File size: 8,181 Bytes
3e805ab
58c92f2
 
 
 
362d86f
 
 
 
 
3e805ab
c96096b
 
3e805ab
c96096b
3e805ab
c96096b
 
 
 
362d86f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c96096b
3e805ab
 
362d86f
 
 
 
c96096b
 
5d013dc
58c92f2
362d86f
 
58c92f2
362d86f
 
 
 
 
58c92f2
 
 
5d013dc
362d86f
 
 
 
 
 
 
 
 
c96096b
3e805ab
58c92f2
362d86f
 
58c92f2
362d86f
 
 
 
 
 
58c92f2
362d86f
58c92f2
362d86f
 
58c92f2
362d86f
 
58c92f2
 
362d86f
 
 
 
 
58c92f2
362d86f
 
 
 
 
 
 
 
 
 
c96096b
8c6ce56
 
362d86f
8c6ce56
 
 
 
58c92f2
362d86f
8c6ce56
362d86f
 
 
8c6ce56
362d86f
 
 
 
8c6ce56
362d86f
 
c96096b
5d013dc
 
 
 
 
 
362d86f
c96096b
 
5d013dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
362d86f
 
5d013dc
 
 
 
362d86f
 
 
 
 
 
 
 
 
 
 
 
58c92f2
362d86f
58c92f2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
# src/models.py
import os
# FIX 1: Force Legacy Keras to prevent DeepFace/RetinaFace crash in TF 2.16+
# NOTE: both variables are assigned before the third-party imports below.
# Presumably TensorFlow (pulled in through deepface) reads them at import
# time, so keep these lines above those imports — TODO confirm.
os.environ["TF_USE_LEGACY_KERAS"] = "1"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # Hides the noisy CUDA/cuInit warnings

import asyncio
import hashlib
import functools

import torch
import cv2
import numpy as np
from PIL import Image
from transformers import AutoProcessor, AutoModel, AutoImageProcessor
from ultralytics import YOLO
import torch.nn.functional as F
from deepface import DeepFace

# Class id that marks a YOLO detection as a person.
YOLO_PERSON_CLASS_ID = 0
# Smallest face box area (w * h, in px) that is kept; roughly a 55×55 px face.
MIN_FACE_AREA = 3_000
# Cap on YOLO crops per request; the full-image crop is counted on top of this.
MAX_CROPS = 6
# Longest edge, in px, that images are downscaled to before any inference.
MAX_IMAGE_SIZE = 512


def _resize_pil(img: Image.Image, max_side: int = MAX_IMAGE_SIZE) -> Image.Image:
    """Downscale *img* so that its longest edge is at most *max_side* pixels.

    The aspect ratio is preserved and LANCZOS resampling is used. Images that
    already fit are returned unchanged (the same object, not a copy).

    Args:
        img: Source PIL image.
        max_side: Maximum allowed length of the longest edge, in pixels.

    Returns:
        The original image if it already fits, otherwise a resized copy.
    """
    w, h = img.size
    longest = max(w, h)
    if longest <= max_side:
        return img
    scale = max_side / longest
    # int() truncates, so a very elongated image (aspect ratio above max_side:1)
    # would get a 0 px short edge, which resize() rejects. Clamp to 1 px.
    new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
    return img.resize(new_size, Image.LANCZOS)

