|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
use std::collections::{HashMap, VecDeque}; |
|
|
use std::sync::Arc; |
|
|
use std::time::{SystemTime, UNIX_EPOCH}; |
|
|
|
|
|
use crate::core::{Id, Point}; |
|
|
use crate::core::proximity::Proximity; |
|
|
use crate::core::merge::Merge; |
|
|
use crate::ports::{Near, NearError, NearResult, SearchResult}; |
|
|
|
|
|
use super::consolidation::{ |
|
|
Consolidate, ConsolidationConfig, ConsolidationPhase, ConsolidationState, |
|
|
ConsolidationMetrics, ConsolidationProgress, ConsolidationTickResult, |
|
|
compute_exact_centroid, centroid_drift, |
|
|
}; |
|
|
|
|
|
|
|
|
/// Strategy used to update a container's centroid as new points arrive.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CentroidMethod {
    /// Incremental arithmetic mean of the accumulated component-wise sum,
    /// re-normalized after each update.
    Euclidean,
    /// Geodesic (spherical) update: slerp from the old centroid toward the
    /// new point with weight 1/(n+1).
    Frechet,
}
|
|
|
|
|
impl Default for CentroidMethod { |
|
|
fn default() -> Self { |
|
|
CentroidMethod::Euclidean |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
/// Tuning knobs for the HAT (hierarchical) index.
#[derive(Debug, Clone)]
pub struct HatConfig {
    /// Maximum children a document/session may hold before the next `add`
    /// rolls over to a fresh one.
    pub max_children: usize,

    /// Lower bound on children per container.
    // NOTE(review): not read in the code visible here — confirm it is used
    // as the merge threshold at call sites.
    pub min_children: usize,

    /// Candidates kept per tree level during beam search (raised to at
    /// least `k` at query time).
    pub beam_width: usize,

    /// Blend factor between semantic distance (weight `1 - w`) and
    /// temporal distance (weight `w`) when routing.
    pub temporal_weight: f32,

    /// Exponential decay rate per hour used by the temporal distance.
    pub time_decay: f32,

    /// Minimum centroid movement required to keep propagating an update
    /// to the next ancestor.
    pub propagation_threshold: f32,

    /// How centroids are updated incrementally.
    pub centroid_method: CentroidMethod,

    /// Iteration cap for the iterative Fréchet-mean refinement.
    pub frechet_iterations: usize,

    /// Enables subspace-based similarity for interior containers.
    pub subspace_enabled: bool,

    /// Parameters for the per-container subspace model.
    pub subspace_config: super::subspace::SubspaceConfig,

    /// Enables the learnable router for semantic distance.
    pub learnable_routing_enabled: bool,

    /// Parameters for the learnable router.
    pub learnable_routing_config: super::learnable_routing::LearnableRoutingConfig,
}
|
|
|
|
|
impl Default for HatConfig { |
|
|
fn default() -> Self { |
|
|
Self { |
|
|
max_children: 50, |
|
|
min_children: 5, |
|
|
beam_width: 3, |
|
|
temporal_weight: 0.0, |
|
|
time_decay: 0.001, |
|
|
propagation_threshold: 0.0, |
|
|
centroid_method: CentroidMethod::Euclidean, |
|
|
frechet_iterations: 5, |
|
|
subspace_enabled: false, |
|
|
subspace_config: super::subspace::SubspaceConfig::default(), |
|
|
learnable_routing_enabled: false, |
|
|
learnable_routing_config: super::learnable_routing::LearnableRoutingConfig::default(), |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
impl HatConfig { |
|
|
pub fn new() -> Self { |
|
|
Self::default() |
|
|
} |
|
|
|
|
|
pub fn with_beam_width(mut self, width: usize) -> Self { |
|
|
self.beam_width = width; |
|
|
self |
|
|
} |
|
|
|
|
|
pub fn with_temporal_weight(mut self, weight: f32) -> Self { |
|
|
self.temporal_weight = weight; |
|
|
self |
|
|
} |
|
|
|
|
|
pub fn with_propagation_threshold(mut self, threshold: f32) -> Self { |
|
|
self.propagation_threshold = threshold; |
|
|
self |
|
|
} |
|
|
|
|
|
pub fn with_centroid_method(mut self, method: CentroidMethod) -> Self { |
|
|
self.centroid_method = method; |
|
|
self |
|
|
} |
|
|
|
|
|
pub fn with_frechet_iterations(mut self, iterations: usize) -> Self { |
|
|
self.frechet_iterations = iterations; |
|
|
self |
|
|
} |
|
|
|
|
|
pub fn with_subspace_enabled(mut self, enabled: bool) -> Self { |
|
|
self.subspace_enabled = enabled; |
|
|
self |
|
|
} |
|
|
|
|
|
pub fn with_subspace_config(mut self, config: super::subspace::SubspaceConfig) -> Self { |
|
|
self.subspace_config = config; |
|
|
self.subspace_enabled = true; |
|
|
self |
|
|
} |
|
|
|
|
|
pub fn with_learnable_routing_enabled(mut self, enabled: bool) -> Self { |
|
|
self.learnable_routing_enabled = enabled; |
|
|
self |
|
|
} |
|
|
|
|
|
pub fn with_learnable_routing_config(mut self, config: super::learnable_routing::LearnableRoutingConfig) -> Self { |
|
|
self.learnable_routing_config = config; |
|
|
self.learnable_routing_enabled = true; |
|
|
self |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
/// Position of a container in the hierarchy, root to leaf.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContainerLevel {
    /// The single root spanning everything (depth 0).
    Global,
    /// A session grouping under the root (depth 1).
    Session,
    /// A document within a session (depth 2).
    Document,
    /// A leaf holding one indexed point (depth 3).
    Chunk,
}
|
|
|
|
|
impl ContainerLevel { |
|
|
fn child_level(&self) -> Option<ContainerLevel> { |
|
|
match self { |
|
|
ContainerLevel::Global => Some(ContainerLevel::Session), |
|
|
ContainerLevel::Session => Some(ContainerLevel::Document), |
|
|
ContainerLevel::Document => Some(ContainerLevel::Chunk), |
|
|
ContainerLevel::Chunk => None, |
|
|
} |
|
|
} |
|
|
|
|
|
fn depth(&self) -> usize { |
|
|
match self { |
|
|
ContainerLevel::Global => 0, |
|
|
ContainerLevel::Session => 1, |
|
|
ContainerLevel::Document => 2, |
|
|
ContainerLevel::Chunk => 3, |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
/// A ranked session returned by `HatIndex::near_sessions`.
#[derive(Debug, Clone)]
pub struct SessionSummary {
    /// Session container id.
    pub id: Id,

    /// Relevance score, reported in the proximity's native orientation.
    pub score: f32,

    /// Number of leaf points under this session (descendant count).
    pub chunk_count: usize,

    /// Session creation time, milliseconds since the UNIX epoch.
    pub timestamp: u64,
}
|
|
|
|
|
|
|
|
/// A ranked document returned by `HatIndex::near_documents`.
#[derive(Debug, Clone)]
pub struct DocumentSummary {
    /// Document container id.
    pub id: Id,

    /// Relevance score, reported in the proximity's native orientation.
    pub score: f32,

    /// Number of leaf points under this document (descendant count).
    pub chunk_count: usize,

    /// Document creation time, milliseconds since the UNIX epoch.
    pub timestamp: u64,
}
|
|
|
|
|
|
|
|
/// A node in the HAT tree: root, session, document, or chunk (leaf).
#[derive(Debug, Clone)]
struct Container {
    /// Stable identifier; also the key in `HatIndex::containers`.
    id: Id,

    /// Position in the hierarchy (Global → Session → Document → Chunk).
    level: ContainerLevel,

    /// Representative point used for routing at this node; for a chunk,
    /// the indexed point itself.
    centroid: Point,

    /// Creation time, milliseconds since the UNIX epoch.
    timestamp: u64,

    /// Ids of direct child containers (empty for chunks).
    children: Vec<Id>,

    /// Number of leaf points under this node (1 for a chunk itself).
    descendant_count: usize,

    /// Running component-wise sum of leaf points, used for the incremental
    /// Euclidean mean. Chunks start seeded with their own point; interior
    /// nodes start `None` until the first point arrives.
    accumulated_sum: Option<Point>,

    /// Subspace statistics for interior nodes; `None` for chunks.
    subspace: Option<super::subspace::Subspace>,
}
|
|
|
|
|
impl Container { |
|
|
fn new(id: Id, level: ContainerLevel, centroid: Point) -> Self { |
|
|
let timestamp = SystemTime::now() |
|
|
.duration_since(UNIX_EPOCH) |
|
|
.unwrap() |
|
|
.as_millis() as u64; |
|
|
|
|
|
|
|
|
let accumulated_sum = if level == ContainerLevel::Chunk { |
|
|
Some(centroid.clone()) |
|
|
} else { |
|
|
None |
|
|
}; |
|
|
|
|
|
|
|
|
let subspace = if level != ContainerLevel::Chunk { |
|
|
Some(super::subspace::Subspace::new(centroid.dimensionality())) |
|
|
} else { |
|
|
None |
|
|
}; |
|
|
|
|
|
Self { |
|
|
id, |
|
|
level, |
|
|
centroid, |
|
|
timestamp, |
|
|
children: Vec::new(), |
|
|
descendant_count: if level == ContainerLevel::Chunk { 1 } else { 0 }, |
|
|
accumulated_sum, |
|
|
subspace, |
|
|
} |
|
|
} |
|
|
|
|
|
fn is_leaf(&self) -> bool { |
|
|
self.level == ContainerLevel::Chunk |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
pub struct HatIndex { |
|
|
|
|
|
containers: HashMap<Id, Container>, |
|
|
|
|
|
|
|
|
root_id: Option<Id>, |
|
|
|
|
|
|
|
|
active_session: Option<Id>, |
|
|
|
|
|
|
|
|
active_document: Option<Id>, |
|
|
|
|
|
|
|
|
dimensionality: usize, |
|
|
|
|
|
|
|
|
proximity: Arc<dyn Proximity>, |
|
|
|
|
|
|
|
|
merge: Arc<dyn Merge>, |
|
|
|
|
|
|
|
|
higher_is_better: bool, |
|
|
|
|
|
|
|
|
config: HatConfig, |
|
|
|
|
|
|
|
|
consolidation_state: Option<ConsolidationState>, |
|
|
|
|
|
|
|
|
consolidation_points_cache: HashMap<Id, Vec<Point>>, |
|
|
|
|
|
|
|
|
learnable_router: Option<super::learnable_routing::LearnableRouter>, |
|
|
} |
|
|
|
|
|
impl HatIndex { |
|
|
|
|
|
pub fn cosine(dimensionality: usize) -> Self { |
|
|
use crate::core::proximity::Cosine; |
|
|
use crate::core::merge::Mean; |
|
|
Self::new( |
|
|
dimensionality, |
|
|
Arc::new(Cosine), |
|
|
Arc::new(Mean), |
|
|
true, |
|
|
HatConfig::default(), |
|
|
) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn with_config(mut self, config: HatConfig) -> Self { |
|
|
|
|
|
if config.learnable_routing_enabled { |
|
|
self.learnable_router = Some(super::learnable_routing::LearnableRouter::new( |
|
|
self.dimensionality, |
|
|
config.learnable_routing_config.clone(), |
|
|
)); |
|
|
} |
|
|
self.config = config; |
|
|
self |
|
|
} |
|
|
|
|
|
|
|
|
pub fn new( |
|
|
dimensionality: usize, |
|
|
proximity: Arc<dyn Proximity>, |
|
|
merge: Arc<dyn Merge>, |
|
|
higher_is_better: bool, |
|
|
config: HatConfig, |
|
|
) -> Self { |
|
|
|
|
|
let learnable_router = if config.learnable_routing_enabled { |
|
|
Some(super::learnable_routing::LearnableRouter::new( |
|
|
dimensionality, |
|
|
config.learnable_routing_config.clone(), |
|
|
)) |
|
|
} else { |
|
|
None |
|
|
}; |
|
|
|
|
|
Self { |
|
|
containers: HashMap::new(), |
|
|
root_id: None, |
|
|
active_session: None, |
|
|
active_document: None, |
|
|
dimensionality, |
|
|
proximity, |
|
|
merge, |
|
|
higher_is_better, |
|
|
config, |
|
|
consolidation_state: None, |
|
|
consolidation_points_cache: HashMap::new(), |
|
|
learnable_router, |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
fn distance(&self, a: &Point, b: &Point) -> f32 { |
|
|
let prox = self.proximity.proximity(a, b); |
|
|
if self.higher_is_better { |
|
|
1.0 - prox |
|
|
} else { |
|
|
prox |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
fn temporal_distance(&self, t1: u64, t2: u64) -> f32 { |
|
|
let diff = (t1 as i64 - t2 as i64).unsigned_abs() as f64; |
|
|
|
|
|
|
|
|
let hours = diff / (1000.0 * 60.0 * 60.0); |
|
|
(1.0 - (-self.config.time_decay as f64 * hours).exp()) as f32 |
|
|
} |
|
|
|
|
|
|
|
|
    /// Routing distance to a container: a semantic term blended with a
    /// temporal term by `config.temporal_weight`.
    ///
    /// Semantic term precedence: learnable router (if enabled and present)
    /// > subspace similarity (if enabled, for interior nodes with a
    /// subspace) > plain centroid distance.
    fn combined_distance(&self, query: &Point, query_time: u64, container: &Container) -> f32 {
        let semantic = if self.config.learnable_routing_enabled {
            if let Some(ref router) = self.learnable_router {
                // Router yields a similarity; convert to a distance.
                let sim = router.weighted_similarity(query, &container.centroid);
                1.0 - sim
            } else {
                // Enabled in config but no router instance: fall back.
                self.distance(query, &container.centroid)
            }
        } else if self.config.subspace_enabled && !container.is_leaf() {
            if let Some(ref subspace) = container.subspace {
                // Subspace similarity combines centroid and subspace terms.
                let sim = super::subspace::combined_subspace_similarity(
                    query, subspace, &self.config.subspace_config
                );
                1.0 - sim
            } else {
                self.distance(query, &container.centroid)
            }
        } else {
            self.distance(query, &container.centroid)
        };

        let temporal = self.temporal_distance(query_time, container.timestamp);

        // Linear blend: w = 0 is purely semantic, w = 1 purely temporal.
        let w = self.config.temporal_weight;
        semantic * (1.0 - w) + temporal * w
    }
|
|
|
|
|
|
|
|
fn ensure_root(&mut self) { |
|
|
if self.root_id.is_none() { |
|
|
let root = Container::new( |
|
|
Id::now(), |
|
|
ContainerLevel::Global, |
|
|
Point::origin(self.dimensionality), |
|
|
); |
|
|
let root_id = root.id; |
|
|
self.containers.insert(root_id, root); |
|
|
self.root_id = Some(root_id); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
fn ensure_session(&mut self) { |
|
|
self.ensure_root(); |
|
|
|
|
|
if self.active_session.is_none() { |
|
|
let session = Container::new( |
|
|
Id::now(), |
|
|
ContainerLevel::Session, |
|
|
Point::origin(self.dimensionality), |
|
|
); |
|
|
let session_id = session.id; |
|
|
self.containers.insert(session_id, session); |
|
|
|
|
|
|
|
|
if let Some(root_id) = self.root_id { |
|
|
if let Some(root) = self.containers.get_mut(&root_id) { |
|
|
root.children.push(session_id); |
|
|
} |
|
|
} |
|
|
|
|
|
self.active_session = Some(session_id); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
fn ensure_document(&mut self) { |
|
|
self.ensure_session(); |
|
|
|
|
|
if self.active_document.is_none() { |
|
|
let document = Container::new( |
|
|
Id::now(), |
|
|
ContainerLevel::Document, |
|
|
Point::origin(self.dimensionality), |
|
|
); |
|
|
let doc_id = document.id; |
|
|
self.containers.insert(doc_id, document); |
|
|
|
|
|
|
|
|
if let Some(session_id) = self.active_session { |
|
|
if let Some(session) = self.containers.get_mut(&session_id) { |
|
|
session.children.push(doc_id); |
|
|
} |
|
|
} |
|
|
|
|
|
self.active_document = Some(doc_id); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
pub fn new_session(&mut self) { |
|
|
self.active_session = None; |
|
|
self.active_document = None; |
|
|
} |
|
|
|
|
|
|
|
|
    /// Closes the current document; the next `add` lazily starts a fresh
    /// one under the active session.
    pub fn new_document(&mut self) {
        self.active_document = None;
    }
|
|
|
|
|
|
|
|
|
|
|
    /// Iteratively refines a Fréchet mean of `points` starting from
    /// `initial`, running at most `config.frechet_iterations` rounds, and
    /// normalizing the result.
    ///
    /// Each round averages the points' tangent directions (log map) at the
    /// current mean and steps along the averaged tangent (exp map).
    // NOTE(review): treats dot products as cos(angle), i.e. assumes the
    // points and the running mean are unit vectors — confirm callers pass
    // normalized points.
    fn compute_frechet_mean(&self, points: &[Point], initial: &Point) -> Point {
        let mut mean = initial.clone();
        let iterations = self.config.frechet_iterations;

        for _ in 0..iterations {
            // Accumulated tangent-space update for this round.
            let mut tangent_sum = vec![0.0f32; mean.dimensionality()];

            for point in points {
                // cos(angle) between mean and point (assuming unit norm).
                let dot: f32 = mean.dims().iter()
                    .zip(point.dims().iter())
                    .map(|(a, b)| a * b)
                    .sum();

                // Clamp so acos stays defined despite rounding error.
                let dot_clamped = dot.clamp(-1.0, 1.0);
                let theta = dot_clamped.acos();

                if theta.abs() < 1e-8 {
                    // Point coincides with the mean: no tangent contribution.
                    continue;
                }

                // Component of `point` orthogonal to `mean` (unnormalized
                // tangent direction).
                let mut direction: Vec<f32> = point.dims().iter()
                    .zip(mean.dims().iter())
                    .map(|(q, p)| q - dot * p)
                    .collect();

                let dir_norm: f32 = direction.iter().map(|x| x * x).sum::<f32>().sqrt();
                if dir_norm < 1e-8 {
                    // Degenerate direction (antipodal/rounding): skip.
                    continue;
                }

                // Log map: scale unit tangent direction by the angle.
                for (i, d) in direction.iter_mut().enumerate() {
                    tangent_sum[i] += theta * (*d / dir_norm);
                }
            }

            // Average tangent over all points.
            let n = points.len() as f32;
            for t in tangent_sum.iter_mut() {
                *t /= n;
            }

            let tangent_norm: f32 = tangent_sum.iter().map(|x| x * x).sum::<f32>().sqrt();

            if tangent_norm < 1e-8 {
                // Converged: the mean is already a stationary point.
                break;
            }

            // Exp map: move along the averaged tangent by its length.
            let cos_t = tangent_norm.cos();
            let sin_t = tangent_norm.sin();

            let new_dims: Vec<f32> = mean.dims().iter()
                .zip(tangent_sum.iter())
                .map(|(p, v)| cos_t * p + sin_t * (v / tangent_norm))
                .collect();

            mean = Point::new(new_dims);
        }

        mean.normalize()
    }
|
|
|
|
|
|
|
|
|
|
|
    /// Incrementally folds `new_point` into `container_id`'s centroid,
    /// updating the accumulated sum and descendant count, and (when
    /// configured) the incremental subspace model.
    ///
    /// Returns the Euclidean shift of the centroid: 0.0 if the container
    /// is missing, `f32::MAX` for the very first point (so propagation
    /// always continues upward).
    fn update_centroid(&mut self, container_id: Id, new_point: &Point) -> f32 {
        let method = self.config.centroid_method;

        // Snapshot the current state first so no borrow is held across
        // the mutations below.
        let (old_centroid, n, accumulated_sum) = {
            if let Some(container) = self.containers.get(&container_id) {
                (
                    container.centroid.clone(),
                    container.descendant_count as f32,
                    container.accumulated_sum.clone(),
                )
            } else {
                return 0.0;
            }
        };

        // First point: the centroid simply becomes the point.
        if n == 0.0 {
            if let Some(container) = self.containers.get_mut(&container_id) {
                container.centroid = new_point.clone();
                container.accumulated_sum = Some(new_point.clone());
                container.descendant_count += 1;
            }
            // Sentinel "infinite" shift so propagation never stops here.
            return f32::MAX;
        }

        let (new_centroid, new_sum) = match method {
            CentroidMethod::Euclidean => {
                // Extend the running component-wise sum with the new point.
                let new_sum = if let Some(ref sum) = accumulated_sum {
                    sum.dims().iter()
                        .zip(new_point.dims().iter())
                        .map(|(s, p)| s + p)
                        .collect::<Vec<f32>>()
                } else {
                    new_point.dims().to_vec()
                };

                // Mean of the updated sum, re-normalized.
                let count = n + 1.0;
                let mean_dims: Vec<f32> = new_sum.iter().map(|s| s / count).collect();
                let centroid = Point::new(mean_dims).normalize();
                (centroid, Point::new(new_sum))
            }
            CentroidMethod::Frechet => {
                // The sum is still maintained (for exact recomputation),
                // but the centroid moves geodesically instead.
                let new_sum = if let Some(ref sum) = accumulated_sum {
                    sum.dims().iter()
                        .zip(new_point.dims().iter())
                        .map(|(s, p)| s + p)
                        .collect::<Vec<f32>>()
                } else {
                    new_point.dims().to_vec()
                };

                // Slerp toward the new point with weight 1/(n+1), matching
                // the incremental-mean step size.
                let new_count = n + 1.0;
                let weight = 1.0 / new_count;
                let centroid = Self::geodesic_interpolate_static(&old_centroid, new_point, weight);
                (centroid, Point::new(new_sum))
            }
        };

        // Copied out before get_mut so the exclusive borrow stays clean.
        let subspace_enabled = self.config.subspace_enabled;
        if let Some(container) = self.containers.get_mut(&container_id) {
            container.centroid = new_centroid.clone();
            container.accumulated_sum = Some(new_sum);
            container.descendant_count += 1;

            // Feed the point into the incremental covariance model for
            // interior nodes when configured to do so.
            if subspace_enabled
                && self.config.subspace_config.incremental_covariance
                && container.level != ContainerLevel::Chunk
            {
                if let Some(ref mut subspace) = container.subspace {
                    subspace.add_point(new_point);
                }
            }
        }

        // Euclidean shift between old and new centroid, used by the caller
        // to decide whether to keep propagating upward.
        let delta: f32 = old_centroid.dims()
            .iter()
            .zip(new_centroid.dims().iter())
            .map(|(old, new)| (new - old).powi(2))
            .sum::<f32>()
            .sqrt();

        delta
    }
|
|
|
|
|
|
|
|
    /// Spherical linear interpolation (slerp) from `a` to `b` at parameter
    /// `t` (0 → `a`, 1 → `b`); the result is re-normalized.
    // NOTE(review): treats the dot product as cos(angle), i.e. assumes the
    // inputs are (near) unit vectors — confirm Point normalization upstream.
    fn geodesic_interpolate_static(a: &Point, b: &Point, t: f32) -> Point {
        let dot: f32 = a.dims().iter()
            .zip(b.dims().iter())
            .map(|(x, y)| x * y)
            .sum();

        // Clamp strictly inside (-1, 1) so acos is finite and sin(theta)
        // never divides by zero below.
        let dot_clamped = dot.clamp(-0.9999, 0.9999);
        let theta = dot_clamped.acos();

        if theta.abs() < 1e-8 {
            // Vectors (nearly) coincide; interpolation is a no-op.
            return a.clone();
        }

        // Standard slerp weights.
        let sin_theta = theta.sin();
        let weight_a = ((1.0 - t) * theta).sin() / sin_theta;
        let weight_b = (t * theta).sin() / sin_theta;

        let result_dims: Vec<f32> = a.dims().iter()
            .zip(b.dims().iter())
            .map(|(x, y)| weight_a * x + weight_b * y)
            .collect();

        Point::new(result_dims).normalize()
    }
|
|
|
|
|
|
|
|
|
|
|
fn geodesic_interpolate(&self, a: &Point, b: &Point, t: f32) -> Point { |
|
|
|
|
|
let dot: f32 = a.dims().iter() |
|
|
.zip(b.dims().iter()) |
|
|
.map(|(x, y)| x * y) |
|
|
.sum(); |
|
|
|
|
|
|
|
|
let dot_clamped = dot.clamp(-0.9999, 0.9999); |
|
|
let theta = dot_clamped.acos(); |
|
|
|
|
|
if theta.abs() < 1e-8 { |
|
|
|
|
|
return a.clone(); |
|
|
} |
|
|
|
|
|
|
|
|
let sin_theta = theta.sin(); |
|
|
let weight_a = ((1.0 - t) * theta).sin() / sin_theta; |
|
|
let weight_b = (t * theta).sin() / sin_theta; |
|
|
|
|
|
let result_dims: Vec<f32> = a.dims().iter() |
|
|
.zip(b.dims().iter()) |
|
|
.map(|(x, y)| weight_a * x + weight_b * y) |
|
|
.collect(); |
|
|
|
|
|
Point::new(result_dims).normalize() |
|
|
} |
|
|
|
|
|
|
|
|
fn propagate_centroid_update( |
|
|
&mut self, |
|
|
container_id: Id, |
|
|
new_point: &Point, |
|
|
ancestors: &[Id], |
|
|
) { |
|
|
let threshold = self.config.propagation_threshold; |
|
|
let mut delta = self.update_centroid(container_id, new_point); |
|
|
|
|
|
|
|
|
for ancestor_id in ancestors { |
|
|
if delta < threshold { |
|
|
break; |
|
|
} |
|
|
delta = self.update_centroid(*ancestor_id, new_point); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
fn search_tree( |
|
|
&self, |
|
|
query: &Point, |
|
|
query_time: u64, |
|
|
start_id: Id, |
|
|
k: usize, |
|
|
) -> Vec<(Id, f32)> { |
|
|
let mut results: Vec<(Id, f32)> = Vec::new(); |
|
|
|
|
|
|
|
|
let beam_width = self.config.beam_width.max(k); |
|
|
|
|
|
|
|
|
let mut current_level = vec![start_id]; |
|
|
|
|
|
while !current_level.is_empty() { |
|
|
let mut next_level: Vec<(Id, f32)> = Vec::new(); |
|
|
|
|
|
for container_id in ¤t_level { |
|
|
if let Some(container) = self.containers.get(container_id) { |
|
|
if container.is_leaf() { |
|
|
|
|
|
let dist = self.combined_distance(query, query_time, container); |
|
|
results.push((*container_id, dist)); |
|
|
} else { |
|
|
|
|
|
for child_id in &container.children { |
|
|
if let Some(child) = self.containers.get(child_id) { |
|
|
let dist = self.combined_distance(query, query_time, child); |
|
|
next_level.push((*child_id, dist)); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
if next_level.is_empty() { |
|
|
break; |
|
|
} |
|
|
|
|
|
|
|
|
next_level.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); |
|
|
current_level = next_level |
|
|
.into_iter() |
|
|
.take(beam_width) |
|
|
.map(|(id, _)| id) |
|
|
.collect(); |
|
|
} |
|
|
|
|
|
|
|
|
results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); |
|
|
results.truncate(k); |
|
|
results |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pub fn near_sessions(&self, query: &Point, k: usize) -> NearResult<Vec<SessionSummary>> { |
|
|
if query.dimensionality() != self.dimensionality { |
|
|
return Err(NearError::DimensionalityMismatch { |
|
|
expected: self.dimensionality, |
|
|
got: query.dimensionality(), |
|
|
}); |
|
|
} |
|
|
|
|
|
let root_id = match self.root_id { |
|
|
Some(id) => id, |
|
|
None => return Ok(vec![]), |
|
|
}; |
|
|
|
|
|
let query_time = SystemTime::now() |
|
|
.duration_since(UNIX_EPOCH) |
|
|
.unwrap() |
|
|
.as_millis() as u64; |
|
|
|
|
|
|
|
|
let root = match self.containers.get(&root_id) { |
|
|
Some(r) => r, |
|
|
None => return Ok(vec![]), |
|
|
}; |
|
|
|
|
|
let mut sessions: Vec<SessionSummary> = root.children |
|
|
.iter() |
|
|
.filter_map(|session_id| { |
|
|
let session = self.containers.get(session_id)?; |
|
|
if session.level != ContainerLevel::Session { |
|
|
return None; |
|
|
} |
|
|
let dist = self.combined_distance(query, query_time, session); |
|
|
let score = if self.higher_is_better { 1.0 - dist } else { dist }; |
|
|
|
|
|
Some(SessionSummary { |
|
|
id: *session_id, |
|
|
score, |
|
|
chunk_count: session.descendant_count, |
|
|
timestamp: session.timestamp, |
|
|
}) |
|
|
}) |
|
|
.collect(); |
|
|
|
|
|
|
|
|
sessions.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap()); |
|
|
sessions.truncate(k); |
|
|
|
|
|
Ok(sessions) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn near_documents(&self, session_id: Id, query: &Point, k: usize) -> NearResult<Vec<DocumentSummary>> { |
|
|
if query.dimensionality() != self.dimensionality { |
|
|
return Err(NearError::DimensionalityMismatch { |
|
|
expected: self.dimensionality, |
|
|
got: query.dimensionality(), |
|
|
}); |
|
|
} |
|
|
|
|
|
let query_time = SystemTime::now() |
|
|
.duration_since(UNIX_EPOCH) |
|
|
.unwrap() |
|
|
.as_millis() as u64; |
|
|
|
|
|
let session = match self.containers.get(&session_id) { |
|
|
Some(s) => s, |
|
|
None => return Ok(vec![]), |
|
|
}; |
|
|
|
|
|
let mut documents: Vec<DocumentSummary> = session.children |
|
|
.iter() |
|
|
.filter_map(|doc_id| { |
|
|
let doc = self.containers.get(doc_id)?; |
|
|
if doc.level != ContainerLevel::Document { |
|
|
return None; |
|
|
} |
|
|
let dist = self.combined_distance(query, query_time, doc); |
|
|
let score = if self.higher_is_better { 1.0 - dist } else { dist }; |
|
|
|
|
|
Some(DocumentSummary { |
|
|
id: *doc_id, |
|
|
score, |
|
|
chunk_count: doc.descendant_count, |
|
|
timestamp: doc.timestamp, |
|
|
}) |
|
|
}) |
|
|
.collect(); |
|
|
|
|
|
documents.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap()); |
|
|
documents.truncate(k); |
|
|
|
|
|
Ok(documents) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn near_in_document(&self, doc_id: Id, query: &Point, k: usize) -> NearResult<Vec<SearchResult>> { |
|
|
if query.dimensionality() != self.dimensionality { |
|
|
return Err(NearError::DimensionalityMismatch { |
|
|
expected: self.dimensionality, |
|
|
got: query.dimensionality(), |
|
|
}); |
|
|
} |
|
|
|
|
|
let query_time = SystemTime::now() |
|
|
.duration_since(UNIX_EPOCH) |
|
|
.unwrap() |
|
|
.as_millis() as u64; |
|
|
|
|
|
let doc = match self.containers.get(&doc_id) { |
|
|
Some(d) => d, |
|
|
None => return Ok(vec![]), |
|
|
}; |
|
|
|
|
|
let mut chunks: Vec<SearchResult> = doc.children |
|
|
.iter() |
|
|
.filter_map(|chunk_id| { |
|
|
let chunk = self.containers.get(chunk_id)?; |
|
|
if chunk.level != ContainerLevel::Chunk { |
|
|
return None; |
|
|
} |
|
|
let dist = self.combined_distance(query, query_time, chunk); |
|
|
let score = if self.higher_is_better { 1.0 - dist } else { dist }; |
|
|
|
|
|
Some(SearchResult::new(*chunk_id, score)) |
|
|
}) |
|
|
.collect(); |
|
|
|
|
|
chunks.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap()); |
|
|
chunks.truncate(k); |
|
|
|
|
|
Ok(chunks) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn stats(&self) -> HatStats { |
|
|
let mut stats = HatStats::default(); |
|
|
|
|
|
for container in self.containers.values() { |
|
|
match container.level { |
|
|
ContainerLevel::Global => stats.global_count += 1, |
|
|
ContainerLevel::Session => stats.session_count += 1, |
|
|
ContainerLevel::Document => stats.document_count += 1, |
|
|
ContainerLevel::Chunk => stats.chunk_count += 1, |
|
|
} |
|
|
} |
|
|
|
|
|
stats |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pub fn record_retrieval_success(&mut self, query: &Point, result_id: Id) { |
|
|
if let Some(ref mut router) = self.learnable_router { |
|
|
|
|
|
if let Some(container) = self.containers.get(&result_id) { |
|
|
router.record_success(query, &container.centroid, container.level.depth()); |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pub fn record_retrieval_failure(&mut self, query: &Point, result_id: Id) { |
|
|
if let Some(ref mut router) = self.learnable_router { |
|
|
if let Some(container) = self.containers.get(&result_id) { |
|
|
router.record_failure(query, &container.centroid, container.level.depth()); |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pub fn record_implicit_feedback(&mut self, query: &Point, result_id: Id, relevance: f32) { |
|
|
if let Some(ref mut router) = self.learnable_router { |
|
|
if let Some(container) = self.containers.get(&result_id) { |
|
|
router.record_implicit(query, &container.centroid, container.level.depth(), relevance); |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
pub fn router_stats(&self) -> Option<super::learnable_routing::RouterStats> { |
|
|
self.learnable_router.as_ref().map(|r| r.stats()) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn routing_weights(&self) -> Option<&[f32]> { |
|
|
self.learnable_router.as_ref().map(|r| r.weights()) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn reset_routing_weights(&mut self) { |
|
|
if let Some(ref mut router) = self.learnable_router { |
|
|
router.reset_weights(); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
    /// True when a learnable router instance is currently present.
    pub fn is_learnable_routing_enabled(&self) -> bool {
        self.learnable_router.is_some()
    }
|
|
} |
|
|
|
|
|
|
|
|
/// Per-level container counts reported by `HatIndex::stats`.
#[derive(Debug, Clone, Default)]
pub struct HatStats {
    /// Number of global (root) containers — at most 1.
    pub global_count: usize,
    /// Number of session containers.
    pub session_count: usize,
    /// Number of document containers.
    pub document_count: usize,
    /// Number of chunk (leaf) containers, i.e. indexed points.
    pub chunk_count: usize,
}
|
|
|
|
|
impl Near for HatIndex { |
|
|
fn near(&self, query: &Point, k: usize) -> NearResult<Vec<SearchResult>> { |
|
|
|
|
|
if query.dimensionality() != self.dimensionality { |
|
|
return Err(NearError::DimensionalityMismatch { |
|
|
expected: self.dimensionality, |
|
|
got: query.dimensionality(), |
|
|
}); |
|
|
} |
|
|
|
|
|
|
|
|
let root_id = match self.root_id { |
|
|
Some(id) => id, |
|
|
None => return Ok(vec![]), |
|
|
}; |
|
|
|
|
|
|
|
|
let query_time = SystemTime::now() |
|
|
.duration_since(UNIX_EPOCH) |
|
|
.unwrap() |
|
|
.as_millis() as u64; |
|
|
|
|
|
|
|
|
let results = self.search_tree(query, query_time, root_id, k); |
|
|
|
|
|
|
|
|
let search_results: Vec<SearchResult> = results |
|
|
.into_iter() |
|
|
.map(|(id, dist)| { |
|
|
let score = if self.higher_is_better { |
|
|
1.0 - dist |
|
|
} else { |
|
|
dist |
|
|
}; |
|
|
SearchResult::new(id, score) |
|
|
}) |
|
|
.collect(); |
|
|
|
|
|
Ok(search_results) |
|
|
} |
|
|
|
|
|
fn within(&self, query: &Point, threshold: f32) -> NearResult<Vec<SearchResult>> { |
|
|
|
|
|
if query.dimensionality() != self.dimensionality { |
|
|
return Err(NearError::DimensionalityMismatch { |
|
|
expected: self.dimensionality, |
|
|
got: query.dimensionality(), |
|
|
}); |
|
|
} |
|
|
|
|
|
|
|
|
let all_results = self.near(query, self.containers.len())?; |
|
|
|
|
|
let filtered: Vec<SearchResult> = all_results |
|
|
.into_iter() |
|
|
.filter(|r| { |
|
|
if self.higher_is_better { |
|
|
r.score >= threshold |
|
|
} else { |
|
|
r.score <= threshold |
|
|
} |
|
|
}) |
|
|
.collect(); |
|
|
|
|
|
Ok(filtered) |
|
|
} |
|
|
|
|
|
    /// Inserts a chunk: ensures the active root → session → document chain
    /// exists, attaches the new leaf, propagates centroid updates upward,
    /// then rolls over to a fresh document/session once `max_children` is
    /// reached.
    fn add(&mut self, id: Id, point: &Point) -> NearResult<()> {
        if point.dimensionality() != self.dimensionality {
            return Err(NearError::DimensionalityMismatch {
                expected: self.dimensionality,
                got: point.dimensionality(),
            });
        }

        // Lazily builds the root/session/document chain as needed.
        self.ensure_document();

        // The chunk is a leaf container whose centroid is the point itself.
        let chunk = Container::new(id, ContainerLevel::Chunk, point.clone());
        self.containers.insert(id, chunk);

        if let Some(doc_id) = self.active_document {
            if let Some(doc) = self.containers.get_mut(&doc_id) {
                doc.children.push(id);
            }

            // Ancestors are visited document → session → root; propagation
            // may stop early once the centroid shift drops below threshold.
            let mut ancestors = Vec::new();
            if let Some(session_id) = self.active_session {
                ancestors.push(session_id);
                if let Some(root_id) = self.root_id {
                    ancestors.push(root_id);
                }
            }

            self.propagate_centroid_update(doc_id, point, &ancestors);
        }

        // Rollover checks run after insertion, so a document may hold
        // exactly `max_children` entries before the next add replaces it.
        if let Some(doc_id) = self.active_document {
            if let Some(doc) = self.containers.get(&doc_id) {
                if doc.children.len() >= self.config.max_children {
                    self.new_document();
                }
            }
        }

        // Same rollover rule one level up, counting documents per session.
        if let Some(session_id) = self.active_session {
            if let Some(session) = self.containers.get(&session_id) {
                if session.children.len() >= self.config.max_children {
                    self.new_session();
                }
            }
        }

        Ok(())
    }
|
|
|
|
|
fn remove(&mut self, id: Id) -> NearResult<()> { |
|
|
|
|
|
self.containers.remove(&id); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Ok(()) |
|
|
} |
|
|
|
|
|
    /// No-op: the tree is maintained incrementally on every `add`, so
    /// there is no batch rebuild step.
    fn rebuild(&mut self) -> NearResult<()> {
        Ok(())
    }
|
|
|
|
|
    /// Always ready: there is no background build phase.
    fn is_ready(&self) -> bool {
        true
    }
|
|
|
|
|
fn len(&self) -> usize { |
|
|
|
|
|
self.containers.values() |
|
|
.filter(|c| c.level == ContainerLevel::Chunk) |
|
|
.count() |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
impl HatIndex { |
|
|
|
|
|
fn collect_leaf_points(&self, container_id: Id) -> Vec<Point> { |
|
|
let container = match self.containers.get(&container_id) { |
|
|
Some(c) => c, |
|
|
None => return vec![], |
|
|
}; |
|
|
|
|
|
if container.is_leaf() { |
|
|
return vec![container.centroid.clone()]; |
|
|
} |
|
|
|
|
|
let mut points = Vec::new(); |
|
|
for child_id in &container.children { |
|
|
points.extend(self.collect_leaf_points(*child_id)); |
|
|
} |
|
|
points |
|
|
} |
|
|
|
|
|
|
|
|
fn containers_at_level(&self, level: ContainerLevel) -> Vec<Id> { |
|
|
self.containers |
|
|
.iter() |
|
|
.filter(|(_, c)| c.level == level) |
|
|
.map(|(id, _)| *id) |
|
|
.collect() |
|
|
} |
|
|
|
|
|
|
|
|
    /// Recomputes `container_id`'s centroid exactly from all leaf points
    /// beneath it, refreshing the descendant count, accumulated sum, and
    /// (when enabled) the subspace model. Returns the centroid drift, or
    /// `None` when the container is missing or has no leaves.
    fn recompute_centroid(&mut self, container_id: Id) -> Option<f32> {
        let points = self.collect_leaf_points(container_id);

        if points.is_empty() {
            return None;
        }

        let new_centroid = match compute_exact_centroid(&points) {
            Some(c) => c,
            None => return None,
        };

        // Copy config values out before taking the mutable borrow below.
        let subspace_enabled = self.config.subspace_enabled;
        let subspace_rank = self.config.subspace_config.rank;

        let drift = if let Some(container) = self.containers.get_mut(&container_id) {
            let old_centroid = container.centroid.clone();
            let drift = centroid_drift(&old_centroid, &new_centroid);
            container.centroid = new_centroid;
            container.descendant_count = points.len();

            // Rebuild the accumulated component-wise sum from the exact
            // leaf set (incremental updates may have let it drift).
            let sum: Vec<f32> = points.iter()
                .fold(vec![0.0f32; self.dimensionality], |mut acc, p| {
                    for (i, &v) in p.dims().iter().enumerate() {
                        acc[i] += v;
                    }
                    acc
                });
            container.accumulated_sum = Some(Point::new(sum));

            // Rebuild the subspace model from scratch for interior nodes.
            if subspace_enabled && container.level != ContainerLevel::Chunk {
                let mut subspace = super::subspace::Subspace::new(self.dimensionality);
                for point in &points {
                    subspace.add_point(point);
                }
                subspace.recompute_subspace(subspace_rank);
                container.subspace = Some(subspace);
            }

            Some(drift)
        } else {
            None
        };

        drift
    }
|
|
|
|
|
|
|
|
fn should_merge(&self, container_id: Id, threshold: usize) -> bool { |
|
|
if let Some(container) = self.containers.get(&container_id) { |
|
|
|
|
|
if container.level == ContainerLevel::Chunk || |
|
|
container.level == ContainerLevel::Global || |
|
|
container.level == ContainerLevel::Session { |
|
|
return false; |
|
|
} |
|
|
container.children.len() < threshold |
|
|
} else { |
|
|
false |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
fn should_split(&self, container_id: Id, threshold: usize) -> bool { |
|
|
if let Some(container) = self.containers.get(&container_id) { |
|
|
|
|
|
if container.level == ContainerLevel::Chunk { |
|
|
return false; |
|
|
} |
|
|
container.children.len() > threshold |
|
|
} else { |
|
|
false |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
fn find_merge_sibling(&self, container_id: Id) -> Option<Id> { |
|
|
|
|
|
let parent_id = self.containers.iter() |
|
|
.find(|(_, c)| c.children.contains(&container_id)) |
|
|
.map(|(id, _)| *id)?; |
|
|
|
|
|
let parent = self.containers.get(&parent_id)?; |
|
|
|
|
|
|
|
|
let mut smallest: Option<(Id, usize)> = None; |
|
|
for child_id in &parent.children { |
|
|
if *child_id == container_id { |
|
|
continue; |
|
|
} |
|
|
if let Some(child) = self.containers.get(child_id) { |
|
|
let size = child.children.len(); |
|
|
if smallest.is_none() || size < smallest.unwrap().1 { |
|
|
smallest = Some((*child_id, size)); |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
smallest.map(|(id, _)| id) |
|
|
} |
|
|
|
|
|
|
|
|
fn merge_containers(&mut self, a_id: Id, b_id: Id) { |
|
|
|
|
|
let b_children: Vec<Id> = if let Some(b) = self.containers.get(&b_id) { |
|
|
b.children.clone() |
|
|
} else { |
|
|
return; |
|
|
}; |
|
|
|
|
|
|
|
|
if let Some(a) = self.containers.get_mut(&a_id) { |
|
|
a.children.extend(b_children); |
|
|
} |
|
|
|
|
|
|
|
|
let parent_id = self.containers.iter() |
|
|
.find(|(_, c)| c.children.contains(&b_id)) |
|
|
.map(|(id, _)| *id); |
|
|
|
|
|
if let Some(pid) = parent_id { |
|
|
if let Some(parent) = self.containers.get_mut(&pid) { |
|
|
parent.children.retain(|id| *id != b_id); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
self.containers.remove(&b_id); |
|
|
|
|
|
|
|
|
self.recompute_centroid(a_id); |
|
|
} |
|
|
|
|
|
|
|
|
fn split_container(&mut self, container_id: Id) -> Option<Id> { |
|
|
|
|
|
let (level, children, parent_id) = { |
|
|
let container = self.containers.get(&container_id)?; |
|
|
let parent_id = self.containers.iter() |
|
|
.find(|(_, c)| c.children.contains(&container_id)) |
|
|
.map(|(id, _)| *id); |
|
|
(container.level, container.children.clone(), parent_id) |
|
|
}; |
|
|
|
|
|
if children.len() < 2 { |
|
|
return None; |
|
|
} |
|
|
|
|
|
|
|
|
let mid = children.len() / 2; |
|
|
let (keep, move_to_new) = children.split_at(mid); |
|
|
|
|
|
|
|
|
let new_id = Id::now(); |
|
|
let new_container = Container::new( |
|
|
new_id, |
|
|
level, |
|
|
Point::origin(self.dimensionality), |
|
|
); |
|
|
self.containers.insert(new_id, new_container); |
|
|
|
|
|
|
|
|
if let Some(container) = self.containers.get_mut(&container_id) { |
|
|
container.children = keep.to_vec(); |
|
|
} |
|
|
|
|
|
|
|
|
if let Some(new_container) = self.containers.get_mut(&new_id) { |
|
|
new_container.children = move_to_new.to_vec(); |
|
|
} |
|
|
|
|
|
|
|
|
if let Some(pid) = parent_id { |
|
|
if let Some(parent) = self.containers.get_mut(&pid) { |
|
|
parent.children.push(new_id); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
self.recompute_centroid(container_id); |
|
|
self.recompute_centroid(new_id); |
|
|
|
|
|
Some(new_id) |
|
|
} |
|
|
|
|
|
|
|
|
fn prune_empty(&mut self) -> usize { |
|
|
let mut pruned = 0; |
|
|
|
|
|
loop { |
|
|
let empty_ids: Vec<Id> = self.containers |
|
|
.iter() |
|
|
.filter(|(_, c)| { |
|
|
c.level != ContainerLevel::Chunk && |
|
|
c.level != ContainerLevel::Global && |
|
|
c.children.is_empty() |
|
|
}) |
|
|
.map(|(id, _)| *id) |
|
|
.collect(); |
|
|
|
|
|
if empty_ids.is_empty() { |
|
|
break; |
|
|
} |
|
|
|
|
|
for id in empty_ids { |
|
|
|
|
|
let parent_id = self.containers.iter() |
|
|
.find(|(_, c)| c.children.contains(&id)) |
|
|
.map(|(pid, _)| *pid); |
|
|
|
|
|
if let Some(pid) = parent_id { |
|
|
if let Some(parent) = self.containers.get_mut(&pid) { |
|
|
parent.children.retain(|cid| *cid != id); |
|
|
} |
|
|
} |
|
|
|
|
|
self.containers.remove(&id); |
|
|
pruned += 1; |
|
|
} |
|
|
} |
|
|
|
|
|
pruned |
|
|
} |
|
|
} |
|
|
|
|
|
impl Consolidate for HatIndex {
    /// Start a new consolidation run: reset the phase state machine and seed
    /// the work queue with every container currently in the index.
    fn begin_consolidation(&mut self, config: ConsolidationConfig) {
        let mut state = ConsolidationState::new(config);
        state.start();

        // Seed with all container ids. The CollectingLeaves phase below
        // rebuilds this queue in level order before any real work happens.
        let all_ids: VecDeque<Id> = self.containers.keys().copied().collect();
        state.work_queue = all_ids;

        self.consolidation_state = Some(state);
        // Any point cache from a previous run is stale now.
        self.consolidation_points_cache.clear();
    }

    /// Advance the consolidation state machine by one bounded increment
    /// (at most `batch_size` items for the batched phases), then return
    /// either progress (run continues) or the final metrics (run done).
    fn consolidation_tick(&mut self) -> ConsolidationTickResult {
        // Take ownership of the state so `self` can be mutably borrowed by
        // the phase handlers; it is put back at the end unless complete.
        let mut state = match self.consolidation_state.take() {
            Some(s) => s,
            None => {
                // No run in flight: report an empty completed run.
                return ConsolidationTickResult::Complete(ConsolidationMetrics::default());
            }
        };

        let batch_size = state.config.batch_size;

        match state.phase {
            // Defensive: begin_consolidation already calls start(), but a
            // state left in Idle is kicked off here.
            ConsolidationPhase::Idle => {
                state.start();
            }

            // Rebuild the work queue level by level: documents, then
            // sessions, then globals — presumably so lower levels are
            // recomputed before their parents (confirm against level
            // hierarchy).
            ConsolidationPhase::CollectingLeaves => {
                state.next_phase();

                let docs = self.containers_at_level(ContainerLevel::Document);
                let sessions = self.containers_at_level(ContainerLevel::Session);
                let globals = self.containers_at_level(ContainerLevel::Global);

                state.work_queue.clear();
                state.work_queue.extend(docs);
                state.work_queue.extend(sessions);
                state.work_queue.extend(globals);
            }

            // Recompute up to batch_size centroids per tick, recording drift
            // for each container that actually produced one.
            ConsolidationPhase::RecomputingCentroids => {
                let mut processed = 0;
                let mut to_recompute = Vec::new();

                // Drain ids first so the loop below can mutably borrow self.
                while processed < batch_size {
                    match state.work_queue.pop_front() {
                        Some(id) => {
                            to_recompute.push(id);
                            state.processed.insert(id);
                            processed += 1;
                        }
                        None => break,
                    };
                }

                for container_id in to_recompute {
                    // recompute_centroid returns None for missing containers;
                    // only successful recomputes count toward drift/metrics.
                    if let Some(drift) = self.recompute_centroid(container_id) {
                        state.record_drift(drift);
                        state.metrics.centroids_recomputed += 1;
                    }
                    state.metrics.containers_processed += 1;
                }

                // Queue drained: move on, and pre-load the next phase's work
                // (document containers are the merge/split candidates).
                if state.work_queue.is_empty() {
                    state.next_phase();

                    if state.phase == ConsolidationPhase::AnalyzingStructure {
                        let docs = self.containers_at_level(ContainerLevel::Document);
                        state.work_queue.extend(docs);
                    }
                }
            }

            // Classify up to batch_size containers as merge or split
            // candidates; the actual restructuring happens in later phases.
            ConsolidationPhase::AnalyzingStructure => {
                let merge_threshold = state.config.merge_threshold;
                let split_threshold = state.config.split_threshold;
                let mut processed = 0;
                let mut to_analyze = Vec::new();

                while processed < batch_size {
                    match state.work_queue.pop_front() {
                        Some(id) => {
                            to_analyze.push(id);
                            state.processed.insert(id);
                            processed += 1;
                        }
                        None => break,
                    };
                }

                for container_id in to_analyze {
                    // Merge takes priority over split (they are mutually
                    // exclusive for a given container anyway).
                    if self.should_merge(container_id, merge_threshold) {
                        if let Some(sibling) = self.find_merge_sibling(container_id) {
                            state.add_merge_candidate(container_id, sibling);
                        }
                    } else if self.should_split(container_id, split_threshold) {
                        state.add_split_candidate(container_id);
                    }
                }

                if state.work_queue.is_empty() {
                    state.next_phase();
                }
            }

            // Execute up to batch_size queued merges per tick.
            ConsolidationPhase::Merging => {
                let mut processed = 0;
                let mut to_merge = Vec::new();

                while processed < batch_size {
                    match state.next_merge() {
                        Some(pair) => {
                            to_merge.push(pair);
                            processed += 1;
                        }
                        None => break,
                    };
                }

                for (a, b) in to_merge {
                    self.merge_containers(a, b);
                    state.metrics.containers_merged += 1;
                }

                if !state.has_merges() {
                    state.next_phase();
                }
            }

            // Execute up to batch_size queued splits per tick.
            ConsolidationPhase::Splitting => {
                let mut processed = 0;
                let mut to_split = Vec::new();

                while processed < batch_size {
                    match state.next_split() {
                        Some(id) => {
                            to_split.push(id);
                            processed += 1;
                        }
                        None => break,
                    };
                }

                for container_id in to_split {
                    // split_container returns None when the split was not
                    // possible (e.g. < 2 children); only real splits count.
                    if self.split_container(container_id).is_some() {
                        state.metrics.containers_split += 1;
                    }
                }

                if !state.has_splits() {
                    state.next_phase();
                }
            }

            // Pruning runs in a single tick (prune_empty loops internally).
            ConsolidationPhase::Pruning => {
                let pruned = self.prune_empty();
                state.metrics.containers_pruned = pruned;
                state.next_phase();
            }

            // NOTE(review): this phase is currently a placeholder — the loop
            // body is empty, so no layout optimization actually happens yet.
            ConsolidationPhase::OptimizingLayout => {
                for container in self.containers.values_mut() {
                    if container.children.len() > 1 {
                    }
                }
                state.next_phase();
            }

            // Terminal phase: nothing to do; completion handled below.
            ConsolidationPhase::Complete => {
            }
        }

        state.metrics.ticks += 1;

        if state.is_complete() {
            // Run finished: the state is dropped (not restored) and the
            // cache cleared so the next run starts fresh.
            let metrics = state.metrics.clone();
            self.consolidation_points_cache.clear();
            ConsolidationTickResult::Complete(metrics)
        } else {
            // Run continues: put the state back for the next tick.
            let progress = state.progress();
            self.consolidation_state = Some(state);
            ConsolidationTickResult::Continue(progress)
        }
    }

    /// True while a consolidation run is in flight.
    fn is_consolidating(&self) -> bool {
        self.consolidation_state.is_some()
    }

    /// Progress snapshot of the current run, or `None` when idle.
    fn consolidation_progress(&self) -> Option<ConsolidationProgress> {
        self.consolidation_state.as_ref().map(|s| s.progress())
    }

    /// Abort the current run, discarding all in-flight state and caches.
    fn cancel_consolidation(&mut self) {
        self.consolidation_state = None;
        self.consolidation_points_cache.clear();
    }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
impl HatIndex {
    /// Serialize the index (container tree, centroids, accumulated sums,
    /// session/document cursors, and optional router weights) to bytes.
    ///
    /// NOTE(review): per-container subspaces are NOT serialized — after a
    /// `from_bytes` round-trip they come back empty and presumably need
    /// recomputation (e.g. via a consolidation pass); confirm.
    ///
    /// # Errors
    /// Propagates any `PersistError` from `SerializedHat::to_bytes`.
    pub fn to_bytes(&self) -> Result<Vec<u8>, super::persistence::PersistError> {
        use super::persistence::{SerializedHat, SerializedContainer, LevelByte};

        let containers: Vec<SerializedContainer> = self.containers.iter()
            .map(|(_, c)| {
                // On-disk level byte: Global maps to LevelByte::Root (the
                // names differ between the two enums).
                let level = match c.level {
                    ContainerLevel::Global => LevelByte::Root,
                    ContainerLevel::Session => LevelByte::Session,
                    ContainerLevel::Document => LevelByte::Document,
                    ContainerLevel::Chunk => LevelByte::Chunk,
                };

                SerializedContainer {
                    id: c.id,
                    level,
                    timestamp: c.timestamp,
                    children: c.children.clone(),
                    // usize -> u64 for a platform-independent on-disk width.
                    descendant_count: c.descendant_count as u64,
                    centroid: c.centroid.dims().to_vec(),
                    accumulated_sum: c.accumulated_sum.as_ref().map(|p| p.dims().to_vec()),
                }
            })
            .collect();

        // Router weights are optional; absent when no learnable router is set.
        let router_weights = self.learnable_router.as_ref()
            .map(|r| r.weights().to_vec());

        let serialized = SerializedHat {
            version: 1,
            dimensionality: self.dimensionality as u32,
            root_id: self.root_id,
            containers,
            active_session: self.active_session,
            active_document: self.active_document,
            router_weights,
        };

        serialized.to_bytes()
    }

    /// Reconstruct a `HatIndex` from bytes produced by [`to_bytes`].
    ///
    /// NOTE(review): the proximity/merge strategies are not persisted — a
    /// loaded index always uses `Cosine` + `Mean` regardless of what the
    /// saved index used; confirm this is acceptable for all callers.
    /// Subspaces come back empty (see `to_bytes`).
    ///
    /// # Errors
    /// Returns `PersistError` when the bytes are corrupt, a centroid's
    /// length disagrees with the stored dimensionality, or router weights
    /// fail to deserialize.
    pub fn from_bytes(data: &[u8]) -> Result<Self, super::persistence::PersistError> {
        use super::persistence::{SerializedHat, LevelByte, PersistError};
        use crate::core::proximity::Cosine;
        use crate::core::merge::Mean;

        let serialized = SerializedHat::from_bytes(data)?;
        let dimensionality = serialized.dimensionality as usize;

        // Fresh index shell; its containers/cursors are overwritten below.
        // NOTE(review): the `true` flag's meaning is defined by Self::new
        // (outside this view) — verify it matches what a loaded index needs.
        let mut index = Self::new(
            dimensionality,
            Arc::new(Cosine),
            Arc::new(Mean),
            true,
            HatConfig::default(),
        );

        for sc in serialized.containers {
            // Inverse of the level mapping in to_bytes.
            let level = match sc.level {
                LevelByte::Root => ContainerLevel::Global,
                LevelByte::Session => ContainerLevel::Session,
                LevelByte::Document => ContainerLevel::Document,
                LevelByte::Chunk => ContainerLevel::Chunk,
            };

            // Every centroid must match the index dimensionality; reject
            // the whole payload otherwise.
            if sc.centroid.len() != dimensionality {
                return Err(PersistError::DimensionMismatch {
                    expected: dimensionality,
                    found: sc.centroid.len(),
                });
            }

            let centroid = Point::new(sc.centroid);
            let accumulated_sum = sc.accumulated_sum.map(Point::new);

            let container = Container {
                id: sc.id,
                level,
                centroid,
                timestamp: sc.timestamp,
                children: sc.children,
                descendant_count: sc.descendant_count as usize,
                accumulated_sum,
                // Subspaces are not persisted: non-leaf containers get a
                // fresh empty subspace, leaves get none.
                subspace: if level != ContainerLevel::Chunk {
                    Some(super::subspace::Subspace::new(dimensionality))
                } else {
                    None
                },
            };

            index.containers.insert(sc.id, container);
        }

        // Restore the tree root and the active insertion cursors.
        index.root_id = serialized.root_id;
        index.active_session = serialized.active_session;
        index.active_document = serialized.active_document;

        // Rehydrate the learnable router, if weights were saved. Weights are
        // round-tripped through little-endian bytes for deserialize_weights.
        if let Some(weights) = serialized.router_weights {
            let mut router = super::learnable_routing::LearnableRouter::default_for_dims(dimensionality);
            let weight_bytes: Vec<u8> = weights.iter()
                .flat_map(|w| w.to_le_bytes())
                .collect();
            router.deserialize_weights(&weight_bytes)
                .map_err(|e| PersistError::Corrupted(e.to_string()))?;
            index.learnable_router = Some(router);
        }

        Ok(index)
    }

    /// Serialize the index and write it to `path`.
    ///
    /// # Errors
    /// Propagates serialization errors and filesystem write errors.
    pub fn save_to_file(&self, path: &std::path::Path) -> Result<(), super::persistence::PersistError> {
        let bytes = self.to_bytes()?;
        std::fs::write(path, bytes)?;
        Ok(())
    }

    /// Read `path` and reconstruct the index from its contents.
    ///
    /// # Errors
    /// Propagates filesystem read errors and deserialization errors.
    pub fn load_from_file(path: &std::path::Path) -> Result<Self, super::persistence::PersistError> {
        let bytes = std::fs::read(path)?;
        Self::from_bytes(&bytes)
    }
}
|
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Inserting a single point is reflected in the index length.
    #[test]
    fn test_hat_add() {
        let mut index = HatIndex::cosine(3);
        index.add(Id::now(), &Point::new(vec![1.0, 0.0, 0.0])).unwrap();
        assert_eq!(index.len(), 1);
    }

    /// A near query returns k results, best-first, with a sensible top score.
    #[test]
    fn test_hat_near() {
        let mut index = HatIndex::cosine(3);

        let corpus = [
            Point::new(vec![1.0, 0.0, 0.0]),
            Point::new(vec![0.0, 1.0, 0.0]),
            Point::new(vec![0.0, 0.0, 1.0]),
            Point::new(vec![0.7, 0.7, 0.0]).normalize(),
        ];
        for p in corpus.iter() {
            index.add(Id::now(), p).unwrap();
        }

        let hits = index.near(&Point::new(vec![1.0, 0.0, 0.0]), 2).unwrap();
        assert_eq!(hits.len(), 2);
        // The closest match should be strongly aligned with the query.
        assert!(hits[0].score > 0.5);
    }

    /// Points added across a session boundary are all searchable.
    #[test]
    fn test_hat_sessions() {
        let mut index = HatIndex::cosine(3);

        // First session: points near the x axis.
        for i in 0..5 {
            let p = Point::new(vec![1.0, i as f32 * 0.1, 0.0]).normalize();
            index.add(Id::now(), &p).unwrap();
        }

        index.new_session();

        // Second session: points near the y axis.
        for i in 0..5 {
            let p = Point::new(vec![0.0, 1.0, i as f32 * 0.1]).normalize();
            index.add(Id::now(), &p).unwrap();
        }

        assert_eq!(index.len(), 10);

        let probe = Point::new(vec![0.5, 0.5, 0.0]).normalize();
        let hits = index.near(&probe, 5).unwrap();
        assert_eq!(hits.len(), 5);
    }

    /// Adding points builds the full container hierarchy under a root.
    #[test]
    fn test_hat_hierarchy_structure() {
        let mut index = HatIndex::cosine(3);

        for _ in 0..10 {
            index.add(Id::now(), &Point::new(vec![1.0, 0.0, 0.0])).unwrap();
        }

        // 10 chunks plus the structural containers above them.
        assert!(index.containers.len() >= 13);
        assert!(index.root_id.is_some());
    }

    /// Querying an empty index yields no results rather than an error.
    #[test]
    fn test_hat_empty() {
        let index = HatIndex::cosine(3);
        let hits = index.near(&Point::new(vec![1.0, 0.0, 0.0]), 5).unwrap();
        assert!(hits.is_empty());
    }

    /// A point of the wrong dimensionality is rejected with the right error.
    #[test]
    fn test_hat_dimensionality_check() {
        let mut index = HatIndex::cosine(3);

        match index.add(Id::now(), &Point::new(vec![1.0, 0.0])) {
            Err(NearError::DimensionalityMismatch { expected, got }) => {
                assert_eq!((expected, got), (3, 2));
            }
            _ => panic!("Expected DimensionalityMismatch error"),
        }
    }

    /// The index stays correct with many points in a higher dimension.
    #[test]
    fn test_hat_scale() {
        let mut index = HatIndex::cosine(128);

        // One-hot basis vectors, cycling through all 128 dimensions.
        for i in 0..1000 {
            let mut dims = vec![0.0f32; 128];
            dims[i % 128] = 1.0;
            index.add(Id::now(), &Point::new(dims).normalize()).unwrap();
        }

        assert_eq!(index.len(), 1000);

        let probe = Point::new(vec![1.0; 128]).normalize();
        let hits = index.near(&probe, 10).unwrap();
        assert_eq!(hits.len(), 10);
    }
}
|
|
|