| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| use crate::core::{Id, Point}; |
| use std::io::{self, Read, Write, Cursor}; |
|
|
| |
/// Four-byte signature written at the very start of every serialized HAT
/// blob. `SerializedHat::from_bytes` rejects input that does not begin with
/// exactly these bytes.
const MAGIC: &[u8; 4] = b"HAT\0";

/// Format version stored right after the magic bytes.
/// `SerializedHat::from_bytes` accepts only this exact value.
const VERSION: u32 = 1;
|
|
| |
/// Errors produced while encoding or decoding the HAT binary format.
#[derive(Debug)]
pub enum PersistError {
    /// The input does not start with the expected magic bytes.
    InvalidMagic,
    /// The header carries a format version this code does not read.
    /// The payload is the version number that was found.
    UnsupportedVersion(u32),
    /// An underlying read or write failed (for example, the input ended
    /// before a complete field could be read).
    Io(io::Error),
    /// The data is structurally invalid; the message describes what was wrong.
    Corrupted(String),
    /// A length did not match the expected dimensionality.
    DimensionMismatch { expected: usize, found: usize },
}

impl std::fmt::Display for PersistError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PersistError::InvalidMagic => write!(f, "Invalid HAT file magic bytes"),
            PersistError::UnsupportedVersion(v) => write!(f, "Unsupported HAT version: {}", v),
            PersistError::Io(e) => write!(f, "IO error: {}", e),
            PersistError::Corrupted(msg) => write!(f, "Data corruption: {}", msg),
            PersistError::DimensionMismatch { expected, found } => {
                write!(f, "Dimension mismatch: expected {}, found {}", expected, found)
            }
        }
    }
}

impl std::error::Error for PersistError {
    /// Exposes the wrapped `io::Error` so callers walking the error chain
    /// (e.g. error reporters) can reach the root cause. All other variants
    /// are leaf errors and have no source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            PersistError::Io(e) => Some(e),
            PersistError::InvalidMagic
            | PersistError::UnsupportedVersion(_)
            | PersistError::Corrupted(_)
            | PersistError::DimensionMismatch { .. } => None,
        }
    }
}

/// Lets `?` convert I/O failures into [`PersistError::Io`].
impl From<io::Error> for PersistError {
    fn from(e: io::Error) -> Self {
        PersistError::Io(e)
    }
}
|
|
| |
/// Single-byte tag identifying a container's level in the serialized form.
///
/// The explicit discriminants are the values written to and read from the
/// byte stream, so they must not be renumbered.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LevelByte {
    Root = 0,
    Session = 1,
    Document = 2,
    Chunk = 3,
}

