opex792 commited on
Commit
5bab2e7
·
verified ·
1 Parent(s): e6d5f01

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +298 -307
main.py CHANGED
@@ -2,26 +2,119 @@ import os
2
  import torch
3
  from imagebind import data
4
  from imagebind.models import imagebind_model
5
- from imagebind.models.imagebind_model import ModalityType
6
  from pydub import AudioSegment
7
- from fastapi import FastAPI, UploadFile, File, Form
8
- from typing import List, Dict
 
 
 
9
  import tempfile
10
- from pydantic import BaseModel
11
  import uvicorn
12
  import numpy as np
13
- from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
14
- from fastapi import Depends, HTTPException, status
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- app = FastAPI()
 
 
 
 
 
 
 
17
 
18
- # Add these lines after the app initialization
19
  security = HTTPBearer()
20
- API_TOKEN = os.getenv("API_TOKEN", "your-default-token-here") # Set a default token or use environment variable
21
 
22
- # Add this function for token verification
23
  async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
24
- if credentials.credentials != API_TOKEN:
 
25
  raise HTTPException(
26
  status_code=status.HTTP_401_UNAUTHORIZED,
27
  detail="Invalid authentication token",
@@ -29,333 +122,231 @@ async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(secur
29
  )
30
  return credentials.credentials
31
 
32
- def convert_audio_to_wav(audio_path: str) -> str:
33
- """Convert MP3 to WAV if necessary."""
34
- if audio_path.lower().endswith('.mp3'):
 
 
 
 
 
 
 
 
 
 
35
  wav_path = audio_path.rsplit('.', 1)[0] + '.wav'
36
- if not os.path.exists(wav_path):
37
- audio = AudioSegment.from_mp3(audio_path)
 
38
  audio.export(wav_path, format='wav')
39
- return wav_path
 
 
 
 
 
 
 
 
40
  return audio_path
41
 
42
- class EmbeddingManager:
43
- def __init__(self):
44
- self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
45
- self.model = imagebind_model.imagebind_huge(pretrained=True)
46
- self.model.eval()
47
- self.model.to(self.device)
48
-
49
- def compute_embeddings(self,
50
- images: List[str] = None,
51
- audio_files: List[str] = None,
52
- texts: List[str] = None) -> dict:
53
- """Compute embeddings for provided modalities only."""
54
- with torch.no_grad():
55
- inputs = {}
56
-
57
- if texts:
58
- inputs[ModalityType.TEXT] = data.load_and_transform_text(texts, self.device)
59
- if images:
60
- inputs[ModalityType.VISION] = data.load_and_transform_vision_data(images, self.device)
61
- if audio_files:
62
- inputs[ModalityType.AUDIO] = data.load_and_transform_audio_data(audio_files, self.device)
63
-
64
- if not inputs:
65
- return {}
66
-
67
- embeddings = self.model(inputs)
68
-
69
- result = {}
70
- if ModalityType.VISION in inputs:
71
- result['vision'] = embeddings[ModalityType.VISION].cpu().numpy().tolist()
72
- if ModalityType.AUDIO in inputs:
73
- result['audio'] = embeddings[ModalityType.AUDIO].cpu().numpy().tolist()
74
- if ModalityType.TEXT in inputs:
75
- result['text'] = embeddings[ModalityType.TEXT].cpu().numpy().tolist()
76
-
77
- return result
78
-
79
- @staticmethod
80
- def compute_similarities(embeddings: Dict[str, List[List[float]]]) -> dict:
81
- """Compute similarities between available embeddings."""
82
- similarities = {}
83
-
84
- # Convert available embeddings to tensors
85
- tensors = {
86
- k: torch.tensor(v) for k, v in embeddings.items()
87
- if isinstance(v, (list, np.ndarray)) and len(v) > 0
88
- }
89
-
90
- # Compute cross-modal similarities
91
- modality_pairs = [
92
- ('vision', 'audio', 'vision_audio'),
93
- ('vision', 'text', 'vision_text'),
94
- ('audio', 'text', 'audio_text')
95
- ]
96
-
97
- for mod1, mod2, key in modality_pairs:
98
- if mod1 in tensors and mod2 in tensors:
99
- similarities[key] = torch.softmax(
100
- tensors[mod1] @ tensors[mod2].T,
101
- dim=-1
102
- ).numpy().tolist()
103
-
104
- # Compute same-modality similarities
105
- for modality in ['vision', 'audio', 'text']:
106
- if modality in tensors:
107
- key = f'{modality}_{modality}'
108
- similarities[key] = torch.softmax(
109
- tensors[modality] @ tensors[modality].T,
110
- dim=-1
111
- ).numpy().tolist()
112
-
113
- return similarities
114
 
115
- # Initialize the embedding manager
116
- embedding_manager = EmbeddingManager()
 
117
 
118
- class EmbeddingResponse(BaseModel):
119
- embeddings: dict
120
- file_names: dict
 
 
 
 
121
 
122
- class SimilarityRequest(BaseModel):
123
- embeddings: Dict[str, List[List[float]]]
124
- threshold: float = 0.5
125
- top_k: int | None = None
126
- include_self_similarity: bool = False
127
- normalize_scores: bool = True
128
 
129
  class SimilarityMatch(BaseModel):
130
- index_a: int
131
- index_b: int
132
- score: float
133
- modality_a: str
134
- modality_b: str
135
- item_a: str # Original item identifier (filename or text)
136
- item_b: str # Original item identifier (filename or text)
 
 
 
 
 
 
137
 
138
  class SimilarityResponse(BaseModel):
139
  matches: List[SimilarityMatch]
140
- statistics: Dict[str, float] # Contains avg_score, max_score, etc.
141
- modality_pairs: List[str] # Lists which modality comparisons were performed
142
-
143
- class ModalityPair:
144
- def __init__(self, mod1: str, mod2: str):
145
- self.mod1 = min(mod1, mod2) # Ensure consistent ordering
146
- self.mod2 = max(mod1, mod2)
147
-
148
- def __str__(self):
149
- return f"{self.mod1}_to_{self.mod2}"
150
-
151
- def compute_similarity_matrix(tensor1: torch.Tensor, tensor2: torch.Tensor, normalize: bool = True) -> torch.Tensor:
152
- """Compute cosine similarity between two sets of embeddings."""
153
- # Normalize embeddings if requested
154
- if normalize:
155
- tensor1 = torch.nn.functional.normalize(tensor1, dim=1)
156
- tensor2 = torch.nn.functional.normalize(tensor2, dim=1)
157
-
158
- # Compute similarity matrix
159
- similarity = torch.matmul(tensor1, tensor2.T)
160
-
161
- return similarity
162
 
163
- def get_top_k_matches(similarity_matrix: torch.Tensor, top_k: int | None = None) -> List[tuple]:
164
- """Get top-k matches from a similarity matrix."""
165
- if top_k is None:
166
- top_k = similarity_matrix.numel()
167
-
168
- # Flatten and get top-k indices
169
- flat_sim = similarity_matrix.flatten()
170
- top_k = min(top_k, flat_sim.numel())
171
- values, indices = torch.topk(flat_sim, k=top_k)
172
-
173
- # Convert flat indices to 2D indices
174
- rows = indices // similarity_matrix.size(1)
175
- cols = indices % similarity_matrix.size(1)
176
-
177
- return [(r.item(), c.item(), v.item()) for r, c, v in zip(rows, cols, values)]
178
-
179
- @app.post("/compute_embeddings", response_model=EmbeddingResponse)
180
- async def generate_embeddings(
181
- credentials: HTTPAuthorizationCredentials = Depends(verify_token),
182
- texts: str | None = Form(None),
183
- images: List[UploadFile] | None = File(default=None),
184
- audio_files: List[UploadFile] | None = File(default=None)
185
  ):
186
- """Generate embeddings for any provided files and texts."""
187
- temp_files = []
 
 
188
 
189
  try:
190
- image_paths = []
191
- image_names = []
192
- audio_paths = []
193
- audio_names = []
194
- text_list = []
195
 
196
- # Process images if provided
197
  if images:
198
- for img in images:
199
- with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(img.filename)[1]) as tmp:
200
- content = await img.read()
201
- tmp.write(content)
202
- image_paths.append(tmp.name)
203
- image_names.append(img.filename)
204
- temp_files.append(tmp.name)
205
 
