fcastrovilli committed on
Commit
63b0848
·
1 Parent(s): 6284a4a

refactor: computeSimilarities

Browse files
Files changed (2) hide show
  1. README.md +18 -2
  2. main.py +143 -5
README.md CHANGED
@@ -54,8 +54,24 @@ docker run -p 7860:7860 imagebind-api
54
 
55
  The API will be available at `http://localhost:7860` with the following endpoints:
56
 
57
- - POST `/compute_embeddings`: Generate embeddings for images, audio files, and text
58
- - POST `/compute_similarities`: Compute similarities between embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  For detailed API documentation, visit `http://localhost:7860/docs`
61
 
 
54
 
55
  The API will be available at `http://localhost:7860` with the following endpoints:
56
 
57
+ ### POST `/compute_embeddings`
58
+
59
+ Generate embeddings for images, audio files, and text.
60
+
61
+ ### POST `/compute_similarities`
62
+
63
+ Compute similarities between embeddings with advanced filtering options:
64
+
65
+ - Threshold filtering for minimum similarity scores
66
+ - Top-K results limitation
67
+ - Optional self-similarity inclusion
68
+ - Score normalization
69
+ - Detailed match information including original file/text references
70
+ - Statistical analysis of similarity scores
71
+
72
+ ### GET `/health`
73
+
74
+ Basic health check endpoint.
75
 
76
  For detailed API documentation, visit `http://localhost:7860/docs`
77
 
main.py CHANGED
@@ -119,8 +119,62 @@ class EmbeddingResponse(BaseModel):
119
  embeddings: dict
120
  file_names: dict
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  class SimilarityResponse(BaseModel):
123
- similarities: dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
  @app.post("/compute_embeddings", response_model=EmbeddingResponse)
