github-actions[bot]
Deploy from 128291a769737147011181c09a08b5186e167d8e
a2b95a9
raw
history blame
14.8 kB
"""Face detection and age estimation inference engine.
This module implements the FaceAgeInferenceEngine class that coordinates
face detection and age estimation using YOLO Face-Person Detector and MiVOLO v2.
"""
import time
from contextlib import contextmanager
from functools import lru_cache
import numpy as np
import torch
from opentelemetry import metrics, trace
from transformers import (
AutoConfig,
AutoImageProcessor,
AutoModel,
AutoModelForImageClassification,
)
from .config import Settings, settings
from .image import compute_scaled_line_width, draw_face_annotations
from .types import BoundingBox, InferenceError, InferenceOutput
# Type aliases for detection results; each BoundingBox is (x1, y1, x2, y2) in pixels.
type FaceDetections = list[BoundingBox]
type PersonDetections = list[BoundingBox]

# Tracer used for all spans emitted by this module.
tracer = trace.get_tracer(__name__)

# Meter and metric instruments.
# Uses no-op provider when running standalone, real provider when ml-api sets one.
meter = metrics.get_meter(__name__)

# End-to-end latency of FaceAgeInferenceEngine.predict.
_inference_duration = meter.create_histogram(
    "inference.duration_ms",
    unit="ms",
    description="Total inference time in milliseconds",
)
# Latency of the YOLO face/person detection stage only.
_yolo_duration = meter.create_histogram(
    "inference.yolo_duration_ms",
    unit="ms",
    description="YOLO face detection time in milliseconds",
)
# Latency of the MiVOLO age-estimation stage only.
_mivolo_duration = meter.create_histogram(
    "inference.mivolo_duration_ms",
    unit="ms",
    description="MiVOLO age estimation time in milliseconds",
)
# Running count of faces found across all requests.
_faces_detected = meter.create_counter(
    "inference.faces_detected",
    description="Total number of faces detected",
)
# Running count of failed predictions, labelled by error_type.
_inference_errors = meter.create_counter(
    "inference.errors",
    description="Number of inference errors",
)
@contextmanager
def _telemetry_span(name: str, histogram=None):
    """Start a tracing span and optionally record its duration.

    Args:
        name: Name of the span to start.
        histogram: Optional histogram instrument; when provided, the elapsed
            wall-clock time of the managed block is recorded in milliseconds.

    Yields:
        The active span.
    """
    start = time.perf_counter()
    try:
        with tracer.start_as_current_span(name) as span:
            yield span
    finally:
        # Fix: record the elapsed time even when the wrapped block raises,
        # so latency histograms also cover failed operations. The original
        # only recorded on the success path, hiding error latencies.
        if histogram is not None:
            elapsed_ms = (time.perf_counter() - start) * 1000
            histogram.record(elapsed_ms)
def _compute_iou(box1: BoundingBox, box2: BoundingBox) -> float:
    """Return the intersection-over-union of two axis-aligned boxes.

    Args:
        box1: First bounding box (x1, y1, x2, y2).
        box2: Second bounding box (x1, y1, x2, y2).

    Returns:
        IoU value between 0 and 1.
    """
    inter_left = max(box1[0], box2[0])
    inter_top = max(box1[1], box2[1])
    inter_right = min(box1[2], box2[2])
    inter_bottom = min(box1[3], box2[3])

    # No positive-area overlap means IoU is zero by definition.
    if inter_right <= inter_left or inter_bottom <= inter_top:
        return 0.0

    inter_area = (inter_right - inter_left) * (inter_bottom - inter_top)
    union_area = (
        (box1[2] - box1[0]) * (box1[3] - box1[1])
        + (box2[2] - box2[0]) * (box2[3] - box2[1])
        - inter_area
    )
    # Guard against degenerate zero-area boxes.
    return inter_area / union_area if union_area > 0 else 0.0
def _face_inside_person(face: BoundingBox, person: BoundingBox) -> bool:
    """Report whether a face box's center lies within a person box.

    Args:
        face: Face bounding box (x1, y1, x2, y2).
        person: Person bounding box (x1, y1, x2, y2).

    Returns:
        True if face center is inside person box.
    """
    center_x = (face[0] + face[2]) / 2
    center_y = (face[1] + face[3]) / 2
    within_horizontal = person[0] <= center_x <= person[2]
    within_vertical = person[1] <= center_y <= person[3]
    return within_horizontal and within_vertical
class FaceAgeInferenceEngine:
    """Inference engine coordinating detection and age estimation.

    Uses YOLO Face-Person Detector for detection and MiVOLO v2 for age estimation.
    Models are automatically downloaded from HuggingFace Hub on first use.

    Pipeline (see predict): detect faces+persons -> match each face to a
    person -> crop both regions -> MiVOLO age regression -> annotate image.
    """

    def __init__(self, service_settings: Settings | None = None) -> None:
        """Initialize inference models.

        Effectful: downloads models from HuggingFace Hub if not cached.

        Args:
            service_settings: Configuration object (uses global if None).

        Raises:
            InferenceError: If models cannot be loaded.
        """
        self.settings = service_settings or settings
        # Determine torch dtype and device
        self.device = torch.device(self.settings.device)
        # Half precision on CUDA to cut memory/latency; CPU stays float32.
        self.dtype = torch.float16 if "cuda" in self.settings.device else torch.float32
        try:
            # Load YOLO Face-Person Detector from HuggingFace Hub.
            # trust_remote_code is required: both models ship custom code.
            self.detector = AutoModel.from_pretrained(
                self.settings.detector_model_id,
                trust_remote_code=True,
                dtype=self.dtype,
            ).to(self.device)
            # Load MiVOLO v2 config, model, and image processor
            self.mivolo_config = AutoConfig.from_pretrained(
                self.settings.mivolo_model_id,
                trust_remote_code=True,
            )
            self.mivolo = AutoModelForImageClassification.from_pretrained(
                self.settings.mivolo_model_id,
                trust_remote_code=True,
                dtype=self.dtype,
            ).to(self.device)
            self.image_processor = AutoImageProcessor.from_pretrained(
                self.settings.mivolo_model_id,
                trust_remote_code=True,
            )
        except Exception as exc:
            # Wrap any loading failure (network, cache, weights) in the
            # service's own error type so callers need not know HF internals.
            raise InferenceError(
                f"Failed to load models from HuggingFace Hub: {exc}"
            ) from exc

    def _extract_detections(self, results) -> tuple[FaceDetections, PersonDetections]:
        """Extract face and person bounding boxes from YOLO results.

        Args:
            results: YOLO detection results. Assumed to be an
                ultralytics-style result object exposing ``.boxes`` and
                ``.names`` — TODO confirm against the detector's remote code.

        Returns:
            Tuple of (face_boxes, person_boxes) where each box is (x1, y1, x2, y2).
        """
        faces: FaceDetections = []
        persons: PersonDetections = []
        for box in results.boxes:
            cls_id = int(box.cls.item())
            # Class names are matched case-insensitively; anything other
            # than "face"/"person" is ignored.
            cls_name = results.names[cls_id].lower()
            # Move coordinates off the device before converting to plain ints.
            coords = box.xyxy[0].cpu().numpy()
            bbox: BoundingBox = (
                int(coords[0]),
                int(coords[1]),
                int(coords[2]),
                int(coords[3]),
            )
            if cls_name == "face":
                faces.append(bbox)
            elif cls_name == "person":
                persons.append(bbox)
        return faces, persons

    def _match_faces_to_persons(
        self,
        faces: FaceDetections,
        persons: PersonDetections,
    ) -> list[tuple[BoundingBox, BoundingBox | None]]:
        """Match each face to its corresponding person bounding box.

        A person is a candidate only if the face's center lies inside the
        person box; among candidates the one with the highest face/person
        IoU wins.

        Args:
            faces: List of face bounding boxes.
            persons: List of person bounding boxes.

        Returns:
            List of (face, person) pairs. Person may be None if no match found.
        """
        matched: list[tuple[BoundingBox, BoundingBox | None]] = []
        for face in faces:
            best_person: BoundingBox | None = None
            best_overlap = 0.0
            for person in persons:
                if _face_inside_person(face, person):
                    overlap = _compute_iou(face, person)
                    # "or best_person is None" keeps the first containing
                    # person even if its IoU ties the initial 0.0.
                    if overlap > best_overlap or best_person is None:
                        best_person = person
                        best_overlap = overlap
            matched.append((face, best_person))
        return matched

    def _crop_regions(
        self,
        image_bgr: np.ndarray,
        matched_pairs: list[tuple[BoundingBox, BoundingBox | None]],
    ) -> tuple[list[np.ndarray], list[np.ndarray | None]]:
        """Crop face and body regions from image.

        Args:
            image_bgr: Input image in BGR format (H, W, 3).
            matched_pairs: List of (face, person) bounding box pairs.

        Returns:
            Tuple of (face_crops, body_crops). Body crop may be None if no person matched.
            Both lists are index-aligned with matched_pairs.
        """
        face_crops: list[np.ndarray] = []
        body_crops: list[np.ndarray | None] = []
        h, w = image_bgr.shape[:2]
        for face, person in matched_pairs:
            # Crop face (clamp to image bounds so slicing never wraps).
            x1, y1, x2, y2 = face
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(w, x2), min(h, y2)
            face_crop = image_bgr[y1:y2, x1:x2]
            face_crops.append(face_crop)
            # Crop body if available
            if person is not None:
                px1, py1, px2, py2 = person
                px1, py1 = max(0, px1), max(0, py1)
                px2, py2 = min(w, px2), min(h, py2)
                body_crop = image_bgr[py1:py2, px1:px2]
                body_crops.append(body_crop)
            else:
                body_crops.append(None)
        return face_crops, body_crops

    def _run_mivolo(
        self,
        face_crops: list[np.ndarray],
        body_crops: list[np.ndarray | None],
    ) -> list[float]:
        """Run MiVOLO v2 age estimation on cropped regions.

        Uses chunked batching to avoid OOM on group photos with many faces.

        Args:
            face_crops: List of face crop images (BGR).
            body_crops: List of body crop images (BGR), may contain None.
                Must be index-aligned with face_crops.

        Returns:
            List of estimated ages, one per face crop, in input order.
        """
        if not face_crops:
            return []
        # Clamp to at least 1 so a misconfigured batch size cannot stall.
        batch_size = max(1, int(self.settings.mivolo_batch_size))

        def _run_batch(
            batch_faces: list[np.ndarray],
            batch_bodies: list[np.ndarray | None],
        ) -> list[float]:
            # NOTE(review): assumes the remote-code image processor returns a
            # dict whose "pixel_values" is a torch tensor (it is .to()'d
            # below) — confirm against the MiVOLO processor code.
            faces_input = self.image_processor(images=batch_faces)["pixel_values"]
            faces_input = faces_input.to(dtype=self.dtype, device=self.device)
            # Collect only the entries that actually have a body crop.
            valid_body_indices: list[int] = []
            valid_body_images: list[np.ndarray] = []
            for i, body_crop in enumerate(batch_bodies):
                if body_crop is not None:
                    valid_body_indices.append(i)
                    valid_body_images.append(body_crop)
            # Faces with no matched person get an all-zeros body tensor;
            # the model is fed a full-size body batch either way.
            body_input = torch.zeros_like(faces_input)
            if valid_body_images:
                valid_body_input = self.image_processor(images=valid_body_images)[
                    "pixel_values"
                ]
                valid_body_input = valid_body_input.to(dtype=self.dtype, device=self.device)
                # Scatter processed bodies back to their batch positions.
                for tensor_idx, batch_idx in enumerate(valid_body_indices):
                    body_input[batch_idx] = valid_body_input[tensor_idx]
            with torch.no_grad():
                output = self.mivolo(faces_input=faces_input, body_input=body_input)
            # age_output is flattened so each face yields one scalar age.
            return output.age_output.cpu().flatten().tolist()

        # Process crops in fixed-size chunks to bound peak memory.
        ages: list[float] = []
        for start in range(0, len(face_crops), batch_size):
            ages.extend(
                _run_batch(
                    face_crops[start : start + batch_size],
                    body_crops[start : start + batch_size],
                )
            )
        return ages

    def _run_yolo_detection(
        self,
        image_bgr: np.ndarray,
    ) -> tuple[FaceDetections, PersonDetections]:
        """Run YOLO face/person detection with telemetry."""
        with _telemetry_span("inference.yolo_detection", _yolo_duration) as det_span:
            # Thresholds come from service settings; [0] takes the single
            # image's result from the batch-shaped output.
            results = self.detector(
                image_bgr,
                conf=self.settings.confidence_threshold,
                iou=self.settings.iou_threshold,
            )[0]
            faces, persons = self._extract_detections(results)
            det_span.set_attribute("faces_detected", len(faces))
            det_span.set_attribute("persons_detected", len(persons))
            _faces_detected.add(len(faces))
            return faces, persons

    def _run_mivolo_with_metrics(
        self,
        face_crops: list[np.ndarray],
        body_crops: list[np.ndarray | None],
    ) -> list[float]:
        """Run MiVOLO v2 age estimation with telemetry."""
        with _telemetry_span("inference.mivolo_age", _mivolo_duration) as age_span:
            ages = self._run_mivolo(face_crops, body_crops)
            age_span.set_attribute("ages_estimated", len(ages))
            return ages

    def predict(self, image_bgr: np.ndarray) -> InferenceOutput:
        """Run face detection and age estimation on an image.

        Effectful: calls ML models, renders annotations.

        Args:
            image_bgr: Input image in BGR format.

        Returns:
            Inference results with annotated image.

        Raises:
            InferenceError: If inference or annotation fails.
        """
        # Reject empty images up front, outside the telemetry span.
        if image_bgr.size == 0:
            raise InferenceError("Decoded image is empty.")
        with _telemetry_span("inference.predict", _inference_duration) as span:
            span.set_attribute("image.height", image_bgr.shape[0])
            span.set_attribute("image.width", image_bgr.shape[1])
            try:
                # 1. Run face+person detection
                faces, persons = self._run_yolo_detection(image_bgr)
                # 2. Match faces to persons
                matched_pairs = self._match_faces_to_persons(faces, persons)
                # 3. Crop face and body regions
                face_crops, body_crops = self._crop_regions(image_bgr, matched_pairs)
                # 4. Run MiVOLO v2 on crops
                ages = self._run_mivolo_with_metrics(face_crops, body_crops)
                # 5. Compute annotation parameters
                line_width = compute_scaled_line_width(image_bgr.shape)
                # 6. Draw annotations (face boxes only with age labels)
                annotated = draw_face_annotations(image_bgr, faces, ages, line_width)
                span.set_attribute("total_faces", len(faces))
            except InferenceError:
                # Known domain errors: count and propagate unchanged.
                _inference_errors.add(1, {"error_type": "inference_error"})
                raise
            except Exception as exc:
                # Unexpected failures: count, attach to the span, and wrap
                # so callers see a single exception type.
                _inference_errors.add(1, {"error_type": "unknown_error"})
                span.record_exception(exc)
                raise InferenceError(
                    "Unable to run inference on the provided image."
                ) from exc
            return InferenceOutput(
                ages=tuple(ages),
                annotated_image=annotated,
            )
@lru_cache(maxsize=1)
def get_inference_engine() -> FaceAgeInferenceEngine:
    """Return the process-wide inference engine, creating it on first call.

    The result is memoised so the heavy model weights are loaded only once.

    Returns:
        Initialized inference engine.
    """
    engine = FaceAgeInferenceEngine()
    return engine
# Public API of this module.
__all__ = [
    "FaceAgeInferenceEngine",
    "get_inference_engine",
]