"""Face detection and age estimation inference engine. This module implements the FaceAgeInferenceEngine class that coordinates face detection and age estimation using YOLO Face-Person Detector and MiVOLO v2. """ import time from contextlib import contextmanager from functools import lru_cache import numpy as np import torch from opentelemetry import metrics, trace from transformers import ( AutoConfig, AutoImageProcessor, AutoModel, AutoModelForImageClassification, ) from .config import Settings, settings from .image import compute_scaled_line_width, draw_face_annotations from .types import BoundingBox, InferenceError, InferenceOutput # Type alias for detection results type FaceDetections = list[BoundingBox] type PersonDetections = list[BoundingBox] # Get tracer for this module tracer = trace.get_tracer(__name__) # Get meter and create metrics instruments # Uses no-op provider when running standalone, real provider when ml-api sets one meter = metrics.get_meter(__name__) _inference_duration = meter.create_histogram( "inference.duration_ms", unit="ms", description="Total inference time in milliseconds", ) _yolo_duration = meter.create_histogram( "inference.yolo_duration_ms", unit="ms", description="YOLO face detection time in milliseconds", ) _mivolo_duration = meter.create_histogram( "inference.mivolo_duration_ms", unit="ms", description="MiVOLO age estimation time in milliseconds", ) _faces_detected = meter.create_counter( "inference.faces_detected", description="Total number of faces detected", ) _inference_errors = meter.create_counter( "inference.errors", description="Number of inference errors", ) @contextmanager def _telemetry_span(name: str, histogram=None): """Start a span and optionally record elapsed time to a histogram.""" start = time.perf_counter() with tracer.start_as_current_span(name) as span: yield span if histogram is not None: elapsed_ms = (time.perf_counter() - start) * 1000 histogram.record(elapsed_ms) def _compute_iou(box1: BoundingBox, box2: BoundingBox) -> float: """Compute intersection over union between two bounding boxes. Args: box1: First bounding box (x1, y1, x2, y2). box2: Second bounding box (x1, y1, x2, y2). Returns: IoU value between 0 and 1. """ x1 = max(box1[0], box2[0]) y1 = max(box1[1], box2[1]) x2 = min(box1[2], box2[2]) y2 = min(box1[3], box2[3]) if x2 <= x1 or y2 <= y1: return 0.0 intersection = (x2 - x1) * (y2 - y1) area1 = (box1[2] - box1[0]) * (box1[3] - box1[1]) area2 = (box2[2] - box2[0]) * (box2[3] - box2[1]) union = area1 + area2 - intersection return intersection / union if union > 0 else 0.0 def _face_inside_person(face: BoundingBox, person: BoundingBox) -> bool: """Check if a face bounding box is inside a person bounding box. Args: face: Face bounding box (x1, y1, x2, y2). person: Person bounding box (x1, y1, x2, y2). Returns: True if face center is inside person box. """ face_cx = (face[0] + face[2]) / 2 face_cy = (face[1] + face[3]) / 2 return person[0] <= face_cx <= person[2] and person[1] <= face_cy <= person[3] class FaceAgeInferenceEngine: """Inference engine coordinating detection and age estimation. Uses YOLO Face-Person Detector for detection and MiVOLO v2 for age estimation. Models are automatically downloaded from HuggingFace Hub on first use. """ def __init__(self, service_settings: Settings | None = None) -> None: """Initialize inference models. Effectful: downloads models from HuggingFace Hub if not cached. Args: service_settings: Configuration object (uses global if None). Raises: InferenceError: If models cannot be loaded. """ self.settings = service_settings or settings # Determine torch dtype and device self.device = torch.device(self.settings.device) self.dtype = torch.float16 if "cuda" in self.settings.device else torch.float32 try: # Load YOLO Face-Person Detector from HuggingFace Hub self.detector = AutoModel.from_pretrained( self.settings.detector_model_id, trust_remote_code=True, dtype=self.dtype, ).to(self.device) # Load MiVOLO v2 config, model, and image processor self.mivolo_config = AutoConfig.from_pretrained( self.settings.mivolo_model_id, trust_remote_code=True, ) self.mivolo = AutoModelForImageClassification.from_pretrained( self.settings.mivolo_model_id, trust_remote_code=True, dtype=self.dtype, ).to(self.device) self.image_processor = AutoImageProcessor.from_pretrained( self.settings.mivolo_model_id, trust_remote_code=True, ) except Exception as exc: raise InferenceError( f"Failed to load models from HuggingFace Hub: {exc}" ) from exc def _extract_detections(self, results) -> tuple[FaceDetections, PersonDetections]: """Extract face and person bounding boxes from YOLO results. Args: results: YOLO detection results. Returns: Tuple of (face_boxes, person_boxes) where each box is (x1, y1, x2, y2). """ faces: FaceDetections = [] persons: PersonDetections = [] for box in results.boxes: cls_id = int(box.cls.item()) cls_name = results.names[cls_id].lower() coords = box.xyxy[0].cpu().numpy() bbox: BoundingBox = ( int(coords[0]), int(coords[1]), int(coords[2]), int(coords[3]), ) if cls_name == "face": faces.append(bbox) elif cls_name == "person": persons.append(bbox) return faces, persons def _match_faces_to_persons( self, faces: FaceDetections, persons: PersonDetections, ) -> list[tuple[BoundingBox, BoundingBox | None]]: """Match each face to its corresponding person bounding box. Args: faces: List of face bounding boxes. persons: List of person bounding boxes. Returns: List of (face, person) pairs. Person may be None if no match found. """ matched: list[tuple[BoundingBox, BoundingBox | None]] = [] for face in faces: best_person: BoundingBox | None = None best_overlap = 0.0 for person in persons: if _face_inside_person(face, person): overlap = _compute_iou(face, person) if overlap > best_overlap or best_person is None: best_person = person best_overlap = overlap matched.append((face, best_person)) return matched def _crop_regions( self, image_bgr: np.ndarray, matched_pairs: list[tuple[BoundingBox, BoundingBox | None]], ) -> tuple[list[np.ndarray], list[np.ndarray | None]]: """Crop face and body regions from image. Args: image_bgr: Input image in BGR format. matched_pairs: List of (face, person) bounding box pairs. Returns: Tuple of (face_crops, body_crops). Body crop may be None if no person matched. """ face_crops: list[np.ndarray] = [] body_crops: list[np.ndarray | None] = [] h, w = image_bgr.shape[:2] for face, person in matched_pairs: # Crop face (clamp to image bounds) x1, y1, x2, y2 = face x1, y1 = max(0, x1), max(0, y1) x2, y2 = min(w, x2), min(h, y2) face_crop = image_bgr[y1:y2, x1:x2] face_crops.append(face_crop) # Crop body if available if person is not None: px1, py1, px2, py2 = person px1, py1 = max(0, px1), max(0, py1) px2, py2 = min(w, px2), min(h, py2) body_crop = image_bgr[py1:py2, px1:px2] body_crops.append(body_crop) else: body_crops.append(None) return face_crops, body_crops def _run_mivolo( self, face_crops: list[np.ndarray], body_crops: list[np.ndarray | None], ) -> list[float]: """Run MiVOLO v2 age estimation on cropped regions. Uses chunked batching to avoid OOM on group photos with many faces. Args: face_crops: List of face crop images (BGR). body_crops: List of body crop images (BGR), may contain None. Returns: List of estimated ages. """ if not face_crops: return [] batch_size = max(1, int(self.settings.mivolo_batch_size)) def _run_batch( batch_faces: list[np.ndarray], batch_bodies: list[np.ndarray | None], ) -> list[float]: faces_input = self.image_processor(images=batch_faces)["pixel_values"] faces_input = faces_input.to(dtype=self.dtype, device=self.device) valid_body_indices: list[int] = [] valid_body_images: list[np.ndarray] = [] for i, body_crop in enumerate(batch_bodies): if body_crop is not None: valid_body_indices.append(i) valid_body_images.append(body_crop) body_input = torch.zeros_like(faces_input) if valid_body_images: valid_body_input = self.image_processor(images=valid_body_images)[ "pixel_values" ] valid_body_input = valid_body_input.to(dtype=self.dtype, device=self.device) for tensor_idx, batch_idx in enumerate(valid_body_indices): body_input[batch_idx] = valid_body_input[tensor_idx] with torch.no_grad(): output = self.mivolo(faces_input=faces_input, body_input=body_input) return output.age_output.cpu().flatten().tolist() ages: list[float] = [] for start in range(0, len(face_crops), batch_size): ages.extend( _run_batch( face_crops[start : start + batch_size], body_crops[start : start + batch_size], ) ) return ages def _run_yolo_detection( self, image_bgr: np.ndarray, ) -> tuple[FaceDetections, PersonDetections]: """Run YOLO face/person detection with telemetry.""" with _telemetry_span("inference.yolo_detection", _yolo_duration) as det_span: results = self.detector( image_bgr, conf=self.settings.confidence_threshold, iou=self.settings.iou_threshold, )[0] faces, persons = self._extract_detections(results) det_span.set_attribute("faces_detected", len(faces)) det_span.set_attribute("persons_detected", len(persons)) _faces_detected.add(len(faces)) return faces, persons def _run_mivolo_with_metrics( self, face_crops: list[np.ndarray], body_crops: list[np.ndarray | None], ) -> list[float]: """Run MiVOLO v2 age estimation with telemetry.""" with _telemetry_span("inference.mivolo_age", _mivolo_duration) as age_span: ages = self._run_mivolo(face_crops, body_crops) age_span.set_attribute("ages_estimated", len(ages)) return ages def predict(self, image_bgr: np.ndarray) -> InferenceOutput: """Run face detection and age estimation on an image. Effectful: calls ML models, renders annotations. Args: image_bgr: Input image in BGR format. Returns: Inference results with annotated image. Raises: InferenceError: If inference or annotation fails. """ if image_bgr.size == 0: raise InferenceError("Decoded image is empty.") with _telemetry_span("inference.predict", _inference_duration) as span: span.set_attribute("image.height", image_bgr.shape[0]) span.set_attribute("image.width", image_bgr.shape[1]) try: # 1. Run face+person detection faces, persons = self._run_yolo_detection(image_bgr) # 2. Match faces to persons matched_pairs = self._match_faces_to_persons(faces, persons) # 3. Crop face and body regions face_crops, body_crops = self._crop_regions(image_bgr, matched_pairs) # 4. Run MiVOLO v2 on crops ages = self._run_mivolo_with_metrics(face_crops, body_crops) # 5. Compute annotation parameters line_width = compute_scaled_line_width(image_bgr.shape) # 6. Draw annotations (face boxes only with age labels) annotated = draw_face_annotations(image_bgr, faces, ages, line_width) span.set_attribute("total_faces", len(faces)) except InferenceError: _inference_errors.add(1, {"error_type": "inference_error"}) raise except Exception as exc: _inference_errors.add(1, {"error_type": "unknown_error"}) span.record_exception(exc) raise InferenceError( "Unable to run inference on the provided image." ) from exc return InferenceOutput( ages=tuple(ages), annotated_image=annotated, ) @lru_cache(maxsize=1) def get_inference_engine() -> FaceAgeInferenceEngine: """Get or create singleton inference engine. Cached to avoid reloading heavy ML models. Returns: Initialized inference engine. """ return FaceAgeInferenceEngine() __all__ = [ "FaceAgeInferenceEngine", "get_inference_engine", ]