206
- # Process audio files if provided
207
  if audio_files:
208
- for audio in audio_files:
209
- with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(audio.filename)[1]) as tmp:
210
- content = await audio.read()
211
- tmp.write(content)
212
- audio_path = convert_audio_to_wav(tmp.name)
213
- audio_paths.append(audio_path)
214
- audio_names.append(audio.filename)
215
- temp_files.append(tmp.name)
216
- if audio_path != tmp.name:
217
- temp_files.append(audio_path)
218
-
219
- # Process texts if provided
220
- if texts:
221
- text_list = [text.strip() for text in texts.split('\n') if text.strip()]
222
-
223
- # Compute embeddings only if we have any input
224
- if not any([image_paths, audio_paths, text_list]):
225
- return EmbeddingResponse(
226
- embeddings={},
227
- file_names={}
228
- )
229
-
230
- embeddings = embedding_manager.compute_embeddings(
231
- image_paths if image_paths else None,
232
- audio_paths if audio_paths else None,
233
- text_list if text_list else None
234
  )
235
 
236
- file_names = {}
237
- if image_names:
238
- file_names['images'] = image_names
239
- if audio_names:
240
- file_names['audio'] = audio_names
241
- if text_list:
242
- file_names['texts'] = text_list
243
-
244
- return EmbeddingResponse(
245
- embeddings=embeddings,
246
- file_names=file_names
247
- )
248
 
 
 
 
 
 
249
  finally:
250
- # Clean up temporary files
251
- for temp_file in temp_files:
252
  try:
253
- os.unlink(temp_file)
254
- except:
255
- pass
256
-
257
- @app.post("/compute_similarities", response_model=SimilarityResponse)
258
- async def compute_similarities(
259
- request: SimilarityRequest,
260
- file_names: Dict[str, List[str]], # Maps modality to list of file/text names
261
- credentials: HTTPAuthorizationCredentials = Depends(verify_token)
262
- ):
263
- """
264
- Compute cross-modal similarities with advanced filtering and matching options.
265
-
266
- Parameters:
267
- - embeddings: Dict mapping modality to embedding tensors
268
- - threshold: Minimum similarity score to include in results
269
- - top_k: Maximum number of matches to return (per modality pair)
270
- - include_self_similarity: Whether to include same-item comparisons
271
- - normalize_scores: Whether to normalize embeddings before comparison
272
- - file_names: Dict mapping modality to list of original file/text names
273
- """
274
-
275
- matches = []
276
- statistics = {
277
- "avg_score": 0.0,
278
- "max_score": 0.0,
279
- "min_score": 1.0,
280
- "total_comparisons": 0
281
- }
282
-
283
- # Convert embeddings to tensors
284
- tensors = {
285
- k: torch.tensor(v) for k, v in request.embeddings.items()
286
- if isinstance(v, (list, np.ndarray)) and len(v) > 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  }
288
-
289
- modality_pairs = []
290
- all_scores = []
291
-
292
- # Get all possible modality pairs
293
- modalities = list(tensors.keys())
294
- for i, mod1 in enumerate(modalities):
295
- for mod2 in modalities[i:]: # Include self-comparisons if requested
296
- if mod1 == mod2 and not request.include_self_similarity:
297
- continue
298
-
299
- pair = ModalityPair(mod1, mod2)
300
- modality_pairs.append(str(pair))
301
-
302
- # Compute similarity matrix
303
- sim_matrix = compute_similarity_matrix(
304
- tensors[mod1],
305
- tensors[mod2],
306
- normalize=request.normalize_scores
307
- )
308
-
309
- # Get top matches
310
- top_matches = get_top_k_matches(sim_matrix, request.top_k)
311
-
312
- # Filter by threshold and create match objects
313
- for idx_a, idx_b, score in top_matches:
314
- if score < request.threshold:
315
- continue
316
-
317
- # Skip self-matches if not requested
318
- if mod1 == mod2 and idx_a == idx_b and not request.include_self_similarity:
319
- continue
320
-
321
- matches.append(SimilarityMatch(
322
- index_a=idx_a,
323
- index_b=idx_b,
324
- score=float(score),
325
- modality_a=mod1,
326
- modality_b=mod2,
327
- item_a=file_names[mod1][idx_a],
328
- item_b=file_names[mod2][idx_b]
329
- ))
330
- all_scores.append(score)
331
-
332
- # Compute statistics
333
- if all_scores:
334
- statistics.update({
335
- "avg_score": float(np.mean(all_scores)),
336
- "max_score": float(np.max(all_scores)),
337
- "min_score": float(np.min(all_scores)),
338
- "total_comparisons": len(all_scores)
339
- })
340
-
341
- # Sort matches by score in descending order
342
- matches.sort(key=lambda x: x.score, reverse=True)
343
-
344
  return SimilarityResponse(
345
- matches=matches,
346
- statistics=statistics,
347
- modality_pairs=modality_pairs
348
  )
