Spaces:
Runtime error
Runtime error
| use std::ffi::CString; | |
| use std::ffi::{c_char, c_int}; | |
| use crate::errors::{ChromaError, ErrorCodes}; | |
| use super::{Index, IndexConfig, PersistentIndex}; | |
| use crate::types::{Metadata, MetadataValue, MetadataValueConversionError, Segment}; | |
| use thiserror::Error; | |
/// Opaque stand-in for the index object that lives on the native side of the FFI boundary.
///
/// This follows the "opaque struct" recipe from the Rustonomicon:
/// https://doc.rust-lang.org/nomicon/ffi.html#representing-opaque-structs
///
/// * `#[repr(C)]` gives the type a layout that is defined for use in FFI signatures.
/// * `_data` is a zero-length array, so the type is zero-sized and is only ever used
///   behind a raw pointer.
/// * `_marker` contains a raw pointer and `PhantomPinned`, which opts the type out of the
///   automatic `Send`, `Sync` and `Unpin` implementations.
#[repr(C)]
struct IndexPtrFFI {
    _data: [u8; 0],
    _marker: core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>,
}
// TODO: Make this config:
// - Watchable - for dynamic updates
// - Have a notion of static vs dynamic config
// - Have a notion of default config
// - HNSWIndex should store a ref to the config so it can look up the config values.
// deferring this for a config pass

