AdarshDRC committed on
Commit
5d013dc
·
verified ·
1 Parent(s): 8f0f0e4

Update src/models.py

Browse files
Files changed (1) hide show
  1. src/models.py +41 -18
src/models.py CHANGED
@@ -44,7 +44,7 @@ class AIModelManager:
44
  )
45
  print(f"Loading models onto: {self.device.upper()}...")
46
 
47
- self.siglip_processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224", use_fast=True)
48
  self.siglip_model = AutoModel.from_pretrained("google/siglip-base-patch16-224").to(self.device).eval()
49
 
50
  self.dinov2_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
@@ -57,7 +57,7 @@ class AIModelManager:
57
  # FIX 2: Removed torch.compile() because HF Spaces do not have the g++ compiler installed by default.
58
  # This fixes the "InvalidCxxCompiler" Search crash.
59
 
60
- self.yolo = YOLO("yolo11n.pt")
61
 
62
  self._cache = {}
63
  self._cache_maxsize = 256
@@ -128,27 +128,50 @@ class AIModelManager:
128
  except Exception as e:
129
  print(f"🟠 Face lane error: {e} — falling back to object lane")
130
 
131
- crops = [small_pil]
 
 
 
 
 
132
  yolo_results = self.yolo(image_path, conf=0.5, verbose=False)
133
 
134
  for r in yolo_results:
135
- if r.boxes is None:
136
- continue
137
- for box in r.boxes:
138
- cls_id = int(box.cls.item())
139
- if faces_found and cls_id == YOLO_PERSON_CLASS_ID:
140
- continue
141
- x1, y1, x2, y2 = box.xyxy[0].tolist()
142
- w, h = x2 - x1, y2 - y1
143
- if w < 30 or h < 30:
144
- continue
145
- crop = small_pil.crop((x1, y1, x2, y2))
146
- crops.append(crop)
147
- if len(crops) >= MAX_CROPS + 1:
148
- break
149
- if len(crops) >= MAX_CROPS + 1:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  break
151
 
 
 
 
 
152
  print(f"🧠 Embedding {len(crops)} crop(s) in one batch …")
153
  vecs = self._embed_crops_batch(crops)
154
  for vec in vecs:
 
44
  )
45
  print(f"Loading models onto: {self.device.upper()}...")
46
 
47
+ self.siglip_processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224", use_fast=False)
48
  self.siglip_model = AutoModel.from_pretrained("google/siglip-base-patch16-224").to(self.device).eval()
49
 
50
  self.dinov2_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
 
57
  # FIX 2: Removed torch.compile() because HF Spaces do not have the g++ compiler installed by default.
58
  # This fixes the "InvalidCxxCompiler" Search crash.
59
 
60
+ self.yolo = YOLO("yolo11n-seg.pt") # seg model → pixel masks → accurate crops
61
 
62
  self._cache = {}
63
  self._cache_maxsize = 256
 
128
  except Exception as e:
129
  print(f"🟠 Face lane error: {e} — falling back to object lane")
130
 
131
+ # Full-res PIL for crops. YOLO returns coordinates in full-res pixel space.
132
+ # We crop from original_pil then resize each crop before embedding.
133
+ # BUG FIX: old optimised code cropped from small_pil (512px) using
134
+ # full-res YOLO coordinates → completely wrong crop regions.
135
+ crops_pil = [original_pil] # full-image always included for global context
136
+
137
  yolo_results = self.yolo(image_path, conf=0.5, verbose=False)
138
 
139
  for r in yolo_results:
140
+ # Use segmentation masks when available (yolo11n-seg.pt)
141
+ if r.masks is not None:
142
+ for seg_idx, mask_xy in enumerate(r.masks.xy):
143
+ cls_id = int(r.boxes.cls[seg_idx].item())
144
+ if faces_found and cls_id == YOLO_PERSON_CLASS_ID:
145
+ print("🔵 PERSON crop skipped — face lane already active")
146
+ continue
147
+ polygon = np.array(mask_xy, dtype=np.int32)
148
+ if len(polygon) < 3:
149
+ continue
150
+ x, y, w, h = cv2.boundingRect(polygon)
151
+ if w < 30 or h < 30:
152
+ continue
153
+ crop = original_pil.crop((x, y, x + w, y + h))
154
+ crops_pil.append(crop)
155
+ if len(crops_pil) >= MAX_CROPS + 1:
156
+ break
157
+ elif r.boxes is not None:
158
+ # Fallback: plain bounding boxes (shouldn't happen with seg model)
159
+ for box in r.boxes:
160
+ cls_id = int(box.cls.item())
161
+ if faces_found and cls_id == YOLO_PERSON_CLASS_ID:
162
+ continue
163
+ x1, y1, x2, y2 = box.xyxy[0].tolist()
164
+ if (x2 - x1) < 30 or (y2 - y1) < 30:
165
+ continue
166
+ crop = original_pil.crop((x1, y1, x2, y2))
167
+ crops_pil.append(crop)
168
+ if len(crops_pil) >= MAX_CROPS + 1:
169
  break
170
 
171
+ # Resize each crop to MAX_IMAGE_SIZE before batched embedding
172
+ # (models expect ~224px anyway; no quality loss, big speed gain)
173
+ crops = [_resize_pil(c, MAX_IMAGE_SIZE) for c in crops_pil]
174
+
175
  print(f"🧠 Embedding {len(crops)} crop(s) in one batch …")
176
  vecs = self._embed_crops_batch(crops)
177
  for vec in vecs: