# FaceInsight_AI / src / inference / predictor.py
# vaisagan's picture
# Upload src/inference/predictor.py with huggingface_hub
# 12406b6 verified
"""
Full inference pipeline: image → gender + age + emotion + age-at-70 per face.
"""
from __future__ import annotations
from typing import List, Tuple
import cv2
import numpy as np
import torch
from PIL import Image
from src.data.dataset import eval_transforms
from src.inference.age_progression import age_to_70
from src.inference.emotion_detector import EmotionDetector
from src.inference.face_detector import FaceDetector
from src.models.face_model import load_model
# Gender class names; Predictor indexes this list with the argmax of the
# model's gender logits, so the order must match the model's output order.
GENDER_LABELS = ["Male", "Female"]

# The model's age output is multiplied by this value to obtain years, and the
# result is clamped to at most this value.
MAX_AGE = 90.0

# One colour triple per emotion label.
# NOTE(review): channel order (BGR vs RGB) is not established in this module —
# confirm at the place where these colours are drawn.
EMOTION_COLORS = {
    "Happy": (0, 200, 0),
    "Sad": (200, 50, 50),
    "Angry": (0, 0, 220),
    "Fear": (150, 0, 200),
    "Surprise": (200, 150, 0),
    "Disgust": (0, 150, 150),
    "Neutral": (120, 120, 120),
}
class Predictor:
    """Detect faces in an image and predict gender, age, emotion and an aged face for each.

    Attributes set by ``__init__``:
        device: ``torch.device`` that model inputs are moved to.
        transform: result of ``eval_transforms(img_size)``; applied to each PIL face crop.
        detector: ``FaceDetector`` built with the given confidence threshold.
        emotion: ``EmotionDetector`` instance.
        model: object returned by ``load_model(model_path, device)``; it is called on a
            batched tensor and must return a ``(gender_logits, age_norm)`` pair.
    """

    def __init__(
        self,
        model_path: str,
        img_size: int = 224,
        confidence: float = 0.7,
        device: str | None = None,
    ) -> None:
        """Load the model and build the detector, emotion and preprocessing helpers.

        Args:
            model_path: Path handed to ``load_model``.
            img_size: Size handed to ``eval_transforms``.
            confidence: Confidence threshold handed to ``FaceDetector``.
            device: Torch device string. When ``None``, ``"cuda"`` is used if
                available, otherwise ``"cpu"``.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.transform = eval_transforms(img_size)
        self.detector = FaceDetector(confidence_threshold=confidence)
        self.emotion = EmotionDetector()
        self.model = load_model(model_path, self.device)

    # ── single face ───────────────────────────────────────────────────────
    def _predict_crop(self, face_rgb: np.ndarray) -> dict:
        """Run every per-face predictor on one face crop.

        Args:
            face_rgb: Face crop as a numpy array (treated as RGB).

        Returns:
            Dict with keys:
                ``gender``: entry of ``GENDER_LABELS`` with the highest softmax probability.
                ``gender_conf``: that probability as a percentage, rounded to 1 decimal.
                ``age``: model age output scaled by ``MAX_AGE``, clamped to
                    ``[1.0, MAX_AGE]`` and rounded to 1 decimal.
                ``emotion`` / ``emotion_conf``: label and percentage (1 decimal) from
                    ``self.emotion.top_emotion``.
                ``emotion_probs``: whatever ``self.emotion.predict`` returns.
                ``aged_face``: output of ``age_to_70`` for this crop.
        """
        pil = Image.fromarray(face_rgb)
        # Add a batch dimension of 1 before sending the tensor to the model's device.
        inp = self.transform(pil).unsqueeze(0).to(self.device)
        with torch.no_grad():
            gender_logits, age_norm = self.model(inp)

        probs = torch.softmax(gender_logits, dim=1).squeeze()
        gender_idx = int(probs.argmax().item())
        gender_conf = float(probs[gender_idx].item())
        gender_label = GENDER_LABELS[gender_idx]

        # The age head output is interpreted as a fraction of MAX_AGE.
        age = float(age_norm.item()) * MAX_AGE
        age = round(max(1.0, min(MAX_AGE, age)), 1)

        # NOTE(review): top_emotion() and predict() both receive the same crop; if
        # top_emotion() already calls predict() internally this runs the emotion
        # model twice per face — confirm against EmotionDetector.
        emotion_label, emotion_conf = self.emotion.top_emotion(face_rgb)
        emotion_probs = self.emotion.predict(face_rgb)

        gender_int = 0 if gender_label == "Male" else 1
        aged_face = age_to_70(face_rgb, current_age=age, gender=gender_int)

        return {
            "gender": gender_label,
            "gender_conf": round(gender_conf * 100, 1),
            "age": age,
            "emotion": emotion_label,
            "emotion_conf": round(emotion_conf * 100, 1),
            "emotion_probs": emotion_probs,
            "aged_face": aged_face,  # RGB numpy array
        }

    # ── full image ────────────────────────────────────────────────────────
    def predict_image(self, image_rgb: np.ndarray) -> List[dict]:
        """
        Args:
            image_rgb: RGB numpy array
        Returns:
            List of result dicts per detected face (see _predict_crop), each
            extended with a "box" entry taken from the face detector.
        """
        # The detector is fed a BGR copy of the image.
        # NOTE(review): the crops are then passed to _predict_crop, which treats
        # them as RGB — confirm that crop_faces() returns RGB crops.
        bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
        crops, boxes = self.detector.crop_faces(bgr)
        results = []
        for crop, box in zip(crops, boxes):
            res = self._predict_crop(crop)
            res["box"] = box
            results.append(res)
        return results

    # ── annotated image ───────────────────────────────────────────────────
    def annotate(self, image_rgb: np.ndarray) -> np.ndarray:
        """Return a copy of image_rgb with face boxes and labels drawn."""
        results = self.predict_image(image_rgb)
        out = image_rgb.copy()
        for r in results:
            # OpenCV drawing calls need plain integer pixel coordinates.
            x1, y1, x2, y2 = (int(v) for v in r["box"])
            # Box/label colour depends only on the predicted gender; the triple is
            # written into `out` as-is, i.e. in the image's own channel order.
            color = (52, 152, 219) if r["gender"] == "Male" else (231, 76, 60)
            cv2.rectangle(out, (x1, y1), (x2, y2), color, 2)
            lines = [
                f"{r['gender']} {r['gender_conf']:.0f}%",
                f"Age ~{r['age']:.0f}",
                f"{r['emotion']} {r['emotion_conf']:.0f}%",
            ]
            # Labels are stacked upwards from y_off, 22 px apart. The floor of 60
            # keeps the topmost baseline (y_off - 44) at y >= 16 for faces that
            # touch the top edge of the image.
            y_off = max(y1 - 10, 60)
            # reversed(): the last label sits closest to the box, the first on top.
            for i, line in enumerate(reversed(lines)):
                cv2.putText(
                    out, line,
                    (x1 + 4, y_off - i * 22),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.58, color, 2, cv2.LINE_AA,
                )
        return out