# src/models.py
import torch
import cv2
import numpy as np
from PIL import Image
from transformers import AutoProcessor, AutoModel, AutoImageProcessor
from ultralytics import YOLO
import torch.nn.functional as F
from deepface import DeepFace

# YOLO/COCO class index for "person" — detections with this class are excluded
# from the object lane whenever the face lane already found faces, so human
# body crops never pollute the object index.
YOLO_PERSON_CLASS_ID = 0

# Minimum face bounding-box area (pixels²) required to index a face.
# Filters out tiny/background faces — e.g. a face on a TV screen in the
# background, or a crowd member ~50px wide.
MIN_FACE_AREA = 3000  # roughly 55x55 pixels minimum

class AIModelManager:
    """Owns the vision models and exposes a two-lane feature-extraction pipeline.

    Lane 1 (faces): DeepFace/GhostFaceNet with a RetinaFace detector.
    Lane 2 (objects): YOLO segmentation crops embedded with SigLIP + DINOv2
    and fused into a single 1536-D vector.
    """

    def __init__(self):
        # Device preference: CUDA, then Apple MPS, then CPU fallback.
        self.device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
        print(f"Loading models onto: {self.device.upper()}...")

        # SigLIP image tower — semantic embeddings via get_image_features().
        self.siglip_processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224", use_fast=False)
        self.siglip_model = AutoModel.from_pretrained("google/siglip-base-patch16-224").to(self.device)
        self.siglip_model.eval()

        # DINOv2 — self-supervised visual features; the CLS token is used below.
        self.dinov2_processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
        self.dinov2_model = AutoModel.from_pretrained('facebook/dinov2-base').to(self.device)
        self.dinov2_model.eval()

        # YOLO segmentation model used by the object lane.
        self.yolo = YOLO('yolo11n-seg.pt')

    def _embed_object_crop(self, crop_pil):
        """Runs SigLIP + DINOv2 on a single crop and returns the fused 1536-D vector.

        Each model's output is L2-normalized before concatenation, and the
        concatenated vector is re-normalized so the fused embedding is unit
        length (each half is weighted 1/sqrt(2)).
        """
        with torch.no_grad():
            siglip_inputs = self.siglip_processor(images=crop_pil, return_tensors="pt").to(self.device)
            siglip_out = self.siglip_model.get_image_features(**siglip_inputs)
            # Defensive unwrapping: depending on the transformers version,
            # get_image_features may return a tensor, an output object, or a tuple.
            if hasattr(siglip_out, 'image_embeds'):
                siglip_out = siglip_out.image_embeds
            elif isinstance(siglip_out, tuple):
                siglip_out = siglip_out[0]
            siglip_vec = F.normalize(siglip_out, p=2, dim=1).cpu()

            dinov2_inputs = self.dinov2_processor(images=crop_pil, return_tensors="pt").to(self.device)
            dinov2_out = self.dinov2_model(**dinov2_inputs)
            # CLS token ([:, 0, :]) serves as the global image descriptor.
            dinov2_vec = dinov2_out.last_hidden_state[:, 0, :]
            dinov2_vec = F.normalize(dinov2_vec, p=2, dim=1).cpu()

            object_vec = torch.cat((siglip_vec, dinov2_vec), dim=1)
            object_vec = F.normalize(object_vec, p=2, dim=1)

        return object_vec.flatten().numpy()

    def _extract_face_vectors(self, img_np, img_w, img_h):
        """Lane 1: detect and embed every sufficiently large face in the image.

        Returns a list of {"type": "face", "vector": ndarray} dicts; an empty
        list means no usable face was found (caller uses that as the signal to
        let person-class YOLO detections into the object lane).
        """
        vectors = []
        try:
            face_objs = DeepFace.represent(
                img_path=img_np,
                model_name="GhostFaceNet",
                detector_backend="retinaface",
                enforce_detection=True,
                align=True
            )
        except ValueError:
            # enforce_detection=True raises ValueError when no face is found.
            # Kept narrow around the represent() call so ValueErrors from the
            # post-processing loop below are not silently swallowed.
            print("🟠 NO FACES DETECTED -> Falling back to Object Lane for any humans.")
            return vectors

        for index, face in enumerate(face_objs):
            # Filter out tiny/background faces. Missing facial_area fields
            # default to the full image size, which always passes the filter.
            facial_area = face.get("facial_area", {})
            fw = facial_area.get("w", img_w)
            fh = facial_area.get("h", img_h)
            face_area_px = fw * fh

            if face_area_px < MIN_FACE_AREA:
                print(f"🟡 FACE {index+1} SKIPPED: Too small ({fw}x{fh}px = {face_area_px}px²) — likely background noise.")
                continue

            # L2-normalize the embedding so cosine similarity == dot product.
            face_vec = torch.tensor([face["embedding"]])
            face_vec = F.normalize(face_vec, p=2, dim=1)

            vectors.append({
                "type": "face",
                "vector": face_vec.flatten().numpy()
            })
            print(f"🟢 FACE {index+1} EXTRACTED: {fw}x{fh}px — Added to Face Index.")

        return vectors

    def _collect_object_crops(self, image_path, original_img_pil, faces_were_found):
        """Lane 2: run YOLO segmentation and return the PIL crops to embed.

        The full image is always the first crop (global context). Person-class
        detections are skipped when faces_were_found is True, so human body
        crops don't pollute the object index; when no faces were found they
        are allowed through as a fallback (silhouettes, backs of heads, etc.).
        """
        yolo_results = self.yolo(image_path, conf=0.5)

        crops = [original_img_pil]

        for r in yolo_results:
            if r.masks is None:
                continue
            for seg_idx, mask_xy in enumerate(r.masks.xy):
                # masks and boxes are index-aligned within one result object.
                detected_class_id = int(r.boxes.cls[seg_idx].item())
                if faces_were_found and detected_class_id == YOLO_PERSON_CLASS_ID:
                    print("🔵 PERSON crop SKIPPED (faces already in Face Lane) — avoiding object index pollution.")
                    continue

                polygon = np.array(mask_xy, dtype=np.int32)
                if len(polygon) < 3:
                    # Degenerate mask — cannot form a bounding rect.
                    continue
                x, y, w, h = cv2.boundingRect(polygon)
                if w < 30 or h < 30:
                    # Too small to produce a meaningful embedding.
                    continue

                crops.append(original_img_pil.crop((x, y, x + w, y + h)))

        return crops

    def process_image(self, image_path: str, is_query=False):
        """
        Master function: Extracts EVERY face and EVERY non-human object from an image.

        Key design decisions:
        - Face lane runs first and tags every face with its bounding box area.
        - Only faces above MIN_FACE_AREA are indexed (filters background/tiny faces).
        - For queries, ALL detected faces are used (not just the first one) —
          the calling code loops over all returned vectors, so a group-photo
          query searches for every person simultaneously.
          (is_query is kept as a parameter for future use / logging only.)
        - Object lane SKIPS any YOLO detection whose class is 'person', so humans
          never pollute the object index when faces were already found.
        - If NO faces are found at all, humans caught by YOLO DO go into the object
          lane (as a fallback for silhouettes, backs-of-head, full body shots etc.)

        Returns a list of {"type": "face"|"object", "vector": np.ndarray} dicts.
        """
        # Context manager closes the underlying file handle; .convert() returns
        # a new, fully-loaded RGB image so it stays valid after the file closes.
        with Image.open(image_path) as img_file:
            original_img_pil = img_file.convert('RGB')
        img_np = np.array(original_img_pil)
        img_h, img_w = img_np.shape[:2]

        # ==========================================
        # LANE 1: THE FACE LANE
        # ==========================================
        extracted_vectors = self._extract_face_vectors(img_np, img_w, img_h)
        faces_were_found = bool(extracted_vectors)

        # ==========================================
        # LANE 2: THE OBJECT LANE
        # ==========================================
        crops = self._collect_object_crops(image_path, original_img_pil, faces_were_found)
        for crop in crops:
            extracted_vectors.append({
                "type": "object",
                "vector": self._embed_object_crop(crop)
            })

        return extracted_vectors