Spaces:
Running
Running
| """Face detection and age estimation inference engine. | |
| This module implements the FaceAgeInferenceEngine class that coordinates | |
| face detection and age estimation using YOLO Face-Person Detector and MiVOLO v2. | |
| """ | |
| import time | |
| from contextlib import contextmanager | |
| from functools import lru_cache | |
| import numpy as np | |
| import torch | |
| from opentelemetry import metrics, trace | |
| from transformers import ( | |
| AutoConfig, | |
| AutoImageProcessor, | |
| AutoModel, | |
| AutoModelForImageClassification, | |
| ) | |
| from .config import Settings, settings | |
| from .image import compute_scaled_line_width, draw_face_annotations | |
| from .types import BoundingBox, InferenceError, InferenceOutput | |
# Type aliases for detection results (PEP 695 ``type`` statements, 3.12+).
# Both are lists of (x1, y1, x2, y2) pixel-coordinate boxes.
type FaceDetections = list[BoundingBox]
type PersonDetections = list[BoundingBox]
# Get tracer for this module; used by every telemetry span below.
tracer = trace.get_tracer(__name__)
# Get meter and create metrics instruments
# Uses no-op provider when running standalone, real provider when ml-api sets one
meter = metrics.get_meter(__name__)
# Wall-clock duration of a full predict() call.
_inference_duration = meter.create_histogram(
    "inference.duration_ms",
    unit="ms",
    description="Total inference time in milliseconds",
)
# Duration of the YOLO face/person detection stage only.
_yolo_duration = meter.create_histogram(
    "inference.yolo_duration_ms",
    unit="ms",
    description="YOLO face detection time in milliseconds",
)
# Duration of the MiVOLO age-estimation stage only.
_mivolo_duration = meter.create_histogram(
    "inference.mivolo_duration_ms",
    unit="ms",
    description="MiVOLO age estimation time in milliseconds",
)
# Running count of faces detected across all requests.
_faces_detected = meter.create_counter(
    "inference.faces_detected",
    description="Total number of faces detected",
)
# Running count of failed inference calls, labelled by error_type.
_inference_errors = meter.create_counter(
    "inference.errors",
    description="Number of inference errors",
)
@contextmanager
def _telemetry_span(name: str, histogram=None):
    """Context manager: open a tracer span and optionally time its body.

    Fix: this generator was used in ``with`` statements throughout the module
    but lacked the ``@contextmanager`` decorator (imported at the top of the
    file and previously unused), so every ``with _telemetry_span(...)`` would
    fail at runtime — a bare generator has no ``__enter__``/``__exit__``.

    Args:
        name: Span name passed to this module's tracer.
        histogram: Optional OpenTelemetry histogram; when provided, the
            elapsed wall-clock time of the ``with`` body is recorded in
            milliseconds.

    Yields:
        The active span, so callers can attach attributes to it.
    """
    start = time.perf_counter()
    with tracer.start_as_current_span(name) as span:
        yield span
        # Runs only on the success path: an exception in the body propagates
        # here and skips the record, so failures do not skew latency metrics.
        if histogram is not None:
            elapsed_ms = (time.perf_counter() - start) * 1000
            histogram.record(elapsed_ms)
