Spaces:
Runtime error
Runtime error
| import os | |
| os.environ["INSIGHTFACE_HOME"] = "/tmp/.insightface" | |
| import json | |
| import tempfile | |
| import numpy as np | |
| from insightface.app import FaceAnalysis | |
| from scipy.spatial.distance import cosine | |
| import cv2 # OpenCV for image processing | |
| from typing import List, Dict, Any | |
| from datetime import datetime, timedelta | |
| import requests | |
| import base64 | |
| from io import BytesIO | |
| from PIL import Image | |
| from azure.storage.blob import BlobServiceClient, BlobClient, ContainerClient | |
| from config import get_config | |
| import time | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
class EndpointHandler:
    """Face-similarity endpoint built on InsightFace and Azure Blob Storage.

    The handler keeps a JSON "embeddings database" in a blob (a list of dicts
    holding ``image_url``, ``timestamp`` and either ``embedding`` + ``gender``
    or ``no_face_detected``).  It can (re)build that database from the images
    in the container and answer "which stored images look most like these
    query images" requests by cosine similarity.

    Calling the instance with a request dict dispatches to one of those two
    operations; see ``process_json_input`` for the accepted shapes.
    """

    # Blob path (inside the configured container) of the embeddings database.
    # NOTE(review): ``self.prefix`` / ``self.embeddings_folder`` are read from
    # the config but this fixed path is what the code has always used.
    EMBEDDINGS_BLOB_NAME = 'profile-media/embeddings/embeddings_db.json'
    # Blob name suffixes (compared case-insensitively) that count as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
    # Directory that paths handed to ``load_image_from_local`` are resolved in.
    LOCAL_IMAGE_ROOT = '/app'
    # After a failed download of the embeddings blob, retry after this many
    # seconds instead of waiting for the full cache TTL.
    CACHE_FAILURE_RETRY_SECONDS = 60

    def __init__(self, model_dir=None):
        """Load the face models, connect to Azure and set up the caches.

        Args:
            model_dir: Accepted for interface compatibility; not used.

        Raises:
            Exception: Whatever the CPU fallback initialisation raises when
                neither the GPU nor the CPU set-up succeeds.
        """
        # Local import: only the image-cache lock below needs it.
        import threading

        print("\n" + "="*80)
        print("INITIALIZING FACEANALYSIS")
        print("="*80)
        try:
            # Try GPU first
            print("Attempting to initialize with GPU (CUDA)...")
            self.app = FaceAnalysis(root="/tmp/.insightface", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
            self.app.prepare(ctx_id=0)  # 0 = GPU
            # The provider list above contains a CPU fallback, so the absence
            # of an exception does not prove CUDA is in use: ask the sessions.
            # When nothing can be detected, keep the historical assumption.
            active = self._active_providers()
            self.gpu_available = (not active) or ('CUDAExecutionProvider' in active)
            if self.gpu_available:
                print("✅ GPU initialization successful (ctx_id=0)")
            else:
                print("⚠️ CUDA provider not active; models are running on CPU")
        except RuntimeError as e:
            # GPU not available, fall back to CPU
            print(f"⚠️ GPU initialization failed: {str(e)[:100]}...")
            print("Falling back to CPU (CPUExecutionProvider)...")
            try:
                self.app = FaceAnalysis(root="/tmp/.insightface", providers=['CPUExecutionProvider'])
                self.app.prepare(ctx_id=-1)  # -1 = CPU
                print("✅ CPU initialization successful (ctx_id=-1)")
                self.gpu_available = False
            except Exception as cpu_error:
                print(f"❌ CPU initialization also failed: {cpu_error}")
                raise
        print("="*80 + "\n")

        print("=" * 80)
        print("InsightFace Providers:")
        for model, session in self._iter_model_sessions():
            print(f" {model.__class__.__name__}: {session.get_providers()}")
        print("=" * 80)

        # Get configuration
        config = get_config()
        azure_config = config.get_azure_config()
        storage_config = config.get_storage_config()

        # Initialize Azure Blob Storage client
        if azure_config['connection_string']:
            self.blob_service_client = BlobServiceClient.from_connection_string(
                azure_config['connection_string']
            )
        else:
            # Use account name and key if connection string not available
            account_url = f"https://{azure_config['account_name']}.blob.core.windows.net"
            self.blob_service_client = BlobServiceClient(
                account_url=account_url,
                credential=azure_config['account_key']
            )
        self.container_name = storage_config['container_name']
        self.prefix = storage_config['prefix']
        self.embeddings_folder = storage_config['embeddings_folder']
        self.container_client = self.blob_service_client.get_container_client(self.container_name)

        # Initialize caching
        self.embeddings_cache = None
        self.cache_timestamp = 0
        self.cache_ttl = 86400  # 24 hours in seconds
        self.image_cache = {}  # In-memory image cache (URL -> numpy array)
        self.image_cache_max_size = 100  # Max 100 images in memory
        # load_image_from_url runs on the worker threads below, so every
        # access to image_cache goes through this lock.
        self._image_cache_lock = threading.Lock()
        self.thread_pool = ThreadPoolExecutor(max_workers=8)

        # Pre-warm GPU and compile models
        self._prewarm_models()

    def _iter_model_sessions(self):
        """Yield ``(model, session)`` for every loaded model exposing a session.

        Accepts ``self.app.models`` being either a mapping or a plain iterable
        and looks for the session under ``session`` or ``sess``, so the helper
        works regardless of which of those the installed library uses.
        """
        models = getattr(self.app, 'models', None) or {}
        model_objects = models.values() if isinstance(models, dict) else models
        for model in model_objects:
            session = getattr(model, 'session', None)
            if session is None:
                session = getattr(model, 'sess', None)
            if session is not None and hasattr(session, 'get_providers'):
                yield model, session

    def _active_providers(self):
        """Return the set of execution provider names used by the models."""
        active = set()
        for _, session in self._iter_model_sessions():
            active.update(session.get_providers())
        return active

    def _prewarm_models(self):
        """Pre-warm GPU and compile ONNX models on startup to eliminate cold-start latency.

        Failures are reported and otherwise ignored.
        """
        try:
            print("\n" + "="*80)
            mode = "GPU" if self.gpu_available else "CPU"
            print(f"PRE-WARMING MODELS ({mode} MODE)")
            print("="*80)
            start = time.time()
            # Create a small dummy image (100x100 random RGB)
            dummy_img = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
            # Run inference on dummy image to trigger compilation
            _ = self.app.get(dummy_img)
            elapsed = time.time() - start
            print(f"✅ Models pre-warmed in {elapsed:.2f}s ({mode})")
            print("="*80 + "\n")
        except Exception as e:
            print(f"Warning: Model pre-warming failed (non-fatal): {e}")
            print("="*80 + "\n")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one request and never raise.

        Args:
            data: Either the request itself or ``{"inputs": request}``.

        Returns:
            The operation's result dict, or ``{"error": message}`` when
            anything goes wrong.
        """
        try:
            if "inputs" in data:
                return self.process_hf_input(data)
            return self.process_json_input(data)
        except Exception as e:
            # Top-level boundary: report every failure to the caller as data.
            return {"error": str(e)}

    def process_hf_input(self, hf_data):
        """Process Hugging Face format input (the request sits under ``"inputs"``)."""
        if "inputs" in hf_data:
            actual_data = hf_data["inputs"]
            return self.process_json_input(actual_data)
        else:
            return {"error": "Invalid Hugging Face JSON structure."}

    def process_json_input(self, json_data):
        """Dispatch a request dict to the matching operation.

        Accepted shapes:
            * ``{"query_images": [...], "gender": ..., "top_n": 5}`` runs a
              similarity search and returns ``{"similar_images": [...]}``.
            * ``{"extract_embeddings": <truthy>}`` rebuilds the embeddings
              database and returns a status dict.

        Raises:
            ValueError: If ``json_data`` is not a dict of one of those shapes.
        """
        # A string would pass the ``in`` checks below by substring match and
        # then fail with an unrelated TypeError, so reject non-dicts up front.
        if not isinstance(json_data, dict):
            raise ValueError("Invalid JSON structure.")

        if "query_images" in json_data and "gender" in json_data:
            query_images = json_data["query_images"]
            gender = json_data["gender"]
            top_n = json_data.get("top_n", 5)
            similar_images = self.find_similar_images_aggregate(query_images, gender, top_n)
            return {"similar_images": similar_images}
        elif json_data.get("extract_embeddings"):
            self.extract_and_save_embeddings()
            return {"status": "Embeddings extraction completed."}
        else:
            raise ValueError("Invalid JSON structure.")

    def load_embeddings_from_azure(self):
        """Load existing embeddings from Azure Blob Storage with caching.

        Returns:
            The cached list of entry dicts.  The same list object is returned
            on every path, so in-place edits by ``extract_and_save_embeddings``
            are visible to later callers.  When the download or parsing fails,
            the previously cached data (or an empty list if there is none) is
            returned and a new attempt is made after
            ``CACHE_FAILURE_RETRY_SECONDS``.
        """
        current_time = time.time()
        # Return cached embeddings if still valid
        if self.embeddings_cache is not None and (current_time - self.cache_timestamp) < self.cache_ttl:
            cache_age = int(current_time - self.cache_timestamp)
            print(f"[TIMING] Using cached embeddings (age: {cache_age}s)")
            return self.embeddings_cache

        print("\n" + "="*80)
        print("LOADING EMBEDDINGS FROM AZURE")
        print("="*80)
        blob_name = self.EMBEDDINGS_BLOB_NAME
        print(f"Account: {self.blob_service_client.account_name}")
        print(f"Container: {self.container_name}")
        print(f"Blob Path: {blob_name}")
        print(f"Full URL: https://{self.blob_service_client.account_name}.blob.core.windows.net/{self.container_name}/{blob_name}")
        print("="*80 + "\n")

        try:
            blob_client = self.container_client.get_blob_client(blob_name)
            # Parse straight from memory: a fixed temp-file path would be
            # shared (and raced on) by concurrent requests.
            download_start = time.time()
            raw = blob_client.download_blob().readall()
            download_time = time.time() - download_start
            parse_start = time.time()
            embeddings = json.loads(raw)
            parse_time = time.time() - parse_start
        except Exception as e:
            print(f'Could not load embeddings from Azure (missing blob or transient error): {e}')
            if self.embeddings_cache is None:
                self.embeddings_cache = []
            # Back-date the timestamp so the cache expires again after the
            # retry interval rather than pinning this result for a full TTL.
            retry_in = min(self.cache_ttl, self.CACHE_FAILURE_RETRY_SECONDS)
            self.cache_timestamp = current_time - self.cache_ttl + retry_in
            return self.embeddings_cache

        self.embeddings_cache = embeddings
        self.cache_timestamp = current_time
        print(f"[TIMING] Loaded {len(self.embeddings_cache)} embeddings: download={download_time:.3f}s, parse={parse_time:.3f}s")
        return self.embeddings_cache

    def _entry_is_current(self, entry, blob, cutoff):
        """Tell whether ``entry`` lets us skip re-processing ``blob``.

        Mirrors the original rule: skip when the entry is a no-face marker or
        newer than ``cutoff`` AND the blob was last modified at or before
        ``cutoff``.  A missing or unparsable timestamp means "not current".

        NOTE(review): blobs modified after ``cutoff`` are therefore always
        re-processed, even if already embedded since — confirm this is intended
        rather than comparing ``last_modified`` with the entry timestamp.
        """
        try:
            embedded_at = datetime.fromisoformat(entry['timestamp'])
            entry_valid = entry.get('no_face_detected') or embedded_at > cutoff
            blob_old = blob.last_modified.replace(tzinfo=None) <= cutoff
        except (KeyError, TypeError, ValueError):
            return False
        return bool(entry_valid and blob_old)

    def _build_entry(self, blob_name, image_url, now):
        """Download one image blob and build its database entry.

        Returns:
            ``None`` when the image cannot be read, a no-face marker entry
            when no face is detected, otherwise an entry with the first
            detected face's embedding and gender.
        """
        blob_client = self.container_client.get_blob_client(blob_name)
        img = self.load_image_from_blob(blob_client)
        if img is None:
            print(f"Failed to read image: {blob_name}")
            return None

        faces = self.app.get(img)
        if len(faces) == 0:
            print(f"No face detected in: {blob_name}")
            return {
                'image_url': image_url,
                'no_face_detected': True,
                'timestamp': now.isoformat()
            }

        face = faces[0]
        gender = 'male' if face.gender == 1 else 'female'
        print(f"Successfully processed: {blob_name} (gender: {gender})")
        return {
            'embedding': face.embedding.tolist(),
            'gender': gender,
            'image_url': image_url,
            'timestamp': now.isoformat()
        }

    def extract_and_save_embeddings(self):
        """Extract embeddings from images and save them to Azure Blob Storage.

        Walks the image blobs under the processed folders, (re)computes the
        entry of every image that is not current, uploads the database as
        JSON and refreshes the in-memory cache with it.  Per-image, per-folder
        and upload errors are printed and do not abort the run.
        """
        embeddings_db = self.load_embeddings_from_azure()
        now = datetime.utcnow()
        thirty_days_ago = now - timedelta(days=30)

        # URL -> entry index (first occurrence wins, as a linear search would)
        # so that the lookup per blob is O(1) instead of a scan of the list.
        entries_by_url = {}
        for item in embeddings_db:
            url = item.get('image_url')
            if url is not None:
                entries_by_url.setdefault(url, item)

        # Process images from both profile-media and ai-images folders
        folders_to_process = [
            'profile-media/',    # profile-media folder (without container name)
            'ai-images/men/',    # ai-images/men folder (without container name)
            'ai-images/women/'   # ai-images/women folder (without container name)
        ]
        for folder_prefix in folders_to_process:
            try:
                print(f"Processing folder: {folder_prefix}")
                # List all blobs in the container with the current prefix
                blob_list = self.container_client.list_blobs(name_starts_with=folder_prefix)
                for blob in blob_list:
                    blob_name = blob.name
                    if not blob_name.lower().endswith(self.IMAGE_EXTENSIONS):
                        continue
                    image_url = f'https://{self.blob_service_client.account_name}.blob.core.windows.net/{self.container_name}/{blob_name}'
                    existing_entry = entries_by_url.get(image_url)
                    if existing_entry is not None and self._entry_is_current(existing_entry, blob, thirty_days_ago):
                        continue

                    print(f"Processing image: {blob_name}")
                    try:
                        entry = self._build_entry(blob_name, image_url, now)
                    except Exception as e:
                        print(f"Error processing image {blob_name}: {e}")
                        continue
                    if entry is None:
                        continue

                    if existing_entry is not None:
                        # Replace wholesale: a plain update() would leave a
                        # stale embedding on a no-face entry, or a stale
                        # no_face_detected flag on a re-embedded one.
                        existing_entry.clear()
                        existing_entry.update(entry)
                    else:
                        embeddings_db.append(entry)
                        entries_by_url[image_url] = entry
            except Exception as e:
                print(f"Error processing folder {folder_prefix}: {e}")
                continue

        print(f"Total embeddings in database: {len(embeddings_db)}")

        # Save embeddings back to Azure - in profile-media/embeddings/
        try:
            blob_name = self.EMBEDDINGS_BLOB_NAME
            payload = json.dumps(embeddings_db).encode('utf-8')
            blob_client = self.container_client.get_blob_client(blob_name)
            blob_client.upload_blob(payload, overwrite=True)
            print(f"Embeddings saved to Azure: {blob_name}")
        except Exception as e:
            print(f"Error saving embeddings: {e}")

        # Serve the freshly built database from now on instead of whatever
        # (possibly empty) list the cache held before the run.
        self.embeddings_cache = embeddings_db
        self.cache_timestamp = time.time()

    def find_similar_images_aggregate(self, query_images: List[str], gender: str, top_n: int = 5) -> List[str]:
        """Rank stored profile-media images by mean similarity to the queries.

        Args:
            query_images: Image references understood by
                ``_extract_query_embedding``.  A single string is treated as
                a one-element list.
            gender: ``'all'`` or the gender value stored entries must carry.
            top_n: Maximum number of URLs to return.

        Returns:
            Up to ``top_n`` image URLs, best match first; an empty list when
            there are no candidates or no query image yields a face.
        """
        start_time = time.time()
        if isinstance(query_images, str):
            # Otherwise the string would be iterated character by character.
            query_images = [query_images]
        print(f"Debug: Starting similarity search with {len(query_images)} query images")
        print(f"Debug: Looking for gender: {gender}, top_n: {top_n}")

        # Load embeddings database once (cached)
        embeddings_db = self.load_embeddings_from_azure()
        print(f"Debug: Total embeddings in database: {len(embeddings_db)}")

        # Filter to only include images from profile-media folder structure
        profile_media_db = [item for item in embeddings_db if 'image_url' in item and 'profile-media' in item['image_url']]
        print(f"Debug: Profile-media embeddings: {len(profile_media_db)}")

        # Keep entries that carry a gender and a usable embedding; entries
        # flagged no_face_detected are skipped, matching
        # find_similar_images_by_embedding.
        filtered_db = [
            item for item in profile_media_db
            if 'gender' in item
            and 'embedding' in item
            and not item.get('no_face_detected', False)
            and (gender == 'all' or item['gender'] == gender)
        ]
        print(f"Debug: Filtered by gender '{gender}': {len(filtered_db)}")
        if len(filtered_db) == 0:
            print(f"Debug: No embeddings found for gender '{gender}' in profile-media folder")
            return []

        # Process query images in parallel
        futures = {}
        for i, image_input in enumerate(query_images):
            future = self.thread_pool.submit(self._extract_query_embedding, image_input, i)
            futures[future] = i

        # Collect results from parallel processing
        query_embeddings = []
        for future in as_completed(futures):
            i = futures[future]
            try:
                query_embedding = future.result()
                if query_embedding is not None:
                    query_embeddings.append(query_embedding)
                    print(f"Debug: Successfully extracted face embedding from query image {i+1}")
            except Exception as e:
                print(f"Debug: Error processing query image {i+1}: {e}")
        if not query_embeddings:
            print("Debug: No valid query embeddings extracted")
            return []

        # Compute similarities for all query embeddings against filtered database
        similarity_start = time.time()
        # Convert each stored embedding once, not once per query image.
        db_vectors = [(item['image_url'], np.asarray(item['embedding'])) for item in filtered_db]
        similarities = {}
        for query_embedding in query_embeddings:
            for url, vector in db_vectors:
                similarities.setdefault(url, []).append(1 - cosine(query_embedding, vector))
        similarity_time = time.time() - similarity_start

        # Aggregate similarities
        print(f"[TIMING] Similarity computation: {similarity_time:.3f}s")
        print(f"Debug: Total similarities found: {len(similarities)}")
        aggregated_similarities = [(np.mean(scores), url) for url, scores in similarities.items()]
        aggregated_similarities.sort(reverse=True, key=lambda x: x[0])
        result = [url for _, url in aggregated_similarities[:top_n]]

        elapsed = time.time() - start_time
        print(f"\n[TIMING] REQUEST SUMMARY:")
        print(f" Total time: {elapsed:.3f}s")
        print(f" Query embeddings extracted: {len(query_embeddings)} images")
        print(f" Similarity computations: {len(similarities)} results")
        print(f"Debug: Returning {len(result)} recommendations\n")
        return result

    def _extract_query_embedding(self, image_input: str, index: int) -> Any:
        """Extract embedding from a single query image (for parallel processing).

        ``image_input`` is resolved by prefix: ``http`` is fetched as a URL,
        ``data:image/`` is decoded as base64, ``ai-images/`` is read from the
        local filesystem, anything else is treated as a blob path.

        Returns:
            The embedding of the first detected face, or ``None`` when the
            image cannot be loaded, contains no face, or an error occurs.
        """
        try:
            # Keep the log readable: base64 inputs can be megabytes long.
            if len(image_input) > 120:
                shown = f"{image_input[:120]}... ({len(image_input)} chars)"
            else:
                shown = image_input
            print(f"\nDebug: Processing query image {index+1}: {shown}")
            request_start = time.time()

            # Load image
            load_start = time.time()
            if image_input.startswith('http'):
                # It's a URL
                img = self.load_image_from_url(image_input)
            elif image_input.startswith('data:image/'):
                # It's a base64-encoded image
                img = self.load_image_from_base64(image_input)
            elif image_input.startswith('ai-images/'):
                # Local filesystem (ai-images baked into container)
                img = self.load_image_from_local(image_input)
            else:
                # It's a local file path reference - convert to full Azure blob URL
                # NOTE(review): account and container are hard-coded here, not
                # taken from the configured client — confirm they always match.
                blob_url = f"https://koottuprod.blob.core.windows.net/koottu-media/{image_input}"
                img = self.load_image_from_url(blob_url)
            load_time = time.time() - load_start
            if img is None:
                print(f"Failed to load image: {shown}")
                return None

            # Face detection
            detect_start = time.time()
            faces = self.app.get(img)
            detect_time = time.time() - detect_start
            if len(faces) == 0:
                elapsed = time.time() - request_start
                print(f" [TIMING] No faces detected: load={load_time:.3f}s, detect={detect_time:.3f}s, total={elapsed:.3f}s")
                return None

            embedding_extraction_time = time.time() - request_start
            print(f" [TIMING] Face found: load={load_time:.3f}s, detect={detect_time:.3f}s, total={embedding_extraction_time:.3f}s")
            return faces[0].embedding
        except Exception as e:
            print(f"Error extracting embedding from image {index+1}: {e}")
            return None

    def find_similar_images_by_embedding(self, query_embedding: np.ndarray, gender: str = 'all', top_n: int = 10, excluded_images: List[str] = None) -> List[str]:
        """Find similar images based on a given embedding vector.

        Args:
            query_embedding: Vector compared against the stored embeddings.
            gender: ``'all'`` or the gender value stored entries must carry.
            top_n: Maximum number of URLs to return.
            excluded_images: Image URLs to leave out of the result.

        Returns:
            Up to ``top_n`` profile-media image URLs, most similar first; an
            empty list when an error occurs.
        """
        try:
            # Load embeddings database from Azure
            embeddings_db = self.load_embeddings_from_azure()
            # A set gives constant-time exclusion checks.
            excluded = frozenset(excluded_images) if excluded_images is not None else frozenset()

            similarities = []
            for item in embeddings_db:
                # Only images from the profile-media folder structure
                if 'image_url' not in item or 'profile-media' not in item['image_url']:
                    continue
                if gender != 'all' and ('gender' not in item or item['gender'] != gender):
                    continue
                if item['image_url'] in excluded:
                    continue
                if 'embedding' not in item or item.get('no_face_detected', False):
                    continue
                similarity = 1 - cosine(query_embedding, np.asarray(item['embedding']))
                similarities.append((similarity, item['image_url']))

            # Sort by similarity and return top matches
            similarities.sort(reverse=True, key=lambda x: x[0])
            return [url for _, url in similarities[:top_n]]
        except Exception as e:
            print(f"Error in find_similar_images_by_embedding: {e}")
            return []

    def load_image_from_url(self, url):
        """Download ``url`` and return it as a BGR array, using the image cache.

        Returns:
            The decoded image, or ``None`` when downloading or decoding fails.
        """
        try:
            # Check cache first
            with self._image_cache_lock:
                cached = self.image_cache.get(url)
            if cached is not None:
                print(f" [CACHE HIT] {url}")
                return cached

            start = time.time()
            response = requests.get(url, timeout=30)
            download_time = time.time() - start
            response.raise_for_status()

            start = time.time()
            image = Image.open(BytesIO(response.content)).convert('RGB')
            image = np.array(image)
            parse_time = time.time() - start

            start = time.time()
            result = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            convert_time = time.time() - start

            # Cache the result; the lock keeps concurrent workers from
            # evicting and inserting at the same time.
            with self._image_cache_lock:
                if self.image_cache and len(self.image_cache) >= self.image_cache_max_size:
                    # Remove oldest entry (simple FIFO)
                    self.image_cache.pop(next(iter(self.image_cache)))
                self.image_cache[url] = result
            print(f" [TIMING] Image load: download={download_time:.3f}s, parse={parse_time:.3f}s, convert={convert_time:.3f}s [CACHED]")
            return result
        except Exception as e:
            print(f"Error loading image from URL {url}: {e}")
            return None

    def load_image_from_blob(self, blob_client):
        """Download a blob and return it as a BGR array, or ``None`` on failure."""
        try:
            blob_bytes = blob_client.download_blob().readall()
            image = Image.open(BytesIO(blob_bytes)).convert('RGB')
            image = np.array(image)
            return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        except Exception as e:
            print(f"Error loading image from blob: {e}")
            return None

    def load_image_from_local(self, local_path):
        """Load image from local filesystem (for ai-images baked into container).

        ``local_path`` is resolved inside ``LOCAL_IMAGE_ROOT``; paths that
        would escape it (``..`` segments, absolute paths) are rejected.

        Returns:
            The image as read by OpenCV, or ``None`` when the path is
            rejected, missing or unreadable.
        """
        try:
            start = time.time()
            root = self.LOCAL_IMAGE_ROOT
            full_path = os.path.normpath(os.path.join(root, local_path))
            # The path comes from the request, so make sure it stays in root.
            if not full_path.startswith(root.rstrip(os.sep) + os.sep):
                print(f"Local image path escapes {root}: {local_path}")
                return None
            if not os.path.exists(full_path):
                print(f"Local image not found: {full_path}")
                return None
            image = cv2.imread(full_path)
            if image is None:
                print(f"Failed to read image: {full_path}")
                return None
            load_time = time.time() - start
            print(f" [TIMING] Image load (local): {load_time:.3f}s [LOCAL FILESYSTEM]")
            return image
        except Exception as e:
            print(f"Error loading local image {local_path}: {e}")
            return None

    def load_image_from_base64(self, base64_string):
        """Decode a ``<header>,<base64 data>`` string into an image.

        Unlike the other loaders this one lets decoding errors propagate
        (for example a string without a comma raises ``ValueError``).
        """
        _, encoded = base64_string.split(',', 1)
        data = base64.b64decode(encoded)
        np_arr = np.frombuffer(data, np.uint8)
        img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
        return img  # Returns BGR image as expected by OpenCV
# Build the shared EndpointHandler instance when this module is imported.
handler = EndpointHandler()