Spaces:
Running
Running
Update src/models.py
Browse files- src/models.py +41 -18
src/models.py
CHANGED
|
@@ -44,7 +44,7 @@ class AIModelManager:
|
|
| 44 |
)
|
| 45 |
print(f"Loading models onto: {self.device.upper()}...")
|
| 46 |
|
| 47 |
-
self.siglip_processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224", use_fast=
|
| 48 |
self.siglip_model = AutoModel.from_pretrained("google/siglip-base-patch16-224").to(self.device).eval()
|
| 49 |
|
| 50 |
self.dinov2_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
|
|
@@ -57,7 +57,7 @@ class AIModelManager:
|
|
| 57 |
# FIX 2: Removed torch.compile() because HF Spaces do not have the g++ compiler installed by default.
|
| 58 |
# This fixes the "InvalidCxxCompiler" Search crash.
|
| 59 |
|
| 60 |
-
self.yolo = YOLO("yolo11n.pt")
|
| 61 |
|
| 62 |
self._cache = {}
|
| 63 |
self._cache_maxsize = 256
|
|
@@ -128,27 +128,50 @@ class AIModelManager:
|
|
| 128 |
except Exception as e:
|
| 129 |
print(f"🟠 Face lane error: {e} — falling back to object lane")
|
| 130 |
|
| 131 |
-
crops
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
yolo_results = self.yolo(image_path, conf=0.5, verbose=False)
|
| 133 |
|
| 134 |
for r in yolo_results:
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
break
|
| 151 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
print(f"🧠 Embedding {len(crops)} crop(s) in one batch …")
|
| 153 |
vecs = self._embed_crops_batch(crops)
|
| 154 |
for vec in vecs:
|
|
|
|
| 44 |
)
|
| 45 |
print(f"Loading models onto: {self.device.upper()}...")
|
| 46 |
|
| 47 |
+
self.siglip_processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224", use_fast=False)
|
| 48 |
self.siglip_model = AutoModel.from_pretrained("google/siglip-base-patch16-224").to(self.device).eval()
|
| 49 |
|
| 50 |
self.dinov2_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
|
|
|
|
| 57 |
# FIX 2: Removed torch.compile() because HF Spaces do not have the g++ compiler installed by default.
|
| 58 |
# This fixes the "InvalidCxxCompiler" Search crash.
|
| 59 |
|
| 60 |
+
self.yolo = YOLO("yolo11n-seg.pt") # seg model → pixel masks → accurate crops
|
| 61 |
|
| 62 |
self._cache = {}
|
| 63 |
self._cache_maxsize = 256
|
|
|
|
| 128 |
except Exception as e:
|
| 129 |
print(f"🟠 Face lane error: {e} — falling back to object lane")
|
| 130 |
|
| 131 |
+
# Full-res PIL for crops — YOLO returns coordinates in full-res pixel space.
|
| 132 |
+
# We crop from original_pil then resize each crop before embedding.
|
| 133 |
+
# BUG FIX: old optimised code cropped from small_pil (512px) using
|
| 134 |
+
# full-res YOLO coordinates → completely wrong crop regions.
|
| 135 |
+
crops_pil = [original_pil] # full-image always included for global context
|
| 136 |
+
|
| 137 |
yolo_results = self.yolo(image_path, conf=0.5, verbose=False)
|
| 138 |
|
| 139 |
for r in yolo_results:
|
| 140 |
+
# Use segmentation masks when available (yolo11n-seg.pt)
|
| 141 |
+
if r.masks is not None:
|
| 142 |
+
for seg_idx, mask_xy in enumerate(r.masks.xy):
|
| 143 |
+
cls_id = int(r.boxes.cls[seg_idx].item())
|
| 144 |
+
if faces_found and cls_id == YOLO_PERSON_CLASS_ID:
|
| 145 |
+
print("🔵 PERSON crop skipped — face lane already active")
|
| 146 |
+
continue
|
| 147 |
+
polygon = np.array(mask_xy, dtype=np.int32)
|
| 148 |
+
if len(polygon) < 3:
|
| 149 |
+
continue
|
| 150 |
+
x, y, w, h = cv2.boundingRect(polygon)
|
| 151 |
+
if w < 30 or h < 30:
|
| 152 |
+
continue
|
| 153 |
+
crop = original_pil.crop((x, y, x + w, y + h))
|
| 154 |
+
crops_pil.append(crop)
|
| 155 |
+
if len(crops_pil) >= MAX_CROPS + 1:
|
| 156 |
+
break
|
| 157 |
+
elif r.boxes is not None:
|
| 158 |
+
# Fallback: plain bounding boxes (shouldn't happen with seg model)
|
| 159 |
+
for box in r.boxes:
|
| 160 |
+
cls_id = int(box.cls.item())
|
| 161 |
+
if faces_found and cls_id == YOLO_PERSON_CLASS_ID:
|
| 162 |
+
continue
|
| 163 |
+
x1, y1, x2, y2 = box.xyxy[0].tolist()
|
| 164 |
+
if (x2 - x1) < 30 or (y2 - y1) < 30:
|
| 165 |
+
continue
|
| 166 |
+
crop = original_pil.crop((x1, y1, x2, y2))
|
| 167 |
+
crops_pil.append(crop)
|
| 168 |
+
if len(crops_pil) >= MAX_CROPS + 1:
|
| 169 |
break
|
| 170 |
|
| 171 |
+
# Resize each crop to MAX_IMAGE_SIZE before batched embedding
|
| 172 |
+
# (models expect ~224px anyway; no quality loss, big speed gain)
|
| 173 |
+
crops = [_resize_pil(c, MAX_IMAGE_SIZE) for c in crops_pil]
|
| 174 |
+
|
| 175 |
print(f"🧠 Embedding {len(crops)} crop(s) in one batch …")
|
| 176 |
vecs = self._embed_crops_batch(crops)
|
| 177 |
for vec in vecs:
|