349
 
350
- @app.get("/health")
351
- async def health_check(
352
- credentials: HTTPAuthorizationCredentials = Depends(verify_token)
353
- ):
354
- """Basic healthcheck endpoint that returns the status of the service."""
355
  return {
356
  "status": "healthy",
357
- "model_device": embedding_manager.device
 
 
358
  }
359
 
360
  if __name__ == "__main__":
361
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
2
  import torch
3
  from imagebind import data
4
  from imagebind.models import imagebind_model
5
+ from imagebind.models.imagebind_model import ModalityType as ImageBindModalityType
6
  from pydub import AudioSegment
7
+ from fastapi import FastAPI, UploadFile, File, Form, Depends, HTTPException, status
8
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
9
+ from fastapi.concurrency import run_in_threadpool
10
+ from pydantic import BaseModel, Field, BaseSettings
11
+ from typing import List, Dict, Optional, Tuple, Any
12
  import tempfile
 
13
  import uvicorn
14
  import numpy as np
15
+ import logging
16
+ from contextlib import asynccontextmanager
17
+
18
+ class Settings(BaseSettings):
19
+ api_token: str = "your-default-token-here"
20
+ model_device: Optional[str] = None
21
+ log_level: str = "INFO"
22
+
23
+ class Config:
24
+ env_file = ".env"
25
+ env_file_encoding = 'utf-8'
26
+
27
+ settings = Settings()
28
+
29
+ logging.basicConfig(level=settings.log_level.upper())
30
+ logger = logging.getLogger(__name__)
31
+
32
+ class EmbeddingManager:
33
+ _instance = None
34
+
35
+ def __new__(cls, *args, **kwargs):
36
+ if not cls._instance:
37
+ cls._instance = super(EmbeddingManager, cls).__new__(cls, *args, **kwargs)
38
+ return cls._instance
39
+
40
+ def __init__(self):
41
+ if not hasattr(self, 'initialized'):
42
+ self.device = settings.model_device or ("cuda:0" if torch.cuda.is_available() else "cpu")
43
+ logger.info(f"Initializing EmbeddingManager on device: {self.device}")
44
+ try:
45
+ self.model = imagebind_model.imagebind_huge(pretrained=True)
46
+ self.model.eval()
47
+ self.model.to(self.device)
48
+ self.initialized = True
49
+ logger.info("ImageBind model loaded successfully.")
50
+ except Exception as e:
51
+ logger.error(f"Failed to load ImageBind model: {e}")
52
+ raise RuntimeError(f"Failed to load ImageBind model: {e}")
53
+
54
+ async def compute_embeddings(self,
55
+ image_inputs: Optional[List[Tuple[str, str]]] = None,
56
+ audio_inputs: Optional[List[Tuple[str, str]]] = None,
57
+ text_inputs: Optional[List[str]] = None,
58
+ depth_inputs: Optional[List[Tuple[str, str]]] = None,
59
+ thermal_inputs: Optional[List[Tuple[str, str]]] = None,
60
+ imu_inputs: Optional[List[Tuple[str, str]]] = None
61
+ ) -> Dict[str, List[Dict[str, Any]]]:
62
+ inputs = {}
63
+ input_ids = {}
64
+
65
+ if text_inputs:
66
+ inputs[ImageBindModalityType.TEXT] = data.load_and_transform_text(text_inputs, self.device)
67
+ input_ids[ImageBindModalityType.TEXT] = text_inputs
68
+ if image_inputs:
69
+ paths = [item[0] for item in image_inputs]
70
+ inputs[ImageBindModalityType.VISION] = data.load_and_transform_vision_data(paths, self.device)
71
+ input_ids[ImageBindModalityType.VISION] = [item[1] for item in image_inputs]
72
+ if audio_inputs:
73
+ paths = [item[0] for item in audio_inputs]
74
+ inputs[ImageBindModalityType.AUDIO] = data.load_and_transform_audio_data(paths, self.device)
75
+ input_ids[ImageBindModalityType.AUDIO] = [item[1] for item in audio_inputs]
76
+
77
+ if depth_inputs:
78
+ logger.warning("Depth modality processing is not yet fully implemented.")
79
+ if thermal_inputs:
80
+ logger.warning("Thermal modality processing is not yet fully implemented.")
81
+ if imu_inputs:
82
+ logger.warning("IMU modality processing is not yet fully implemented.")
83
+
84
+ if not inputs:
85
+ return {}
86
+
87
+ with torch.no_grad():
88
+ raw_embeddings = await run_in_threadpool(self.model, inputs)
89
+
90
+ result_embeddings = {}
91
+ for modality_type, embeddings_tensor in raw_embeddings.items():
92
+ modality_key = modality_type.name.lower()
93
+ result_embeddings[modality_key] = []
94
+ ids_for_modality = input_ids.get(modality_type, [])
95
+ for i, emb in enumerate(embeddings_tensor.cpu().numpy().tolist()):
96
+ item_id = ids_for_modality[i] if i < len(ids_for_modality) else f"item_{i}"
97
+ result_embeddings[modality_key].append({"id": item_id, "embedding": emb})
98
+
99
+ return result_embeddings
100
+
101
+ embedding_manager: Optional[EmbeddingManager] = None
102
 
