AdarshDRC committed on
Commit
c96096b
·
verified ·
1 Parent(s): 0c1e873

Update src/models.py

Browse files
Files changed (1) hide show
  1. src/models.py +144 -48
src/models.py CHANGED
@@ -1,58 +1,154 @@
1
  # src/models.py
2
  import torch
 
 
3
  from PIL import Image
4
- from transformers import AutoProcessor, AutoModel
5
  from ultralytics import YOLO
 
 
 
 
 
 
 
 
 
 
6
 
7
class AIModelManager:
    """Loads the vision models once at startup and exposes helpers to
    embed images/text with SigLIP and to crop detected objects with YOLO."""

    def __init__(self):
        # Load SigLIP (joint Vision & Text encoder). use_fast=False keeps the
        # slow (reference) processor implementation.
        self.processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224", use_fast=False)
        self.model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
        self.model.eval()  # Set to evaluation mode (disables dropout etc.)

        # Load YOLOv8 nano for speed (BUG FIX: the old comment said "YOLOv11",
        # but the weights loaded here are 'yolov8n.pt' — the comment was wrong).
        self.yolo = YOLO('yolov8n.pt')  # Will auto-download the tiny weights

    def encode_image(self, image: Image.Image):
        """Converts a PIL Image into a flat 1-D numpy vector via SigLIP."""
        inputs = self.processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model.get_image_features(**inputs)

        # Hugging Face quirk: depending on the library version, the call above
        # may return an output object or a raw tensor — unwrap defensively.
        if hasattr(outputs, 'image_embeds'):
            image_features = outputs.image_embeds
        elif hasattr(outputs, 'pooler_output'):
            image_features = outputs.pooler_output
        else:
            image_features = outputs

        return image_features.flatten().numpy()

    def encode_text(self, text: str):
        """Converts a text string into a flat 1-D numpy vector via SigLIP."""
        # SigLIP was trained with fixed-length padding, hence padding="max_length".
        inputs = self.processor(text=text, return_tensors="pt", padding="max_length")
        with torch.no_grad():
            outputs = self.model.get_text_features(**inputs)

        # Hugging Face quirk: extract the raw tensor from the output object.
        if hasattr(outputs, 'text_embeds'):
            text_features = outputs.text_embeds
        elif hasattr(outputs, 'pooler_output'):
            text_features = outputs.pooler_output
        else:
            text_features = outputs

        return text_features.flatten().numpy()

    def get_crops_from_image(self, image: Image.Image):
        """Uses YOLO to find objects and returns a list of cropped PIL Images."""
        results = self.yolo(image, conf=0.5)  # Only keep confident detections
        crops = []
        for result in results:
            for box in result.boxes.xyxy:  # Bounding box coordinates (x1, y1, x2, y2)
                x1, y1, x2, y2 = map(int, box.tolist())
                cropped_img = image.crop((x1, y1, x2, y2))
                crops.append(cropped_img)
        return crops
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # src/models.py
2
  import torch
3
+ import cv2
4
+ import numpy as np
5
  from PIL import Image
6
+ from transformers import AutoProcessor, AutoModel, AutoImageProcessor
7
  from ultralytics import YOLO
8
+ import torch.nn.functional as F
9
+ from deepface import DeepFace
10
+
11
+ # YOLO class index for "person" — we must exclude these from the object lane
12
+ # when faces have already been found, to avoid polluting the object index with humans.
13
+ YOLO_PERSON_CLASS_ID = 0
14
+
15
+ # Minimum face bounding box area (pixels²) to avoid indexing tiny/background faces
16
+ # e.g. a face on a TV screen in the background, or a crowd member 50px wide
17
+ MIN_FACE_AREA = 3000 # roughly 55x55 pixels minimum
18
 