impl LevelByte {
    /// Converts a raw byte back into a level tag.
    ///
    /// Returns `None` for any value that is not one of the discriminants
    /// declared on [`LevelByte`].
    pub fn from_u8(v: u8) -> Option<Self> {
        // Indexed by discriminant: position `i` holds the variant whose value is `i`.
        const BY_DISCRIMINANT: [LevelByte; 4] = [
            LevelByte::Root,
            LevelByte::Session,
            LevelByte::Document,
            LevelByte::Chunk,
        ];
        BY_DISCRIMINANT.get(usize::from(v)).copied()
    }
}
|
|
| |
| #[derive(Debug, Clone)] |
| pub struct SerializedContainer { |
| pub id: Id, |
| pub level: LevelByte, |
| pub timestamp: u64, |
| pub children: Vec<Id>, |
| pub descendant_count: u64, |
| pub centroid: Vec<f32>, |
| pub accumulated_sum: Option<Vec<f32>>, |
| } |
|
|
| |
| #[derive(Debug, Clone)] |
| pub struct SerializedHat { |
| pub version: u32, |
| pub dimensionality: u32, |
| pub root_id: Option<Id>, |
| pub containers: Vec<SerializedContainer>, |
| pub active_session: Option<Id>, |
| pub active_document: Option<Id>, |
| pub router_weights: Option<Vec<f32>>, |
| } |
|
|
| impl SerializedHat { |
| |
| pub fn to_bytes(&self) -> Result<Vec<u8>, PersistError> { |
| let mut buf = Vec::new(); |
|
|
| |
| buf.write_all(MAGIC)?; |
| buf.write_all(&self.version.to_le_bytes())?; |
| buf.write_all(&self.dimensionality.to_le_bytes())?; |
| buf.write_all(&(self.containers.len() as u64).to_le_bytes())?; |
|
|
| |
| if let Some(id) = &self.root_id { |
| buf.write_all(id.as_bytes())?; |
| } else { |
| buf.write_all(&[0u8; 16])?; |
| } |
|
|
| |
| for container in &self.containers { |
| |
| buf.write_all(container.id.as_bytes())?; |
|
|
| |
| buf.write_all(&[container.level as u8])?; |
|
|
| |
| buf.write_all(&container.timestamp.to_le_bytes())?; |
|
|
| |
| buf.write_all(&(container.children.len() as u32).to_le_bytes())?; |
| for child_id in &container.children { |
| buf.write_all(child_id.as_bytes())?; |
| } |
|
|
| |
| buf.write_all(&container.descendant_count.to_le_bytes())?; |
|
|
| |
| for &v in &container.centroid { |
| buf.write_all(&v.to_le_bytes())?; |
| } |
|
|
| |
| if let Some(sum) = &container.accumulated_sum { |
| buf.write_all(&[1u8])?; |
| for &v in sum { |
| buf.write_all(&v.to_le_bytes())?; |
| } |
| } else { |
| buf.write_all(&[0u8])?; |
| } |
| } |
|
|
| |
| if let Some(id) = &self.active_session { |
| buf.write_all(id.as_bytes())?; |
| } else { |
| buf.write_all(&[0u8; 16])?; |
| } |
|
|
| if let Some(id) = &self.active_document { |
| buf.write_all(id.as_bytes())?; |
| } else { |
| buf.write_all(&[0u8; 16])?; |
| } |
|
|
| |
| if let Some(weights) = &self.router_weights { |
| buf.write_all(&[1u8])?; |
| for &w in weights { |
| buf.write_all(&w.to_le_bytes())?; |
| } |
| } else { |
| buf.write_all(&[0u8])?; |
| } |
|
|
| Ok(buf) |
| } |
|
|
| |
| pub fn from_bytes(data: &[u8]) -> Result<Self, PersistError> { |
| let mut cursor = Cursor::new(data); |
|
|
| |
| let mut magic = [0u8; 4]; |
| cursor.read_exact(&mut magic)?; |
| if &magic != MAGIC { |
| return Err(PersistError::InvalidMagic); |
| } |
|
|
| let mut version_bytes = [0u8; 4]; |
| cursor.read_exact(&mut version_bytes)?; |
| let version = u32::from_le_bytes(version_bytes); |
| if version != VERSION { |
| return Err(PersistError::UnsupportedVersion(version)); |
| } |
|
|
| let mut dims_bytes = [0u8; 4]; |
| cursor.read_exact(&mut dims_bytes)?; |
| let dimensionality = u32::from_le_bytes(dims_bytes); |
|
|
| let mut count_bytes = [0u8; 8]; |
| cursor.read_exact(&mut count_bytes)?; |
| let container_count = u64::from_le_bytes(count_bytes); |
|
|
| let mut root_bytes = [0u8; 16]; |
| cursor.read_exact(&mut root_bytes)?; |
| let root_id = if root_bytes == [0u8; 16] { |
| None |
| } else { |
| Some(Id::from_bytes(root_bytes)) |
| }; |
|
|
| |
| let mut containers = Vec::with_capacity(container_count as usize); |
| for _ in 0..container_count { |
| |
| let mut id_bytes = [0u8; 16]; |
| cursor.read_exact(&mut id_bytes)?; |
| let id = Id::from_bytes(id_bytes); |
|
|
| |
| let mut level_byte = [0u8; 1]; |
| cursor.read_exact(&mut level_byte)?; |
| let level = LevelByte::from_u8(level_byte[0]) |
| .ok_or_else(|| PersistError::Corrupted(format!("Invalid level: {}", level_byte[0])))?; |
|
|
| |
| let mut ts_bytes = [0u8; 8]; |
| cursor.read_exact(&mut ts_bytes)?; |
| let timestamp = u64::from_le_bytes(ts_bytes); |
|
|
| |
| let mut child_count_bytes = [0u8; 4]; |
| cursor.read_exact(&mut child_count_bytes)?; |
| let child_count = u32::from_le_bytes(child_count_bytes) as usize; |
|
|
| let mut children = Vec::with_capacity(child_count); |
| for _ in 0..child_count { |
| let mut child_bytes = [0u8; 16]; |
| cursor.read_exact(&mut child_bytes)?; |
| children.push(Id::from_bytes(child_bytes)); |
| } |
|
|
| |
| let mut desc_bytes = [0u8; 8]; |
| cursor.read_exact(&mut desc_bytes)?; |
| let descendant_count = u64::from_le_bytes(desc_bytes); |
|
|
| |
| let mut centroid = Vec::with_capacity(dimensionality as usize); |
| for _ in 0..dimensionality { |
| let mut v_bytes = [0u8; 4]; |
| cursor.read_exact(&mut v_bytes)?; |
| centroid.push(f32::from_le_bytes(v_bytes)); |
| } |
|
|
| |
| let mut has_sum = [0u8; 1]; |
| cursor.read_exact(&mut has_sum)?; |
| let accumulated_sum = if has_sum[0] == 1 { |
| let mut sum = Vec::with_capacity(dimensionality as usize); |
| for _ in 0..dimensionality { |
| let mut v_bytes = [0u8; 4]; |
| cursor.read_exact(&mut v_bytes)?; |
| sum.push(f32::from_le_bytes(v_bytes)); |
| } |
| Some(sum) |
| } else { |
| None |
| }; |
|
|
| containers.push(SerializedContainer { |
| id, |
| level, |
| timestamp, |
| children, |
| descendant_count, |
| centroid, |
| accumulated_sum, |
| }); |
| } |
|
|
| |
| let mut active_session_bytes = [0u8; 16]; |
| cursor.read_exact(&mut active_session_bytes)?; |
| let active_session = if active_session_bytes == [0u8; 16] { |
| None |
| } else { |
| Some(Id::from_bytes(active_session_bytes)) |
| }; |
|
|
| let mut active_document_bytes = [0u8; 16]; |
| cursor.read_exact(&mut active_document_bytes)?; |
| let active_document = if active_document_bytes == [0u8; 16] { |
| None |
| } else { |
| Some(Id::from_bytes(active_document_bytes)) |
| }; |
|
|
| |
| let router_weights = if cursor.position() < data.len() as u64 { |
| let mut has_weights = [0u8; 1]; |
| cursor.read_exact(&mut has_weights)?; |
| if has_weights[0] == 1 { |
| let mut weights = Vec::with_capacity(dimensionality as usize); |
| for _ in 0..dimensionality { |
| let mut w_bytes = [0u8; 4]; |
| cursor.read_exact(&mut w_bytes)?; |
| weights.push(f32::from_le_bytes(w_bytes)); |
| } |
| Some(weights) |
| } else { |
| None |
| } |
| } else { |
| None |
| }; |
|
|
| Ok(SerializedHat { |
| version, |
| dimensionality, |
| root_id, |
| containers, |
| active_session, |
| active_document, |
| router_weights, |
| }) |
| } |
| } |
|
|
| |
| fn id_to_bytes(id: &Option<Id>) -> [u8; 16] { |
| match id { |
| Some(id) => *id.as_bytes(), |
| None => [0u8; 16], |
| } |
| } |
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a small two-container value that exercises every optional field
    /// in both its present and absent form.
    fn sample_hat() -> SerializedHat {
        SerializedHat {
            version: VERSION,
            dimensionality: 128,
            root_id: Some(Id::now()),
            containers: vec![
                SerializedContainer {
                    id: Id::now(),
                    level: LevelByte::Root,
                    timestamp: 1234567890,
                    children: vec![Id::now(), Id::now()],
                    descendant_count: 10,
                    centroid: vec![0.1; 128],
                    accumulated_sum: None,
                },
                SerializedContainer {
                    id: Id::now(),
                    level: LevelByte::Chunk,
                    timestamp: 1234567891,
                    children: vec![],
                    descendant_count: 1,
                    centroid: vec![0.5; 128],
                    accumulated_sum: Some(vec![0.5; 128]),
                },
            ],
            active_session: Some(Id::now()),
            active_document: None,
            router_weights: Some(vec![1.0; 128]),
        }
    }

    #[test]
    fn test_serialized_hat_roundtrip() {
        let original = sample_hat();

        let bytes = original.to_bytes().unwrap();
        let restored = SerializedHat::from_bytes(&bytes).unwrap();

        assert_eq!(restored.version, original.version);
        assert_eq!(restored.dimensionality, original.dimensionality);
        // IDs are compared through their 16-byte slot encoding so the test
        // does not depend on `Id` implementing `PartialEq`.
        assert_eq!(id_to_bytes(&restored.root_id), id_to_bytes(&original.root_id));
        assert_eq!(
            id_to_bytes(&restored.active_session),
            id_to_bytes(&original.active_session)
        );
        assert_eq!(
            id_to_bytes(&restored.active_document),
            id_to_bytes(&original.active_document)
        );
        assert!(restored.router_weights.is_some());
        assert_eq!(restored.router_weights, original.router_weights);

        assert_eq!(restored.containers.len(), original.containers.len());
        for (got, want) in restored.containers.iter().zip(&original.containers) {
            assert_eq!(got.id.as_bytes(), want.id.as_bytes());
            assert_eq!(got.level, want.level);
            assert_eq!(got.timestamp, want.timestamp);
            assert_eq!(got.descendant_count, want.descendant_count);
            assert_eq!(got.centroid, want.centroid);
            assert_eq!(got.accumulated_sum, want.accumulated_sum);

            let got_children: Vec<[u8; 16]> =
                got.children.iter().map(|c| *c.as_bytes()).collect();
            let want_children: Vec<[u8; 16]> =
                want.children.iter().map(|c| *c.as_bytes()).collect();
            assert_eq!(got_children, want_children);
        }
    }

    #[test]
    fn test_truncated_data_is_rejected() {
        let bytes = sample_hat().to_bytes().unwrap();
        // Dropping the final byte cuts the last router weight short.
        let result = SerializedHat::from_bytes(&bytes[..bytes.len() - 1]);
        assert!(matches!(result, Err(PersistError::Io(_))));
    }

    #[test]
    fn test_unsupported_version() {
        let mut header = MAGIC.to_vec();
        header.extend_from_slice(&(VERSION + 1).to_le_bytes());
        let result = SerializedHat::from_bytes(&header);
        assert!(matches!(
            result,
            Err(PersistError::UnsupportedVersion(v)) if v == VERSION + 1
        ));
    }

    #[test]
    fn test_invalid_magic() {
        let bad_data = b"BAD\0rest of data...";
        let result = SerializedHat::from_bytes(bad_data);
        assert!(matches!(result, Err(PersistError::InvalidMagic)));
    }

    #[test]
    fn test_level_byte_conversion() {
        assert_eq!(LevelByte::from_u8(0), Some(LevelByte::Root));
        assert_eq!(LevelByte::from_u8(1), Some(LevelByte::Session));
        assert_eq!(LevelByte::from_u8(2), Some(LevelByte::Document));
        assert_eq!(LevelByte::from_u8(3), Some(LevelByte::Chunk));
        assert_eq!(LevelByte::from_u8(4), None);
    }
}
|
|