File size: 6,793 Bytes
eacd6a2
 
 
 
 
f017803
1f4e421
eacd6a2
 
 
 
 
 
 
 
 
 
 
1f4e421
 
 
 
eacd6a2
 
8e3c6f8
 
 
 
 
 
 
 
 
 
 
1f4e421
8e3c6f8
 
1f4e421
8e3c6f8
eacd6a2
1f4e421
eacd6a2
1f4e421
 
8e3c6f8
1f4e421
 
eacd6a2
 
 
 
 
 
1f4e421
eacd6a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f4e421
 
 
eacd6a2
1f4e421
eacd6a2
1f4e421
eacd6a2
1f4e421
 
 
 
 
 
 
 
 
 
 
 
 
 
eacd6a2
 
1f4e421
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import torch
import numpy as np
from PIL import Image
from transformers import AutoImageProcessor
import cv2
from huggingface_hub import hf_hub_download
from mtcnn import MTCNN  # For high-quality
from pathlib import Path
import sys
import os
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from safetensors.torch import load_file as load_safetensors

# Make package-local imports work both when running from inside the package
# (src/ on sys.path) and when deployed where the fully-qualified path is needed.
try:
    # Add the parent directory (src/) to sys.path so sibling packages resolve.
    src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
    if src_path not in sys.path: sys.path.append(src_path)
    from components.multi_task_model_trainer import MultiTaskEfficientNet
    from utils.common import read_yaml
except ImportError:
    # Fallback for Hugging Face Spaces
    from src.cnnClassifier.components.multi_task_model_trainer import MultiTaskEfficientNet
    from src.cnnClassifier.utils.common import read_yaml