def _img_hash(image_path: str) -> str:
    """Return the MD5 hex digest of the full contents of *image_path*.

    The digest is used as a content-based cache key, not for security. The
    whole file is hashed (streamed in 64 KiB chunks); hashing only a leading
    prefix would let two different files that share their first bytes map to
    the same key.

    Args:
        image_path: Path of the file to fingerprint.

    Returns:
        32-character lowercase hexadecimal MD5 digest.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # usedforsecurity=False keeps this working on FIPS-restricted builds.
    digest = hashlib.md5(usedforsecurity=False)
    with open(image_path, "rb") as f:
        while chunk := f.read(65536):
            digest.update(chunk)
    return digest.hexdigest()

class AIModelManager:
    """Own the vision models and turn an image file into embedding vectors.

    Models used:

    * SigLIP and DINOv2 (Hugging Face): each crop is embedded by both, the two
      L2-normalised vectors are concatenated and normalised again, giving one
      "object" vector per crop.
    * YOLO (``yolo11n-seg.pt``): proposes object crops.
    * DeepFace (GhostFaceNet with the RetinaFace detector): "face" vectors.

    Results are kept in a small in-memory cache (oldest entry evicted first)
    keyed by file content and the ``detect_faces`` flag.
    """

    def __init__(self):
        """Pick a device (CUDA, then MPS, then CPU) and load all models."""
        self.device = (
            "cuda" if torch.cuda.is_available()
            else ("mps" if torch.backends.mps.is_available() else "cpu")
        )
        print(f"Loading models onto: {self.device.upper()}...")

        self.siglip_processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224", use_fast=False)
        self.siglip_model = AutoModel.from_pretrained("google/siglip-base-patch16-224").to(self.device).eval()

        self.dinov2_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
        self.dinov2_model = AutoModel.from_pretrained("facebook/dinov2-base").to(self.device).eval()

        # Half precision is only enabled on CUDA.
        if self.device == "cuda":
            self.siglip_model = self.siglip_model.half()
            self.dinov2_model = self.dinov2_model.half()

        # FIX 2: torch.compile() is intentionally not used. HF Spaces do not
        # ship a g++ compiler by default, which caused the
        # "InvalidCxxCompiler" crash on search.

        self.yolo = YOLO("yolo11n-seg.pt")   # seg model → pixel masks → accurate crops

        # cache key -> list of result dicts; insertion order doubles as age.
        self._cache = {}
        self._cache_maxsize = 256

        print("✅ Models ready!")

    def _to_device(self, inputs) -> dict:
        """Move processor outputs to ``self.device``, casting fp32 to fp16 on CUDA."""
        moved = {k: v.to(self.device) for k, v in inputs.items()}
        if self.device == "cuda":
            # Match the half-precision model weights; non-float tensors are kept.
            moved = {k: v.half() if v.dtype == torch.float32 else v for k, v in moved.items()}
        return moved

    def _embed_crops_batch(self, crops: list[Image.Image]) -> list[np.ndarray]:
        """Embed every crop with SigLIP and DINOv2 in one batch each.

        Args:
            crops: PIL images to embed.

        Returns:
            One 1-D float array per crop, in input order: the L2-normalised
            concatenation of the normalised SigLIP image features and the
            normalised first token of DINOv2's last hidden state. Empty list
            for empty input.
        """
        if not crops:
            return []

        with torch.no_grad():
            sig_inputs = self._to_device(
                self.siglip_processor(images=crops, return_tensors="pt", padding=True)
            )
            sig_out = self.siglip_model.get_image_features(**sig_inputs)
            # Tolerate return types other than a plain tensor.
            if hasattr(sig_out, "image_embeds"):
                sig_out = sig_out.image_embeds
            elif isinstance(sig_out, tuple):
                sig_out = sig_out[0]
            # Normalise in fp32 on the CPU regardless of the model precision.
            sig_vecs = F.normalize(sig_out.float(), p=2, dim=1).cpu()

            dino_inputs = self._to_device(
                self.dinov2_processor(images=crops, return_tensors="pt")
            )
            dino_out = self.dinov2_model(**dino_inputs)
            dino_vecs = dino_out.last_hidden_state[:, 0, :]
            dino_vecs = F.normalize(dino_vecs.float(), p=2, dim=1).cpu()

            fused = F.normalize(torch.cat([sig_vecs, dino_vecs], dim=1), p=2, dim=1)

        return [row.numpy() for row in fused]

    def _extract_faces(self, img_np: np.ndarray) -> list[dict]:
        """Run the face lane and return ``{"type": "face", ...}`` entries.

        Faces whose box area is below ``MIN_FACE_AREA`` are dropped. Any
        failure is logged and the faces gathered so far are returned, so the
        caller can still fall back to the object lane.
        """
        faces = []
        try:
            print("🔍 Face detection …")
            # NOTE(review): img_np is RGB (built from a PIL image). DeepFace
            # documents numpy input as BGR — confirm whether a channel swap is
            # wanted. Left unchanged so new vectors stay comparable with any
            # that are already indexed.
            face_objs = DeepFace.represent(
                img_path=img_np,
                model_name="GhostFaceNet",
                detector_backend="retinaface",
                enforce_detection=False,
                align=True,
            )
            for face in (face_objs or []):
                fa = face.get("facial_area", {})
                if fa.get("w", 0) * fa.get("h", 0) < MIN_FACE_AREA:
                    continue
                vec = F.normalize(torch.tensor([face["embedding"]]), p=2, dim=1)
                faces.append({"type": "face", "vector": vec.flatten().numpy()})
        except Exception as e:
            # Deliberate best effort: a broken face lane must not fail the request.
            print(f"🟠 Face lane error: {e} — falling back to object lane")
        return faces

    def _collect_crops(self, original_pil: Image.Image, skip_person: bool) -> list[Image.Image]:
        """Return the full image followed by at most ``MAX_CROPS`` YOLO crops.

        Args:
            original_pil: Full-resolution RGB image; also the crop source.
            skip_person: Drop person detections (used when the face lane has
                already produced vectors).
        """
        limit = MAX_CROPS + 1
        crops_pil = [original_pil]   # full-image always included for global context

        # YOLO gets the very image we crop from, so its coordinates are
        # guaranteed to be in that image's pixel space. Passing the file path
        # lets the loader decode it independently (possibly applying EXIF
        # orientation), which can disagree with the PIL image.
        yolo_results = self.yolo(original_pil, conf=0.5, verbose=False)

        for r in yolo_results:
            if r.masks is not None:
                # Use segmentation masks when available (yolo11n-seg.pt).
                for seg_idx, mask_xy in enumerate(r.masks.xy):
                    if len(crops_pil) >= limit:
                        break
                    cls_id = int(r.boxes.cls[seg_idx].item())
                    if skip_person and cls_id == YOLO_PERSON_CLASS_ID:
                        print("🔵 PERSON crop skipped — face lane already active")
                        continue
                    polygon = np.array(mask_xy, dtype=np.int32)
                    if len(polygon) < 3:
                        continue
                    x, y, w, h = cv2.boundingRect(polygon)
                    if w < 30 or h < 30:
                        continue
                    crops_pil.append(original_pil.crop((x, y, x + w, y + h)))
            elif r.boxes is not None:
                # Fallback: plain bounding boxes (shouldn't happen with seg model).
                for box in r.boxes:
                    if len(crops_pil) >= limit:
                        break
                    cls_id = int(box.cls.item())
                    if skip_person and cls_id == YOLO_PERSON_CLASS_ID:
                        continue
                    x1, y1, x2, y2 = box.xyxy[0].tolist()
                    if (x2 - x1) < 30 or (y2 - y1) < 30:
                        continue
                    crops_pil.append(original_pil.crop((x1, y1, x2, y2)))
            if len(crops_pil) >= limit:
                break
        return crops_pil

    def _remember(self, key, value: list) -> None:
        """Store *value* under *key*, evicting the oldest entry when full."""
        if len(self._cache) >= self._cache_maxsize:
            # Defaults keep this safe on an empty cache or a lost race.
            oldest = next(iter(self._cache), None)
            self._cache.pop(oldest, None)
        # Store a copy so later mutation of the returned list cannot leak in.
        self._cache[key] = list(value)

    def process_image(self, image_path: str, is_query: bool = False, detect_faces: bool = True) -> list[dict]:
        """Extract face and object vectors from the image at *image_path*.

        Args:
            image_path: Path of the image file.
            is_query: Currently unused; kept for interface compatibility.
            detect_faces: Run the DeepFace lane before the object lane.

        Returns:
            A list of ``{"type": "face" | "object", "vector": np.ndarray}``
            dicts. Face entries (if any) come first; the first object entry
            is the embedding of the full image.
        """
        # detect_faces changes the result, so it has to be part of the key;
        # otherwise a face-less result would be served to a face-enabled call.
        cache_key = (_img_hash(image_path), bool(detect_faces))
        cached = self._cache.get(cache_key)
        if cached is not None:
            print("⚡ Cache hit — skipping inference")
            return list(cached)

        # convert() yields an independent image, so the file can be closed here.
        with Image.open(image_path) as raw:
            original_pil = raw.convert("RGB")

        extracted = []
        if detect_faces:
            # Face detection runs on the downscaled image.
            small_pil = _resize_pil(original_pil, MAX_IMAGE_SIZE)
            extracted.extend(self._extract_faces(np.array(small_pil)))
        faces_found = bool(extracted)

        crops_pil = self._collect_crops(original_pil, skip_person=faces_found)

        # Resize each crop to MAX_IMAGE_SIZE before batched embedding.
        crops = [_resize_pil(c, MAX_IMAGE_SIZE) for c in crops_pil]

        print(f"🧠 Embedding {len(crops)} crop(s) in one batch …")
        extracted.extend(
            {"type": "object", "vector": vec} for vec in self._embed_crops_batch(crops)
        )

        self._remember(cache_key, extracted)
        return extracted

    async def process_image_async(self, image_path: str, is_query: bool = False, detect_faces: bool = True) -> list[dict]:
        """Run :meth:`process_image` in the default executor without blocking the loop.

        Takes the same arguments and returns the same value as
        :meth:`process_image`.
        """
        # get_running_loop() is the supported way to reach the loop from a coroutine.
        loop = asyncio.get_running_loop()
        call = functools.partial(self.process_image, image_path, is_query, detect_faces)
        return await loop.run_in_executor(None, call)