Spaces:
Runtime error
Runtime error
import logging
import os
import secrets
import tempfile
from contextlib import asynccontextmanager
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import uvicorn
from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status
from fastapi.concurrency import run_in_threadpool
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType as ImageBindModalityType
from pydantic import BaseModel, Field  # BaseSettings is no longer imported from here
from pydantic_settings import BaseSettings  # changed import: BaseSettings comes from pydantic_settings
from pydub import AudioSegment
class Settings(BaseSettings):
    """Runtime configuration for the service.

    Each field below carries a default; the nested ``Config`` additionally
    names a UTF-8 encoded ``.env`` file for pydantic-settings to read.
    """

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"

    # Bearer token that ``verify_token`` compares incoming credentials against.
    api_token: str = "your-default-token-here"
    # Explicit torch device string; when unset, EmbeddingManager picks
    # "cuda:0" if CUDA is available and "cpu" otherwise.
    model_device: Optional[str] = None
    # Logging level name; it is upper-cased before being given to ``logging``.
    log_level: str = "INFO"
# Module-wide configuration object, built once at import time.
settings = Settings()

# Root logging is configured from the (upper-cased) configured level name;
# ``logger`` is this module's own logger.
logging.basicConfig(level=settings.log_level.upper())
logger = logging.getLogger(__name__)
class EmbeddingManager:
    """Singleton that owns the ImageBind model and computes embeddings.

    ``__new__`` hands out one shared instance and ``__init__`` loads the
    model only the first time, so repeated ``EmbeddingManager()`` calls
    reuse the already-loaded model.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Extra arguments are accepted for signature compatibility but are
        # not forwarded: ``object.__new__`` raises TypeError when given any.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Pick the device and load the pretrained model (first call only).

        Raises:
            RuntimeError: if the model cannot be created or moved to the
                selected device.
        """
        # ``__init__`` runs on every ``EmbeddingManager()`` call; the
        # ``initialized`` flag prevents reloading the model each time.
        if hasattr(self, "initialized"):
            return
        self.device = settings.model_device or (
            "cuda:0" if torch.cuda.is_available() else "cpu"
        )
        logger.info(f"Initializing EmbeddingManager on device: {self.device}")
        try:
            self.model = imagebind_model.imagebind_huge(pretrained=True)
            self.model.eval()
            self.model.to(self.device)
        except Exception as e:
            logger.error(f"Failed to load ImageBind model: {e}")
            raise RuntimeError(f"Failed to load ImageBind model: {e}") from e
        self.initialized = True
        logger.info("ImageBind model loaded successfully.")

    def _embed_blocking(self, text_inputs, image_inputs, audio_inputs):
        """Load the inputs, run the model and format the result (blocking).

        Meant to be executed in a worker thread. ``torch.no_grad()`` is
        entered here, in the thread that performs the forward pass, because
        grad mode is thread-local: entering it in the event-loop thread does
        not affect a thread-pool worker.
        """
        inputs = {}
        input_ids = {}
        if text_inputs:
            inputs[ImageBindModalityType.TEXT] = data.load_and_transform_text(
                text_inputs, self.device
            )
            # Text items have no separate id; the text itself is used.
            input_ids[ImageBindModalityType.TEXT] = text_inputs
        if image_inputs:
            inputs[ImageBindModalityType.VISION] = data.load_and_transform_vision_data(
                [path for path, _ in image_inputs], self.device
            )
            input_ids[ImageBindModalityType.VISION] = [item_id for _, item_id in image_inputs]
        if audio_inputs:
            inputs[ImageBindModalityType.AUDIO] = data.load_and_transform_audio_data(
                [path for path, _ in audio_inputs], self.device
            )
            input_ids[ImageBindModalityType.AUDIO] = [item_id for _, item_id in audio_inputs]

        with torch.no_grad():
            raw_embeddings = self.model(inputs)

        result_embeddings = {}
        for modality_type, embeddings_tensor in raw_embeddings.items():
            # NOTE(review): upstream ImageBind appears to define ModalityType
            # as a namespace of plain strings, in which case the keys have no
            # ``.name`` - confirm. Falling back to the key itself handles both.
            modality_key = str(getattr(modality_type, "name", modality_type)).lower()
            ids_for_modality = input_ids.get(modality_type, [])
            # ``detach()`` keeps ``numpy()`` valid even if a tensor still
            # carries autograd state.
            vectors = embeddings_tensor.detach().cpu().numpy().tolist()
            result_embeddings[modality_key] = [
                {
                    "id": ids_for_modality[i] if i < len(ids_for_modality) else f"item_{i}",
                    "embedding": vector,
                }
                for i, vector in enumerate(vectors)
            ]
        return result_embeddings

    async def compute_embeddings(self,
                                 image_inputs: Optional[List[Tuple[str, str]]] = None,
                                 audio_inputs: Optional[List[Tuple[str, str]]] = None,
                                 text_inputs: Optional[List[str]] = None,
                                 depth_inputs: Optional[List[Tuple[str, str]]] = None,
                                 thermal_inputs: Optional[List[Tuple[str, str]]] = None,
                                 imu_inputs: Optional[List[Tuple[str, str]]] = None
                                 ) -> Dict[str, List[Dict[str, Any]]]:
        """Compute embeddings for the supplied text, image and audio inputs.

        Args:
            image_inputs: ``(file_path, id)`` pairs for images.
            audio_inputs: ``(file_path, id)`` pairs for audio files.
            text_inputs: Raw strings; each string doubles as its own id.
            depth_inputs, thermal_inputs, imu_inputs: Accepted but not
                processed; a warning is logged when any of them is given.

        Returns:
            A dict mapping a lower-cased modality name to a list of
            ``{"id": ..., "embedding": [...]}`` dicts, or ``{}`` when no
            text, image or audio input was supplied.
        """
        if depth_inputs:
            logger.warning("Depth modality processing is not yet fully implemented.")
        if thermal_inputs:
            logger.warning("Thermal modality processing is not yet fully implemented.")
        if imu_inputs:
            logger.warning("IMU modality processing is not yet fully implemented.")

        if not (text_inputs or image_inputs or audio_inputs):
            return {}

        # File decoding and the forward pass are blocking, so both run in a
        # worker thread rather than on the event loop.
        return await run_in_threadpool(
            self._embed_blocking, text_inputs, image_inputs, audio_inputs
        )
