Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| from imagebind import data | |
| from imagebind.models import imagebind_model | |
| from imagebind.models.imagebind_model import ModalityType | |
| from pydub import AudioSegment | |
| from fastapi import FastAPI, UploadFile, File, Form | |
| from typing import List, Dict | |
| import tempfile | |
| from pydantic import BaseModel | |
| import uvicorn | |
| import numpy as np | |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
| from fastapi import Depends, HTTPException, status | |
app = FastAPI()

# Parses the "Authorization: Bearer <token>" header; verify_token checks the result.
security = HTTPBearer()

# Shared secret that clients must present as a bearer token, read from the
# API_TOKEN environment variable. The placeholder fallback is kept so existing
# setups keep working, but it is public and therefore offers no real
# protection, so a warning is logged whenever it is in effect.
_PLACEHOLDER_API_TOKEN = "your-default-token-here"
API_TOKEN = os.getenv("API_TOKEN", _PLACEHOLDER_API_TOKEN)
if API_TOKEN == _PLACEHOLDER_API_TOKEN:
    import logging

    logging.getLogger(__name__).warning(
        "API_TOKEN environment variable is not set; the public placeholder "
        "token is in use. Set API_TOKEN before exposing this service."
    )
async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency: accept the request only if its bearer token equals ``API_TOKEN``.

    Args:
        credentials: Parsed ``Authorization`` header supplied by ``security``.

    Returns:
        The presented token string, unchanged.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            token does not match.
    """
    import secrets

    # compare_digest does not short-circuit on the first differing byte, so
    # response timing cannot be used to recover the token piece by piece.
    # Compare bytes because compare_digest rejects non-ASCII str arguments.
    token_ok = secrets.compare_digest(
        credentials.credentials.encode("utf-8"),
        API_TOKEN.encode("utf-8"),
    )
    if not token_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return credentials.credentials
def convert_audio_to_wav(audio_path: str) -> str:
    """Return a WAV path for ``audio_path``, transcoding MP3 input on demand.

    Paths that do not end in ``.mp3`` (case-insensitive) come back unchanged.
    For an MP3, the sibling path with the extension replaced by ``.wav`` is
    returned. The MP3 is decoded and written there only when nothing exists at
    that path yet; an existing file at that path is reused as-is.
    """
    if not audio_path.lower().endswith('.mp3'):
        return audio_path

    # The last four characters are ".mp3" in some letter case; swap them out.
    target = audio_path[:-len('.mp3')] + '.wav'
    if os.path.exists(target):
        return target

    AudioSegment.from_mp3(audio_path).export(target, format='wav')
    return target
class EmbeddingManager:
    """Hold a pretrained ImageBind "huge" model and embed inputs with it."""

    def __init__(self):
        use_gpu = torch.cuda.is_available()
        self.device = "cuda:0" if use_gpu else "cpu"
        # Load pretrained weights, switch to eval mode, then move to the device.
        self.model = imagebind_model.imagebind_huge(pretrained=True)
        self.model.eval()
        self.model.to(self.device)

    def compute_embeddings(self,
                           images: List[str] = None,
                           audio_files: List[str] = None,
                           texts: List[str] = None) -> dict:
        """Embed whichever of ``images`` / ``audio_files`` / ``texts`` are non-empty.

        ``images`` and ``audio_files`` are file paths (the upload handler passes
        temp-file paths); ``texts`` are raw strings. Gradients are not tracked.

        Returns:
            A dict with a 'vision', 'audio' and/or 'text' key for each supplied
            modality, holding the model output converted to nested Python lists
            (presumably one vector per input item). Returns ``{}`` when nothing
            was supplied, without touching the model.
        """
        with torch.no_grad():
            batch = {}
            if texts:
                batch[ModalityType.TEXT] = data.load_and_transform_text(texts, self.device)
            if images:
                batch[ModalityType.VISION] = data.load_and_transform_vision_data(images, self.device)
            if audio_files:
                batch[ModalityType.AUDIO] = data.load_and_transform_audio_data(audio_files, self.device)

            if not batch:
                return {}

            outputs = self.model(batch)
            # Output labels, in the order they appear in the returned dict.
            labels = (
                (ModalityType.VISION, 'vision'),
                (ModalityType.AUDIO, 'audio'),
                (ModalityType.TEXT, 'text'),
            )
            return {
                label: outputs[modality].cpu().numpy().tolist()
                for modality, label in labels
                if modality in batch
            }
def compute_similarities(embeddings: Dict[str, List[List[float]]]) -> dict:
    """Softmax-normalised dot-product similarities between embedding groups.

    Each entry of ``embeddings`` that is a non-empty list or ndarray is turned
    into a tensor; anything else is ignored. For every available pair among
    vision/audio/text (cross-modal first, then each modality with itself) the
    raw dot-product matrix ``A @ B.T`` is softmaxed along its last dimension,
    so each row sums to 1. The values are relative scores, not cosine
    similarities.

    Returns:
        Dict keyed 'vision_audio', 'vision_text', 'audio_text', then
        'vision_vision', 'audio_audio', 'text_text', limited to the modalities
        present, with each value a nested Python list.

    NOTE(review): a later ``async def compute_similarities`` endpoint in this
    module reuses this name. Once the module has finished importing, the name
    refers to that endpoint and this helper is unreachable by name.
    """
    stacked = {}
    for name, rows in embeddings.items():
        if isinstance(rows, (list, np.ndarray)) and len(rows) > 0:
            stacked[name] = torch.tensor(rows)

    def _softmax_scores(left, right):
        return torch.softmax(left @ right.T, dim=-1).numpy().tolist()

    result = {}
    for first, second in (('vision', 'audio'), ('vision', 'text'), ('audio', 'text')):
        if first in stacked and second in stacked:
            result[f'{first}_{second}'] = _softmax_scores(stacked[first], stacked[second])

    for name in ('vision', 'audio', 'text'):
        if name in stacked:
            result[f'{name}_{name}'] = _softmax_scores(stacked[name], stacked[name])

    return result
# Module-level instance shared by the request handlers (generate_embeddings and
# health_check read it). Constructing it loads the pretrained ImageBind model
# and moves it to the GPU when one is available — this happens at import time,
# so importing the module is slow (it presumably also fetches weights on the
# first run; confirm against imagebind).
embedding_manager = EmbeddingManager()
# Response body produced by generate_embeddings.
class EmbeddingResponse(BaseModel):
    # Modality label ('vision' / 'audio' / 'text') -> embeddings as nested
    # lists, as returned by EmbeddingManager.compute_embeddings.
    embeddings: dict
    # Input labels grouped under 'images' / 'audio' / 'texts'. Note these keys
    # differ from the 'vision' / 'text' keys used in `embeddings`.
    file_names: dict
# Request body for the async compute_similarities endpoint.
class SimilarityRequest(BaseModel):
    # Modality name -> list of embedding vectors (one inner list per item).
    embeddings: Dict[str, List[List[float]]]
    # Matches scoring below this are dropped. Scores are cosine similarities
    # when normalize_scores is True, raw dot products otherwise.
    threshold: float = 0.5
    # Maximum number of matches kept per modality pair; None keeps all.
    top_k: int | None = None
    # True: also compare each modality with itself (which includes every item
    # against itself). False: only cross-modality pairs are compared.
    include_self_similarity: bool = False
    # L2-normalise embeddings before taking dot products.
    normalize_scores: bool = True
# One scored item pair returned by the async compute_similarities endpoint.
class SimilarityMatch(BaseModel):
    # Position of the item within embeddings[modality_a].
    index_a: int
    # Position of the item within embeddings[modality_b].
    index_b: int
    # Cosine similarity if normalize_scores was set, raw dot product otherwise.
    score: float
    modality_a: str  # Embeddings key the first item came from
    modality_b: str  # Embeddings key the second item came from
    item_a: str # Original item identifier (filename or text)
    item_b: str # Original item identifier (filename or text)
# Response body of the async compute_similarities endpoint.
class SimilarityResponse(BaseModel):
    # Matches that passed the threshold, sorted by descending score.
    matches: List[SimilarityMatch]
    # Keys: avg_score, max_score and min_score over the returned matches, plus
    # total_comparisons, which counts returned matches (not every pair scored).
    # With no matches the values stay at 0.0 / 0.0 / 1.0 / 0.
    statistics: Dict[str, float]
    # Labels such as "audio_to_vision" for each modality pair that was compared.
    modality_pairs: List[str] # Lists which modality comparisons were performed
class ModalityPair:
    """Order-independent label for a pair of modality names."""

    def __init__(self, mod1: str, mod2: str):
        # Sort the names so ("vision", "audio") and ("audio", "vision")
        # produce the same label.
        self.mod1, self.mod2 = sorted((mod1, mod2))

    def __str__(self):
        return "{}_to_{}".format(self.mod1, self.mod2)
def compute_similarity_matrix(tensor1: torch.Tensor, tensor2: torch.Tensor, normalize: bool = True) -> torch.Tensor:
    """Return the pairwise score matrix ``tensor1 @ tensor2.T``.

    Args:
        tensor1: First set of embeddings, one row per item (normalisation works
            along dim 1, so 2-D input is assumed).
        tensor2: Second set of embeddings with the same row width.
        normalize: If True, L2-normalise every row of both inputs first, so
            entry ``[i, j]`` is the cosine similarity of the two rows.

    Returns:
        Tensor of shape ``(len(tensor1), len(tensor2))``.
    """
    if normalize:
        l2_rows = torch.nn.functional.normalize
        tensor1, tensor2 = l2_rows(tensor1, dim=1), l2_rows(tensor2, dim=1)
    return tensor1 @ tensor2.T
| def get_top_k_matches(similarity_matrix: torch.Tensor, top_k: int | None = None) -> List[tuple]: | |
| """Get top-k matches from a similarity matrix.""" | |
| if top_k is None: | |
| top_k = similarity_matrix.numel() | |
| # Flatten and get top-k indices | |
| flat_sim = similarity_matrix.flatten() | |
| top_k = min(top_k, flat_sim.numel()) | |
| values, indices = torch.topk(flat_sim, k=top_k) | |
| # Convert flat indices to 2D indices | |
| rows = indices // similarity_matrix.size(1) | |
| cols = indices % similarity_matrix.size(1) | |
| return [(r.item(), c.item(), v.item()) for r, c, v in zip(rows, cols, values)] | |
async def _save_upload_to_temp(upload: UploadFile, temp_files: List[str]) -> str:
    """Copy an uploaded file into a named temp file and return its path.

    The path is appended to ``temp_files`` as soon as the file exists, so the
    caller's cleanup removes it even if reading the upload fails midway. The
    upload's extension (if any) is kept because later steps, such as
    ``convert_audio_to_wav``, dispatch on it.
    """
    # Guard against uploads that carry no filename: treat them as having no extension.
    suffix = os.path.splitext(upload.filename or "")[1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        temp_files.append(tmp.name)
        tmp.write(await upload.read())
    # Leaving the with-block closes (and flushes) the file, so every byte is on
    # disk before anything else reopens the path.
    return tmp.name


async def generate_embeddings(
    credentials: HTTPAuthorizationCredentials = Depends(verify_token),
    texts: str | None = Form(None),
    images: List[UploadFile] | None = File(default=None),
    audio_files: List[UploadFile] | None = File(default=None)
):
    """Generate embeddings for any provided files and texts.

    NOTE(review): no route decorator is visible for this handler here; confirm
    how it is registered with ``app``.

    Args:
        credentials: Bearer token, validated by ``verify_token``.
        texts: Newline-separated texts; blank lines are ignored.
        images: Uploaded image files.
        audio_files: Uploaded audio files; MP3 uploads are converted to WAV.

    Returns:
        EmbeddingResponse. ``embeddings`` has one entry per supplied modality
        ('vision', 'audio', 'text'), and ``file_names`` holds the matching
        labels under 'images', 'audio' and 'texts'. Both are empty when nothing
        usable was supplied. Every temp file created here is removed before
        returning, or on error.
    """
    temp_files: List[str] = []
    try:
        image_paths, image_names = [], []
        for img in images or []:
            image_paths.append(await _save_upload_to_temp(img, temp_files))
            image_names.append(img.filename)

        audio_paths, audio_names = [], []
        for audio in audio_files or []:
            raw_path = await _save_upload_to_temp(audio, temp_files)
            # Convert only once the upload is fully written and closed, so the
            # decoder never reads a partially flushed file.
            wav_path = convert_audio_to_wav(raw_path)
            if wav_path != raw_path:
                temp_files.append(wav_path)
            audio_paths.append(wav_path)
            audio_names.append(audio.filename)

        text_list = []
        if texts:
            text_list = [line.strip() for line in texts.split('\n') if line.strip()]

        if not (image_paths or audio_paths or text_list):
            return EmbeddingResponse(
                embeddings={},
                file_names={}
            )

        embeddings = embedding_manager.compute_embeddings(
            image_paths or None,
            audio_paths or None,
            text_list or None
        )

        file_names = {}
        if image_names:
            file_names['images'] = image_names
        if audio_names:
            file_names['audio'] = audio_names
        if text_list:
            file_names['texts'] = text_list

        return EmbeddingResponse(
            embeddings=embeddings,
            file_names=file_names
        )
    finally:
        # Best-effort cleanup. A file that is already gone or cannot be removed
        # must not mask the response or the original exception.
        for temp_file in temp_files:
            try:
                os.unlink(temp_file)
            except OSError:
                pass
# generate_embeddings keys its embeddings 'vision' / 'audio' / 'text' but its
# file_names 'images' / 'audio' / 'texts'. Map embedding keys to the
# file_names spelling so that endpoint's output can be passed straight through.
_FILE_NAME_KEY_ALIASES = {"vision": "images", "text": "texts"}


def _resolve_item_name(file_names: Dict[str, List[str]], modality: str, index: int) -> str:
    """Return the label of item ``index`` of ``modality`` from ``file_names``.

    Looks under the modality key first, then under its alias in
    ``_FILE_NAME_KEY_ALIASES``.

    Raises:
        HTTPException: 400 when no label list exists for the modality or the
            list is too short, instead of letting a KeyError/IndexError
            escape as a 500.
    """
    names = file_names.get(modality)
    if names is None:
        names = file_names.get(_FILE_NAME_KEY_ALIASES.get(modality, modality))
    if names is None or not 0 <= index < len(names):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"file_names has no entry {index} for modality '{modality}'",
        )
    return names[index]


async def compute_similarities(
    request: SimilarityRequest,
    file_names: Dict[str, List[str]],  # Maps modality to list of file/text names
    credentials: HTTPAuthorizationCredentials = Depends(verify_token)
):
    """
    Compute cross-modal similarities with advanced filtering and matching options.

    NOTE(review): this reuses the name of the module-level helper
    ``compute_similarities(embeddings)`` defined earlier in the file and
    replaces it once the module has loaded. No route decorator is visible for
    it here; confirm how it is registered with ``app``.

    Parameters:
    - request.embeddings: modality name -> list of embedding vectors
    - request.threshold: minimum similarity score to include in results
    - request.top_k: maximum number of matches kept per modality pair
    - request.include_self_similarity: also compare each modality with itself
      (each item is then also matched against itself)
    - request.normalize_scores: L2-normalise embeddings (cosine similarity)
      instead of using raw dot products
    - file_names: modality -> list of original file/text names. Keys may be
      the embedding keys ('vision', 'audio', 'text') or the keys emitted by
      generate_embeddings ('images', 'audio', 'texts').

    Returns:
        SimilarityResponse with matches sorted by descending score, summary
        statistics and the labels of the modality pairs compared.

    Raises:
        HTTPException: 400 when ``file_names`` lacks a label for a matched item.
    """
    matches = []
    # "total_comparisons" ends up as the number of matches that passed the
    # threshold, not the number of scores computed.
    statistics = {
        "avg_score": 0.0,
        "max_score": 0.0,
        "min_score": 1.0,
        "total_comparisons": 0
    }

    # Convert non-empty embedding lists to tensors.
    tensors = {
        k: torch.tensor(v) for k, v in request.embeddings.items()
        if isinstance(v, (list, np.ndarray)) and len(v) > 0
    }

    modality_pairs = []
    all_scores = []

    modalities = list(tensors.keys())
    for i, mod1 in enumerate(modalities):
        # Starting at i visits each unordered pair once, including (mod1, mod1).
        for mod2 in modalities[i:]:
            # Same-modality comparisons, which include every item against
            # itself, only run when include_self_similarity is set.
            if mod1 == mod2 and not request.include_self_similarity:
                continue

            modality_pairs.append(str(ModalityPair(mod1, mod2)))

            sim_matrix = compute_similarity_matrix(
                tensors[mod1],
                tensors[mod2],
                normalize=request.normalize_scores
            )

            for idx_a, idx_b, score in get_top_k_matches(sim_matrix, request.top_k):
                if score < request.threshold:
                    continue
                matches.append(SimilarityMatch(
                    index_a=idx_a,
                    index_b=idx_b,
                    score=float(score),
                    modality_a=mod1,
                    modality_b=mod2,
                    item_a=_resolve_item_name(file_names, mod1, idx_a),
                    item_b=_resolve_item_name(file_names, mod2, idx_b)
                ))
                all_scores.append(score)

    if all_scores:
        statistics.update({
            "avg_score": float(np.mean(all_scores)),
            "max_score": float(np.max(all_scores)),
            "min_score": float(np.min(all_scores)),
            "total_comparisons": len(all_scores)
        })

    # Best matches first across all modality pairs.
    matches.sort(key=lambda x: x.score, reverse=True)

    return SimilarityResponse(
        matches=matches,
        statistics=statistics,
        modality_pairs=modality_pairs
    )
async def health_check(
    credentials: HTTPAuthorizationCredentials = Depends(verify_token)
):
    """Report that the service is up and which device hosts the model.

    Requires a valid bearer token (enforced by ``verify_token``).

    NOTE(review): no route decorator is visible for this handler here; confirm
    how it is registered with ``app``.
    """
    report = dict(status="healthy")
    report["model_device"] = embedding_manager.device
    return report
if __name__ == "__main__":
    # When run as a script, serve the app on all interfaces, port 7860
    # (presumably the port the hosting Space expects; confirm).
    uvicorn.run(app, host="0.0.0.0", port=7860)