# src/models.py
import torch
import cv2
import numpy as np
from PIL import Image
from transformers import AutoProcessor, AutoModel, AutoImageProcessor
from ultralytics import YOLO
import torch.nn.functional as F
from deepface import DeepFace
# YOLO class index for "person" — we must exclude these from the object lane
# when faces have already been found, to avoid polluting the object index with humans.
# (Class 0 is "person" in the COCO label set used by Ultralytics YOLO models.)
YOLO_PERSON_CLASS_ID = 0

# Minimum face bounding box area (pixels²) to avoid indexing tiny/background faces
# e.g. a face on a TV screen in the background, or a crowd member 50px wide
MIN_FACE_AREA = 3000 # roughly 55x55 pixels minimum
class AIModelManager:
    """Owns the vision models and runs the two-lane extraction pipeline.

    Lane 1 (faces): DeepFace/GhostFaceNet with RetinaFace detection.
    Lane 2 (objects): YOLO segmentation crops embedded with a fused
    SigLIP + DINOv2 vector (1536-D after concatenation).
    """

    def __init__(self):
        # Prefer CUDA, then Apple MPS, then CPU.
        self.device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
        print(f"Loading models onto: {self.device.upper()}...")
        self.siglip_processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224", use_fast=False)
        self.siglip_model = AutoModel.from_pretrained("google/siglip-base-patch16-224").to(self.device)
        self.siglip_model.eval()
        self.dinov2_processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
        self.dinov2_model = AutoModel.from_pretrained('facebook/dinov2-base').to(self.device)
        self.dinov2_model.eval()
        self.yolo = YOLO('yolo11n-seg.pt')

    def _embed_object_crop(self, crop_pil):
        """Runs SigLIP + DINOv2 on a single crop and returns the fused 1536-D vector."""
        with torch.no_grad():
            siglip_inputs = self.siglip_processor(images=crop_pil, return_tensors="pt").to(self.device)
            siglip_out = self.siglip_model.get_image_features(**siglip_inputs)
            # get_image_features may return a tensor, an output object, or a
            # tuple depending on the transformers version — normalize to a tensor.
            if hasattr(siglip_out, 'image_embeds'):
                siglip_out = siglip_out.image_embeds
            elif isinstance(siglip_out, tuple):
                siglip_out = siglip_out[0]
            siglip_vec = F.normalize(siglip_out, p=2, dim=1).cpu()
            dinov2_inputs = self.dinov2_processor(images=crop_pil, return_tensors="pt").to(self.device)
            dinov2_out = self.dinov2_model(**dinov2_inputs)
            # CLS token as the global image descriptor.
            dinov2_vec = dinov2_out.last_hidden_state[:, 0, :]
            dinov2_vec = F.normalize(dinov2_vec, p=2, dim=1).cpu()
            # Fuse and re-normalize so the concatenated vector is unit length.
            object_vec = torch.cat((siglip_vec, dinov2_vec), dim=1)
            object_vec = F.normalize(object_vec, p=2, dim=1)
        return object_vec.flatten().numpy()

    def _extract_faces(self, img_np, img_w, img_h, out_vectors):
        """Lane 1 helper: detect faces and append normalized face vectors.

        Appends one {"type": "face", "vector": ...} entry per face that passes
        the MIN_FACE_AREA filter. Returns True iff at least one face was indexed.
        """
        try:
            face_objs = DeepFace.represent(
                img_path=img_np,
                model_name="GhostFaceNet",
                detector_backend="retinaface",
                enforce_detection=True,
                align=True
            )
        except ValueError:
            # enforce_detection=True raises ValueError when no face is found.
            print("🟠 NO FACES DETECTED -> Falling back to Object Lane for any humans.")
            return False
        found = False
        for index, face in enumerate(face_objs):
            # --- BUG FIX 5: Filter out tiny/background faces ---
            facial_area = face.get("facial_area", {})
            # Default to full-image dims so faces with no reported area
            # are never filtered out by accident.
            fw = facial_area.get("w", img_w)
            fh = facial_area.get("h", img_h)
            face_area_px = fw * fh
            if face_area_px < MIN_FACE_AREA:
                print(f"🟡 FACE {index+1} SKIPPED: Too small ({fw}x{fh}px = {face_area_px}px²) — likely background noise.")
                continue
            face_vec = F.normalize(torch.tensor([face["embedding"]]), p=2, dim=1)
            out_vectors.append({
                "type": "face",
                "vector": face_vec.flatten().numpy()
            })
            found = True
            print(f"🟢 FACE {index+1} EXTRACTED: {fw}x{fh}px — Added to Face Index.")
            # --- BUG FIX 2: For queries, do NOT break — search with ALL faces ---
            # The calling code in main.py already loops over all returned vectors,
            # so returning multiple face vectors means we search for every person
            # in a group photo query simultaneously.
        return found

    def _collect_object_crops(self, image_path, original_img_pil, faces_were_found):
        """Lane 2 helper: run YOLO segmentation and return the PIL crops to embed.

        The full image is always included first as a global-context crop.
        Person-class detections are skipped when the face lane already
        produced vectors (avoids polluting the object index with humans);
        they are allowed through as a fallback when no faces were found.
        """
        crops = [original_img_pil]
        yolo_results = self.yolo(image_path, conf=0.5)
        for r in yolo_results:
            if r.masks is None:
                continue
            for seg_idx, mask_xy in enumerate(r.masks.xy):
                # --- BUG FIX 1: Skip 'person' class detections when faces were found ---
                detected_class_id = int(r.boxes.cls[seg_idx].item())
                if faces_were_found and detected_class_id == YOLO_PERSON_CLASS_ID:
                    print("🔵 PERSON crop SKIPPED (faces already in Face Lane) — avoiding object index pollution.")
                    continue
                polygon = np.array(mask_xy, dtype=np.int32)
                if len(polygon) < 3:
                    continue  # degenerate mask — not a real region
                x, y, w, h = cv2.boundingRect(polygon)
                if w < 30 or h < 30:
                    continue  # too small to embed meaningfully
                crops.append(original_img_pil.crop((x, y, x + w, y + h)))
        return crops

    def process_image(self, image_path: str, is_query=False):
        """
        Master function: Extracts EVERY face and EVERY non-human object from an image.

        Key design decisions:
        - Face lane runs first and tags every face with its bounding box area.
        - Only faces above MIN_FACE_AREA are indexed (filters background/tiny faces).
        - For queries, ALL detected faces are used (not just the first one).
        - Object lane SKIPS any YOLO detection whose class is 'person', so humans
          never pollute the object index when faces were already found.
        - If NO faces are found at all, humans caught by YOLO DO go into the object
          lane (as a fallback for silhouettes, backs-of-head, full body shots etc.)

        :param image_path: path to the image file to process.
        :param is_query: kept for future use / logging only — behavior is
            currently identical for index-time and query-time calls.
        :return: list of {"type": "face"|"object", "vector": np.ndarray} dicts.
        """
        extracted_vectors = []
        # Context manager so the underlying file handle is closed promptly
        # (bug fix: bare Image.open leaked the handle until GC).
        with Image.open(image_path) as img:
            original_img_pil = img.convert('RGB')
        img_np = np.array(original_img_pil)
        img_h, img_w = img_np.shape[:2]

        # LANE 1: THE FACE LANE
        faces_were_found = self._extract_faces(img_np, img_w, img_h, extracted_vectors)

        # LANE 2: THE OBJECT LANE
        for crop in self._collect_object_crops(image_path, original_img_pil, faces_were_found):
            extracted_vectors.append({
                "type": "object",
                "vector": self._embed_object_crop(crop)
            })
        return extracted_vectors