19
class AIModelManager:
    """Loads all vision models once (SigLIP + DINOv2 + YOLO segmentation) and
    extracts per-image vectors along two lanes: a face lane (DeepFace/GhostFaceNet)
    and an object lane (SigLIP ⊕ DINOv2 fused embeddings of YOLO crops)."""

    def __init__(self):
        # Prefer CUDA, then Apple MPS, then CPU.
        self.device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
        print(f"Loading models onto: {self.device.upper()}...")

        # SigLIP: joint vision/text encoder used for semantic similarity.
        self.siglip_processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224", use_fast=False)
        self.siglip_model = AutoModel.from_pretrained("google/siglip-base-patch16-224").to(self.device)
        self.siglip_model.eval()

        # DINOv2: self-supervised visual features, complements SigLIP.
        self.dinov2_processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
        self.dinov2_model = AutoModel.from_pretrained('facebook/dinov2-base').to(self.device)
        self.dinov2_model.eval()

        # YOLO11 nano segmentation model (provides per-instance masks).
        self.yolo = YOLO('yolo11n-seg.pt')

    def _embed_object_crop(self, crop_pil):
        """Runs SigLIP + DINOv2 on a single crop and returns the fused 1536-D vector.

        Each branch is L2-normalized independently, concatenated, then the
        fused vector is normalized again so cosine similarity weighs both
        branches equally.
        """
        with torch.no_grad():
            siglip_inputs = self.siglip_processor(images=crop_pil, return_tensors="pt").to(self.device)
            siglip_out = self.siglip_model.get_image_features(**siglip_inputs)
            # Version quirk: may return an output object or a raw tensor.
            if hasattr(siglip_out, 'image_embeds'):
                siglip_out = siglip_out.image_embeds
            elif isinstance(siglip_out, tuple):
                siglip_out = siglip_out[0]
            siglip_vec = F.normalize(siglip_out, p=2, dim=1).cpu()

            dinov2_inputs = self.dinov2_processor(images=crop_pil, return_tensors="pt").to(self.device)
            dinov2_out = self.dinov2_model(**dinov2_inputs)
            # CLS token of the last hidden state is DINOv2's global descriptor.
            dinov2_vec = dinov2_out.last_hidden_state[:, 0, :]
            dinov2_vec = F.normalize(dinov2_vec, p=2, dim=1).cpu()

            object_vec = torch.cat((siglip_vec, dinov2_vec), dim=1)
            object_vec = F.normalize(object_vec, p=2, dim=1)

        return object_vec.flatten().numpy()

    def process_image(self, image_path: str, is_query=False):
        """
        Master function: Extracts EVERY face and EVERY non-human object from an image.

        Key design decisions:
        - Face lane runs first and tags every face with its bounding box area.
        - Only faces above MIN_FACE_AREA are indexed (filters background/tiny faces).
        - For queries, ALL detected faces are used (not just the first one).
        - Object lane SKIPS any YOLO detection whose class is 'person', so humans
          never pollute the object index when faces were already found.
        - If NO faces are found at all, humans caught by YOLO DO go into the object
          lane (as a fallback for silhouettes, backs-of-head, full body shots etc.)

        Returns:
            list[dict]: each entry is {"type": "face"|"object", "vector": np.ndarray}.
        """
        extracted_vectors = []
        original_img_pil = Image.open(image_path).convert('RGB')
        img_np = np.array(original_img_pil)  # RGB channel order (PIL convention)
        img_h, img_w = img_np.shape[:2]

        faces_were_found = False  # Track whether Lane 1 found anything usable

        # ==========================================
        # LANE 1: THE FACE LANE
        # ==========================================
        try:
            # BUG FIX: DeepFace interprets numpy input as BGR (OpenCV convention),
            # but img_np came from PIL and is RGB. Convert before detection so
            # RetinaFace/GhostFaceNet see correct channel order.
            face_objs = DeepFace.represent(
                img_path=cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR),
                model_name="GhostFaceNet",
                detector_backend="retinaface",
                enforce_detection=True,  # raises ValueError when no face is found
                align=True
            )

            for index, face in enumerate(face_objs):
                # Filter out tiny/background faces. Missing area info defaults
                # to the full frame so such faces are deliberately kept.
                facial_area = face.get("facial_area", {})
                fw = facial_area.get("w", img_w)
                fh = facial_area.get("h", img_h)
                face_area_px = fw * fh

                if face_area_px < MIN_FACE_AREA:
                    print(f"🟡 FACE {index+1} SKIPPED: Too small ({fw}x{fh}px = {face_area_px}px²) — likely background noise.")
                    continue

                face_vec = torch.tensor([face["embedding"]])
                face_vec = F.normalize(face_vec, p=2, dim=1)

                extracted_vectors.append({
                    "type": "face",
                    "vector": face_vec.flatten().numpy()
                })
                faces_were_found = True
                print(f"🟢 FACE {index+1} EXTRACTED: {fw}x{fh}px — Added to Face Index.")

            # For queries, do NOT break — search with ALL faces. The calling
            # code in main.py already loops over all returned vectors, so
            # returning multiple face vectors means we search for every person
            # in a group photo query simultaneously.
            # (is_query flag is kept as parameter for future use / logging only)

        except ValueError:
            print("🟠 NO FACES DETECTED -> Falling back to Object Lane for any humans.")

        # ==========================================
        # LANE 2: THE OBJECT LANE
        # ==========================================
        yolo_results = self.yolo(image_path, conf=0.5)

        # Always include the full image as one crop for global context
        crops = [original_img_pil]

        for r in yolo_results:
            if r.masks is None:
                continue
            for seg_idx, mask_xy in enumerate(r.masks.xy):
                # Skip 'person' class detections when faces were found: this
                # prevents human body crops from polluting the object index.
                # If no faces were found (back-of-head, silhouette, etc.), we
                # DO allow person-class detections through as a fallback.
                detected_class_id = int(r.boxes.cls[seg_idx].item())
                if faces_were_found and detected_class_id == YOLO_PERSON_CLASS_ID:
                    # BUG FIX: was a placeholder-less f-string; plain string is equivalent.
                    print("🔵 PERSON crop SKIPPED (faces already in Face Lane) — avoiding object index pollution.")
                    continue

                polygon = np.array(mask_xy, dtype=np.int32)
                if len(polygon) < 3:
                    continue  # degenerate mask — cannot form a polygon
                x, y, w, h = cv2.boundingRect(polygon)
                if w < 30 or h < 30:
                    continue  # too small to embed meaningfully

                cropped_img = original_img_pil.crop((x, y, x + w, y + h))
                crops.append(cropped_img)

        for crop in crops:
            vec = self._embed_object_crop(crop)
            extracted_vectors.append({
                "type": "object",
                "vector": vec
            })

        return extracted_vectors