File size: 7,250 Bytes
eacd6a2
 
 
 
 
f017803
9e440e5
eacd6a2
 
 
 
 
9e440e5
eacd6a2
 
 
 
 
 
1f4e421
 
 
eacd6a2
 
8e3c6f8
 
 
 
3b1eb50
9e440e5
 
 
8e3c6f8
 
9e440e5
 
 
 
3b1eb50
eacd6a2
9e440e5
eacd6a2
1f4e421
3b1eb50
 
1f4e421
3b1eb50
9e440e5
 
 
 
 
 
 
 
 
 
 
 
 
eacd6a2
9e440e5
eacd6a2
3b1eb50
9e440e5
 
eacd6a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f4e421
 
eacd6a2
1f4e421
eacd6a2
 
1f4e421
 
 
 
 
 
 
 
 
 
 
 
 
 
eacd6a2
 
1f4e421
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
import re
import sys
from pathlib import Path

import torch
import numpy as np
from PIL import Image
from transformers import AutoImageProcessor
import cv2
from huggingface_hub import hf_hub_download
from mtcnn import MTCNN
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from safetensors.torch import load_file as load_safetensors
import pandas as pd

try:
    src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
    if src_path not in sys.path: sys.path.append(src_path)
    from components.multi_task_model_trainer import MultiTaskEfficientNet
    from utils.common import read_yaml
except ImportError:
    from src.cnnClassifier.components.multi_task_model_trainer import MultiTaskEfficientNet
    from src.cnnClassifier.utils.common import read_yaml