103
+ @asynccontextmanager
104
+ async def lifespan(app: FastAPI):
105
+ global embedding_manager
106
+ logger.info("Application startup...")
107
+ embedding_manager = EmbeddingManager()
108
+ settings.model_device = embedding_manager.device
109
+ yield
110
+ logger.info("Application shutdown...")
111
 
112
+ app = FastAPI(lifespan=lifespan, title="ImageBind API", version="0.2.0")
113
  security = HTTPBearer()
 
114
 
 
115
  async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
116
+ if credentials.scheme != "Bearer" or credentials.credentials != settings.api_token:
117
+ logger.warning(f"Invalid authentication attempt. Scheme: {credentials.scheme}")
118
  raise HTTPException(
119
  status_code=status.HTTP_401_UNAUTHORIZED,
120
  detail="Invalid authentication token",
 
122
  )
123
  return credentials.credentials
124
 
125
+ async def _save_upload_file_tmp(upload_file: UploadFile) -> Tuple[str, str]:
126
+ try:
127
+ suffix = os.path.splitext(upload_file.filename)[1]
128
+ with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
129
+ content = await upload_file.read()
130
+ tmp.write(content)
131
+ return tmp.name, upload_file.filename
132
+ except Exception as e:
133
+ logger.error(f"Error saving uploaded file {upload_file.filename}: {e}")
134
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Could not save file: {upload_file.filename}")
135
+
136
+ def convert_audio_to_wav(audio_path: str, original_filename: str) -> str:
137
+ if audio_path.lower().endswith('.mp3') or not audio_path.lower().endswith('.wav'):
138
  wav_path = audio_path.rsplit('.', 1)[0] + '.wav'