class PredictionPipeline:
    """Age/gender prediction pipeline.

    Downloads model artifacts (checkpoint, params, label CSV) from the
    Hugging Face Hub on construction and exposes face-level inference via
    predict_hq (MTCNN) and predict_lq (Haar cascade).
    """

    def __init__(self, repo_id: str = "ALYYAN/Facial-Age-Det"):
        """Download artifacts, build the model, transforms and face detectors.

        Args:
            repo_id: Hugging Face model repo holding the checkpoint,
                params.yaml and fairface_cleaned.csv.
        """
        self.device = "cpu"
        self.repo_id = repo_id

        print("--- Initializing Prediction Pipeline by downloading artifacts from Hub ---")

        # Download all artifacts from the HF model repo (cached locally by
        # huggingface_hub). Each call returns the path to the FILE itself.
        self.model_path = hf_hub_download(repo_id=self.repo_id, filename="checkpoint-26873/model.safetensors")
        self.params_path = hf_hub_download(repo_id=self.repo_id, filename="params.yaml")
        self.data_csv_path = hf_hub_download(repo_id=self.repo_id, filename="fairface_cleaned.csv")

        self.base_model_name = "google/efficientnet-b2"
        self.params = read_yaml(Path(self.params_path))

        self.label_maps = self._load_label_maps()
        self.processor = AutoImageProcessor.from_pretrained(self.base_model_name)
        self.transforms = Compose([
            Resize((self.params.IMAGE_SIZE, self.params.IMAGE_SIZE)),
            ToTensor(),
            Normalize(mean=self.processor.image_mean, std=self.processor.image_std),
        ])
        self.model = self._load_model()

        # BUG FIX: predict_hq/predict_lq reference self.hq_face_detector and
        # self.lq_face_detector, but the original __init__ only created
        # self.face_detector (and never instantiated MTCNN), so every call to
        # either predict method raised AttributeError.
        haar_cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
        self.lq_face_detector = cv2.CascadeClassifier(haar_cascade_path)
        self.hq_face_detector = MTCNN()  # high-quality detector for predict_hq
        # Kept for backward compatibility with any external code that used it.
        self.face_detector = self.lq_face_detector

        print(f"--- Pipeline Initialized Successfully on device: {self.device} ---")
    
    def _load_model(self):
        """Build the multi-task network and load the downloaded checkpoint.

        Returns:
            MultiTaskEfficientNet in eval mode on ``self.device``.

        Raises:
            FileNotFoundError: if the downloaded weight file is missing.
        """
        num_age = len(self.label_maps['age_id2label'])
        num_gender = len(self.label_maps['gender_id2label'])
        num_race = 7  # race head size fixed at 7 (FairFace); unused for prediction output
        model = MultiTaskEfficientNet(self.base_model_name, num_age, num_gender, num_race)

        # BUG FIX: hf_hub_download(..., filename="checkpoint-26873/model.safetensors")
        # returns the path of the weight FILE itself as a str. The original code
        # treated self.model_path as a directory Path and did
        # `self.model_path / 'model.safetensors'`, which raises TypeError on a
        # str and would point at a non-existent nested path anyway.
        weight_file = Path(self.model_path)
        if not weight_file.exists():
            raise FileNotFoundError(f"Weights not found at {weight_file}")
        if weight_file.suffix == ".safetensors":
            state_dict = load_safetensors(weight_file, device="cpu")
        else:
            # Fallback for repos that ship a legacy PyTorch pickle checkpoint.
            state_dict = torch.load(weight_file, map_location="cpu")
        model.load_state_dict(state_dict)
        model.to(self.device)
        model.eval()
        return model

    def _draw_predictions(self, image, box, labels):
        """Annotate `image` in place: face rectangle plus a filled text banner.

        Args:
            image: HxWx3 ndarray, modified in place.
            box: (x, y, w, h) face rectangle coordinates.
            labels: dict with 'gender' and 'age' display strings.
        """
        x, y, w, h = map(int, box)
        font = cv2.FONT_HERSHEY_DUPLEX
        font_scale, font_thickness = 0.6, 1
        text_color = (255, 255, 255)
        bg_color = (255, 75, 75)
        text_lines = [f"Gender: {labels['gender']}", f"Age: {labels['age']}"]
        line_height = 25
        # Banner must be wide enough for the longest line.
        max_width = max(
            cv2.getTextSize(line, font, font_scale, font_thickness)[0][0]
            for line in text_lines
        )
        total_height = len(text_lines) * line_height - 5
        # Face outline, then a filled banner hugging the rectangle's top edge.
        cv2.rectangle(image, (x, y), (x + w, y + h), bg_color, 2)
        cv2.rectangle(image, (x - 1, y - total_height), (x + max_width + 10, y), bg_color, -1)
        for idx, line in enumerate(text_lines):
            baseline_y = y - total_height + idx * line_height + 18
            cv2.putText(image, line, (x + 5, baseline_y), font, font_scale,
                        text_color, font_thickness, cv2.LINE_AA)

    def predict_hq(self, image_array: np.ndarray) -> tuple[np.ndarray, list]:
        """High-quality prediction using MTCNN for images and videos.

        Args:
            image_array: RGB image as an HxWx3 uint8 ndarray.

        Returns:
            (annotated_image, predictions): a copy of the input with boxes and
            labels drawn, and a list of dicts each holding 'age', 'gender' and
            'box' (x, y, w, h).
        """
        # FIX: the return annotation `(np.ndarray, list)` was a tuple literal,
        # not a valid PEP 484 type; replaced with `tuple[np.ndarray, list]`.
        annotated_image, predictions = image_array.copy(), []
        face_results = self.hq_face_detector.detect_faces(image_array)
        if not face_results:
            return annotated_image, predictions

        for face in face_results:
            # Drop low-confidence detections to avoid spurious boxes.
            if face['confidence'] < 0.95:
                continue
            x, y, w, h = face['box']
            # Clamp the crop to image bounds (MTCNN may return negative coords).
            face_img = image_array[max(0, y):min(image_array.shape[0], y + h),
                                   max(0, x):min(image_array.shape[1], x + w)]
            if face_img.size == 0:
                continue
            pil_face = Image.fromarray(face_img)
            pixel_values = self.transforms(pil_face).unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = self.model(pixel_values=pixel_values)
            # Label maps are keyed by string ids, hence the str() conversion.
            pred_id_age = str(outputs['age_logits'].argmax(1).item())
            pred_id_gender = str(outputs['gender_logits'].argmax(1).item())
            age_label = self.label_maps['age_id2label'].get(pred_id_age, "N/A")
            gender_label = self.label_maps['gender_id2label'].get(pred_id_gender, "N/A")
            prediction_labels = {"age": age_label, "gender": gender_label}
            predictions.append({**prediction_labels, 'box': (x, y, w, h)})
            self._draw_predictions(annotated_image, (x, y, w, h), prediction_labels)
        return annotated_image, predictions

    def predict_lq(self, image_array: np.ndarray) -> tuple[np.ndarray, list]:
        """Lightweight prediction using Haar Cascade for live feed.

        Args:
            image_array: RGB image as an HxWx3 uint8 ndarray.

        Returns:
            (annotated_image, predictions): a copy of the input with boxes and
            labels drawn, and a list of dicts each holding 'age', 'gender' and
            'box' (x, y, w, h).
        """
        # FIX: the return annotation `(np.ndarray, list)` was a tuple literal,
        # not a valid PEP 484 type; replaced with `tuple[np.ndarray, list]`.
        annotated_image, predictions = image_array.copy(), []
        # Haar cascades operate on grayscale frames.
        gray_image = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
        faces = self.lq_face_detector.detectMultiScale(
            gray_image, scaleFactor=1.1, minNeighbors=5, minSize=(60, 60))
        if len(faces) == 0:
            return annotated_image, predictions

        for (x, y, w, h) in faces:
            face_img = image_array[y:y + h, x:x + w]
            if face_img.size == 0:
                continue
            pil_face = Image.fromarray(face_img)
            pixel_values = self.transforms(pil_face).unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = self.model(pixel_values=pixel_values)
            # Label maps are keyed by string ids, hence the str() conversion.
            pred_id_age = str(outputs['age_logits'].argmax(1).item())
            pred_id_gender = str(outputs['gender_logits'].argmax(1).item())
            age_label = self.label_maps['age_id2label'].get(pred_id_age, "N/A")
            gender_label = self.label_maps['gender_id2label'].get(pred_id_gender, "N/A")
            prediction_labels = {"age": age_label, "gender": gender_label}
            predictions.append({**prediction_labels, 'box': (x, y, w, h)})
            self._draw_predictions(annotated_image, (x, y, w, h), prediction_labels)
        return annotated_image, predictions