class PredictionPipeline:
    """Detect faces in an image and predict an age bucket and a gender for each face.

    Model weights, training params and the CSV used to rebuild the label maps
    are downloaded from a Hugging Face Hub repo at construction time. Two face
    detectors are available: MTCNN (``predict_hq``) and an OpenCV Haar cascade
    (``predict_lq``). Both return ``(annotated_image, predictions)`` where each
    prediction is ``{"age": str, "gender": str, "box": (x, y, w, h)}``.
    """

    def __init__(self, repo_id: str = "ALYYAN/Facial-Age-Det"):
        """Download artifacts from ``repo_id`` and build model, transforms and detectors.

        Args:
            repo_id: Hub repository that holds ``checkpoint-26873/model.safetensors``,
                ``params.yaml`` and ``fairface_cleaned.csv``.
        """
        self.device = "cpu"
        self.repo_id = repo_id
        print("--- Initializing Prediction Pipeline by downloading artifacts from Hub ---")
        cache_dir = os.getenv("HF_HOME", "/app/hf_cache")
        self.model_file = hf_hub_download(
            repo_id=self.repo_id, filename="checkpoint-26873/model.safetensors", cache_dir=cache_dir
        )
        self.params_path = hf_hub_download(repo_id=self.repo_id, filename="params.yaml", cache_dir=cache_dir)
        self.data_csv_path = hf_hub_download(
            repo_id=self.repo_id, filename="fairface_cleaned.csv", cache_dir=cache_dir
        )
        self.base_model_name = "google/efficientnet-b2"
        self.params = read_yaml(Path(self.params_path))

        # Must be built before _load_model(), which sizes the output heads from these maps.
        self.label_maps = self._load_label_maps_from_csv()

        self.processor = AutoImageProcessor.from_pretrained(self.base_model_name)
        image_size = self.params.IMAGE_SIZE
        self.transforms = Compose([
            Resize((image_size, image_size)),
            ToTensor(),
            Normalize(mean=self.processor.image_mean, std=self.processor.image_std),
        ])
        self.model = self._load_model()
        haar_cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
        self.lq_face_detector = cv2.CascadeClassifier(haar_cascade_path)
        self.hq_face_detector = MTCNN()
        print(f"--- Pipeline Initialized Successfully on device: {self.device} ---")

    @staticmethod
    def _age_sort_key(label):
        """Sort key ordering age labels by the first integer they contain.

        ``"10-19"`` -> ``(0, 10)``; ``"more than 70"`` -> ``(0, 70)``. Labels with
        no digits at all sort last, alphabetically, instead of raising.
        NOTE(review): ids must match the label order used at training time; for
        labels of the form ``"<int>-<int>"`` this yields the same order as
        ``int(label.split('-')[0])`` — confirm against the trainer.
        """
        match = re.search(r"\d+", label)
        return (0, int(match.group())) if match else (1, label)

    def _load_label_maps_from_csv(self):
        """Rebuild ``{task}_id2label`` maps (string id -> label) from the downloaded CSV.

        Age labels are ordered numerically via ``_age_sort_key``; gender labels
        are ordered alphabetically.
        """
        print(f"Generating label maps from downloaded CSV: {self.data_csv_path}")
        df = pd.read_csv(self.data_csv_path)
        sort_keys = {'age': self._age_sort_key, 'gender': None}
        label_maps = {}
        for task, sort_key in sort_keys.items():
            # Keep a list (not a set) so ties keep unique()'s first-appearance order.
            labels = sorted([str(label) for label in df[task].unique()], key=sort_key)
            label_maps[f'{task}_id2label'] = {str(i): label for i, label in enumerate(labels)}
        return label_maps

    def _load_model(self):
        """Build ``MultiTaskEfficientNet`` sized from the label maps and load the weights.

        Raises:
            FileNotFoundError: if the downloaded weight file is missing on disk.
        """
        num_age = len(self.label_maps['age_id2label'])
        num_gender = len(self.label_maps['gender_id2label'])
        # NOTE(review): race head size is hard-coded; presumably it must match the checkpoint.
        num_race = 7
        model = MultiTaskEfficientNet(self.base_model_name, num_age, num_gender, num_race)
        weight_file = Path(self.model_file)
        if not weight_file.exists():
            raise FileNotFoundError(f"Weights not found: {weight_file}")
        if weight_file.suffix == ".safetensors":
            state_dict = load_safetensors(weight_file, device="cpu")
        else:
            # torch.load unpickles the file: only use this branch with trusted checkpoints.
            state_dict = torch.load(weight_file, map_location="cpu")
        model.load_state_dict(state_dict)
        model.to(self.device)
        model.eval()
        return model

    @staticmethod
    def _clip_box(box, height, width):
        """Clip an ``(x, y, w, h)`` box to an image of ``height`` x ``width``.

        Returns plain Python ints; ``w``/``h`` become 0 if the box lies
        entirely outside the image.
        """
        x, y, w, h = (int(v) for v in box)
        x0, y0 = max(0, x), max(0, y)
        x1, y1 = min(width, x + w), min(height, y + h)
        return x0, y0, max(0, x1 - x0), max(0, y1 - y0)

    def _draw_predictions(self, image, box, labels):
        """Draw the face box and a two-line "Gender/Age" label onto ``image`` in place."""
        x, y, w, h = (int(c) for c in box)
        font, font_scale, font_thickness = cv2.FONT_HERSHEY_DUPLEX, 0.6, 1
        text_color, bg_color = (255, 255, 255), (255, 75, 75)
        line_height = 25
        text_lines = [f"Gender: {labels['gender']}", f"Age: {labels['age']}"]
        max_width = max(
            cv2.getTextSize(line, font, font_scale, font_thickness)[0][0] for line in text_lines
        )
        total_height = len(text_lines) * line_height - 5
        # Place the label above the box; if that would run off the top of the
        # image, place it just inside the top of the box instead.
        label_top = y - total_height if y >= total_height else y
        cv2.rectangle(image, (x, y), (x + w, y + h), bg_color, 2)
        cv2.rectangle(image, (x - 1, label_top), (x + max_width + 10, label_top + total_height), bg_color, -1)
        for i, line in enumerate(text_lines):
            y_text = label_top + i * line_height + 18
            cv2.putText(image, line, (x + 5, y_text), font, font_scale, text_color, font_thickness, cv2.LINE_AA)

    def _classify_face(self, face_img):
        """Run the model on one face crop and return ``{"age": ..., "gender": ...}`` labels.

        Unknown class ids map to ``"N/A"``.
        """
        # convert("RGB") so grayscale/RGBA crops match the 3-channel Normalize.
        pil_face = Image.fromarray(face_img).convert("RGB")
        pixel_values = self.transforms(pil_face).unsqueeze(0).to(self.device)
        with torch.no_grad():
            outputs = self.model(pixel_values=pixel_values)
        age_id = str(outputs['age_logits'].argmax(1).item())
        gender_id = str(outputs['gender_logits'].argmax(1).item())
        return {
            "age": self.label_maps['age_id2label'].get(age_id, "N/A"),
            "gender": self.label_maps['gender_id2label'].get(gender_id, "N/A"),
        }

    def _annotate_faces(self, image_array, boxes):
        """Classify and draw every box; shared tail of ``predict_hq``/``predict_lq``."""
        annotated_image, predictions = image_array.copy(), []
        img_h, img_w = image_array.shape[:2]
        for raw_box in boxes:
            x, y, w, h = self._clip_box(raw_box, img_h, img_w)
            # Crop from the untouched input so earlier drawn labels never leak into later faces.
            face_img = image_array[y:y + h, x:x + w]
            if face_img.size == 0:
                continue
            labels = self._classify_face(face_img)
            predictions.append({**labels, 'box': (x, y, w, h)})
            self._draw_predictions(annotated_image, (x, y, w, h), labels)
        return annotated_image, predictions

    def predict_hq(self, image_array: np.ndarray, min_confidence: float = 0.95) -> "tuple[np.ndarray, list]":
        """Detect faces with MTCNN and predict age/gender for each.

        Args:
            image_array: input image (MTCNN expects RGB).
            min_confidence: detections below this MTCNN confidence are ignored.

        Returns:
            ``(annotated_copy, predictions)``; boxes are clipped to the image.
        """
        face_results = self.hq_face_detector.detect_faces(image_array) or []
        boxes = [face['box'] for face in face_results if face['confidence'] >= min_confidence]
        return self._annotate_faces(image_array, boxes)

    def predict_lq(
        self,
        image_array: np.ndarray,
        scale_factor: float = 1.1,
        min_neighbors: int = 5,
        min_size: tuple = (60, 60),
    ) -> "tuple[np.ndarray, list]":
        """Detect faces with the Haar cascade and predict age/gender for each.

        Args:
            image_array: RGB image, or an already-grayscale 2-D array.
            scale_factor, min_neighbors, min_size: forwarded to ``detectMultiScale``.

        Returns:
            ``(annotated_copy, predictions)`` with boxes as plain Python ints.
        """
        if image_array.ndim == 2:
            gray_image = image_array
        else:
            gray_image = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
        faces = self.lq_face_detector.detectMultiScale(
            gray_image, scaleFactor=scale_factor, minNeighbors=min_neighbors, minSize=min_size
        )
        return self._annotate_faces(image_array, faces)