File size: 12,407 Bytes
c4ee290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
"""
ShortSmith v2 - Body Recognizer Module

Full-body person recognition using re-identification (re-ID) embeddings
(OSNet-style; currently a MobileNetV2 backbone stands in for OSNet) for:
- Identifying people when face is not visible
- Back views, profile shots, masks, helmets
- Clothing and appearance-based matching

Complements face recognition for comprehensive person tracking.
"""

from pathlib import Path
from typing import List, Optional, Tuple, Union
from dataclasses import dataclass
import numpy as np

from utils.logger import get_logger, LogTimer
from utils.helpers import ModelLoadError, InferenceError
from config import get_config, ModelConfig

# Module-level logger; every log call in this file goes through the
# "models.body_recognizer" channel.
logger = get_logger("models.body_recognizer")


@dataclass
class BodyDetection:
    """Represents a detected person body in an image.

    ``bbox`` holds integer pixel coordinates ``(x1, y1, x2, y2)``: ``(x1, y1)``
    is the top-left corner and ``(x2, y2)`` the exclusive bottom-right corner,
    matching how the recognizer slices crops as ``image[y1:y2, x1:x2]``.
    """
    bbox: Tuple[int, int, int, int]  # (x1, y1, x2, y2)
    confidence: float                 # Detection confidence
    embedding: Optional[np.ndarray]   # Body appearance embedding (None if unavailable)
    track_id: Optional[int] = None    # Tracking ID if available

    @property
    def center(self) -> Tuple[int, int]:
        """Center point of the bounding box (integer division, floors)."""
        x1, y1, x2, y2 = self.bbox
        return ((x1 + x2) // 2, (y1 + y2) // 2)

    @property
    def area(self) -> int:
        """Area of the bounding box in pixels."""
        return self.width * self.height

    @property
    def width(self) -> int:
        """Horizontal extent of the bounding box (x2 - x1)."""
        return self.bbox[2] - self.bbox[0]

    @property
    def height(self) -> int:
        """Vertical extent of the bounding box (y2 - y1)."""
        return self.bbox[3] - self.bbox[1]

    @property
    def aspect_ratio(self) -> float:
        """Height/width ratio (a standing person is typically ~2.5-3.0).

        Returns 0.0 for a zero-width box rather than dividing by zero, so the
        result is always a float as the annotation promises.
        """
        width = self.width
        if width == 0:
            return 0.0
        return self.height / width


@dataclass
class BodyMatch:
    """
    Outcome of scoring one detected body against a registered reference.

    ``is_match`` records whether ``similarity`` reached the threshold that was
    in effect when the match was produced.
    """
    detection: BodyDetection             # The detection that was scored
    similarity: float                    # Similarity to the reference embedding
    is_match: bool                       # True when similarity >= threshold
    reference_id: Optional[str] = None   # Reference this detection was compared to


class BodyRecognizer:
    """
    Body recognition using person re-identification models.

    Pipeline:
    - YOLO (ultralytics) for person detection, restricted to class 0 (person)
    - A re-ID feature extractor for body appearance embeddings. OSNet is the
      intended model; this implementation currently uses an ImageNet
      MobileNetV2 backbone with its classifier removed as a stand-in.

    Both models are optional. If the detector fails to load, the whole input
    image is treated as a single person crop. If the re-ID model fails to
    load, detections carry ``embedding=None`` and cannot be matched.

    Designed to work alongside FaceRecognizer for complete
    person identification across all viewing angles.
    """

    # Weights file handed to ultralytics.YOLO (nano model for speed).
    DETECTOR_WEIGHTS = "yolov8n.pt"
    # (height, width) that body crops are resized to before embedding.
    REID_INPUT_SIZE = (256, 128)

    def __init__(
        self,
        config: Optional[ModelConfig] = None,
        load_model: bool = True,
    ):
        """
        Initialize body recognizer.

        Args:
            config: Model configuration (defaults to ``get_config().model``)
            load_model: Whether to load models immediately
        """
        self.config = config or get_config().model
        self.detector = None
        self.reid_model = None
        self._transform = None  # Set together with reid_model on successful load
        self._reference_embeddings: dict = {}  # reference_id -> normalized embedding

        if load_model:
            self._load_models()

        logger.info(f"BodyRecognizer initialized (threshold={self.config.body_similarity_threshold})")

    def _load_models(self) -> None:
        """Load person detection and re-identification models."""
        with LogTimer(logger, "Loading body recognition models"):
            self._load_detector()
            self._load_reid_model()

    def _use_cuda(self) -> bool:
        """Return True when the config asks for CUDA and a CUDA device is available."""
        import torch

        return self.config.device == "cuda" and torch.cuda.is_available()

    def _load_detector(self) -> None:
        """Load the YOLO person detector; leaves ``self.detector`` as None on failure."""
        try:
            from ultralytics import YOLO

            self.detector = YOLO(self.DETECTOR_WEIGHTS)
            logger.info("YOLO detector loaded")

        except ImportError:
            logger.warning("ultralytics not installed, using fallback detection")
            self.detector = None

        except Exception as e:
            logger.warning(f"Failed to load YOLO detector: {e}")
            self.detector = None

    def _load_reid_model(self) -> None:
        """
        Load the re-identification feature extractor and its preprocessing.

        Uses an ImageNet-pretrained MobileNetV2 whose classifier is replaced by
        an identity, so the forward pass returns backbone features instead of
        class logits. On any failure ``self.reid_model`` is left as None, which
        disables embedding extraction.
        """
        try:
            import torch
            import torchvision.transforms as T
            from torchvision.models import mobilenet_v2

            # NOTE: stand-in for OSNet (torchreid); swap in a real re-ID model
            # for production-quality matching.
            try:
                model = mobilenet_v2(weights="DEFAULT")
            except TypeError:
                # torchvision < 0.13 has no `weights` argument.
                model = mobilenet_v2(pretrained=True)
            model.classifier = torch.nn.Identity()  # Drop the classifier head

            if self._use_cuda():
                model = model.cuda()
            model.eval()

            # Preprocessing for RGB uint8 body crops (ImageNet normalization).
            transform = T.Compose([
                T.ToPILImage(),
                T.Resize(self.REID_INPUT_SIZE),
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])

            # Publish both only once everything above succeeded.
            self.reid_model = model
            self._transform = transform
            logger.info("Re-ID model loaded (MobileNetV2 backbone)")

        except Exception as e:
            logger.warning(f"Failed to load re-ID model: {e}")
            self.reid_model = None

    @staticmethod
    def _load_image(image: Union[str, Path, np.ndarray]) -> np.ndarray:
        """
        Return ``image`` as an array, reading it with OpenCV when given a path.

        Raises:
            InferenceError: If a path is given and the file cannot be read.
        """
        if isinstance(image, (str, Path)):
            import cv2

            img = cv2.imread(str(image))
            if img is None:
                raise InferenceError(f"Could not load image: {image}")
            return img
        return image

    def detect_persons(
        self,
        image: Union[str, Path, np.ndarray],
        min_confidence: float = 0.5,
        min_area: int = 2000,
    ) -> List[BodyDetection]:
        """
        Detect persons in an image.

        Args:
            image: Image path or numpy array (BGR format)
            min_confidence: Minimum detection confidence
            min_area: Minimum bounding box area

        Returns:
            List of BodyDetection objects. Without a detector, the whole image
            is returned as a single detection with confidence 1.0.

        Raises:
            InferenceError: If ``image`` is a path that cannot be read.
        """
        img = self._load_image(image)

        if self.detector is None:
            # Fallback: assume the full image is a person crop.
            h, w = img.shape[:2]
            bbox = (0, 0, w, h)
            detections = [BodyDetection(
                bbox=bbox,
                confidence=1.0,
                embedding=self._extract_embedding(img, bbox),
            )]
        else:
            detections = self._detect_with_yolo(img, min_confidence, min_area)

        logger.debug(f"Detected {len(detections)} persons")
        return detections

    def _detect_with_yolo(
        self,
        img: np.ndarray,
        min_confidence: float,
        min_area: int,
    ) -> List[BodyDetection]:
        """Run YOLO person detection; on error, log and keep detections found so far."""
        detections: List[BodyDetection] = []
        try:
            results = self.detector(img, classes=[0], verbose=False)  # class 0 = person

            for result in results:
                for box in result.boxes:
                    conf = float(box.conf[0])
                    if conf < min_confidence:
                        continue

                    bbox = tuple(map(int, box.xyxy[0].tolist()))
                    x1, y1, x2, y2 = bbox
                    if (x2 - x1) * (y2 - y1) < min_area:
                        continue

                    detections.append(BodyDetection(
                        bbox=bbox,
                        confidence=conf,
                        embedding=self._extract_embedding(img, bbox),
                    ))

        except Exception as e:
            logger.warning(f"YOLO detection failed: {e}")

        return detections

    def _extract_embedding(
        self,
        image: np.ndarray,
        bbox: Tuple[int, int, int, int],
    ) -> Optional[np.ndarray]:
        """
        Compute an L2-normalized appearance embedding for one body crop.

        The box is clipped to the image bounds first, so negative or
        out-of-range coordinates cannot wrap around via negative indexing.
        Grayscale and BGRA inputs are converted to 3-channel RGB.

        Returns:
            The normalized embedding, or None when no re-ID model is loaded,
            the clipped box is empty, or inference fails.
        """
        if self.reid_model is None:
            return None

        try:
            import torch

            h, w = image.shape[:2]
            x1, y1, x2, y2 = (int(v) for v in bbox)
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(w, x2), min(h, y2)
            if x2 <= x1 or y2 <= y1:
                return None

            crop = image[y1:y2, x1:x2]
            if crop.ndim == 2:
                crop = crop[:, :, None]
            if crop.shape[2] == 1:
                # Single channel: replicate to three channels.
                crop_rgb = np.repeat(crop, 3, axis=2)
            else:
                # BGR(A) -> RGB; the 2::-1 slice also drops an alpha channel.
                crop_rgb = crop[:, :, 2::-1]
            crop_rgb = np.ascontiguousarray(crop_rgb)

            tensor = self._transform(crop_rgb).unsqueeze(0)
            if self._use_cuda():
                tensor = tensor.cuda()

            with torch.no_grad():
                embedding = self.reid_model(tensor).cpu().numpy()[0]

            # Normalize so a plain dot product equals cosine similarity.
            return embedding / (np.linalg.norm(embedding) + 1e-8)

        except Exception as e:
            logger.debug(f"Embedding extraction failed: {e}")
            return None

    def register_reference(
        self,
        reference_image: Union[str, Path, np.ndarray],
        reference_id: str = "target",
        bbox: Optional[Tuple[int, int, int, int]] = None,
    ) -> bool:
        """
        Register a reference body appearance for matching.

        Args:
            reference_image: Image containing the reference person
            reference_id: Identifier for this reference (overwrites any
                existing reference with the same id)
            bbox: Bounding box of person (largest detection used if None)

        Returns:
            True if registration successful

        Raises:
            InferenceError: If the image cannot be read, no person is detected,
                or no embedding can be extracted.
        """
        with LogTimer(logger, f"Registering body reference '{reference_id}'"):
            img = self._load_image(reference_image)

            if bbox is None:
                detections = self.detect_persons(img, min_confidence=0.5)
                if not detections:
                    raise InferenceError("No person detected in reference image")

                # Use the largest detection; its embedding was already computed
                # from the same image and box, so reuse it.
                largest = max(detections, key=lambda d: d.area)
                embedding = largest.embedding
            else:
                embedding = self._extract_embedding(img, bbox)

            if embedding is None:
                raise InferenceError("Could not extract body embedding")

            self._reference_embeddings[reference_id] = embedding
            logger.info(f"Registered body reference: {reference_id}")
            return True

    def match_bodies(
        self,
        image: Union[str, Path, np.ndarray],
        reference_id: str = "target",
        threshold: Optional[float] = None,
    ) -> List[BodyMatch]:
        """
        Score every detected body against a registered reference.

        Args:
            image: Image to search
            reference_id: Reference to match against
            threshold: Similarity threshold; None uses the configured
                ``body_similarity_threshold`` (an explicit 0.0 is honoured)

        Returns:
            List of BodyMatch objects sorted by descending similarity; empty
            if the reference is not registered. Detections without an
            embedding are skipped.
        """
        if threshold is None:
            threshold = self.config.body_similarity_threshold

        reference = self._reference_embeddings.get(reference_id)
        if reference is None:
            logger.warning(f"Body reference '{reference_id}' not registered")
            return []

        matches = []
        for detection in self.detect_persons(image):
            if detection.embedding is None:
                continue

            similarity = self._cosine_similarity(reference, detection.embedding)
            matches.append(BodyMatch(
                detection=detection,
                similarity=similarity,
                is_match=similarity >= threshold,
                reference_id=reference_id,
            ))

        matches.sort(key=lambda m: m.similarity, reverse=True)
        return matches

    def find_target_in_frame(
        self,
        image: Union[str, Path, np.ndarray],
        reference_id: str = "target",
        threshold: Optional[float] = None,
    ) -> Optional[BodyMatch]:
        """
        Find the best matching body in a frame.

        Args:
            image: Frame to search
            reference_id: Reference to match against
            threshold: Similarity threshold (None uses the configured default)

        Returns:
            Highest-similarity BodyMatch that clears the threshold, else None
        """
        matches = self.match_bodies(image, reference_id, threshold)
        # match_bodies sorts by descending similarity, so the first hit is best.
        return next((m for m in matches if m.is_match), None)

    def _cosine_similarity(
        self,
        embedding1: np.ndarray,
        embedding2: np.ndarray,
    ) -> float:
        """
        Compute cosine similarity of two embeddings.

        Implemented as a plain dot product, which equals cosine similarity
        only because ``_extract_embedding`` L2-normalizes every embedding it
        returns (both references and detections come from it).
        """
        return float(np.dot(embedding1, embedding2))

    def clear_references(self) -> None:
        """Clear all registered references."""
        self._reference_embeddings.clear()
        logger.info("Cleared all body references")

    def get_registered_references(self) -> List[str]:
        """Get list of registered reference IDs."""
        return list(self._reference_embeddings.keys())


# Export public interface: the recognizer and the two result dataclasses it
# produces. Controls what `from <module> import *` exposes.
__all__ = ["BodyRecognizer", "BodyDetection", "BodyMatch"]