126
  async def generate_embeddings(
@@ -202,12 +256,96 @@ async def generate_embeddings(
202
 
203
  @app.post("/compute_similarities", response_model=SimilarityResponse)
204
  async def compute_similarities(
205
- embeddings: Dict[str, List[List[float]]],
 
206
  credentials: HTTPAuthorizationCredentials = Depends(verify_token)
207
  ):
208
- """Compute similarities from provided embeddings."""
209
- similarities = embedding_manager.compute_similarities(embeddings)
210
- return SimilarityResponse(similarities=similarities)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
  @app.get("/health")
213
  async def health_check(
 
119
  embeddings: dict
120
  file_names: dict
121
 
122
class SimilarityRequest(BaseModel):
    """Request payload for POST /compute_similarities."""
    # Modality name -> list of embedding vectors (one row per item).
    embeddings: Dict[str, List[List[float]]]
    # Minimum similarity score for a match to be included in the response.
    threshold: float = 0.5
    # Cap on the number of matches returned per modality pair; None = no cap.
    top_k: int | None = None
    # Whether to keep "item compared with itself" matches.
    include_self_similarity: bool = False
    # Normalize embeddings before comparison (cosine similarity).
    normalize_scores: bool = True
129
class SimilarityMatch(BaseModel):
    """A single pairwise similarity result between two embedded items."""
    index_a: int  # Row index of the first item within its modality's embeddings
    index_b: int  # Row index of the second item within its modality's embeddings
    score: float  # Similarity score (cosine when embeddings were normalized)
    modality_a: str  # Modality of the first item
    modality_b: str  # Modality of the second item
    item_a: str  # Original item identifier (filename or text)
    item_b: str  # Original item identifier (filename or text)
138
class SimilarityResponse(BaseModel):
    """Response payload for POST /compute_similarities."""
    matches: List[SimilarityMatch]  # Sorted by score in descending order
    statistics: Dict[str, float]  # Contains avg_score, max_score, etc.
    modality_pairs: List[str]  # Lists which modality comparisons were performed
143
class ModalityPair:
    """Order-insensitive pair of modality names.

    ``("text", "audio")`` and ``("audio", "text")`` produce the same
    canonical label, so a pair is identified regardless of argument order.
    """

    def __init__(self, mod1: str, mod2: str):
        # Sort lexicographically so the label is independent of call order.
        self.mod1, self.mod2 = sorted((mod1, mod2))

    def __str__(self):
        return f"{self.mod1}_to_{self.mod2}"
151
def compute_similarity_matrix(tensor1: torch.Tensor, tensor2: torch.Tensor, normalize: bool = True) -> torch.Tensor:
    """Compute pairwise similarity between two sets of embeddings.

    When ``normalize`` is True each row is L2-normalized first, so the
    result is the cosine similarity; otherwise it is the raw dot product.
    Returns a matrix of shape (len(tensor1), len(tensor2)).
    """
    if normalize:
        tensor1 = torch.nn.functional.normalize(tensor1, dim=1)
        tensor2 = torch.nn.functional.normalize(tensor2, dim=1)
    # Row-by-row dot products via a single matrix multiplication.
    return tensor1 @ tensor2.T
163
+ def get_top_k_matches(similarity_matrix: torch.Tensor, top_k: int | None = None) -> List[tuple]:
164
+ """Get top-k matches from a similarity matrix."""
165
+ if top_k is None:
166
+ top_k = similarity_matrix.numel()
167
+
168
+ # Flatten and get top-k indices
169
+ flat_sim = similarity_matrix.flatten()
170
+ top_k = min(top_k, flat_sim.numel())
171
+ values, indices = torch.topk(flat_sim, k=top_k)
172
+
173
+ # Convert flat indices to 2D indices
174
+ rows = indices // similarity_matrix.size(1)
175
+ cols = indices % similarity_matrix.size(1)
176
+
177
+ return [(r.item(), c.item(), v.item()) for r, c, v in zip(rows, cols, values)]
178
 
179
  @app.post("/compute_embeddings", response_model=EmbeddingResponse)
180
  async def generate_embeddings(
 
256
 
257
@app.post("/compute_similarities", response_model=SimilarityResponse)
async def compute_similarities(
    request: SimilarityRequest,
    file_names: Dict[str, List[str]],  # Maps modality to list of file/text names
    credentials: HTTPAuthorizationCredentials = Depends(verify_token)
):
    """
    Compute cross-modal similarities with advanced filtering and matching options.

    Parameters:
    - request: SimilarityRequest carrying:
        - embeddings: Dict mapping modality to embedding tensors
        - threshold: Minimum similarity score to include in results
        - top_k: Maximum number of matches to return (per modality pair)
        - include_self_similarity: Whether to include same-item comparisons
        - normalize_scores: Whether to normalize embeddings before comparison
    - file_names: Dict mapping modality to list of original file/text names

    Returns a SimilarityResponse with the filtered matches (sorted by score,
    descending), aggregate statistics over the kept scores, and the list of
    modality pairs that were compared.
    """

    def _item_name(modality: str, index: int) -> str:
        # Fall back to a positional identifier when the caller supplied no
        # (or too few) names for a modality, instead of raising
        # KeyError/IndexError mid-request.
        names = file_names.get(modality, [])
        return names[index] if 0 <= index < len(names) else f"{modality}[{index}]"

    matches = []
    statistics = {
        "avg_score": 0.0,
        "max_score": 0.0,
        "min_score": 1.0,
        "total_comparisons": 0
    }

    # Convert embeddings to tensors, skipping empty or non-sequence payloads.
    tensors = {
        k: torch.tensor(v) for k, v in request.embeddings.items()
        if isinstance(v, (list, np.ndarray)) and len(v) > 0
    }

    modality_pairs = []
    all_scores = []

    # Compare every unordered pair of modalities, including each modality
    # with itself. Same-modality pairs are always computed; only the
    # diagonal "item vs itself" entries are gated by include_self_similarity
    # (previously the whole same-modality pair was skipped, which made the
    # diagonal filter below unreachable and contradicted the docstring).
    modalities = list(tensors.keys())
    for i, mod1 in enumerate(modalities):
        for mod2 in modalities[i:]:
            pair = ModalityPair(mod1, mod2)
            modality_pairs.append(str(pair))

            # Compute similarity matrix for this modality pair.
            sim_matrix = compute_similarity_matrix(
                tensors[mod1],
                tensors[mod2],
                normalize=request.normalize_scores
            )

            # Get top matches (top_k applies per modality pair).
            top_matches = get_top_k_matches(sim_matrix, request.top_k)

            # Filter by threshold and create match objects.
            for idx_a, idx_b, score in top_matches:
                if score < request.threshold:
                    continue

                # Skip "item compared with itself" unless explicitly requested.
                if mod1 == mod2 and idx_a == idx_b and not request.include_self_similarity:
                    continue

                matches.append(SimilarityMatch(
                    index_a=idx_a,
                    index_b=idx_b,
                    score=float(score),
                    modality_a=mod1,
                    modality_b=mod2,
                    item_a=_item_name(mod1, idx_a),
                    item_b=_item_name(mod2, idx_b)
                ))
                all_scores.append(score)

    # Aggregate statistics over the scores that survived filtering.
    if all_scores:
        statistics.update({
            "avg_score": float(np.mean(all_scores)),
            "max_score": float(np.max(all_scores)),
            "min_score": float(np.min(all_scores)),
            "total_comparisons": len(all_scores)
        })

    # Highest-scoring matches first.
    matches.sort(key=lambda x: x.score, reverse=True)

    return SimilarityResponse(
        matches=matches,
        statistics=statistics,
        modality_pairs=modality_pairs
    )
349
 
350
  @app.get("/health")
351
  async def health_check(