Commit
·
63b0848
1
Parent(s):
6284a4a
refactor: computeSimilarities
Browse files
README.md
CHANGED
|
@@ -54,8 +54,24 @@ docker run -p 7860:7860 imagebind-api
|
|
| 54 |
|
| 55 |
The API will be available at `http://localhost:7860` with the following endpoints:
|
| 56 |
|
| 57 |
-
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
For detailed API documentation, visit `http://localhost:7860/docs`
|
| 61 |
|
|
|
|
| 54 |
|
| 55 |
The API will be available at `http://localhost:7860` with the following endpoints:
|
| 56 |
|
| 57 |
+
### POST `/compute_embeddings`
|
| 58 |
+
|
| 59 |
+
Generate embeddings for images, audio files, and text.
|
| 60 |
+
|
| 61 |
+
### POST `/compute_similarities`
|
| 62 |
+
|
| 63 |
+
Compute similarities between embeddings with advanced filtering options:
|
| 64 |
+
|
| 65 |
+
- Threshold filtering for minimum similarity scores
|
| 66 |
+
- Top-K results limitation
|
| 67 |
+
- Optional self-similarity inclusion
|
| 68 |
+
- Score normalization
|
| 69 |
+
- Detailed match information including original file/text references
|
| 70 |
+
- Statistical analysis of similarity scores
|
| 71 |
+
|
| 72 |
+
### GET `/health`
|
| 73 |
+
|
| 74 |
+
Basic health check endpoint
|
| 75 |
|
| 76 |
For detailed API documentation, visit `http://localhost:7860/docs`
|
| 77 |
|
main.py
CHANGED
|
@@ -119,8 +119,62 @@ class EmbeddingResponse(BaseModel):
|
|
| 119 |
embeddings: dict
|
| 120 |
file_names: dict
|
| 121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
class SimilarityResponse(BaseModel):
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
@app.post("/compute_embeddings", response_model=EmbeddingResponse)
|
| 126 |
async def generate_embeddings(
|
|
@@ -202,12 +256,96 @@ async def generate_embeddings(
|
|
| 202 |
|
| 203 |
@app.post("/compute_similarities", response_model=SimilarityResponse)
|
| 204 |
async def compute_similarities(
|
| 205 |
-
|
|
|
|
| 206 |
credentials: HTTPAuthorizationCredentials = Depends(verify_token)
|
| 207 |
):
|
| 208 |
-
"""
|
| 209 |
-
similarities
|
| 210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
|
| 212 |
@app.get("/health")
|
| 213 |
async def health_check(
|
|
|
|
| 119 |
embeddings: dict
|
| 120 |
file_names: dict
|
| 121 |
|
| 122 |
+
class SimilarityRequest(BaseModel):
    """Request body for the /compute_similarities endpoint."""

    # Maps modality name to a 2-D list of embeddings, one row per item.
    # NOTE(review): presumably keys match the modalities produced by
    # /compute_embeddings (images/audio/text per the README) — confirm.
    embeddings: Dict[str, List[List[float]]]
    # Minimum similarity score a pair must reach to be included in results.
    threshold: float = 0.5
    # Maximum number of matches returned per modality pair; None = no limit.
    top_k: int | None = None
    # When False, same-modality comparisons are skipped entirely.
    include_self_similarity: bool = False
    # When True, embeddings are L2-normalized first (cosine similarity).
    normalize_scores: bool = True
|
| 128 |
+
|
| 129 |
+
class SimilarityMatch(BaseModel):
    """A single scored pairing between two embedded items."""

    index_a: int  # Row index into the first modality's embedding list
    index_b: int  # Row index into the second modality's embedding list
    score: float  # Similarity score (cosine if normalization was enabled)
    modality_a: str  # Modality of the first item
    modality_b: str  # Modality of the second item
    item_a: str  # Original item identifier (filename or text)
    item_b: str  # Original item identifier (filename or text)
|
| 137 |
+
|
| 138 |
class SimilarityResponse(BaseModel):
    """Response body for the /compute_similarities endpoint."""

    # Matches that passed threshold/top-k filtering, sorted by score.
    matches: List[SimilarityMatch]
    statistics: Dict[str, float]  # Contains avg_score, max_score, etc.
    modality_pairs: List[str]  # Lists which modality comparisons were performed
|
| 142 |
+
|
| 143 |
+
class ModalityPair:
    """Canonical, order-insensitive identifier for a pair of modalities.

    The two names are stored in lexicographic order so that
    ``ModalityPair("text", "image")`` and ``ModalityPair("image", "text")``
    render identically (e.g. ``"image_to_text"``). Equality and hashing
    follow the same canonical ordering, so equivalent pairs compare equal
    and deduplicate in sets/dicts.
    """

    def __init__(self, mod1: str, mod2: str):
        self.mod1 = min(mod1, mod2)  # Ensure consistent ordering
        self.mod2 = max(mod1, mod2)

    def __str__(self) -> str:
        return f"{self.mod1}_to_{self.mod2}"

    def __repr__(self) -> str:
        return f"ModalityPair({self.mod1!r}, {self.mod2!r})"

    def __eq__(self, other) -> bool:
        if not isinstance(other, ModalityPair):
            return NotImplemented
        return (self.mod1, self.mod2) == (other.mod1, other.mod2)

    def __hash__(self) -> int:
        return hash((self.mod1, self.mod2))
|
| 150 |
+
|
| 151 |
+
def compute_similarity_matrix(tensor1: torch.Tensor, tensor2: torch.Tensor, normalize: bool = True) -> torch.Tensor:
    """Return the pairwise similarity matrix between two embedding sets.

    With ``normalize=True`` each row is L2-normalized first, making the
    result the cosine similarity; otherwise it is a raw dot-product matrix.
    """
    a, b = tensor1, tensor2
    if normalize:
        # L2-normalize along the feature dimension.
        a = torch.nn.functional.normalize(a, dim=1)
        b = torch.nn.functional.normalize(b, dim=1)
    # (n, d) @ (d, m) -> (n, m) similarity matrix.
    return a @ b.T
|
| 162 |
+
|
| 163 |
+
def get_top_k_matches(similarity_matrix: torch.Tensor, top_k: int | None = None) -> List[tuple]:
|
| 164 |
+
"""Get top-k matches from a similarity matrix."""
|
| 165 |
+
if top_k is None:
|
| 166 |
+
top_k = similarity_matrix.numel()
|
| 167 |
+
|
| 168 |
+
# Flatten and get top-k indices
|
| 169 |
+
flat_sim = similarity_matrix.flatten()
|
| 170 |
+
top_k = min(top_k, flat_sim.numel())
|
| 171 |
+
values, indices = torch.topk(flat_sim, k=top_k)
|
| 172 |
+
|
| 173 |
+
# Convert flat indices to 2D indices
|
| 174 |
+
rows = indices // similarity_matrix.size(1)
|
| 175 |
+
cols = indices % similarity_matrix.size(1)
|
| 176 |
+
|
| 177 |
+
return [(r.item(), c.item(), v.item()) for r, c, v in zip(rows, cols, values)]
|
| 178 |
|
| 179 |
@app.post("/compute_embeddings", response_model=EmbeddingResponse)
|
| 180 |
async def generate_embeddings(
|
|
|
|
| 256 |
|
| 257 |
@app.post("/compute_similarities", response_model=SimilarityResponse)
|
| 258 |
async def compute_similarities(
|
| 259 |
+
request: SimilarityRequest,
|
| 260 |
+
file_names: Dict[str, List[str]], # Maps modality to list of file/text names
|
| 261 |
credentials: HTTPAuthorizationCredentials = Depends(verify_token)
|
| 262 |
):
|
| 263 |
+
"""
|
| 264 |
+
Compute cross-modal similarities with advanced filtering and matching options.
|
| 265 |
+
|
| 266 |
+
Parameters:
|
| 267 |
+
- embeddings: Dict mapping modality to embedding tensors
|
| 268 |
+
- threshold: Minimum similarity score to include in results
|
| 269 |
+
- top_k: Maximum number of matches to return (per modality pair)
|
| 270 |
+
- include_self_similarity: Whether to include same-item comparisons
|
| 271 |
+
- normalize_scores: Whether to normalize embeddings before comparison
|
| 272 |
+
- file_names: Dict mapping modality to list of original file/text names
|
| 273 |
+
"""
|
| 274 |
+
|
| 275 |
+
matches = []
|
| 276 |
+
statistics = {
|
| 277 |
+
"avg_score": 0.0,
|
| 278 |
+
"max_score": 0.0,
|
| 279 |
+
"min_score": 1.0,
|
| 280 |
+
"total_comparisons": 0
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
# Convert embeddings to tensors
|
| 284 |
+
tensors = {
|
| 285 |
+
k: torch.tensor(v) for k, v in request.embeddings.items()
|
| 286 |
+
if isinstance(v, (list, np.ndarray)) and len(v) > 0
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
modality_pairs = []
|
| 290 |
+
all_scores = []
|
| 291 |
+
|
| 292 |
+
# Get all possible modality pairs
|
| 293 |
+
modalities = list(tensors.keys())
|
| 294 |
+
for i, mod1 in enumerate(modalities):
|
| 295 |
+
for mod2 in modalities[i:]: # Include self-comparisons if requested
|
| 296 |
+
if mod1 == mod2 and not request.include_self_similarity:
|
| 297 |
+
continue
|
| 298 |
+
|
| 299 |
+
pair = ModalityPair(mod1, mod2)
|
| 300 |
+
modality_pairs.append(str(pair))
|
| 301 |
+
|
| 302 |
+
# Compute similarity matrix
|
| 303 |
+
sim_matrix = compute_similarity_matrix(
|
| 304 |
+
tensors[mod1],
|
| 305 |
+
tensors[mod2],
|
| 306 |
+
normalize=request.normalize_scores
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
# Get top matches
|
| 310 |
+
top_matches = get_top_k_matches(sim_matrix, request.top_k)
|
| 311 |
+
|
| 312 |
+
# Filter by threshold and create match objects
|
| 313 |
+
for idx_a, idx_b, score in top_matches:
|
| 314 |
+
if score < request.threshold:
|
| 315 |
+
continue
|
| 316 |
+
|
| 317 |
+
# Skip self-matches if not requested
|
| 318 |
+
if mod1 == mod2 and idx_a == idx_b and not request.include_self_similarity:
|
| 319 |
+
continue
|
| 320 |
+
|
| 321 |
+
matches.append(SimilarityMatch(
|
| 322 |
+
index_a=idx_a,
|
| 323 |
+
index_b=idx_b,
|
| 324 |
+
score=float(score),
|
| 325 |
+
modality_a=mod1,
|
| 326 |
+
modality_b=mod2,
|
| 327 |
+
item_a=file_names[mod1][idx_a],
|
| 328 |
+
item_b=file_names[mod2][idx_b]
|
| 329 |
+
))
|
| 330 |
+
all_scores.append(score)
|
| 331 |
+
|
| 332 |
+
# Compute statistics
|
| 333 |
+
if all_scores:
|
| 334 |
+
statistics.update({
|
| 335 |
+
"avg_score": float(np.mean(all_scores)),
|
| 336 |
+
"max_score": float(np.max(all_scores)),
|
| 337 |
+
"min_score": float(np.min(all_scores)),
|
| 338 |
+
"total_comparisons": len(all_scores)
|
| 339 |
+
})
|
| 340 |
+
|
| 341 |
+
# Sort matches by score in descending order
|
| 342 |
+
matches.sort(key=lambda x: x.score, reverse=True)
|
| 343 |
+
|
| 344 |
+
return SimilarityResponse(
|
| 345 |
+
matches=matches,
|
| 346 |
+
statistics=statistics,
|
| 347 |
+
modality_pairs=modality_pairs
|
| 348 |
+
)
|
| 349 |
|
| 350 |
@app.get("/health")
|
| 351 |
async def health_check(
|