# Shared manager instance; stays ``None`` until ``lifespan`` has run at startup.
embedding_manager: Optional[EmbeddingManager] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: create the embedding manager before serving.

    Code before ``yield`` runs at startup: it builds the ``EmbeddingManager``
    and records the device it selected on ``settings.model_device``. Code
    after ``yield`` runs at shutdown and only logs.

    Wrapped in ``asynccontextmanager`` so that ``FastAPI(lifespan=...)``
    receives an async context manager factory rather than a bare async
    generator function.
    """
    global embedding_manager
    logger.info("Application startup...")
    embedding_manager = EmbeddingManager()
    settings.model_device = embedding_manager.device
    yield
    logger.info("Application shutdown...")
# The ASGI application; ``lifespan`` is passed in as its startup/shutdown hook.
app = FastAPI(
    lifespan=lifespan,
    title="ImageBind API",
    version="0.2.0",
)

# Bearer-credentials dependency consumed by ``verify_token``.
security = HTTPBearer()
async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Validate the bearer token carried by the request.

    Args:
        credentials: Scheme and token extracted by the ``security`` dependency.

    Returns:
        The presented token when it matches ``settings.api_token``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            the scheme is not "bearer" or the token does not match.
    """
    # Authentication scheme names are case-insensitive, so "bearer" and
    # "Bearer" are both accepted.
    scheme_ok = credentials.scheme.lower() == "bearer"
    # Constant-time comparison avoids leaking how much of the token matched.
    # Both sides are encoded because ``compare_digest`` rejects non-ASCII str.
    token_ok = secrets.compare_digest(
        credentials.credentials.encode("utf-8"),
        settings.api_token.encode("utf-8"),
    )
    if not (scheme_ok and token_ok):
        logger.warning(f"Invalid authentication attempt. Scheme: {credentials.scheme}")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return credentials.credentials
