|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
use std::collections::HashMap; |
|
|
use std::sync::Arc; |
|
|
|
|
|
use crate::core::{Id, Point}; |
|
|
use crate::core::proximity::Proximity; |
|
|
use crate::ports::{Near, NearError, NearResult, SearchResult}; |
|
|
|
|
|
|
|
|
/// Brute-force ("flat") nearest-neighbour index.
///
/// Points are kept in a plain `HashMap` and every query scores the query
/// against all stored points, so there is no auxiliary search structure to
/// build or maintain.
pub struct FlatIndex {
    /// Stored points, keyed by their id. Re-adding an id overwrites its point.
    points: HashMap<Id, Point>,

    /// Number of components every stored point and every query must have.
    dimensionality: usize,

    /// Scoring function applied to each (query, stored point) pair.
    proximity: Arc<dyn Proximity>,

    /// Direction of the score: `true` when larger scores mean "closer"
    /// (the `cosine` constructor), `false` when smaller scores mean "closer"
    /// (the `euclidean` constructor). Drives both result ordering and how a
    /// `within` threshold is applied.
    higher_is_better: bool,
}
|
|
|
|
|
impl FlatIndex { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pub fn new( |
|
|
dimensionality: usize, |
|
|
proximity: Arc<dyn Proximity>, |
|
|
higher_is_better: bool, |
|
|
) -> Self { |
|
|
Self { |
|
|
points: HashMap::new(), |
|
|
dimensionality, |
|
|
proximity, |
|
|
higher_is_better, |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
pub fn cosine(dimensionality: usize) -> Self { |
|
|
use crate::core::proximity::Cosine; |
|
|
Self::new(dimensionality, Arc::new(Cosine), true) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn euclidean(dimensionality: usize) -> Self { |
|
|
use crate::core::proximity::Euclidean; |
|
|
Self::new(dimensionality, Arc::new(Euclidean), false) |
|
|
} |
|
|
|
|
|
|
|
|
fn sort_results(&self, results: &mut Vec<SearchResult>) { |
|
|
if self.higher_is_better { |
|
|
|
|
|
results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap()); |
|
|
} else { |
|
|
|
|
|
results.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap()); |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
impl Near for FlatIndex { |
|
|
fn near(&self, query: &Point, k: usize) -> NearResult<Vec<SearchResult>> { |
|
|
|
|
|
if query.dimensionality() != self.dimensionality { |
|
|
return Err(NearError::DimensionalityMismatch { |
|
|
expected: self.dimensionality, |
|
|
got: query.dimensionality(), |
|
|
}); |
|
|
} |
|
|
|
|
|
|
|
|
let mut results: Vec<SearchResult> = self |
|
|
.points |
|
|
.iter() |
|
|
.map(|(id, point)| { |
|
|
let score = self.proximity.proximity(query, point); |
|
|
SearchResult::new(*id, score) |
|
|
}) |
|
|
.collect(); |
|
|
|
|
|
|
|
|
self.sort_results(&mut results); |
|
|
|
|
|
|
|
|
results.truncate(k); |
|
|
|
|
|
Ok(results) |
|
|
} |
|
|
|
|
|
fn within(&self, query: &Point, threshold: f32) -> NearResult<Vec<SearchResult>> { |
|
|
|
|
|
if query.dimensionality() != self.dimensionality { |
|
|
return Err(NearError::DimensionalityMismatch { |
|
|
expected: self.dimensionality, |
|
|
got: query.dimensionality(), |
|
|
}); |
|
|
} |
|
|
|
|
|
|
|
|
let mut results: Vec<SearchResult> = self |
|
|
.points |
|
|
.iter() |
|
|
.filter_map(|(id, point)| { |
|
|
let score = self.proximity.proximity(query, point); |
|
|
let within = if self.higher_is_better { |
|
|
score >= threshold |
|
|
} else { |
|
|
score <= threshold |
|
|
}; |
|
|
if within { |
|
|
Some(SearchResult::new(*id, score)) |
|
|
} else { |
|
|
None |
|
|
} |
|
|
}) |
|
|
.collect(); |
|
|
|
|
|
|
|
|
self.sort_results(&mut results); |
|
|
|
|
|
Ok(results) |
|
|
} |
|
|
|
|
|
fn add(&mut self, id: Id, point: &Point) -> NearResult<()> { |
|
|
if point.dimensionality() != self.dimensionality { |
|
|
return Err(NearError::DimensionalityMismatch { |
|
|
expected: self.dimensionality, |
|
|
got: point.dimensionality(), |
|
|
}); |
|
|
} |
|
|
|
|
|
self.points.insert(id, point.clone()); |
|
|
Ok(()) |
|
|
} |
|
|
|
|
|
fn remove(&mut self, id: Id) -> NearResult<()> { |
|
|
self.points.remove(&id); |
|
|
Ok(()) |
|
|
} |
|
|
|
|
|
fn rebuild(&mut self) -> NearResult<()> { |
|
|
|
|
|
Ok(()) |
|
|
} |
|
|
|
|
|
fn is_ready(&self) -> bool { |
|
|
true |
|
|
} |
|
|
|
|
|
fn len(&self) -> usize { |
|
|
self.points.len() |
|
|
} |
|
|
} |
|
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Cosine index over three unit axes plus one point at 45 degrees between x and y.
    fn setup_index() -> FlatIndex {
        let mut index = FlatIndex::cosine(3);

        let points = vec![
            (Id::from_bytes([1; 16]), Point::new(vec![1.0, 0.0, 0.0])),
            (Id::from_bytes([2; 16]), Point::new(vec![0.0, 1.0, 0.0])),
            (Id::from_bytes([3; 16]), Point::new(vec![0.0, 0.0, 1.0])),
            (Id::from_bytes([4; 16]), Point::new(vec![0.7, 0.7, 0.0]).normalize()),
        ];

        for (id, point) in points {
            index.add(id, &point).unwrap();
        }

        index
    }

    /// Euclidean index over three points on the x axis at 0, 1 and 5.
    fn setup_euclidean_index() -> FlatIndex {
        let mut index = FlatIndex::euclidean(2);
        index.add(Id::from_bytes([1; 16]), &Point::new(vec![0.0, 0.0])).unwrap();
        index.add(Id::from_bytes([2; 16]), &Point::new(vec![1.0, 0.0])).unwrap();
        index.add(Id::from_bytes([3; 16]), &Point::new(vec![5.0, 0.0])).unwrap();
        index
    }

    /// Asserts that `result` is a dimensionality mismatch with the given sizes.
    fn assert_dimensionality_mismatch<T>(result: NearResult<T>, want_expected: usize, want_got: usize) {
        match result {
            Err(NearError::DimensionalityMismatch { expected, got }) => {
                assert_eq!(expected, want_expected);
                assert_eq!(got, want_got);
            }
            _ => panic!("Expected DimensionalityMismatch error"),
        }
    }

    #[test]
    fn test_flat_index_near() {
        let index = setup_index();

        let query = Point::new(vec![1.0, 0.0, 0.0]);
        let results = index.near(&query, 2).unwrap();

        assert_eq!(results.len(), 2);

        // Same direction as the query: perfect score, ranked first.
        assert_eq!(results[0].id, Id::from_bytes([1; 16]));
        assert!((results[0].score - 1.0).abs() < 0.0001);

        // The 45-degree point beats the two orthogonal ones.
        assert_eq!(results[1].id, Id::from_bytes([4; 16]));
        assert!((results[1].score - std::f32::consts::FRAC_1_SQRT_2).abs() < 0.001);
    }

    #[test]
    fn test_flat_index_near_k_bounds() {
        let index = setup_index();
        let query = Point::new(vec![1.0, 0.0, 0.0]);

        assert!(index.near(&query, 0).unwrap().is_empty());
        // Asking for more neighbours than exist returns every stored point.
        assert_eq!(index.near(&query, 10).unwrap().len(), 4);
    }

    #[test]
    fn test_flat_index_empty() {
        let index = FlatIndex::cosine(3);
        let query = Point::new(vec![1.0, 0.0, 0.0]);

        assert_eq!(index.len(), 0);
        assert!(index.near(&query, 5).unwrap().is_empty());
        assert!(index.within(&query, 0.0).unwrap().is_empty());
    }

    #[test]
    fn test_flat_index_within_cosine() {
        let index = setup_index();

        let query = Point::new(vec![1.0, 0.0, 0.0]);
        let results = index.within(&query, 0.5).unwrap();

        assert_eq!(results.len(), 2);
        // Higher-is-better: best match first.
        assert_eq!(results[0].id, Id::from_bytes([1; 16]));
        assert_eq!(results[1].id, Id::from_bytes([4; 16]));
    }

    #[test]
    fn test_flat_index_within_euclidean() {
        let index = setup_euclidean_index();

        let query = Point::new(vec![0.0, 0.0]);
        let results = index.within(&query, 2.0).unwrap();

        // Lower-is-better: the point at 5 is outside the threshold, and the
        // remaining ones come back closest first.
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, Id::from_bytes([1; 16]));
        assert_eq!(results[1].id, Id::from_bytes([2; 16]));
    }

    #[test]
    fn test_flat_index_euclidean() {
        let index = setup_euclidean_index();

        let query = Point::new(vec![0.0, 0.0]);
        let results = index.near(&query, 2).unwrap();

        assert_eq!(results.len(), 2);

        assert_eq!(results[0].id, Id::from_bytes([1; 16]));
        assert!((results[0].score - 0.0).abs() < 0.0001);

        assert_eq!(results[1].id, Id::from_bytes([2; 16]));
        assert!((results[1].score - 1.0).abs() < 0.0001);
    }

    #[test]
    fn test_flat_index_add_remove() {
        let mut index = FlatIndex::cosine(3);

        let id = Id::from_bytes([1; 16]);
        let point = Point::new(vec![1.0, 0.0, 0.0]);

        index.add(id, &point).unwrap();
        assert_eq!(index.len(), 1);

        index.remove(id).unwrap();
        assert_eq!(index.len(), 0);
    }

    #[test]
    fn test_flat_index_remove_missing_is_ok() {
        let mut index = setup_index();

        index.remove(Id::from_bytes([9; 16])).unwrap();
        assert_eq!(index.len(), 4);
    }

    #[test]
    fn test_flat_index_add_replaces_existing_id() {
        let mut index = FlatIndex::euclidean(2);
        let id = Id::from_bytes([1; 16]);

        index.add(id, &Point::new(vec![0.0, 0.0])).unwrap();
        index.add(id, &Point::new(vec![3.0, 4.0])).unwrap();
        assert_eq!(index.len(), 1);

        // The second add replaced the point at the origin.
        let results = index.near(&Point::new(vec![0.0, 0.0]), 1).unwrap();
        assert_eq!(results[0].id, id);
        assert!(results[0].score > 0.0001);
    }

    #[test]
    fn test_flat_index_dimensionality_check() {
        let mut index = FlatIndex::cosine(3);

        let wrong_dims = Point::new(vec![1.0, 0.0]);
        let result = index.add(Id::now(), &wrong_dims);

        assert_dimensionality_mismatch(result, 3, 2);
        // A rejected add must not store anything.
        assert_eq!(index.len(), 0);
    }

    #[test]
    fn test_flat_index_query_dimensionality_check() {
        let index = setup_index();
        let wrong_dims = Point::new(vec![1.0, 0.0]);

        assert_dimensionality_mismatch(index.near(&wrong_dims, 1), 3, 2);
        assert_dimensionality_mismatch(index.within(&wrong_dims, 0.5), 3, 2);
    }

    #[test]
    fn test_flat_index_ready() {
        let index = FlatIndex::cosine(3);
        assert!(index.is_ready());
    }
}
|
|
|