# src/models.py
import os
# FIX 1: Force Legacy Keras to prevent DeepFace/RetinaFace crash in TF 2.16+
# (must be set BEFORE deepface/tensorflow is imported below).
os.environ["TF_USE_LEGACY_KERAS"] = "1"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # Hides the annoying CUDA/cuInit warnings
import asyncio
import hashlib
import functools
import torch
import cv2
import numpy as np
from PIL import Image
from transformers import AutoProcessor, AutoModel, AutoImageProcessor
from ultralytics import YOLO
import torch.nn.functional as F
from deepface import DeepFace

# Pipeline tunables.
YOLO_PERSON_CLASS_ID = 0  # class id YOLO reports for people (skipped when faces were found)
MIN_FACE_AREA = 3000      # ~55×55 px minimum face
MAX_CROPS = 6             # max YOLO crops + 1 full-image crop per request
MAX_IMAGE_SIZE = 512      # resize longest edge before any inference
def _resize_pil(img: Image.Image, max_side: int = MAX_IMAGE_SIZE) -> Image.Image:
    """Downscale *img* so its longest edge is at most *max_side* pixels.

    Aspect ratio is preserved; an image that is already small enough is
    returned unchanged (no copy). Each scaled dimension is clamped to >= 1
    so extreme aspect ratios (e.g. 2000×1) cannot truncate to a zero-size
    image, which ``Image.resize`` would reject.
    """
    w, h = img.size
    longest = max(w, h)
    if longest <= max_side:
        return img
    scale = max_side / longest
    new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
    return img.resize(new_size, Image.LANCZOS)