async def _save_upload_file_tmp(upload_file: UploadFile) -> Tuple[str, str]:
    """Copy an uploaded file into a named temporary file.

    The temporary file keeps the extension of the uploaded filename and is
    NOT deleted automatically; the caller is responsible for removing it.

    Args:
        upload_file: The incoming upload.

    Returns:
        ``(temp_path, filename)``. ``filename`` falls back to
        ``"unnamed_upload"`` when the upload carries no filename.

    Raises:
        HTTPException: 500 when the file cannot be written. Any partially
            written temporary file is removed first.
    """
    # An upload without a filename would otherwise crash ``splitext``.
    filename = upload_file.filename or "unnamed_upload"
    tmp_path: Optional[str] = None
    try:
        suffix = os.path.splitext(filename)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp_path = tmp.name
            # Copy in chunks so a large upload is never held in memory whole.
            while True:
                chunk = await upload_file.read(1024 * 1024)
                if not chunk:
                    break
                tmp.write(chunk)
        return tmp_path, filename
    except Exception as e:
        logger.error(f"Error saving uploaded file {filename}: {e}")
        if tmp_path is not None:
            # Best-effort removal of the partial file; the caller never
            # learns its path, so nobody else could clean it up.
            try:
                os.unlink(tmp_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove partial temporary file {tmp_path}: {cleanup_error}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Could not save file: {filename}",
        ) from e
def convert_audio_to_wav(audio_path: str, original_filename: str) -> str:
    """Ensure an audio file is in WAV format, converting it when necessary.

    Args:
        audio_path: Path of the audio file on disk.
        original_filename: Name used in log and error messages.

    Returns:
        ``audio_path`` unchanged when it already ends in ``.wav`` (case
        insensitive); otherwise the path of a newly written ``.wav`` file
        next to it. After a successful conversion the source file is removed.

    Raises:
        HTTPException: 400 when the file cannot be decoded or exported.
    """
    if audio_path.lower().endswith(".wav"):
        return audio_path

    # ``splitext`` only looks at the final path component, so a dot in a
    # directory name cannot truncate the path when the file has no extension.
    wav_path = os.path.splitext(audio_path)[0] + ".wav"
    try:
        logger.info(f"Converting {original_filename} to WAV format.")
        audio = AudioSegment.from_file(audio_path)
        audio.export(wav_path, format='wav')
    except Exception as e:
        logger.error(f"Error converting audio file {original_filename} to WAV: {e}")
        # Do not leave a partially written output file behind.
        try:
            if os.path.exists(wav_path):
                os.unlink(wav_path)
        except OSError as cleanup_error:
            logger.warning(f"Could not remove partial WAV file {wav_path}: {cleanup_error}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Could not process audio file {original_filename}: {e}",
        ) from e

    # The source is no longer needed; removal is deliberately best-effort.
    try:
        if os.path.exists(audio_path):
            os.unlink(audio_path)
    except OSError as cleanup_error:
        logger.warning(f"Could not remove source audio file {audio_path}: {cleanup_error}")
    return wav_path
class ModalityType(str, Enum):
    """Names of the modalities used in API payloads.

    Mixing ``str`` into the enum keeps each member equal to (and hashable
    as) its plain string value, so members still work as dictionary keys
    next to ordinary strings, while also providing the ``.value`` attribute
    that the similarity endpoint reads.
    """

    VISION = "vision"
    AUDIO = "audio"
    TEXT = "text"
    DEPTH = "depth"
    THERMAL = "thermal"
    IMU = "imu"
