Spaces:
Running
Running
| # src/models.py | |
| import torch | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| from transformers import AutoProcessor, AutoModel, AutoImageProcessor | |
| from ultralytics import YOLO | |
| import torch.nn.functional as F | |
| from deepface import DeepFace | |
| # YOLO class index for "person" — we must exclude these from the object lane | |
| # when faces have already been found, to avoid polluting the object index with humans. | |
| YOLO_PERSON_CLASS_ID = 0 | |
| # Minimum face bounding box area (pixels²) to avoid indexing tiny/background faces | |
| # e.g. a face on a TV screen in the background, or a crowd member 50px wide | |
| MIN_FACE_AREA = 3000 # roughly 55x55 pixels minimum | |
class AIModelManager:
    """Owns all vision models and turns images into searchable embedding vectors.

    Two independent "lanes":
      * Face lane  — DeepFace (RetinaFace detector + GhostFaceNet embedder).
      * Object lane — YOLO segmentation crops embedded with SigLIP + DINOv2
        fused into a single 1536-D vector (768-D each, L2-normalized).
    """

    def __init__(self):
        # Device priority: CUDA GPU, then Apple-silicon MPS, then CPU fallback.
        self.device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
        print(f"Loading models onto: {self.device.upper()}...")
        # SigLIP: semantic ("what is this?") image embedding, 768-D.
        self.siglip_processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224", use_fast=False)
        self.siglip_model = AutoModel.from_pretrained("google/siglip-base-patch16-224").to(self.device)
        self.siglip_model.eval()
        # DINOv2: self-supervised visual ("what does this look like?") embedding, 768-D CLS token.
        self.dinov2_processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
        self.dinov2_model = AutoModel.from_pretrained('facebook/dinov2-base').to(self.device)
        self.dinov2_model.eval()
        # YOLO11 nano segmentation model — provides per-instance masks, not just boxes.
        self.yolo = YOLO('yolo11n-seg.pt')

    def _embed_object_crop(self, crop_pil):
        """Runs SigLIP + DINOv2 on a single crop and returns the fused 1536-D vector.

        Args:
            crop_pil: a PIL Image (RGB) of one object crop or the full frame.

        Returns:
            1-D numpy array of length 1536: [SigLIP 768-D | DINOv2 768-D],
            each half L2-normalized, then the concatenation L2-normalized again
            so both halves contribute equally to cosine similarity.
        """
        with torch.no_grad():
            siglip_inputs = self.siglip_processor(images=crop_pil, return_tensors="pt").to(self.device)
            siglip_out = self.siglip_model.get_image_features(**siglip_inputs)
            # Defensive unwrap: depending on the transformers version,
            # get_image_features may return a plain tensor, an output object
            # with .image_embeds, or a tuple.
            if hasattr(siglip_out, 'image_embeds'):
                siglip_out = siglip_out.image_embeds
            elif isinstance(siglip_out, tuple):
                siglip_out = siglip_out[0]
            siglip_vec = F.normalize(siglip_out, p=2, dim=1).cpu()
            dinov2_inputs = self.dinov2_processor(images=crop_pil, return_tensors="pt").to(self.device)
            dinov2_out = self.dinov2_model(**dinov2_inputs)
            # DINOv2's global image descriptor is the CLS token (position 0).
            dinov2_vec = dinov2_out.last_hidden_state[:, 0, :]
            dinov2_vec = F.normalize(dinov2_vec, p=2, dim=1).cpu()
            object_vec = torch.cat((siglip_vec, dinov2_vec), dim=1)
            object_vec = F.normalize(object_vec, p=2, dim=1)
        return object_vec.flatten().numpy()

    def process_image(self, image_path: str, is_query=False):
        """
        Master function: Extracts EVERY face and EVERY non-human object from an image.

        Key design decisions:
        - Face lane runs first and tags every face with its bounding box area.
        - Only faces above MIN_FACE_AREA are indexed (filters background/tiny faces).
        - For queries, ALL detected faces are used (not just the first one).
        - Object lane SKIPS any YOLO detection whose class is 'person', so humans
          never pollute the object index when faces were already found.
        - If NO faces are found at all, humans caught by YOLO DO go into the object
          lane (as a fallback for silhouettes, backs-of-head, full body shots etc.)

        Args:
            image_path: path to the image file on disk.
            is_query: kept for future use / logging only (see note in Lane 1).

        Returns:
            list of dicts: {"type": "face" | "object", "vector": 1-D numpy array}.
        """
        extracted_vectors = []
        # FIX: open the file inside a context manager so the underlying file
        # handle is closed deterministically; .convert('RGB') returns a new,
        # fully-loaded image that outlives the `with` block.
        with Image.open(image_path) as img_file:
            original_img_pil = img_file.convert('RGB')
        img_np = np.array(original_img_pil)
        img_h, img_w = img_np.shape[:2]
        faces_were_found = False  # Track whether Lane 1 found anything usable

        # ==========================================
        # LANE 1: THE FACE LANE
        # ==========================================
        try:
            face_objs = DeepFace.represent(
                img_path=img_np,
                model_name="GhostFaceNet",
                detector_backend="retinaface",
                enforce_detection=True,  # raises ValueError when no face is found
                align=True
            )
            for index, face in enumerate(face_objs):
                # --- BUG FIX 5: Filter out tiny/background faces ---
                facial_area = face.get("facial_area", {})
                # Default to the full-frame dims so a face with no reported
                # area is never filtered out by accident.
                fw = facial_area.get("w", img_w)
                fh = facial_area.get("h", img_h)
                face_area_px = fw * fh
                if face_area_px < MIN_FACE_AREA:
                    print(f"🟡 FACE {index+1} SKIPPED: Too small ({fw}x{fh}px = {face_area_px}px²) — likely background noise.")
                    continue
                face_vec = torch.tensor([face["embedding"]])
                face_vec = F.normalize(face_vec, p=2, dim=1)
                extracted_vectors.append({
                    "type": "face",
                    "vector": face_vec.flatten().numpy()
                })
                faces_were_found = True
                print(f"🟢 FACE {index+1} EXTRACTED: {fw}x{fh}px — Added to Face Index.")
            # --- BUG FIX 2: For queries, do NOT break — search with ALL faces ---
            # The calling code in main.py already loops over all returned vectors,
            # so returning multiple face vectors means we search for every person
            # in a group photo query simultaneously.
            # (is_query flag is kept as parameter for future use / logging only)
        except ValueError:
            print("🟠 NO FACES DETECTED -> Falling back to Object Lane for any humans.")

        # ==========================================
        # LANE 2: THE OBJECT LANE
        # ==========================================
        yolo_results = self.yolo(image_path, conf=0.5)
        # Always include the full image as one crop for global context
        crops = [original_img_pil]
        for r in yolo_results:
            if r.masks is not None:
                for seg_idx, mask_xy in enumerate(r.masks.xy):
                    # --- BUG FIX 1: Skip 'person' class detections when faces were found ---
                    # This prevents human body crops from polluting the object index.
                    # If no faces were found (back-of-head, silhouette, etc.), we DO
                    # allow person-class detections through as a fallback.
                    detected_class_id = int(r.boxes.cls[seg_idx].item())
                    if faces_were_found and detected_class_id == YOLO_PERSON_CLASS_ID:
                        print("🔵 PERSON crop SKIPPED (faces already in Face Lane) — avoiding object index pollution.")
                        continue
                    polygon = np.array(mask_xy, dtype=np.int32)
                    if len(polygon) < 3:
                        # Degenerate mask — cannot form a polygon.
                        continue
                    x, y, w, h = cv2.boundingRect(polygon)
                    if w < 30 or h < 30:
                        # Too small to embed meaningfully.
                        continue
                    cropped_img = original_img_pil.crop((x, y, x + w, y + h))
                    crops.append(cropped_img)
        for crop in crops:
            vec = self._embed_object_crop(crop)
            extracted_vectors.append({
                "type": "object",
                "vector": vec
            })
        return extracted_vectors