| def _img_hash(image_path: str) -> str: | |
| h = hashlib.md5() | |
| with open(image_path, "rb") as f: | |
| h.update(f.read(65536)) | |
| return h.hexdigest() | |
class AIModelManager:
    """Owns all vision models and turns image files into embedding records.

    Each record is ``{"type": "face" | "object", "vector": np.ndarray}`` with
    an L2-normalized vector. Face vectors come from DeepFace/GhostFaceNet;
    object vectors are a fused (SigLIP ‖ DINOv2) embedding of YOLO-segmented
    crops plus one full-image crop. Results are memoized in a bounded FIFO
    cache keyed by file content AND the ``detect_faces`` flag (the flag
    changes the output, so a content-only key returned wrong cached results).
    """

    def __init__(self):
        # Prefer CUDA, then Apple MPS, then CPU.
        self.device = (
            "cuda" if torch.cuda.is_available()
            else ("mps" if torch.backends.mps.is_available() else "cpu")
        )
        print(f"Loading models onto: {self.device.upper()}...")
        self.siglip_processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224", use_fast=False)
        self.siglip_model = AutoModel.from_pretrained("google/siglip-base-patch16-224").to(self.device).eval()
        self.dinov2_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
        self.dinov2_model = AutoModel.from_pretrained("facebook/dinov2-base").to(self.device).eval()
        if self.device == "cuda":
            # fp16 halves VRAM and speeds up GPU inference.
            self.siglip_model = self.siglip_model.half()
            self.dinov2_model = self.dinov2_model.half()
        # FIX 2: torch.compile() deliberately NOT used — HF Spaces lack the g++
        # compiler by default ("InvalidCxxCompiler" Search crash).
        self.yolo = YOLO("yolo11n-seg.pt")  # seg model → pixel masks → accurate crops
        self._cache = {}          # FIFO cache: (img_hash, detect_faces) → records
        self._cache_maxsize = 256
        print("✅ Models ready!")

    def _embed_crops_batch(self, crops: list[Image.Image]) -> list[np.ndarray]:
        """Embed all crops in one SigLIP + one DINOv2 forward pass.

        Returns one fused, L2-normalized vector per crop (SigLIP and DINOv2
        embeddings each normalized, concatenated, then renormalized).
        An empty crop list yields an empty result.
        """
        if not crops:
            return []
        with torch.no_grad():
            # --- SigLIP lane ---
            sig_inputs = self.siglip_processor(images=crops, return_tensors="pt", padding=True)
            sig_inputs = {k: v.to(self.device) for k, v in sig_inputs.items()}
            if self.device == "cuda":
                # Cast float inputs to match the half-precision weights.
                sig_inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in sig_inputs.items()}
            sig_out = self.siglip_model.get_image_features(**sig_inputs)
            # Different transformers versions return a tensor, an output object,
            # or a tuple — normalize to a plain tensor first.
            if hasattr(sig_out, "image_embeds"):
                sig_out = sig_out.image_embeds
            elif isinstance(sig_out, tuple):
                sig_out = sig_out[0]
            sig_vecs = F.normalize(sig_out.float(), p=2, dim=1).cpu()
            # --- DINOv2 lane ---
            dino_inputs = self.dinov2_processor(images=crops, return_tensors="pt")
            dino_inputs = {k: v.to(self.device) for k, v in dino_inputs.items()}
            if self.device == "cuda":
                dino_inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in dino_inputs.items()}
            dino_out = self.dinov2_model(**dino_inputs)
            dino_vecs = dino_out.last_hidden_state[:, 0, :]  # first token embedding
            dino_vecs = F.normalize(dino_vecs.float(), p=2, dim=1).cpu()
            fused = F.normalize(torch.cat([sig_vecs, dino_vecs], dim=1), p=2, dim=1)
        return [fused[i].numpy() for i in range(len(crops))]

    def _detect_faces(self, img_np: np.ndarray) -> list[dict]:
        """Face lane: GhostFaceNet embedding records for each large-enough face.

        Never raises — any DeepFace failure is logged and an empty list is
        returned so the caller falls back to the object lane.
        """
        records = []
        try:
            print("🔍 Face detection …")
            face_objs = DeepFace.represent(
                img_path=img_np,
                model_name="GhostFaceNet",
                detector_backend="retinaface",
                enforce_detection=False,
                align=True,
            )
            for face in (face_objs or []):
                fa = face.get("facial_area", {})
                # Skip tiny detections (< MIN_FACE_AREA px²) — too noisy to embed.
                if fa.get("w", 0) * fa.get("h", 0) < MIN_FACE_AREA:
                    continue
                vec = F.normalize(torch.tensor([face["embedding"]]), p=2, dim=1)
                records.append({"type": "face", "vector": vec.flatten().numpy()})
        except Exception as e:
            print(f"🟠 Face lane error: {e} — falling back to object lane")
        return records

    def _collect_crops(self, image_path: str, original_pil: Image.Image, faces_found: bool) -> list[Image.Image]:
        """Object lane: the full image plus up to MAX_CROPS YOLO region crops.

        YOLO runs on the full-resolution file, so every coordinate is in
        full-res pixel space and crops MUST come from ``original_pil``
        (cropping the downscaled image with these coords was a past bug).
        The MAX_CROPS budget is enforced across ALL result objects — the old
        code only broke the inner loop, letting later results exceed it.
        """
        crops_pil = [original_pil]  # full image always included for global context
        budget = MAX_CROPS + 1      # +1 accounts for the full-image entry
        for r in self.yolo(image_path, conf=0.5, verbose=False):
            if len(crops_pil) >= budget:
                break
            if r.masks is not None:
                # Segmentation masks available (yolo11n-seg.pt): crop the tight
                # bounding rect of each mask polygon.
                for seg_idx, mask_xy in enumerate(r.masks.xy):
                    cls_id = int(r.boxes.cls[seg_idx].item())
                    if faces_found and cls_id == YOLO_PERSON_CLASS_ID:
                        print("🔵 PERSON crop skipped — face lane already active")
                        continue
                    polygon = np.array(mask_xy, dtype=np.int32)
                    if len(polygon) < 3:
                        continue  # degenerate mask
                    x, y, w, h = cv2.boundingRect(polygon)
                    if w < 30 or h < 30:
                        continue  # too small to embed meaningfully
                    crops_pil.append(original_pil.crop((x, y, x + w, y + h)))
                    if len(crops_pil) >= budget:
                        break
            elif r.boxes is not None:
                # Fallback: plain bounding boxes (shouldn't happen with a seg model).
                for box in r.boxes:
                    cls_id = int(box.cls.item())
                    if faces_found and cls_id == YOLO_PERSON_CLASS_ID:
                        continue
                    x1, y1, x2, y2 = box.xyxy[0].tolist()
                    if (x2 - x1) < 30 or (y2 - y1) < 30:
                        continue
                    crops_pil.append(original_pil.crop((x1, y1, x2, y2)))
                    if len(crops_pil) >= budget:
                        break
        return crops_pil

    def process_image(self, image_path: str, is_query: bool = False, detect_faces: bool = True) -> list[dict]:
        """Extract face and/or object embedding records for one image file.

        ``is_query`` is currently unused by this method but kept for
        interface stability. The cache key includes ``detect_faces`` because
        the flag changes the output — with the old content-only key, a
        ``detect_faces=False`` call could get cached face embeddings back.
        """
        cache_key = (_img_hash(image_path), detect_faces)
        if cache_key in self._cache:
            print("⚡ Cache hit — skipping inference")
            return self._cache[cache_key]
        original_pil = Image.open(image_path).convert("RGB")
        small_pil = _resize_pil(original_pil, MAX_IMAGE_SIZE)
        extracted = []
        faces_found = False
        if detect_faces:
            # Face detection runs on the downscaled image for speed.
            face_records = self._detect_faces(np.array(small_pil))
            extracted.extend(face_records)
            faces_found = bool(face_records)
        crops_pil = self._collect_crops(image_path, original_pil, faces_found)
        # Downscale each crop before batched embedding — the backbones resize
        # to ~224 px anyway, so this loses no quality and saves preprocessing.
        crops = [_resize_pil(c, MAX_IMAGE_SIZE) for c in crops_pil]
        print(f"🧠 Embedding {len(crops)} crop(s) in one batch …")
        extracted.extend({"type": "object", "vector": vec} for vec in self._embed_crops_batch(crops))
        # Bounded FIFO cache: dicts preserve insertion order, so the first
        # key is the oldest entry.
        if len(self._cache) >= self._cache_maxsize:
            del self._cache[next(iter(self._cache))]
        self._cache[cache_key] = extracted
        return extracted

    async def process_image_async(self, image_path: str, is_query: bool = False, detect_faces: bool = True) -> list[dict]:
        """Run ``process_image`` in the default executor so the event loop stays responsive.

        Uses ``get_running_loop()`` (we are inside a coroutine, so a loop is
        guaranteed) instead of the deprecated ``get_event_loop()``.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, functools.partial(self.process_image, image_path, is_query, detect_faces))