class EmbeddingItem(BaseModel):
    """A single embedding vector together with the id of the item it belongs to."""

    # Identifier of the embedded item.
    id: str
    # The embedding vector as a flat list of floats.
    embedding: List[float]
class EmbeddingPayload(BaseModel):
    """Embeddings grouped by modality.

    Every modality is optional and defaults to ``None``.
    """

    vision: Optional[List[EmbeddingItem]] = None
    audio: Optional[List[EmbeddingItem]] = None
    text: Optional[List[EmbeddingItem]] = None
    depth: Optional[List[EmbeddingItem]] = None
    thermal: Optional[List[EmbeddingItem]] = None
    imu: Optional[List[EmbeddingItem]] = None
class EmbeddingResponse(BaseModel):
    """Response body returned by the embedding-generation endpoint."""

    # Computed embeddings, grouped by modality.
    embeddings: EmbeddingPayload
    # Human-readable status text.
    message: str = "Embeddings computed successfully"
class SimilarityMatch(BaseModel):
    """One pair of items whose similarity score met the requested threshold."""

    # Ids of the two compared items.
    item_a_id: str
    item_b_id: str
    # Modalities the two items belong to (identical for within-modality pairs).
    modality_a: ModalityType
    modality_b: ModalityType
    # Similarity score of the pair.
    score: float
class SimilarityRequest(BaseModel):
    """Request body for the similarity endpoint."""

    # Embeddings to compare, grouped by modality.
    embeddings_payload: EmbeddingPayload
    # Minimum score a pair must reach to be reported.
    threshold: float = 0.5
    # Optional cap on the number of matches returned (best scores first).
    top_k: Optional[int] = None
    # L2-normalise the vectors before taking dot products.
    normalize_scores: bool = True
    # Compare items of the same modality with each other.
    compare_within_modalities: bool = True
    # Compare items belonging to different modalities.
    compare_across_modalities: bool = True
class SimilarityResponse(BaseModel):
    """Response body returned by the similarity endpoint."""

    # Matching pairs, ordered by descending score.
    matches: List[SimilarityMatch]
    # Summary figures (match count and average / max / min score).
    statistics: Dict[str, float]
    # Sorted labels of the form "<modality>_vs_<modality>" for every
    # modality pair that was compared.
    modality_pairs_compared: List[str]
