File size: 4,603 Bytes
12406b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
Full inference pipeline: image → gender + age + emotion + age-at-70 per face.
"""

from __future__ import annotations

from typing import List, Optional, Tuple

import cv2
import numpy as np
import torch
from PIL import Image

from src.data.dataset import eval_transforms
from src.inference.age_progression import age_to_70
from src.inference.emotion_detector import EmotionDetector
from src.inference.face_detector import FaceDetector
from src.models.face_model import load_model

# Display names for the gender classifier output, indexed by predicted class.
GENDER_LABELS = ["Male", "Female"]

# Factor that maps the model's normalised age output to years; it is also the
# upper clamp applied to predicted ages.
MAX_AGE = 90.0

# Per-emotion colour triples.  Not referenced elsewhere in the visible code.
# NOTE(review): the channel order (RGB vs BGR) is not established here —
# confirm with whichever consumer draws with these.
EMOTION_COLORS = dict(
    Happy=(0, 200, 0),
    Sad=(200, 50, 50),
    Angry=(0, 0, 220),
    Fear=(150, 0, 200),
    Surprise=(200, 150, 0),
    Disgust=(0, 150, 150),
    Neutral=(120, 120, 120),
)


class Predictor:
    """Detect faces in an image and predict attributes for each one.

    Bundles a face detector, the gender/age model, an emotion detector and
    the age-progression step.  All of them are created once in ``__init__``
    and reused for every call.
    """

    def __init__(
        self,
        model_path: str,
        img_size: int     = 224,
        confidence: float = 0.7,
        device: Optional[str] = None,
    ) -> None:
        """Build the preprocessing, detectors and model.

        Args:
            model_path: Forwarded to ``load_model`` to load the gender/age model.
            img_size: Forwarded to ``eval_transforms`` to build the
                preprocessing pipeline.
            confidence: Forwarded to ``FaceDetector`` as
                ``confidence_threshold``.
            device: Torch device string.  When ``None``, ``"cuda"`` is used if
                available, otherwise ``"cpu"``.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device    = torch.device(device)
        self.transform = eval_transforms(img_size)
        self.detector  = FaceDetector(confidence_threshold=confidence)
        self.emotion   = EmotionDetector()
        # NOTE(review): nothing here switches the model to eval mode —
        # presumably load_model does; confirm.
        self.model     = load_model(model_path, self.device)

    # ── single face ───────────────────────────────────────────────────────

    def _predict_crop(self, face_rgb: np.ndarray) -> dict:
        """Predict gender, age, emotion and an aged rendering for one face crop.

        Args:
            face_rgb: Face crop as a numpy array (RGB, per the parameter name).

        Returns:
            Dict with keys ``gender`` (label from ``GENDER_LABELS``),
            ``gender_conf`` (percent, 1 decimal), ``age`` (years, clamped to
            ``[1.0, MAX_AGE]``, 1 decimal), ``emotion``, ``emotion_conf``
            (percent, 1 decimal), ``emotion_probs`` and ``aged_face``.
        """
        pil = Image.fromarray(face_rgb)
        # Add a batch dimension of 1 and move the tensor to the model's device.
        inp = self.transform(pil).unsqueeze(0).to(self.device)

        with torch.no_grad():
            gender_logits, age_norm = self.model(inp)

        probs        = torch.softmax(gender_logits, dim=1).squeeze()
        gender_idx   = int(probs.argmax().item())
        gender_conf  = float(probs[gender_idx].item())
        gender_label = GENDER_LABELS[gender_idx]
        # The model's age output is scaled by MAX_AGE, then clamped.
        age          = float(age_norm.item()) * MAX_AGE
        age          = round(max(1.0, min(MAX_AGE, age)), 1)

        # NOTE(review): if top_emotion() already runs predict() internally,
        # the emotion model is evaluated twice per face — confirm.
        emotion_label, emotion_conf = self.emotion.top_emotion(face_rgb)
        emotion_probs = self.emotion.predict(face_rgb)
        gender_int    = 0 if gender_label == "Male" else 1
        aged_face     = age_to_70(face_rgb, current_age=age, gender=gender_int)

        return {
            "gender":        gender_label,
            "gender_conf":   round(gender_conf * 100, 1),
            "age":           age,
            "emotion":       emotion_label,
            "emotion_conf":  round(emotion_conf * 100, 1),
            "emotion_probs": emotion_probs,
            "aged_face":     aged_face,   # RGB numpy array
        }

    # ── full image ────────────────────────────────────────────────────────

    def predict_image(self, image_rgb: np.ndarray) -> List[dict]:
        """Run detection and per-face prediction on a whole image.

        Args:
            image_rgb: RGB numpy array.

        Returns:
            One result dict per usable detected face (see ``_predict_crop``),
            each extended with a ``box`` entry as returned by the detector.
        """
        # The detector is handed a BGR image.
        # NOTE(review): the crops are then treated as RGB by _predict_crop —
        # confirm crop_faces() returns RGB crops.
        bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
        crops, boxes = self.detector.crop_faces(bgr)
        results = []
        for crop, box in zip(crops, boxes):
            # A missing or zero-area crop cannot be converted to an image;
            # skip it rather than failing the whole frame.
            if crop is None or crop.size == 0:
                continue
            res        = self._predict_crop(crop)
            res["box"] = box
            results.append(res)
        return results

    # ── annotated image ───────────────────────────────────────────────────

    @staticmethod
    def _int_box(box) -> Tuple[int, int, int, int]:
        """Return the four box coordinates as plain Python ints (truncated)."""
        x1, y1, x2, y2 = (int(v) for v in box)
        return x1, y1, x2, y2

    @staticmethod
    def _label_lines(result: dict) -> List[str]:
        """Return the three caption lines (gender, age, emotion) for a result."""
        return [
            f"{result['gender']} {result['gender_conf']:.0f}%",
            f"Age ~{result['age']:.0f}",
            f"{result['emotion']} {result['emotion_conf']:.0f}%",
        ]

    def annotate(self, image_rgb: np.ndarray) -> np.ndarray:
        """Return a copy of image_rgb with face boxes and labels drawn."""
        results = self.predict_image(image_rgb)
        out     = image_rgb.copy()
        for r in results:
            # OpenCV drawing calls need integer pixel coordinates; the
            # detector's element type is not visible here, so coerce.
            x1, y1, x2, y2 = self._int_box(r["box"])
            # Colours are applied to the RGB copy: blue-ish for "Male",
            # red-ish otherwise.
            color = (52, 152, 219) if r["gender"] == "Male" else (231, 76, 60)

            cv2.rectangle(out, (x1, y1), (x2, y2), color, 2)

            # Captions stack upward from y_off in 22 px steps.  Clamping to 60
            # keeps the topmost of the three lines inside the image.
            y_off = max(y1 - 10, 60)
            for i, line in enumerate(reversed(self._label_lines(r))):
                cv2.putText(
                    out, line,
                    (x1 + 4, y_off - i * 22),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.58, color, 2, cv2.LINE_AA,
                )
        return out