139
+ try:
140
+ logger.info(f"Converting {original_filename} to WAV format.")
141
+ audio = AudioSegment.from_file(audio_path)
142
  audio.export(wav_path, format='wav')
143
+ if audio_path != wav_path and os.path.exists(audio_path):
144
+ try:
145
+ os.unlink(audio_path)
146
+ except OSError:
147
+ pass
148
+ return wav_path
149
+ except Exception as e:
150
+ logger.error(f"Error converting audio file {original_filename} to WAV: {e}")
151
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Could not process audio file {original_filename}: {e}")
152
  return audio_path
153
 
154
+ class ModalityType(str):
155
+ VISION = "vision"
156
+ AUDIO = "audio"
157
+ TEXT = "text"
158
+ DEPTH = "depth"
159
+ THERMAL = "thermal"
160
+ IMU = "imu"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
+ class EmbeddingItem(BaseModel):
163
+ id: str = Field(..., description="Identifier of the item (e.g., filename or text content)")
164
+ embedding: List[float] = Field(..., description="The computed embedding vector")
165
 
166
+ class EmbeddingPayload(BaseModel):
167
+ vision: Optional[List[EmbeddingItem]] = Field(None, description="List of vision embeddings")
168
+ audio: Optional[List[EmbeddingItem]] = Field(None, description="List of audio embeddings")
169
+ text: Optional[List[EmbeddingItem]] = Field(None, description="List of text embeddings")
170
+ depth: Optional[List[EmbeddingItem]] = Field(None, description="List of depth embeddings (future support)")
171
+ thermal: Optional[List[EmbeddingItem]] = Field(None, description="List of thermal embeddings (future support)")
172
+ imu: Optional[List[EmbeddingItem]] = Field(None, description="List of IMU embeddings (future support)")
173
 
174
+ class EmbeddingResponse(BaseModel):
175
+ embeddings: EmbeddingPayload
176
+ message: str = "Embeddings computed successfully"
 
 
 
177
 
178
  class SimilarityMatch(BaseModel):
179
+ item_a_id: str
180
+ item_b_id: str
181
+ modality_a: ModalityType
182
+ modality_b: ModalityType
183
+ score: float = Field(..., ge=0.0, le=1.0001)
184
+
185
+ class SimilarityRequest(BaseModel):
186
+ embeddings_payload: EmbeddingPayload = Field(..., description="Payload containing embeddings from the /compute_embeddings endpoint")
187
+ threshold: float = Field(0.5, ge=0.0, le=1.0, description="Minimum similarity score to include in results")
188
+ top_k: Optional[int] = Field(None, gt=0, description="Maximum number of matches to return per modality pair comparison. If None, all matches above threshold are returned.")
189
+ normalize_scores: bool = Field(True, description="Whether to normalize embeddings before computing cosine similarity (recommended)")
190
+ compare_within_modalities: bool = Field(True, description="Compare items within the same modality (e.g., image1 vs image2)")
191
+ compare_across_modalities: bool = Field(True, description="Compare items across different modalities (e.g., image1 vs text1)")
192
 
193
  class SimilarityResponse(BaseModel):
194
  matches: List[SimilarityMatch]