/// HNSW-specific settings used when an index is initialised.
///
/// Every field except `ef_search` is forwarded to the native `init_index` call;
/// `ef_search` is applied afterwards through `HnswIndex::set_ef`.
#[derive(Clone, Debug)]
pub(crate) struct HnswIndexConfig {
    /// Value passed to the native library as `max_elements`.
    pub(crate) max_elements: usize,
    /// Value passed to the native library as `M`.
    pub(crate) m: usize,
    /// Value passed to the native library as `ef_construction`.
    pub(crate) ef_construction: usize,
    /// Initial `ef` value set on the index right after initialisation.
    pub(crate) ef_search: usize,
    /// Value passed to the native library as `random_seed`.
    pub(crate) random_seed: usize,
    /// Path handed to the native library as the index location.
    pub(crate) persist_path: String,
}
| pub(crate) enum HnswIndexFromSegmentError { | |
| MissingConfig(String), | |
| } | |
| impl ChromaError for HnswIndexFromSegmentError { | |
| fn code(&self) -> ErrorCodes { | |
| crate::errors::ErrorCodes::InvalidArgument | |
| } | |
| } | |
| impl HnswIndexConfig { | |
| pub(crate) fn from_segment( | |
| segment: &Segment, | |
| persist_path: &std::path::Path, | |
| ) -> Result<HnswIndexConfig, Box<dyn ChromaError>> { | |
| let persist_path = match persist_path.to_str() { | |
| Some(persist_path) => persist_path, | |
| None => { | |
| return Err(Box::new(HnswIndexFromSegmentError::MissingConfig( | |
| "persist_path".to_string(), | |
| ))) | |
| } | |
| }; | |
| let metadata = match &segment.metadata { | |
| Some(metadata) => metadata, | |
| None => { | |
| // TODO: This should error, but the configuration is not stored correctly | |
| // after the configuration is refactored to be always stored and doesn't rely on defaults we can fix this | |
| return Ok(HnswIndexConfig { | |
| max_elements: 1000, | |
| m: 16, | |
| ef_construction: 100, | |
| ef_search: 10, | |
| random_seed: 0, | |
| persist_path: persist_path.to_string(), | |
| }); | |
| // return Err(Box::new(HnswIndexFromSegmentError::MissingConfig( | |
| // "metadata".to_string(), | |
| // ))) | |
| } | |
| }; | |
| fn get_metadata_value_as<'a, T>( | |
| metadata: &'a Metadata, | |
| key: &str, | |
| ) -> Result<T, Box<dyn ChromaError>> | |
| where | |
| T: TryFrom<&'a MetadataValue, Error = MetadataValueConversionError>, | |
| { | |
| let res = match metadata.get(key) { | |
| Some(value) => T::try_from(value), | |
| None => { | |
| return Err(Box::new(HnswIndexFromSegmentError::MissingConfig( | |
| key.to_string(), | |
| ))) | |
| } | |
| }; | |
| match res { | |
| Ok(value) => Ok(value), | |
| Err(e) => Err(Box::new(e)), | |
| } | |
| } | |
| let max_elements = get_metadata_value_as::<i32>(metadata, "hsnw:max_elements")?; | |
| let m = get_metadata_value_as::<i32>(metadata, "hnsw:m")?; | |
| let ef_construction = get_metadata_value_as::<i32>(metadata, "hnsw:ef_construction")?; | |
| let ef_search = get_metadata_value_as::<i32>(metadata, "hnsw:ef_search")?; | |
| return Ok(HnswIndexConfig { | |
| max_elements: max_elements as usize, | |
| m: m as usize, | |
| ef_construction: ef_construction as usize, | |
| ef_search: ef_search as usize, | |
| random_seed: 0, | |
| persist_path: persist_path.to_string(), | |
| }); | |
| } | |
| } | |
| /// The HnswIndex struct. | |
| /// # Description | |
| /// This struct wraps a pointer to the C++ HnswIndex class and presents a safe Rust interface. | |
| /// # Notes | |
| /// This struct is not thread safe for concurrent reads and writes. Callers should | |
| /// synchronize access to the index between reads and writes. | |
| pub(crate) struct HnswIndex { | |
| ffi_ptr: *const IndexPtrFFI, | |
| dimensionality: i32, | |
| } | |
| // Make index sync, we should wrap index so that it is sync in the way we expect but for now this implements the trait | |
| unsafe impl Sync for HnswIndex {} | |
| unsafe impl Send for HnswIndex {} | |
| pub(crate) enum HnswIndexInitError { | |
| NoConfigProvided, | |
| InvalidDistanceFunction(String), | |
| InvalidPath(String), | |
| } | |
| impl ChromaError for HnswIndexInitError { | |
| fn code(&self) -> ErrorCodes { | |
| crate::errors::ErrorCodes::InvalidArgument | |
| } | |
| } | |
| impl Index<HnswIndexConfig> for HnswIndex { | |
| fn init( | |
| index_config: &IndexConfig, | |
| hnsw_config: Option<&HnswIndexConfig>, | |
| ) -> Result<Self, Box<dyn ChromaError>> { | |
| match hnsw_config { | |
| None => return Err(Box::new(HnswIndexInitError::NoConfigProvided)), | |
| Some(config) => { | |
| let distance_function_string: String = | |
| index_config.distance_function.clone().into(); | |
| let space_name = match CString::new(distance_function_string) { | |
| Ok(space_name) => space_name, | |
| Err(e) => { | |
| return Err(Box::new(HnswIndexInitError::InvalidDistanceFunction( | |
| e.to_string(), | |
| ))) | |
| } | |
| }; | |
| let ffi_ptr = | |
| unsafe { create_index(space_name.as_ptr(), index_config.dimensionality) }; | |
| let path = match CString::new(config.persist_path.clone()) { | |
| Ok(path) => path, | |
| Err(e) => return Err(Box::new(HnswIndexInitError::InvalidPath(e.to_string()))), | |
| }; | |
| unsafe { | |
| init_index( | |
| ffi_ptr, | |
| config.max_elements, | |
| config.m, | |
| config.ef_construction, | |
| config.random_seed, | |
| true, | |
| true, | |
| path.as_ptr(), | |
| ); | |
| } | |
| let hnsw_index = HnswIndex { | |
| ffi_ptr: ffi_ptr, | |
| dimensionality: index_config.dimensionality, | |
| }; | |
| hnsw_index.set_ef(config.ef_search); | |
| Ok(hnsw_index) | |
| } | |
| } | |
| } | |
| fn add(&self, id: usize, vector: &[f32]) { | |
| unsafe { add_item(self.ffi_ptr, vector.as_ptr(), id, false) } | |
| } | |
| fn query(&self, vector: &[f32], k: usize) -> (Vec<usize>, Vec<f32>) { | |
| let mut ids = vec![0usize; k]; | |
| let mut distance = vec![0.0f32; k]; | |
| unsafe { | |
| knn_query( | |
| self.ffi_ptr, | |
| vector.as_ptr(), | |
| k, | |
| ids.as_mut_ptr(), | |
| distance.as_mut_ptr(), | |
| ); | |
| } | |
| return (ids, distance); | |
| } | |
| fn get(&self, id: usize) -> Option<Vec<f32>> { | |
| unsafe { | |
| let mut data: Vec<f32> = vec![0.0f32; self.dimensionality as usize]; | |
| get_item(self.ffi_ptr, id, data.as_mut_ptr()); | |
| return Some(data); | |
| } | |
| } | |
| } | |
| impl PersistentIndex<HnswIndexConfig> for HnswIndex { | |
| fn save(&self) -> Result<(), Box<dyn ChromaError>> { | |
| unsafe { persist_dirty(self.ffi_ptr) }; | |
| Ok(()) | |
| } | |
| fn load(path: &str, index_config: &IndexConfig) -> Result<Self, Box<dyn ChromaError>> { | |
| let distance_function_string: String = index_config.distance_function.clone().into(); | |
| let space_name = match CString::new(distance_function_string) { | |
| Ok(space_name) => space_name, | |
| Err(e) => { | |
| return Err(Box::new(HnswIndexInitError::InvalidDistanceFunction( | |
| e.to_string(), | |
| ))) | |
| } | |
| }; | |
| let ffi_ptr = unsafe { create_index(space_name.as_ptr(), index_config.dimensionality) }; | |
| let path = match CString::new(path.to_string()) { | |
| Ok(path) => path, | |
| Err(e) => return Err(Box::new(HnswIndexInitError::InvalidPath(e.to_string()))), | |
| }; | |
| unsafe { | |
| load_index(ffi_ptr, path.as_ptr(), true, true); | |
| } | |
| let hnsw_index = HnswIndex { | |
| ffi_ptr: ffi_ptr, | |
| dimensionality: index_config.dimensionality, | |
| }; | |
| Ok(hnsw_index) | |
| } | |
| } | |
| impl HnswIndex { | |
| pub fn set_ef(&self, ef: usize) { | |
| unsafe { set_ef(self.ffi_ptr, ef as c_int) } | |
| } | |
| pub fn get_ef(&self) -> usize { | |
| unsafe { get_ef(self.ffi_ptr) as usize } | |
| } | |
| } | |
| extern "C" { | |
| fn create_index(space_name: *const c_char, dim: c_int) -> *const IndexPtrFFI; | |
| fn init_index( | |
| index: *const IndexPtrFFI, | |
| max_elements: usize, | |
| M: usize, | |
| ef_construction: usize, | |
| random_seed: usize, | |
| allow_replace_deleted: bool, | |
| is_persistent: bool, | |
| path: *const c_char, | |
| ); | |
| fn load_index( | |
| index: *const IndexPtrFFI, | |
| path: *const c_char, | |
| allow_replace_deleted: bool, | |
| is_persistent_index: bool, | |
| ); | |
| fn persist_dirty(index: *const IndexPtrFFI); | |
| fn add_item(index: *const IndexPtrFFI, data: *const f32, id: usize, replace_deleted: bool); | |
| fn get_item(index: *const IndexPtrFFI, id: usize, data: *mut f32); | |
| fn knn_query( | |
| index: *const IndexPtrFFI, | |
| query_vector: *const f32, | |
| k: usize, | |
| ids: *mut usize, | |
| distance: *mut f32, | |
| ); | |
| fn get_ef(index: *const IndexPtrFFI) -> c_int; | |
| fn set_ef(index: *const IndexPtrFFI, ef: c_int); | |
| } | |
// Compiled only for test builds so the dev-only crates used below (rand, rayon, tempfile)
// are not needed by regular builds; each case is marked `#[test]` so the harness runs it.
#[cfg(test)]
pub mod test {
    use super::*;
    use crate::index::types::DistanceFunction;
    use crate::index::utils;
    use rand::Rng;
    use rayon::prelude::*;
    use rayon::ThreadPoolBuilder;
    use tempfile::tempdir;

