# visual-search-api / src/models.py
# Source: Hugging Face Space upload by AdarshDRC
# Commit c96096b (verified) — "Update src/models.py" — 7.09 kB
# src/models.py
import torch
import cv2
import numpy as np
from PIL import Image
from transformers import AutoProcessor, AutoModel, AutoImageProcessor
from ultralytics import YOLO
import torch.nn.functional as F
from deepface import DeepFace
# YOLO class index for "person" — we must exclude these from the object lane
# when faces have already been found, to avoid polluting the object index with humans.
YOLO_PERSON_CLASS_ID = 0
# Minimum face bounding box area (pixels²) to avoid indexing tiny/background faces
# e.g. a face on a TV screen in the background, or a crowd member 50px wide
MIN_FACE_AREA = 3000 # roughly 55x55 pixels minimum
class AIModelManager:
    """Owns every model used by the visual-search pipeline.

    Loads SigLIP (semantic image features), DINOv2 (structural image
    features), and a YOLO segmentation model exactly once, pinned to the
    fastest available device. Face embedding is delegated to DeepFace at
    call time inside ``process_image``.
    """

    def __init__(self):
        # Prefer CUDA, then Apple MPS, then plain CPU.
        if torch.cuda.is_available():
            self.device = "cuda"
        elif torch.backends.mps.is_available():
            self.device = "mps"
        else:
            self.device = "cpu"
        print(f"Loading models onto: {self.device.upper()}...")

        # SigLIP: semantic / text-aligned image embeddings (768-D).
        self.siglip_processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224", use_fast=False)
        self.siglip_model = AutoModel.from_pretrained("google/siglip-base-patch16-224").to(self.device)
        self.siglip_model.eval()

        # DINOv2: self-supervised structural embeddings (768-D CLS token).
        self.dinov2_processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
        self.dinov2_model = AutoModel.from_pretrained('facebook/dinov2-base').to(self.device)
        self.dinov2_model.eval()

        # YOLO segmentation model for per-object crops.
        self.yolo = YOLO('yolo11n-seg.pt')
def _embed_object_crop(self, crop_pil):
"""Runs SigLIP + DINOv2 on a single crop and returns the fused 1536-D vector."""
with torch.no_grad():
siglip_inputs = self.siglip_processor(images=crop_pil, return_tensors="pt").to(self.device)
siglip_out = self.siglip_model.get_image_features(**siglip_inputs)
if hasattr(siglip_out, 'image_embeds'):
siglip_out = siglip_out.image_embeds
elif isinstance(siglip_out, tuple):
siglip_out = siglip_out[0]
siglip_vec = F.normalize(siglip_out, p=2, dim=1).cpu()
dinov2_inputs = self.dinov2_processor(images=crop_pil, return_tensors="pt").to(self.device)
dinov2_out = self.dinov2_model(**dinov2_inputs)
dinov2_vec = dinov2_out.last_hidden_state[:, 0, :]
dinov2_vec = F.normalize(dinov2_vec, p=2, dim=1).cpu()
object_vec = torch.cat((siglip_vec, dinov2_vec), dim=1)
object_vec = F.normalize(object_vec, p=2, dim=1)
return object_vec.flatten().numpy()
def process_image(self, image_path: str, is_query=False):
"""
Master function: Extracts EVERY face and EVERY non-human object from an image.
Key design decisions:
- Face lane runs first and tags every face with its bounding box area.
- Only faces above MIN_FACE_AREA are indexed (filters background/tiny faces).
- For queries, ALL detected faces are used (not just the first one).
- Object lane SKIPS any YOLO detection whose class is 'person', so humans
never pollute the object index when faces were already found.
- If NO faces are found at all, humans caught by YOLO DO go into the object
lane (as a fallback for silhouettes, backs-of-head, full body shots etc.)
"""
extracted_vectors = []
original_img_pil = Image.open(image_path).convert('RGB')
img_np = np.array(original_img_pil)
img_h, img_w = img_np.shape[:2]
faces_were_found = False # Track whether Lane 1 found anything usable
# ==========================================
# LANE 1: THE FACE LANE
# ==========================================
try:
face_objs = DeepFace.represent(
img_path=img_np,
model_name="GhostFaceNet",
detector_backend="retinaface",
enforce_detection=True,
align=True
)
for index, face in enumerate(face_objs):
# --- BUG FIX 5: Filter out tiny/background faces ---
facial_area = face.get("facial_area", {})
fw = facial_area.get("w", img_w)
fh = facial_area.get("h", img_h)
face_area_px = fw * fh
if face_area_px < MIN_FACE_AREA:
print(f"🟡 FACE {index+1} SKIPPED: Too small ({fw}x{fh}px = {face_area_px}px²) — likely background noise.")
continue
face_vec = torch.tensor([face["embedding"]])
face_vec = F.normalize(face_vec, p=2, dim=1)
extracted_vectors.append({
"type": "face",
"vector": face_vec.flatten().numpy()
})
faces_were_found = True
print(f"🟢 FACE {index+1} EXTRACTED: {fw}x{fh}px — Added to Face Index.")
# --- BUG FIX 2: For queries, do NOT break — search with ALL faces ---
# The calling code in main.py already loops over all returned vectors,
# so returning multiple face vectors means we search for every person
# in a group photo query simultaneously.
# (is_query flag is kept as parameter for future use / logging only)
except ValueError:
print("🟠 NO FACES DETECTED -> Falling back to Object Lane for any humans.")
# ==========================================
# LANE 2: THE OBJECT LANE
# ==========================================
yolo_results = self.yolo(image_path, conf=0.5)
# Always include the full image as one crop for global context
crops = [original_img_pil]
for r in yolo_results:
if r.masks is not None:
for seg_idx, mask_xy in enumerate(r.masks.xy):
# --- BUG FIX 1: Skip 'person' class detections when faces were found ---
# This prevents human body crops from polluting the object index.
# If no faces were found (back-of-head, silhouette, etc.), we DO
# allow person-class detections through as a fallback.
detected_class_id = int(r.boxes.cls[seg_idx].item())
if faces_were_found and detected_class_id == YOLO_PERSON_CLASS_ID:
print(f"🔵 PERSON crop SKIPPED (faces already in Face Lane) — avoiding object index pollution.")
continue
polygon = np.array(mask_xy, dtype=np.int32)
if len(polygon) < 3:
continue
x, y, w, h = cv2.boundingRect(polygon)
if w < 30 or h < 30:
continue
cropped_img = original_img_pil.crop((x, y, x + w, y + h))
crops.append(cropped_img)
for crop in crops:
vec = self._embed_object_crop(crop)
extracted_vectors.append({
"type": "object",
"vector": vec
})
return extracted_vectors