195
+ statistics: Dict[str, float]
196
+ modality_pairs_compared: List[str]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
+ @app.post("/compute_embeddings", response_model=EmbeddingResponse, dependencies=[Depends(verify_token)])
199
+ async def generate_embeddings_endpoint(
200
+ texts: Optional[List[str]] = Form(None, description="List of text strings to embed."),
201
+ images: Optional[List[UploadFile]] = File(default=None, description="List of image files."),
202
+ audio_files: Optional[List[UploadFile]] = File(default=None, description="List of audio files (MP3, WAV, etc.).")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  ):
204
+ if embedding_manager is None:
205
+ raise HTTPException(status_code=503, detail="Embedding manager not initialized.")
206
+
207
+ temp_files_to_clean = []
208
 
209
  try:
210
+ image_inputs: List[Tuple[str, str]] = []
211
+ audio_inputs: List[Tuple[str, str]] = []
 
 
 
212
 
 
213
  if images:
214
+ for img_file in images:
215
+ path, name = await _save_upload_file_tmp(img_file)
216
+ image_inputs.append((path, name))
217
+ temp_files_to_clean.append(path)
 
 
 
218
 
 
219
  if audio_files:
220
+ for audio_file_in in audio_files:
221
+ path, name = await _save_upload_file_tmp(audio_file_in)
222
+ temp_files_to_clean.append(path)
223
+ wav_path = convert_audio_to_wav(path, name)
224
+ audio_inputs.append((wav_path, name))
225
+ if wav_path != path:
226
+ temp_files_to_clean.append(wav_path)
227
+
228
+ text_inputs_processed = [t.strip() for t in texts if t.strip()] if texts else None
229
+
230
+ if not any([image_inputs, audio_inputs, text_inputs_processed]):
231
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="No valid inputs provided for embedding.")
232
+
233
+ computed_data = await embedding_manager.compute_embeddings(
234
+ image_inputs=image_inputs if image_inputs else None,
235
+ audio_inputs=audio_inputs if audio_inputs else None,
236
+ text_inputs=text_inputs_processed if text_inputs_processed else None
 
 
 
 
 
 
 
 
 
237
  )
238
 
239
+ payload_data = {
240
+ ModalityType.VISION: computed_data.get(ModalityType.VISION, []),
241
+ ModalityType.AUDIO: computed_data.get(ModalityType.AUDIO, []),
242
+ ModalityType.TEXT: computed_data.get(ModalityType.TEXT, []),
243
+ }
244
+ embedding_payload = EmbeddingPayload(**payload_data)
245
+
246
+ return EmbeddingResponse(embeddings=embedding_payload)
 
 
 
 
247
 
248
+ except HTTPException:
249
+ raise
250
+ except Exception as e:
251
+ logger.error(f"Error in /compute_embeddings: {e}", exc_info=True)
252
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"An unexpected error occurred: {str(e)}")
253
  finally:
254
+ for temp_file in temp_files_to_clean:
 
255
  try:
256
+ if os.path.exists(temp_file):
257
+ os.unlink(temp_file)
258
+ except Exception as e_clean:
259
+ logger.warning(f"Could not clean up temporary file {temp_file}: {e_clean}")
260
+
261
+ def _compute_similarity_matrix(tensor1: torch.Tensor, tensor2: torch.Tensor, normalize: bool) -> torch.Tensor:
262
+ if normalize:
263
+ tensor1 = torch.nn.functional.normalize(tensor1, p=2, dim=1)
264
+ tensor2 = torch.nn.functional.normalize(tensor2, p=2, dim=1)
265
+ return torch.matmul(tensor1, tensor2.T)
266
+
267
+ @app.post("/compute_similarities", response_model=SimilarityResponse, dependencies=[Depends(verify_token)])
268
+ async def compute_similarities_endpoint(request: SimilarityRequest):
269
+ all_matches: List[SimilarityMatch] = []
270
+ all_scores: List[float] = []
271
+ modality_pairs_compared_set = set()
272
+
273
+ embeddings_by_modality: Dict[ModalityType, List[EmbeddingItem]] = {}
274
+ if request.embeddings_payload.vision:
275
+ embeddings_by_modality[ModalityType.VISION] = request.embeddings_payload.vision
276
+ if request.embeddings_payload.audio:
277
+ embeddings_by_modality[ModalityType.AUDIO] = request.embeddings_payload.audio
278
+ if request.embeddings_payload.text:
279
+ embeddings_by_modality[ModalityType.TEXT] = request.embeddings_payload.text
280
+
281
+ modalities_present = list(embeddings_by_modality.keys())
282
+ current_device = embedding_manager.device if embedding_manager else "cpu"
283
+
284
+
285
+ for i, mod1_type in enumerate(modalities_present):
286
+ items1 = embeddings_by_modality[mod1_type]
287
+ if not items1: continue
288
+ tensor1 = torch.tensor([item.embedding for item in items1], device=current_device)
289
+
290
+ if request.compare_within_modalities:
291
+ sim_matrix_intra = _compute_similarity_matrix(tensor1, tensor1, request.normalize_scores)
292
+ modality_pairs_compared_set.add(f"{mod1_type.value}_vs_{mod1_type.value}")
293
+
294
+ for r_idx in range(len(items1)):
295
+ for c_idx in range(r_idx + 1, len(items1)):
296
+ score = float(sim_matrix_intra[r_idx, c_idx].item())
297
+ if score >= request.threshold:
298
+ all_matches.append(SimilarityMatch(
299
+ item_a_id=items1[r_idx].id, item_b_id=items1[c_idx].id,
300
+ modality_a=mod1_type, modality_b=mod1_type, score=score
301
+ ))
302
+ all_scores.append(score)
303
+
304
+ if request.compare_across_modalities:
305
+ for j in range(i + 1, len(modalities_present)):
306
+ mod2_type = modalities_present[j]
307
+ items2 = embeddings_by_modality[mod2_type]
308
+ if not items2: continue
309
+ tensor2 = torch.tensor([item.embedding for item in items2], device=current_device)
310
+
311
+ sim_matrix_inter = _compute_similarity_matrix(tensor1, tensor2, request.normalize_scores)
312
+ modality_pairs_compared_set.add(f"{mod1_type.value}_vs_{mod2_type.value}")
313
+
314
+ for r_idx in range(len(items1)):
315
+ for c_idx in range(len(items2)):
316
+ score = float(sim_matrix_inter[r_idx, c_idx].item())
317
+ if score >= request.threshold:
318
+ all_matches.append(SimilarityMatch(
319
+ item_a_id=items1[r_idx].id, item_b_id=items2[c_idx].id,
320
+ modality_a=mod1_type, modality_b=mod2_type, score=score
321
+ ))
322
+ all_scores.append(score)
323
+
324
+ all_matches.sort(key=lambda x: x.score, reverse=True)
325
+ if request.top_k and len(all_matches) > request.top_k:
326
+ all_matches = all_matches[:request.top_k]
327
+ all_scores = [match.score for match in all_matches]
328
+
329
+ stats = {
330
+ "total_matches_found_above_threshold": len(all_matches),
331
+ "avg_score": float(np.mean(all_scores)) if all_scores else 0.0,
332
+ "max_score": float(np.max(all_scores)) if all_scores else 0.0,
333
+ "min_score": float(np.min(all_scores)) if all_scores else 0.0,
334
  }
335
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
  return SimilarityResponse(
337
+ matches=all_matches,
338
+ statistics=stats,
339
+ modality_pairs_compared=sorted(list(modality_pairs_compared_set))
340
  )
341
 
342
+ @app.get("/health", status_code=status.HTTP_200_OK, dependencies=[Depends(verify_token)])
343
+ async def health_check():
 
 
 
344
  return {
345
  "status": "healthy",
346
+ "model_device": settings.model_device,
347
+ "torch_version": torch.__version__,
348
+ "cuda_available": torch.cuda.is_available()
349
  }
350
 
351
  if __name__ == "__main__":
352
+ uvicorn.run(app, host="0.0.0.0", port=7860, log_level=settings.log_level.lower())