async def generate_embeddings_endpoint(
    texts: Optional[List[str]] = Form(None),
    images: Optional[List[UploadFile]] = File(default=None),
    audio_files: Optional[List[UploadFile]] = File(default=None)
):
    """Compute embeddings for uploaded texts, images and audio files.

    Uploads are written to temporary files (audio is converted to WAV when
    needed), handed to the embedding manager, and all temporary files are
    removed afterwards regardless of the outcome.

    NOTE(review): no route decorator is visible for this handler in the
    reviewed text - confirm how it is registered on ``app``.

    Args:
        texts: Form strings; blank or whitespace-only entries are dropped.
        images: Uploaded image files.
        audio_files: Uploaded audio files.

    Returns:
        An ``EmbeddingResponse`` with vision, audio and text embeddings.

    Raises:
        HTTPException: 503 when the embedding manager is not initialised,
            400 when no usable input was provided, 500 on unexpected errors.
            HTTPExceptions raised by the helpers are passed through.
    """
    if embedding_manager is None:
        raise HTTPException(status_code=503, detail="Embedding manager not initialized.")

    temp_files_to_clean: List[str] = []
    try:
        image_inputs: List[Tuple[str, str]] = []
        for img_file in images or []:
            path, name = await _save_upload_file_tmp(img_file)
            temp_files_to_clean.append(path)
            image_inputs.append((path, name))

        audio_inputs: List[Tuple[str, str]] = []
        for audio_file_in in audio_files or []:
            path, name = await _save_upload_file_tmp(audio_file_in)
            temp_files_to_clean.append(path)
            # Decoding/exporting audio is blocking work; keep it off the
            # event loop so other requests are not stalled.
            wav_path = await run_in_threadpool(convert_audio_to_wav, path, name)
            if wav_path != path:
                temp_files_to_clean.append(wav_path)
            audio_inputs.append((wav_path, name))

        text_inputs_processed = [t.strip() for t in texts or [] if t.strip()]

        if not (image_inputs or audio_inputs or text_inputs_processed):
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="No valid inputs provided for embedding.")

        computed_data = await embedding_manager.compute_embeddings(
            image_inputs=image_inputs or None,
            audio_inputs=audio_inputs or None,
            text_inputs=text_inputs_processed or None
        )
        embedding_payload = EmbeddingPayload(
            vision=computed_data.get(ModalityType.VISION, []),
            audio=computed_data.get(ModalityType.AUDIO, []),
            text=computed_data.get(ModalityType.TEXT, []),
        )
        return EmbeddingResponse(embeddings=embedding_payload)
    except HTTPException:
        # Already carries the intended status code; pass it through untouched.
        raise
    except Exception as e:
        logger.error(f"Error in /compute_embeddings: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"An unexpected error occurred: {str(e)}",
        ) from e
    finally:
        # Temporary files are removed whether or not the request succeeded.
        for temp_file in temp_files_to_clean:
            try:
                if os.path.exists(temp_file):
                    os.unlink(temp_file)
            except Exception as e_clean:
                logger.warning(f"Could not clean up temporary file {temp_file}: {e_clean}")
| def _compute_similarity_matrix(tensor1: torch.Tensor, tensor2: torch.Tensor, normalize: bool) -> torch.Tensor: | |
| if normalize: | |
| tensor1 = torch.nn.functional.normalize(tensor1, p=2, dim=1) | |
| tensor2 = torch.nn.functional.normalize(tensor2, p=2, dim=1) | |
| return torch.matmul(tensor1, tensor2.T) | |
def _modality_label(modality) -> str:
    """Return the plain string name of a modality constant."""
    # Works whether the constant is an enum member (has ``.value``) or a
    # plain string.
    return str(getattr(modality, "value", modality))


def _embeddings_to_tensor(items, modality, device) -> torch.Tensor:
    """Stack the embeddings of ``items`` into one tensor, or raise HTTP 400."""
    try:
        return torch.tensor([item.embedding for item in items], device=device)
    except (ValueError, TypeError) as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Embeddings for modality '{_modality_label(modality)}' must all have the same length.",
        ) from e


