|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
use crate::core::{Id, Point}; |
|
|
use std::io::{self, Read, Write, Cursor}; |
|
|
|
|
|
|
|
|
/// Four-byte signature written at the very start of every serialized HAT
/// stream; `SerializedHat::from_bytes` rejects input that does not begin with it.
const MAGIC: &[u8; 4] = b"HAT\0";
|
|
|
|
|
|
|
|
/// Format version understood by this module. It is stored as a little-endian
/// `u32` right after the magic bytes, and `SerializedHat::from_bytes` refuses
/// any other value.
const VERSION: u32 = 1;
|
|
|
|
|
|
|
|
/// Errors produced while encoding or decoding the HAT binary format.
#[derive(Debug)]
pub enum PersistError {
    /// The input does not start with the expected magic bytes.
    InvalidMagic,

    /// The header declares a format version other than the supported one;
    /// carries the version that was found.
    UnsupportedVersion(u32),

    /// An underlying I/O operation failed (this includes running out of input
    /// while decoding).
    Io(io::Error),

    /// The data is structurally invalid; the string describes what was wrong.
    Corrupted(String),

    /// A vector length does not match the expected dimensionality.
    DimensionMismatch { expected: usize, found: usize },
}
|
|
|
|
|
impl std::fmt::Display for PersistError { |
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
|
|
match self { |
|
|
PersistError::InvalidMagic => write!(f, "Invalid HAT file magic bytes"), |
|
|
PersistError::UnsupportedVersion(v) => write!(f, "Unsupported HAT version: {}", v), |
|
|
PersistError::Io(e) => write!(f, "IO error: {}", e), |
|
|
PersistError::Corrupted(msg) => write!(f, "Data corruption: {}", msg), |
|
|
PersistError::DimensionMismatch { expected, found } => { |
|
|
write!(f, "Dimension mismatch: expected {}, found {}", expected, found) |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
impl std::error::Error for PersistError {} |
|
|
|
|
|
impl From<io::Error> for PersistError { |
|
|
fn from(e: io::Error) -> Self { |
|
|
PersistError::Io(e) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
/// Hierarchy level of a container, as stored in one byte on disk.
///
/// The explicit discriminants are the on-disk values: the serializer writes
/// `level as u8`, so they must not be renumbered.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LevelByte {
    /// Encoded as `0`.
    Root = 0,
    /// Encoded as `1`.
    Session = 1,
    /// Encoded as `2`.
    Document = 2,
    /// Encoded as `3`.
    Chunk = 3,
}
|
|
|
|
|
impl LevelByte { |
|
|
pub fn from_u8(v: u8) -> Option<Self> { |
|
|
match v { |
|
|
0 => Some(LevelByte::Root), |
|
|
1 => Some(LevelByte::Session), |
|
|
2 => Some(LevelByte::Document), |
|
|
3 => Some(LevelByte::Chunk), |
|
|
_ => None, |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
#[derive(Debug, Clone)] |
|
|
pub struct SerializedContainer { |
|
|
pub id: Id, |
|
|
pub level: LevelByte, |
|
|
pub timestamp: u64, |
|
|
pub children: Vec<Id>, |
|
|
pub descendant_count: u64, |
|
|
pub centroid: Vec<f32>, |
|
|
pub accumulated_sum: Option<Vec<f32>>, |
|
|
} |
|
|
|
|
|
|
|
|
#[derive(Debug, Clone)] |
|
|
pub struct SerializedHat { |
|
|
pub version: u32, |
|
|
pub dimensionality: u32, |
|
|
pub root_id: Option<Id>, |
|
|
pub containers: Vec<SerializedContainer>, |
|
|
pub active_session: Option<Id>, |
|
|
pub active_document: Option<Id>, |
|
|
pub router_weights: Option<Vec<f32>>, |
|
|
} |
|
|
|
|
|
impl SerializedHat { |
|
|
|
|
|
pub fn to_bytes(&self) -> Result<Vec<u8>, PersistError> { |
|
|
let mut buf = Vec::new(); |
|
|
|
|
|
|
|
|
buf.write_all(MAGIC)?; |
|
|
buf.write_all(&self.version.to_le_bytes())?; |
|
|
buf.write_all(&self.dimensionality.to_le_bytes())?; |
|
|
buf.write_all(&(self.containers.len() as u64).to_le_bytes())?; |
|
|
|
|
|
|
|
|
if let Some(id) = &self.root_id { |
|
|
buf.write_all(id.as_bytes())?; |
|
|
} else { |
|
|
buf.write_all(&[0u8; 16])?; |
|
|
} |
|
|
|
|
|
|
|
|
for container in &self.containers { |
|
|
|
|
|
buf.write_all(container.id.as_bytes())?; |
|
|
|
|
|
|
|
|
buf.write_all(&[container.level as u8])?; |
|
|
|
|
|
|
|
|
buf.write_all(&container.timestamp.to_le_bytes())?; |
|
|
|
|
|
|
|
|
buf.write_all(&(container.children.len() as u32).to_le_bytes())?; |
|
|
for child_id in &container.children { |
|
|
buf.write_all(child_id.as_bytes())?; |
|
|
} |
|
|
|
|
|
|
|
|
buf.write_all(&container.descendant_count.to_le_bytes())?; |
|
|
|
|
|
|
|
|
for &v in &container.centroid { |
|
|
buf.write_all(&v.to_le_bytes())?; |
|
|
} |
|
|
|
|
|
|
|
|
if let Some(sum) = &container.accumulated_sum { |
|
|
buf.write_all(&[1u8])?; |
|
|
for &v in sum { |
|
|
buf.write_all(&v.to_le_bytes())?; |
|
|
} |
|
|
} else { |
|
|
buf.write_all(&[0u8])?; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if let Some(id) = &self.active_session { |
|
|
buf.write_all(id.as_bytes())?; |
|
|
} else { |
|
|
buf.write_all(&[0u8; 16])?; |
|
|
} |
|
|
|
|
|
if let Some(id) = &self.active_document { |
|
|
buf.write_all(id.as_bytes())?; |
|
|
} else { |
|
|
buf.write_all(&[0u8; 16])?; |
|
|
} |
|
|
|
|
|
|
|
|
if let Some(weights) = &self.router_weights { |
|
|
buf.write_all(&[1u8])?; |
|
|
for &w in weights { |
|
|
buf.write_all(&w.to_le_bytes())?; |
|
|
} |
|
|
} else { |
|
|
buf.write_all(&[0u8])?; |
|
|
} |
|
|
|
|
|
Ok(buf) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn from_bytes(data: &[u8]) -> Result<Self, PersistError> { |
|
|
let mut cursor = Cursor::new(data); |
|
|
|
|
|
|
|
|
let mut magic = [0u8; 4]; |
|
|
cursor.read_exact(&mut magic)?; |
|
|
if &magic != MAGIC { |
|
|
return Err(PersistError::InvalidMagic); |
|
|
} |
|
|
|
|
|
let mut version_bytes = [0u8; 4]; |
|
|
cursor.read_exact(&mut version_bytes)?; |
|
|
let version = u32::from_le_bytes(version_bytes); |
|
|
if version != VERSION { |
|
|
return Err(PersistError::UnsupportedVersion(version)); |
|
|
} |
|
|
|
|
|
let mut dims_bytes = [0u8; 4]; |
|
|
cursor.read_exact(&mut dims_bytes)?; |
|
|
let dimensionality = u32::from_le_bytes(dims_bytes); |
|
|
|
|
|
let mut count_bytes = [0u8; 8]; |
|
|
cursor.read_exact(&mut count_bytes)?; |
|
|
let container_count = u64::from_le_bytes(count_bytes); |
|
|
|
|
|
let mut root_bytes = [0u8; 16]; |
|
|
cursor.read_exact(&mut root_bytes)?; |
|
|
let root_id = if root_bytes == [0u8; 16] { |
|
|
None |
|
|
} else { |
|
|
Some(Id::from_bytes(root_bytes)) |
|
|
}; |
|
|
|
|
|
|
|
|
let mut containers = Vec::with_capacity(container_count as usize); |
|
|
for _ in 0..container_count { |
|
|
|
|
|
let mut id_bytes = [0u8; 16]; |
|
|
cursor.read_exact(&mut id_bytes)?; |
|
|
let id = Id::from_bytes(id_bytes); |
|
|
|
|
|
|
|
|
let mut level_byte = [0u8; 1]; |
|
|
cursor.read_exact(&mut level_byte)?; |
|
|
let level = LevelByte::from_u8(level_byte[0]) |
|
|
.ok_or_else(|| PersistError::Corrupted(format!("Invalid level: {}", level_byte[0])))?; |
|
|
|
|
|
|
|
|
let mut ts_bytes = [0u8; 8]; |
|
|
cursor.read_exact(&mut ts_bytes)?; |
|
|
let timestamp = u64::from_le_bytes(ts_bytes); |
|
|
|
|
|
|
|
|
let mut child_count_bytes = [0u8; 4]; |
|
|
cursor.read_exact(&mut child_count_bytes)?; |
|
|
let child_count = u32::from_le_bytes(child_count_bytes) as usize; |
|
|
|
|
|
let mut children = Vec::with_capacity(child_count); |
|
|
for _ in 0..child_count { |
|
|
let mut child_bytes = [0u8; 16]; |
|
|
cursor.read_exact(&mut child_bytes)?; |
|
|
children.push(Id::from_bytes(child_bytes)); |
|
|
} |
|
|
|
|
|
|
|
|
let mut desc_bytes = [0u8; 8]; |
|
|
cursor.read_exact(&mut desc_bytes)?; |
|
|
let descendant_count = u64::from_le_bytes(desc_bytes); |
|
|
|
|
|
|
|
|
let mut centroid = Vec::with_capacity(dimensionality as usize); |
|
|
for _ in 0..dimensionality { |
|
|
let mut v_bytes = [0u8; 4]; |
|
|
cursor.read_exact(&mut v_bytes)?; |
|
|
centroid.push(f32::from_le_bytes(v_bytes)); |
|
|
} |
|
|
|
|
|
|
|
|
let mut has_sum = [0u8; 1]; |
|
|
cursor.read_exact(&mut has_sum)?; |
|
|
let accumulated_sum = if has_sum[0] == 1 { |
|
|
let mut sum = Vec::with_capacity(dimensionality as usize); |
|
|
for _ in 0..dimensionality { |
|
|
let mut v_bytes = [0u8; 4]; |
|
|
cursor.read_exact(&mut v_bytes)?; |
|
|
sum.push(f32::from_le_bytes(v_bytes)); |
|
|
} |
|
|
Some(sum) |
|
|
} else { |
|
|
None |
|
|
}; |
|
|
|
|
|
containers.push(SerializedContainer { |
|
|
id, |
|
|
level, |
|
|
timestamp, |
|
|
children, |
|
|
descendant_count, |
|
|
centroid, |
|
|
accumulated_sum, |
|
|
}); |
|
|
} |
|
|
|
|
|
|
|
|
let mut active_session_bytes = [0u8; 16]; |
|
|
cursor.read_exact(&mut active_session_bytes)?; |
|
|
let active_session = if active_session_bytes == [0u8; 16] { |
|
|
None |
|
|
} else { |
|
|
Some(Id::from_bytes(active_session_bytes)) |
|
|
}; |
|
|
|
|
|
let mut active_document_bytes = [0u8; 16]; |
|
|
cursor.read_exact(&mut active_document_bytes)?; |
|
|
let active_document = if active_document_bytes == [0u8; 16] { |
|
|
None |
|
|
} else { |
|
|
Some(Id::from_bytes(active_document_bytes)) |
|
|
}; |
|
|
|
|
|
|
|
|
let router_weights = if cursor.position() < data.len() as u64 { |
|
|
let mut has_weights = [0u8; 1]; |
|
|
cursor.read_exact(&mut has_weights)?; |
|
|
if has_weights[0] == 1 { |
|
|
let mut weights = Vec::with_capacity(dimensionality as usize); |
|
|
for _ in 0..dimensionality { |
|
|
let mut w_bytes = [0u8; 4]; |
|
|
cursor.read_exact(&mut w_bytes)?; |
|
|
weights.push(f32::from_le_bytes(w_bytes)); |
|
|
} |
|
|
Some(weights) |
|
|
} else { |
|
|
None |
|
|
} |
|
|
} else { |
|
|
None |
|
|
}; |
|
|
|
|
|
Ok(SerializedHat { |
|
|
version, |
|
|
dimensionality, |
|
|
root_id, |
|
|
containers, |
|
|
active_session, |
|
|
active_document, |
|
|
router_weights, |
|
|
}) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
fn id_to_bytes(id: &Option<Id>) -> [u8; 16] { |
|
|
match id { |
|
|
Some(id) => *id.as_bytes(), |
|
|
None => [0u8; 16], |
|
|
} |
|
|
} |
|
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-container fixture that exercises children, both states of the
    /// accumulated sum, a missing active document and router weights.
    fn sample_hat() -> SerializedHat {
        SerializedHat {
            version: VERSION,
            dimensionality: 128,
            root_id: Some(Id::now()),
            containers: vec![
                SerializedContainer {
                    id: Id::now(),
                    level: LevelByte::Root,
                    timestamp: 1234567890,
                    children: vec![Id::now(), Id::now()],
                    descendant_count: 10,
                    centroid: vec![0.1; 128],
                    accumulated_sum: None,
                },
                SerializedContainer {
                    id: Id::now(),
                    level: LevelByte::Chunk,
                    timestamp: 1234567891,
                    children: vec![],
                    descendant_count: 1,
                    centroid: vec![0.5; 128],
                    accumulated_sum: Some(vec![0.5; 128]),
                },
            ],
            active_session: Some(Id::now()),
            active_document: None,
            router_weights: Some(vec![1.0; 128]),
        }
    }

    /// Raw bytes of a list of ids, so ids can be compared without requiring
    /// `Id: PartialEq`.
    fn raw_ids(ids: &[Id]) -> Vec<[u8; 16]> {
        ids.iter().map(|id| *id.as_bytes()).collect()
    }

    #[test]
    fn test_serialized_hat_roundtrip() {
        let original = sample_hat();

        let bytes = original.to_bytes().unwrap();
        let restored = SerializedHat::from_bytes(&bytes).unwrap();

        // Header and trailing state.
        assert_eq!(restored.version, original.version);
        assert_eq!(restored.dimensionality, original.dimensionality);
        assert_eq!(id_to_bytes(&restored.root_id), id_to_bytes(&original.root_id));
        assert!(restored.root_id.is_some());
        assert_eq!(
            id_to_bytes(&restored.active_session),
            id_to_bytes(&original.active_session)
        );
        assert!(restored.active_document.is_none());
        assert_eq!(restored.router_weights, original.router_weights);

        // Every container field must survive, not just the count.
        assert_eq!(restored.containers.len(), original.containers.len());
        for (got, want) in restored.containers.iter().zip(&original.containers) {
            assert_eq!(got.id.as_bytes(), want.id.as_bytes());
            assert_eq!(got.level, want.level);
            assert_eq!(got.timestamp, want.timestamp);
            assert_eq!(raw_ids(&got.children), raw_ids(&want.children));
            assert_eq!(got.descendant_count, want.descendant_count);
            assert_eq!(got.centroid, want.centroid);
            assert_eq!(got.accumulated_sum, want.accumulated_sum);
        }
    }

    #[test]
    fn test_invalid_magic() {
        let bad_data = b"BAD\0rest of data...";
        let result = SerializedHat::from_bytes(bad_data);
        assert!(matches!(result, Err(PersistError::InvalidMagic)));
    }

    #[test]
    fn test_unsupported_version() {
        let mut data = Vec::new();
        data.extend_from_slice(MAGIC);
        data.extend_from_slice(&(VERSION + 1).to_le_bytes());

        let result = SerializedHat::from_bytes(&data);
        assert!(matches!(
            result,
            Err(PersistError::UnsupportedVersion(v)) if v == VERSION + 1
        ));
    }

    #[test]
    fn test_truncated_data_is_an_error() {
        let bytes = sample_hat().to_bytes().unwrap();

        // Cutting the stream in half lands inside the second container's
        // centroid, so decoding must run out of input.
        let result = SerializedHat::from_bytes(&bytes[..bytes.len() / 2]);
        assert!(matches!(result, Err(PersistError::Io(_))));
    }

    #[test]
    fn test_level_byte_conversion() {
        assert_eq!(LevelByte::from_u8(0), Some(LevelByte::Root));
        assert_eq!(LevelByte::from_u8(1), Some(LevelByte::Session));
        assert_eq!(LevelByte::from_u8(2), Some(LevelByte::Document));
        assert_eq!(LevelByte::from_u8(3), Some(LevelByte::Chunk));
        assert_eq!(LevelByte::from_u8(4), None);
    }
}
|
|
|