vaisagan committed on
Commit
12406b6
·
verified ·
1 Parent(s): 7c5cdf1

Upload src/inference/predictor.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/inference/predictor.py +123 -0
src/inference/predictor.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Full inference pipeline: image → gender + age + emotion + age-at-70 per face.
3
+ """
4
+
5
from __future__ import annotations

from typing import List, Optional, Tuple

import cv2
import numpy as np
import torch
from PIL import Image

from src.data.dataset import eval_transforms
from src.inference.age_progression import age_to_70
from src.inference.emotion_detector import EmotionDetector
from src.inference.face_detector import FaceDetector
from src.models.face_model import load_model
19
+
20
# Scale factor applied to the model's age output; it is also the upper clamp
# used when the predicted age is bounded in Predictor._predict_crop.
MAX_AGE = 90.0

# Display names for the gender head, looked up by predicted class index.
GENDER_LABELS = [
    "Male",
    "Female",
]

# One colour triple per emotion name. Nothing in this module reads the table,
# so the channel order (RGB vs. BGR) cannot be confirmed here.
# NOTE(review): confirm the expected channel order with whoever consumes it.
EMOTION_COLORS = {
    "Happy": (0, 200, 0),
    "Sad": (200, 50, 50),
    "Angry": (0, 0, 220),
    "Fear": (150, 0, 200),
    "Surprise": (200, 150, 0),
    "Disgust": (0, 150, 150),
    "Neutral": (120, 120, 120),
}
32
+
33
+
34
class Predictor:
    """Detect faces in an image and predict gender, age and emotion per face.

    The predictor wires together a face detector, the gender/age model loaded
    from ``model_path``, an emotion detector and the ``age_to_70`` age
    progression helper. All heavy components are created once in ``__init__``.
    """

    def __init__(
        self,
        model_path: str,
        img_size: int = 224,
        confidence: float = 0.7,
        device: Optional[str] = None,
    ) -> None:
        """Build the inference components.

        Args:
            model_path: Passed to ``load_model`` to load the gender/age model.
            img_size: Passed to ``eval_transforms`` to build the preprocessing
                pipeline applied to every face crop.
            confidence: Forwarded to ``FaceDetector`` as its
                ``confidence_threshold``.
            device: Torch device string. When ``None``, ``"cuda"`` is used if
                available, otherwise ``"cpu"``.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.transform = eval_transforms(img_size)
        self.detector = FaceDetector(confidence_threshold=confidence)
        self.emotion = EmotionDetector()
        self.model = load_model(model_path, self.device)

    # ── single face ───────────────────────────────────────────────────────

    def _predict_crop(self, face_rgb: np.ndarray) -> dict:
        """Run every per-face predictor on a single face crop.

        Args:
            face_rgb: Face crop as a numpy array; the name indicates RGB
                channel order is expected.

        Returns:
            A dict with the keys:
                ``gender``        label taken from ``GENDER_LABELS``
                ``gender_conf``   softmax probability of that label, as a
                                  percentage rounded to one decimal
                ``age``           model output scaled by ``MAX_AGE``, clamped
                                  to ``[1.0, MAX_AGE]``, one decimal
                ``emotion``       label returned by the emotion detector
                ``emotion_conf``  emotion detector score times 100, one decimal
                ``emotion_probs`` raw return value of ``EmotionDetector.predict``
                ``aged_face``     return value of ``age_to_70``
        """
        pil = Image.fromarray(face_rgb)
        # Add a batch dimension of 1 before moving the tensor to the device.
        inp = self.transform(pil).unsqueeze(0).to(self.device)

        with torch.no_grad():
            gender_logits, age_norm = self.model(inp)

        probs = torch.softmax(gender_logits, dim=1).squeeze()
        gender_idx = int(probs.argmax().item())
        gender_conf = float(probs[gender_idx].item())
        gender_label = GENDER_LABELS[gender_idx]

        # The model output is scaled by MAX_AGE (presumably it is normalised
        # to [0, 1] - TODO confirm) and then clamped to a plausible range.
        age = float(age_norm.item()) * MAX_AGE
        age = round(max(1.0, min(MAX_AGE, age)), 1)

        # NOTE(review): top_emotion() and predict() are both invoked on the
        # same crop; if top_emotion() runs the network internally this costs
        # a second forward pass per face - confirm in EmotionDetector.
        emotion_label, emotion_conf = self.emotion.top_emotion(face_rgb)
        emotion_probs = self.emotion.predict(face_rgb)

        gender_int = 0 if gender_label == "Male" else 1
        aged_face = age_to_70(face_rgb, current_age=age, gender=gender_int)

        return {
            "gender": gender_label,
            "gender_conf": round(gender_conf * 100, 1),
            "age": age,
            "emotion": emotion_label,
            "emotion_conf": round(emotion_conf * 100, 1),
            "emotion_probs": emotion_probs,
            "aged_face": aged_face,  # RGB numpy array
        }

    # ── full image ────────────────────────────────────────────────────────

    def predict_image(self, image_rgb: np.ndarray) -> List[dict]:
        """Detect all faces in an image and predict attributes for each.

        Args:
            image_rgb: RGB numpy array.

        Returns:
            One result dict per usable detected face (see ``_predict_crop``),
            each extended with a ``"box"`` entry holding the detector's box
            for that face. Detections whose crop is ``None`` or has zero
            elements are skipped.
        """
        # The detector is fed a BGR image.
        # NOTE(review): the crops it returns go straight into _predict_crop,
        # which expects RGB - confirm FaceDetector.crop_faces returns RGB.
        bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
        crops, boxes = self.detector.crop_faces(bgr)

        results = []
        for crop, box in zip(crops, boxes):
            # A degenerate detection (e.g. a box clipped to nothing) cannot be
            # converted to an image, so skip it rather than fail the whole call.
            if crop is None or getattr(crop, "size", None) == 0:
                continue
            res = self._predict_crop(crop)
            res["box"] = box
            results.append(res)
        return results

    # ── annotated image ───────────────────────────────────────────────────

    def annotate(self, image_rgb: np.ndarray) -> np.ndarray:
        """Return a copy of image_rgb with face boxes and labels drawn.

        Each face gets a rectangle and three text lines (gender, age and
        emotion) in a colour chosen by the predicted gender. The input array
        is not modified.
        """
        results = self.predict_image(image_rgb)
        out = image_rgb.copy()
        for r in results:
            # OpenCV drawing calls need plain integer pixel coordinates.
            x1, y1, x2, y2 = (int(v) for v in r["box"])
            color = (52, 152, 219) if r["gender"] == "Male" else (231, 76, 60)

            cv2.rectangle(out, (x1, y1), (x2, y2), color, 2)

            lines = [
                f"{r['gender']} {r['gender_conf']:.0f}%",
                f"Age ~{r['age']:.0f}",
                f"{r['emotion']} {r['emotion_conf']:.0f}%",
            ]
            # Text is stacked upwards from y_off in 22 px steps; the floor of
            # 60 keeps the topmost of the three baselines (y_off - 44) inside
            # the image for faces near the top edge.
            y_off = max(y1 - 10, 60)
            for i, line in enumerate(reversed(lines)):
                cv2.putText(
                    out, line,
                    (x1 + 4, y_off - i * 22),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.58, color, 2, cv2.LINE_AA,
                )
        return out