| def _compute_iou(box1: BoundingBox, box2: BoundingBox) -> float: | |
| """Compute intersection over union between two bounding boxes. | |
| Args: | |
| box1: First bounding box (x1, y1, x2, y2). | |
| box2: Second bounding box (x1, y1, x2, y2). | |
| Returns: | |
| IoU value between 0 and 1. | |
| """ | |
| x1 = max(box1[0], box2[0]) | |
| y1 = max(box1[1], box2[1]) | |
| x2 = min(box1[2], box2[2]) | |
| y2 = min(box1[3], box2[3]) | |
| if x2 <= x1 or y2 <= y1: | |
| return 0.0 | |
| intersection = (x2 - x1) * (y2 - y1) | |
| area1 = (box1[2] - box1[0]) * (box1[3] - box1[1]) | |
| area2 = (box2[2] - box2[0]) * (box2[3] - box2[1]) | |
| union = area1 + area2 - intersection | |
| return intersection / union if union > 0 else 0.0 | |
| def _face_inside_person(face: BoundingBox, person: BoundingBox) -> bool: | |
| """Check if a face bounding box is inside a person bounding box. | |
| Args: | |
| face: Face bounding box (x1, y1, x2, y2). | |
| person: Person bounding box (x1, y1, x2, y2). | |
| Returns: | |
| True if face center is inside person box. | |
| """ | |
| face_cx = (face[0] + face[2]) / 2 | |
| face_cy = (face[1] + face[3]) / 2 | |
| return person[0] <= face_cx <= person[2] and person[1] <= face_cy <= person[3] | |
class FaceAgeInferenceEngine:
    """Inference engine coordinating detection and age estimation.

    Uses YOLO Face-Person Detector for detection and MiVOLO v2 for age estimation.
    Models are automatically downloaded from HuggingFace Hub on first use.

    Pipeline (see ``predict``): detect faces+persons -> match each face to a
    person box -> crop both regions -> batch them through MiVOLO -> draw
    annotated output.
    """

    def __init__(self, service_settings: Settings | None = None) -> None:
        """Initialize inference models.

        Effectful: downloads models from HuggingFace Hub if not cached.

        Args:
            service_settings: Configuration object (uses global if None).

        Raises:
            InferenceError: If models cannot be loaded.
        """
        self.settings = service_settings or settings
        # Determine torch dtype and device: fp16 on CUDA, fp32 otherwise.
        self.device = torch.device(self.settings.device)
        self.dtype = torch.float16 if "cuda" in self.settings.device else torch.float32
        try:
            # Load YOLO Face-Person Detector from HuggingFace Hub.
            # trust_remote_code=True: both models ship custom modeling code.
            self.detector = AutoModel.from_pretrained(
                self.settings.detector_model_id,
                trust_remote_code=True,
                dtype=self.dtype,
            ).to(self.device)
            # Load MiVOLO v2 config, model, and image processor
            self.mivolo_config = AutoConfig.from_pretrained(
                self.settings.mivolo_model_id,
                trust_remote_code=True,
            )
            self.mivolo = AutoModelForImageClassification.from_pretrained(
                self.settings.mivolo_model_id,
                trust_remote_code=True,
                dtype=self.dtype,
            ).to(self.device)
            self.image_processor = AutoImageProcessor.from_pretrained(
                self.settings.mivolo_model_id,
                trust_remote_code=True,
            )
        except Exception as exc:
            # Broad catch is intentional: any load failure (network, auth,
            # incompatible weights) becomes the service's own error type,
            # with the original cause chained for debugging.
            raise InferenceError(
                f"Failed to load models from HuggingFace Hub: {exc}"
            ) from exc

    def _extract_detections(self, results) -> tuple[FaceDetections, PersonDetections]:
        """Extract face and person bounding boxes from YOLO results.

        Args:
            results: YOLO detection results. Expected to expose ``.boxes``
                (each with ``.cls`` and ``.xyxy`` tensors) and a ``.names``
                class-id -> label mapping — Ultralytics-style result object;
                exact type comes from the remote detector code.

        Returns:
            Tuple of (face_boxes, person_boxes) where each box is (x1, y1, x2, y2).
        """
        faces: FaceDetections = []
        persons: PersonDetections = []
        for box in results.boxes:
            cls_id = int(box.cls.item())
            # Lowercase so matching is robust to label casing ("Face"/"face").
            cls_name = results.names[cls_id].lower()
            # Move coords to host memory before converting to ints.
            coords = box.xyxy[0].cpu().numpy()
            bbox: BoundingBox = (
                int(coords[0]),
                int(coords[1]),
                int(coords[2]),
                int(coords[3]),
            )
            if cls_name == "face":
                faces.append(bbox)
            elif cls_name == "person":
                persons.append(bbox)
            # Any other class label is silently ignored.
        return faces, persons

    def _match_faces_to_persons(
        self,
        faces: FaceDetections,
        persons: PersonDetections,
    ) -> list[tuple[BoundingBox, BoundingBox | None]]:
        """Match each face to its corresponding person bounding box.

        A person is a candidate only when the face's center lies inside it;
        among candidates, the one with the highest IoU against the face wins.

        Args:
            faces: List of face bounding boxes.
            persons: List of person bounding boxes.

        Returns:
            List of (face, person) pairs. Person may be None if no match found.
        """
        matched: list[tuple[BoundingBox, BoundingBox | None]] = []
        for face in faces:
            best_person: BoundingBox | None = None
            best_overlap = 0.0
            for person in persons:
                if _face_inside_person(face, person):
                    overlap = _compute_iou(face, person)
                    # ``best_person is None`` keeps the first candidate even
                    # if its IoU ties the 0.0 starting value.
                    if overlap > best_overlap or best_person is None:
                        best_person = person
                        best_overlap = overlap
            matched.append((face, best_person))
        return matched

    def _crop_regions(
        self,
        image_bgr: np.ndarray,
        matched_pairs: list[tuple[BoundingBox, BoundingBox | None]],
    ) -> tuple[list[np.ndarray], list[np.ndarray | None]]:
        """Crop face and body regions from image.

        Output lists are index-aligned with ``matched_pairs``.

        Args:
            image_bgr: Input image in BGR format (H, W, C).
            matched_pairs: List of (face, person) bounding box pairs.

        Returns:
            Tuple of (face_crops, body_crops). Body crop may be None if no person matched.
        """
        face_crops: list[np.ndarray] = []
        body_crops: list[np.ndarray | None] = []
        h, w = image_bgr.shape[:2]
        for face, person in matched_pairs:
            # Crop face (clamp to image bounds)
            x1, y1, x2, y2 = face
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(w, x2), min(h, y2)
            # NOTE(review): a box entirely outside the image yields an empty
            # crop here; downstream processor behavior on empty arrays is
            # not established by this file — confirm if that case can occur.
            face_crop = image_bgr[y1:y2, x1:x2]
            face_crops.append(face_crop)
            # Crop body if available
            if person is not None:
                px1, py1 = None, None  # placeholder comment removed
                px1, py1, px2, py2 = person
                px1, py1 = max(0, px1), max(0, py1)
                px2, py2 = min(w, px2), min(h, py2)
                body_crop = image_bgr[py1:py2, px1:px2]
                body_crops.append(body_crop)
            else:
                body_crops.append(None)
        return face_crops, body_crops

    def _run_mivolo(
        self,
        face_crops: list[np.ndarray],
        body_crops: list[np.ndarray | None],
    ) -> list[float]:
        """Run MiVOLO v2 age estimation on cropped regions.

        Uses chunked batching to avoid OOM on group photos with many faces.

        Args:
            face_crops: List of face crop images (BGR).
            body_crops: List of body crop images (BGR), may contain None.
                Must be index-aligned with ``face_crops``.

        Returns:
            List of estimated ages, one per face crop, in input order.
        """
        if not face_crops:
            return []
        # Guard against misconfiguration: batch size is at least 1.
        batch_size = max(1, int(self.settings.mivolo_batch_size))

        def _run_batch(
            batch_faces: list[np.ndarray],
            batch_bodies: list[np.ndarray | None],
        ) -> list[float]:
            # The ``.to`` call implies the processor returns a torch tensor
            # batch under "pixel_values" — defined by the remote processor code.
            faces_input = self.image_processor(images=batch_faces)["pixel_values"]
            faces_input = faces_input.to(dtype=self.dtype, device=self.device)
            # Collect only the crops that exist; remember their batch slots.
            valid_body_indices: list[int] = []
            valid_body_images: list[np.ndarray] = []
            for i, body_crop in enumerate(batch_bodies):
                if body_crop is not None:
                    valid_body_indices.append(i)
                    valid_body_images.append(body_crop)
            # Faces without a matched person get an all-zero body tensor —
            # presumably MiVOLO's convention for "no body"; TODO confirm
            # against the model card.
            body_input = torch.zeros_like(faces_input)
            if valid_body_images:
                valid_body_input = self.image_processor(images=valid_body_images)[
                    "pixel_values"
                ]
                valid_body_input = valid_body_input.to(dtype=self.dtype, device=self.device)
                # Scatter processed bodies back into their original slots.
                for tensor_idx, batch_idx in enumerate(valid_body_indices):
                    body_input[batch_idx] = valid_body_input[tensor_idx]
            with torch.no_grad():
                output = self.mivolo(faces_input=faces_input, body_input=body_input)
            return output.age_output.cpu().flatten().tolist()

        # Process in fixed-size chunks and concatenate the results.
        ages: list[float] = []
        for start in range(0, len(face_crops), batch_size):
            ages.extend(
                _run_batch(
                    face_crops[start : start + batch_size],
                    body_crops[start : start + batch_size],
                )
            )
        return ages

    def _run_yolo_detection(
        self,
        image_bgr: np.ndarray,
    ) -> tuple[FaceDetections, PersonDetections]:
        """Run YOLO face/person detection with telemetry."""
        with _telemetry_span("inference.yolo_detection", _yolo_duration) as det_span:
            # [0]: single-image call; only the first result is relevant.
            results = self.detector(
                image_bgr,
                conf=self.settings.confidence_threshold,
                iou=self.settings.iou_threshold,
            )[0]
            faces, persons = self._extract_detections(results)
            det_span.set_attribute("faces_detected", len(faces))
            det_span.set_attribute("persons_detected", len(persons))
            _faces_detected.add(len(faces))
        return faces, persons

    def _run_mivolo_with_metrics(
        self,
        face_crops: list[np.ndarray],
        body_crops: list[np.ndarray | None],
    ) -> list[float]:
        """Run MiVOLO v2 age estimation with telemetry."""
        with _telemetry_span("inference.mivolo_age", _mivolo_duration) as age_span:
            ages = self._run_mivolo(face_crops, body_crops)
            age_span.set_attribute("ages_estimated", len(ages))
        return ages

    def predict(self, image_bgr: np.ndarray) -> InferenceOutput:
        """Run face detection and age estimation on an image.

        Effectful: calls ML models, renders annotations.

        Args:
            image_bgr: Input image in BGR format.

        Returns:
            Inference results with annotated image.

        Raises:
            InferenceError: If inference or annotation fails.
        """
        if image_bgr.size == 0:
            raise InferenceError("Decoded image is empty.")
        with _telemetry_span("inference.predict", _inference_duration) as span:
            span.set_attribute("image.height", image_bgr.shape[0])
            span.set_attribute("image.width", image_bgr.shape[1])
            try:
                # 1. Run face+person detection
                faces, persons = self._run_yolo_detection(image_bgr)
                # 2. Match faces to persons
                matched_pairs = self._match_faces_to_persons(faces, persons)
                # 3. Crop face and body regions
                face_crops, body_crops = self._crop_regions(image_bgr, matched_pairs)
                # 4. Run MiVOLO v2 on crops (returns [] when no faces found)
                ages = self._run_mivolo_with_metrics(face_crops, body_crops)
                # 5. Compute annotation parameters
                line_width = compute_scaled_line_width(image_bgr.shape)
                # 6. Draw annotations (face boxes only with age labels)
                annotated = draw_face_annotations(image_bgr, faces, ages, line_width)
                span.set_attribute("total_faces", len(faces))
            except InferenceError:
                # Already the right type: count it and re-raise unchanged.
                _inference_errors.add(1, {"error_type": "inference_error"})
                raise
            except Exception as exc:
                # Unexpected failure: record on the span and wrap in the
                # service error type with the cause chained.
                _inference_errors.add(1, {"error_type": "unknown_error"})
                span.record_exception(exc)
                raise InferenceError(
                    "Unable to run inference on the provided image."
                ) from exc
        # ``ages``/``annotated`` are always bound here: any failure above
        # raised out of the method.
        return InferenceOutput(
            ages=tuple(ages),
            annotated_image=annotated,
        )
@lru_cache(maxsize=1)
def get_inference_engine() -> FaceAgeInferenceEngine:
    """Get or create the singleton inference engine.

    Cached to avoid reloading heavy ML models: ``lru_cache`` (already
    imported at module top but previously unused) memoizes the single
    zero-argument call, so the engine — and its model downloads — are
    created exactly once per process.

    Returns:
        Initialized (and cached) inference engine.
    """
    return FaceAgeInferenceEngine()
# Explicit public API: module-level helpers and telemetry objects stay private.
__all__ = [
    "FaceAgeInferenceEngine",
    "get_inference_engine",
]