async def compute_similarities_endpoint(request: SimilarityRequest):
    """Find pairs of embeddings whose similarity meets the requested threshold.

    Vision, audio and text embeddings from the request are compared within
    each modality and/or across modalities, as selected by the request
    flags. Within a modality each unordered pair is scored once and items
    are never compared with themselves.

    NOTE(review): no route decorator is visible for this handler in the
    reviewed text - confirm how it is registered on ``app``.

    Args:
        request: Embeddings plus threshold, ``top_k`` and comparison flags.

    Returns:
        A ``SimilarityResponse`` with matches sorted by descending score
        (cut to ``top_k`` when it is a positive number), statistics, and
        the sorted list of compared modality pairs. In the statistics,
        ``total_matches_found_above_threshold`` counts every pair that met
        the threshold, before ``top_k`` is applied; the average, maximum
        and minimum describe the matches actually returned.

    Raises:
        HTTPException: 400 when embeddings within a modality have differing
            lengths, or two compared modalities have different dimensions.
    """
    payload = request.embeddings_payload
    embeddings_by_modality: Dict[ModalityType, List[EmbeddingItem]] = {}
    for modality, items in (
        (ModalityType.VISION, payload.vision),
        (ModalityType.AUDIO, payload.audio),
        (ModalityType.TEXT, payload.text),
    ):
        if items:
            embeddings_by_modality[modality] = items

    current_device = embedding_manager.device if embedding_manager else "cpu"
    # Build each modality's tensor once instead of once per compared pair.
    tensors = {
        modality: _embeddings_to_tensor(items, modality, current_device)
        for modality, items in embeddings_by_modality.items()
    }

    modalities_present = list(embeddings_by_modality)
    all_matches: List[SimilarityMatch] = []
    modality_pairs_compared_set = set()

    for i, mod1_type in enumerate(modalities_present):
        items1 = embeddings_by_modality[mod1_type]
        tensor1 = tensors[mod1_type]
        label1 = _modality_label(mod1_type)

        if request.compare_within_modalities:
            # One transfer to host memory instead of one ``.item()`` call
            # per matrix element.
            sim_rows = _compute_similarity_matrix(
                tensor1, tensor1, request.normalize_scores
            ).cpu().tolist()
            modality_pairs_compared_set.add(f"{label1}_vs_{label1}")
            for r_idx, row in enumerate(sim_rows):
                # Strict upper triangle: skip self-pairs and mirrored pairs.
                for c_idx in range(r_idx + 1, len(row)):
                    score = row[c_idx]
                    if score >= request.threshold:
                        all_matches.append(SimilarityMatch(
                            item_a_id=items1[r_idx].id, item_b_id=items1[c_idx].id,
                            modality_a=mod1_type, modality_b=mod1_type, score=score
                        ))

        if request.compare_across_modalities:
            for mod2_type in modalities_present[i + 1:]:
                items2 = embeddings_by_modality[mod2_type]
                tensor2 = tensors[mod2_type]
                label2 = _modality_label(mod2_type)
                if tensor1.shape[1] != tensor2.shape[1]:
                    raise HTTPException(
                        status_code=status.HTTP_400_BAD_REQUEST,
                        detail=f"Embedding dimensions differ between '{label1}' and '{label2}'.",
                    )
                sim_rows = _compute_similarity_matrix(
                    tensor1, tensor2, request.normalize_scores
                ).cpu().tolist()
                modality_pairs_compared_set.add(f"{label1}_vs_{label2}")
                for r_idx, row in enumerate(sim_rows):
                    for c_idx, score in enumerate(row):
                        if score >= request.threshold:
                            all_matches.append(SimilarityMatch(
                                item_a_id=items1[r_idx].id, item_b_id=items2[c_idx].id,
                                modality_a=mod1_type, modality_b=mod2_type, score=score
                            ))

    all_matches.sort(key=lambda match: match.score, reverse=True)
    total_above_threshold = len(all_matches)
    # Only a positive ``top_k`` truncates; a negative value would otherwise
    # slice entries off the END of the list.
    if request.top_k is not None and request.top_k > 0:
        all_matches = all_matches[:request.top_k]
    all_scores = [match.score for match in all_matches]

    stats = {
        "total_matches_found_above_threshold": total_above_threshold,
        "avg_score": float(np.mean(all_scores)) if all_scores else 0.0,
        "max_score": float(np.max(all_scores)) if all_scores else 0.0,
        "min_score": float(np.min(all_scores)) if all_scores else 0.0,
    }
    return SimilarityResponse(
        matches=all_matches,
        statistics=stats,
        modality_pairs_compared=sorted(modality_pairs_compared_set)
    )
async def health_check():
    """Report service status together with basic torch runtime information.

    Returns:
        A dict with a fixed ``"healthy"`` status, the device recorded on
        ``settings.model_device``, the torch version string, and whether
        CUDA is available.
    """
    report = {"status": "healthy"}
    report["model_device"] = settings.model_device
    report["torch_version"] = torch.__version__
    report["cuda_available"] = torch.cuda.is_available()
    return report
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces,
    # port 7860, using the configured log level (lower-cased).
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level=settings.log_level.lower(),
    )