File size: 8,139 Bytes
8ef2d83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
//! # Flat Index Adapter
//!
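//! Brute-force (exact) nearest neighbor search.
//! Compares the query against every stored point: O(n) proximity evaluations per
//! query, followed by a sort of the scored points (O(n log n) for `near`).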
//!
//! Good for:
//! - Testing
//! - Small datasets (< 10,000 points)
//! - When exact results are required
//!
//! Not good for:
//! - Large datasets (use HNSW instead)

use std::collections::HashMap;
use std::sync::Arc;

use crate::core::{Id, Point};
use crate::core::proximity::Proximity;
use crate::ports::{Near, NearError, NearResult, SearchResult};

/// Brute-force index that scores a query against every stored point.
///
/// Results are exact; the cost of a query grows linearly with the number of stored points.
pub struct FlatIndex {
    /// Stored points keyed by ID. Adding an ID that is already present replaces its point.
    points: HashMap<Id, Point>,

    /// Dimensionality every added point and every query must have; anything else is
    /// rejected with `NearError::DimensionalityMismatch`.
    dimensionality: usize,

    /// Proximity function used to score the query against each stored point.
    proximity: Arc<dyn Proximity>,

    /// Whether a higher proximity score means "more similar"
    /// (`true` for cosine/dot product, `false` for euclidean).
    /// Not derived from `proximity`: whoever constructs the index must keep the two consistent.
    higher_is_better: bool,
}

impl FlatIndex {
    /// Create a new, empty flat index.
    ///
    /// `higher_is_better` indicates whether higher proximity scores mean more similar.
    /// - `true` for Cosine, DotProduct
    /// - `false` for Euclidean, Manhattan
    ///
    /// The flag is stored as given and is not checked against `proximity`; passing the
    /// wrong value inverts result ordering and the direction of `within` thresholds.
    pub fn new(
        dimensionality: usize,
        proximity: Arc<dyn Proximity>,
        higher_is_better: bool,
    ) -> Self {
        Self {
            points: HashMap::new(),
            dimensionality,
            proximity,
            higher_is_better,
        }
    }

    /// Create an empty index scored with cosine similarity (higher = better).
    pub fn cosine(dimensionality: usize) -> Self {
        use crate::core::proximity::Cosine;
        Self::new(dimensionality, Arc::new(Cosine), true)
    }

    /// Create an empty index scored with euclidean distance (lower = better).
    pub fn euclidean(dimensionality: usize) -> Self {
        use crate::core::proximity::Euclidean;
        Self::new(dimensionality, Arc::new(Euclidean), false)
    }

    /// Sort results in place from most to least relevant.
    ///
    /// Never panics on NaN scores (which a proximity function might produce for
    /// degenerate inputs — not established here); they are placed after every real score.
    fn sort_results(&self, results: &mut [SearchResult]) {
        let higher_is_better = self.higher_is_better;
        results.sort_by(|a, b| Self::compare_relevance(a.score, b.score, higher_is_better));
    }

    /// Order two scores so that the more relevant one compares as `Less`.
    ///
    /// Real scores are ordered descending when `higher_is_better`, ascending otherwise.
    /// NaN compares greater than any real score (so it sorts last in both modes) and equal
    /// to another NaN, which keeps the comparator a total order as `sort_by` requires.
    fn compare_relevance(a: f32, b: f32, higher_is_better: bool) -> std::cmp::Ordering {
        use std::cmp::Ordering;

        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => {
                // Neither value is NaN, so `partial_cmp` always yields `Some`.
                let ascending = a.partial_cmp(&b).unwrap_or(Ordering::Equal);
                if higher_is_better {
                    ascending.reverse()
                } else {
                    ascending
                }
            }
        }
    }
}

/// Check that a point's dimensionality matches what the index expects.
///
/// Returns `NearError::DimensionalityMismatch { expected, got }` on mismatch.
fn ensure_dimensionality(expected: usize, got: usize) -> NearResult<()> {
    if got != expected {
        return Err(NearError::DimensionalityMismatch { expected, got });
    }
    Ok(())
}

impl Near for FlatIndex {
    /// Return up to `k` stored points ordered from most to least relevant to `query`.
    ///
    /// Every stored point is scored and the whole list is sorted before being cut to `k`,
    /// so the answer is exact. Fails with `NearError::DimensionalityMismatch` when `query`
    /// has the wrong dimensionality.
    fn near(&self, query: &Point, k: usize) -> NearResult<Vec<SearchResult>> {
        ensure_dimensionality(self.dimensionality, query.dimensionality())?;

        let mut ranked: Vec<SearchResult> = Vec::with_capacity(self.points.len());
        for (id, point) in &self.points {
            let score = self.proximity.proximity(query, point);
            ranked.push(SearchResult::new(*id, score));
        }

        self.sort_results(&mut ranked);
        ranked.truncate(k);
        Ok(ranked)
    }

    /// Return every stored point whose score passes `threshold`, most relevant first.
    ///
    /// When higher scores are better a point qualifies with `score >= threshold`,
    /// otherwise with `score <= threshold`; a NaN score never qualifies. Fails with
    /// `NearError::DimensionalityMismatch` when `query` has the wrong dimensionality.
    fn within(&self, query: &Point, threshold: f32) -> NearResult<Vec<SearchResult>> {
        ensure_dimensionality(self.dimensionality, query.dimensionality())?;

        let mut matches: Vec<SearchResult> = Vec::new();
        for (id, point) in &self.points {
            let score = self.proximity.proximity(query, point);
            let passes = match self.higher_is_better {
                true => score >= threshold,
                false => score <= threshold,
            };
            if passes {
                matches.push(SearchResult::new(*id, score));
            }
        }

        self.sort_results(&mut matches);
        Ok(matches)
    }

    /// Store a copy of `point` under `id`, replacing any point already stored there.
    ///
    /// Fails with `NearError::DimensionalityMismatch` (storing nothing) when the point
    /// has the wrong dimensionality.
    fn add(&mut self, id: Id, point: &Point) -> NearResult<()> {
        ensure_dimensionality(self.dimensionality, point.dimensionality())?;
        self.points.insert(id, point.clone());
        Ok(())
    }

    /// Drop the point stored under `id`; an unknown id is accepted without error.
    fn remove(&mut self, id: Id) -> NearResult<()> {
        self.points.remove(&id);
        Ok(())
    }

    /// No-op: a flat index keeps no derived structure that could go stale.
    fn rebuild(&mut self) -> NearResult<()> {
        Ok(())
    }

    /// Always `true`: points are searchable as soon as they are added.
    fn is_ready(&self) -> bool {
        true
    }

    /// Number of points currently stored.
    fn len(&self) -> usize {
        self.points.len()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Build an `Id` whose 16 bytes are all `byte`.
    fn make_id(byte: u8) -> Id {
        Id::from_bytes([byte; 16])
    }

    /// 3-d cosine index: the three unit axis vectors (ids 1, 2, 3) plus the
    /// normalised x/y diagonal (id 4).
    fn setup_index() -> FlatIndex {
        let mut index = FlatIndex::cosine(3);

        let points = vec![
            (make_id(1), Point::new(vec![1.0, 0.0, 0.0])),
            (make_id(2), Point::new(vec![0.0, 1.0, 0.0])),
            (make_id(3), Point::new(vec![0.0, 0.0, 1.0])),
            (make_id(4), Point::new(vec![0.7, 0.7, 0.0]).normalize()),
        ];

        for (id, point) in points {
            index.add(id, &point).unwrap();
        }

        index
    }

    /// 2-d euclidean index with points on the x axis at 0 (id 1), 1 (id 2) and 5 (id 3).
    fn setup_euclidean_index() -> FlatIndex {
        let mut index = FlatIndex::euclidean(2);

        index.add(make_id(1), &Point::new(vec![0.0, 0.0])).unwrap();
        index.add(make_id(2), &Point::new(vec![1.0, 0.0])).unwrap();
        index.add(make_id(3), &Point::new(vec![5.0, 0.0])).unwrap();

        index
    }

    /// Ids of `results`, in result order.
    fn ids(results: &[SearchResult]) -> Vec<Id> {
        results.iter().map(|r| r.id).collect()
    }

    /// Assert that `result` is a dimensionality mismatch with the given values.
    fn assert_dimensionality_mismatch<T>(result: NearResult<T>, want_expected: usize, want_got: usize) {
        match result {
            Err(NearError::DimensionalityMismatch { expected, got }) => {
                assert_eq!(expected, want_expected);
                assert_eq!(got, want_got);
            }
            _ => panic!("Expected DimensionalityMismatch error"),
        }
    }

    #[test]
    fn test_flat_index_near() {
        let index = setup_index();

        // Query for points near [1, 0, 0]
        let query = Point::new(vec![1.0, 0.0, 0.0]);
        let results = index.near(&query, 2).unwrap();

        assert_eq!(results.len(), 2);

        // First result should be [1, 0, 0] with cosine = 1.0
        assert_eq!(results[0].id, make_id(1));
        assert!((results[0].score - 1.0).abs() < 0.0001);

        // Second should be the x/y diagonal (cosine ≈ 0.707), ahead of the orthogonal axes
        assert_eq!(results[1].id, make_id(4));
    }

    #[test]
    fn test_flat_index_near_k_larger_than_len_returns_all() {
        let index = setup_index();
        let query = Point::new(vec![1.0, 0.0, 0.0]);

        let results = index.near(&query, 10).unwrap();

        assert_eq!(results.len(), 4);
    }

    #[test]
    fn test_flat_index_near_k_zero_returns_nothing() {
        let index = setup_index();
        let query = Point::new(vec![1.0, 0.0, 0.0]);

        assert!(index.near(&query, 0).unwrap().is_empty());
    }

    #[test]
    fn test_flat_index_near_on_empty_index() {
        let index = FlatIndex::cosine(3);
        let query = Point::new(vec![1.0, 0.0, 0.0]);

        assert!(index.near(&query, 5).unwrap().is_empty());
    }

    #[test]
    fn test_flat_index_within_cosine() {
        let index = setup_index();

        // Find all points with cosine >= 0.5 to [1, 0, 0]
        let query = Point::new(vec![1.0, 0.0, 0.0]);
        let results = index.within(&query, 0.5).unwrap();

        // Should find [1,0,0] (cosine=1.0) then [0.7,0.7,0] (cosine≈0.707), best first
        assert_eq!(ids(&results), vec![make_id(1), make_id(4)]);
    }

    #[test]
    fn test_flat_index_euclidean() {
        let index = setup_euclidean_index();

        let query = Point::new(vec![0.0, 0.0]);
        let results = index.near(&query, 2).unwrap();
        assert_eq!(results.len(), 2);

        // Nearest should be [0,0] with distance 0
        assert_eq!(results[0].id, make_id(1));
        assert!((results[0].score - 0.0).abs() < 0.0001);

        // Second nearest should be [1,0] with distance 1
        assert_eq!(results[1].id, make_id(2));
        assert!((results[1].score - 1.0).abs() < 0.0001);
    }

    #[test]
    fn test_flat_index_within_euclidean() {
        let index = setup_euclidean_index();

        // Lower is better: keep points scoring <= 1.5 against the origin, nearest first
        let query = Point::new(vec![0.0, 0.0]);
        let results = index.within(&query, 1.5).unwrap();

        assert_eq!(ids(&results), vec![make_id(1), make_id(2)]);
    }

    #[test]
    fn test_flat_index_query_dimensionality_check() {
        let index = setup_index();
        let wrong_dims = Point::new(vec![1.0, 0.0]); // 2 dims

        assert_dimensionality_mismatch(index.near(&wrong_dims, 1), 3, 2);
        assert_dimensionality_mismatch(index.within(&wrong_dims, 0.5), 3, 2);
    }

    #[test]
    fn test_flat_index_add_remove() {
        let mut index = FlatIndex::cosine(3);

        let id = make_id(1);
        let point = Point::new(vec![1.0, 0.0, 0.0]);

        index.add(id, &point).unwrap();
        assert_eq!(index.len(), 1);

        index.remove(id).unwrap();
        assert_eq!(index.len(), 0);
    }

    #[test]
    fn test_flat_index_remove_missing_id_is_ok() {
        let mut index = setup_index();

        index.remove(make_id(9)).unwrap();
        assert_eq!(index.len(), 4);
    }

    #[test]
    fn test_flat_index_add_existing_id_replaces_point() {
        let mut index = FlatIndex::cosine(3);
        let id = make_id(1);

        index.add(id, &Point::new(vec![1.0, 0.0, 0.0])).unwrap();
        index.add(id, &Point::new(vec![0.0, 1.0, 0.0])).unwrap();
        assert_eq!(index.len(), 1);

        // The stored point is now [0, 1, 0], so it matches that query perfectly
        let results = index.near(&Point::new(vec![0.0, 1.0, 0.0]), 1).unwrap();
        assert_eq!(results[0].id, id);
        assert!((results[0].score - 1.0).abs() < 0.0001);
    }

    #[test]
    fn test_flat_index_dimensionality_check() {
        let mut index = FlatIndex::cosine(3);

        let wrong_dims = Point::new(vec![1.0, 0.0]); // 2 dims
        assert_dimensionality_mismatch(index.add(Id::now(), &wrong_dims), 3, 2);

        // A rejected point must not be stored
        assert_eq!(index.len(), 0);
    }

    #[test]
    fn test_flat_index_ready() {
        let index = FlatIndex::cosine(3);
        assert!(index.is_ready());
    }
}