    /// A fresh index reports the configured `ef_search` and accepts a new `ef`.
    #[test]
    fn it_initializes_and_can_set_get_ef() {
        let n = 1000;
        let d: usize = 960;
        let tmp_dir = tempdir().unwrap();
        let persist_path = tmp_dir.path().to_str().unwrap().to_string();
        let distance_function = DistanceFunction::Euclidean;
        let index = HnswIndex::init(
            &IndexConfig {
                dimensionality: d as i32,
                distance_function,
            },
            Some(&HnswIndexConfig {
                max_elements: n,
                m: 16,
                ef_construction: 100,
                ef_search: 10,
                random_seed: 0,
                persist_path,
            }),
        );
        match index {
            Err(e) => panic!("Error initializing index: {}", e),
            Ok(index) => {
                assert_eq!(index.get_ef(), 10);
                index.set_ef(100);
                assert_eq!(index.get_ef(), 100);
            }
        }
    }

    /// Vectors added from several threads can be read back unchanged.
    #[test]
    fn it_can_add_parallel() {
        let n = 10;
        let d: usize = 960;
        let distance_function = DistanceFunction::InnerProduct;
        let tmp_dir = tempdir().unwrap();
        let persist_path = tmp_dir.path().to_str().unwrap().to_string();
        let index = HnswIndex::init(
            &IndexConfig {
                dimensionality: d as i32,
                distance_function,
            },
            Some(&HnswIndexConfig {
                max_elements: n,
                m: 16,
                ef_construction: 100,
                ef_search: 100,
                random_seed: 0,
                persist_path,
            }),
        );
        let index = match index {
            Err(e) => panic!("Error initializing index: {}", e),
            Ok(index) => index,
        };

        let ids: Vec<usize> = (0..n).collect();

        let mut rng: rand::prelude::ThreadRng = rand::thread_rng();
        let mut datas: Vec<Vec<f32>> = Vec::with_capacity(n);
        for _ in 0..n {
            let data: Vec<f32> = (0..d).map(|_| rng.gen()).collect();
            datas.push(data);
        }

        // Add data in parallel. A pool local to this test is used instead of the global one:
        // `build_global` fails if the global pool was already initialised in this process.
        let pool = ThreadPoolBuilder::new().num_threads(12).build().unwrap();
        pool.install(|| {
            (0..n).into_par_iter().for_each(|i| {
                index.add(ids[i], &datas[i]);
            });
        });

        // Get the data and check it
        for (i, id) in ids.iter().enumerate() {
            let actual_data = index.get(*id);
            match actual_data {
                None => panic!("No data found for id: {}", id),
                Some(actual_data) => {
                    assert_eq!(actual_data.len(), d);
                    for j in 0..d {
                        // Floating point epsilon comparison
                        assert!((actual_data[j] - datas[i][j]).abs() < 0.00001);
                    }
                }
            }
        }
    }

    /// A single stored vector is returned by `get` and is its own nearest neighbour.
    #[test]
    fn it_can_add_and_basic_query() {
        let n = 1;
        let d: usize = 960;
        let distance_function = DistanceFunction::Euclidean;
        let tmp_dir = tempdir().unwrap();
        let persist_path = tmp_dir.path().to_str().unwrap().to_string();
        let index = HnswIndex::init(
            &IndexConfig {
                dimensionality: d as i32,
                distance_function,
            },
            Some(&HnswIndexConfig {
                max_elements: n,
                m: 16,
                ef_construction: 100,
                ef_search: 100,
                random_seed: 0,
                persist_path,
            }),
        );
        let index = match index {
            Err(e) => panic!("Error initializing index: {}", e),
            Ok(index) => index,
        };
        assert_eq!(index.get_ef(), 100);

        let data: Vec<f32> = utils::generate_random_data(n, d);
        let ids: Vec<usize> = (0..n).collect();
        for i in 0..n {
            index.add(ids[i], &data[i * d..(i + 1) * d]);
        }

        // Get the data and check it
        for (i, id) in ids.iter().enumerate() {
            let actual_data = index.get(*id);
            match actual_data {
                None => panic!("No data found for id: {}", id),
                Some(actual_data) => {
                    assert_eq!(actual_data.len(), d);
                    for j in 0..d {
                        // Floating point epsilon comparison
                        assert!((actual_data[j] - data[i * d + j]).abs() < 0.00001);
                    }
                }
            }
        }

        // Query the data
        let query = &data[0..d];
        let (ids, distances) = index.query(query, 1);
        assert_eq!(ids.len(), 1);
        assert_eq!(distances.len(), 1);
        assert_eq!(ids[0], 0);
        assert_eq!(distances[0], 0.0);
    }

    /// An index saved to disk can be loaded again and still answers queries.
    #[test]
    fn it_can_persist_and_load() {
        let n = 1000;
        let d: usize = 960;
        let distance_function = DistanceFunction::Euclidean;
        let tmp_dir = tempdir().unwrap();
        let persist_path = tmp_dir.path().to_str().unwrap().to_string();
        let index = HnswIndex::init(
            &IndexConfig {
                dimensionality: d as i32,
                distance_function: distance_function.clone(),
            },
            Some(&HnswIndexConfig {
                max_elements: n,
                m: 32,
                ef_construction: 100,
                ef_search: 100,
                random_seed: 0,
                persist_path: persist_path.clone(),
            }),
        );
        let index = match index {
            Err(e) => panic!("Error initializing index: {}", e),
            Ok(index) => index,
        };

        let data: Vec<f32> = utils::generate_random_data(n, d);
        let ids: Vec<usize> = (0..n).collect();
        for i in 0..n {
            index.add(ids[i], &data[i * d..(i + 1) * d]);
        }

        // Persist the index
        if let Err(e) = index.save() {
            panic!("Error saving index: {}", e);
        }

        // Load the index
        let index = HnswIndex::load(
            &persist_path,
            &IndexConfig {
                dimensionality: d as i32,
                distance_function,
            },
        );
        let index = match index {
            Err(e) => panic!("Error loading index: {}", e),
            Ok(index) => index,
        };
        // TODO: This should be set by the load
        index.set_ef(100);

        // Query the data
        let query = &data[0..d];
        let (ids, distances) = index.query(query, 1);
        assert_eq!(ids.len(), 1);
        assert_eq!(distances.len(), 1);
        assert_eq!(ids[0], 0);
        assert_eq!(distances[0], 0.0);

        // Get the data and check it (`ids` now holds the single id returned by the query)
        for (i, id) in ids.iter().enumerate() {
            let actual_data = index.get(*id);
            match actual_data {
                None => panic!("No data found for id: {}", id),
                Some(actual_data) => {
                    assert_eq!(actual_data.len(), d);
                    for j in 0..d {
                        assert_eq!(actual_data[j], data[i * d + j]);
                    }
                }
            }
        }
    }
}