File size: 6,763 Bytes
eacd6a2
 
 
 
 
1f4e421
eacd6a2
 
 
 
 
 
 
 
 
 
 
1f4e421
 
 
 
eacd6a2
 
1f4e421
 
eacd6a2
 
1f4e421
 
 
 
 
 
 
 
eacd6a2
1f4e421
eacd6a2
1f4e421
 
 
 
 
 
 
 
 
 
eacd6a2
 
 
 
 
 
1f4e421
eacd6a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f4e421
 
 
eacd6a2
1f4e421
eacd6a2
1f4e421
eacd6a2
1f4e421
 
 
 
 
 
 
 
 
 
 
 
 
 
eacd6a2
 
1f4e421
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import sys
from pathlib import Path
from typing import Tuple

import cv2
import numpy as np
import torch
from mtcnn import MTCNN  # For high-quality
from PIL import Image
from safetensors.torch import load_file as load_safetensors
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from transformers import AutoImageProcessor

try:
    src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
    if src_path not in sys.path: sys.path.append(src_path)
    from components.multi_task_model_trainer import MultiTaskEfficientNet
    from utils.common import read_yaml
except ImportError:
    # Fallback for Hugging Face Spaces
    from src.cnnClassifier.components.multi_task_model_trainer import MultiTaskEfficientNet
    from src.cnnClassifier.utils.common import read_yaml

class PredictionPipeline:
    """End-to-end age/gender prediction pipeline.

    Wraps a multi-task EfficientNet classifier behind two face-detection
    front-ends: MTCNN for high-quality offline work (images/videos) and an
    OpenCV Haar cascade for low-latency live feeds.
    """

    def __init__(self, model_path: str = "model/checkpoint-26873"):
        """Load model weights, image transforms, and both face detectors.

        Args:
            model_path: Directory containing ``model.safetensors`` or
                ``pytorch_model.bin``.
        """
        self.device = "cpu"  # Force CPU for deployment
        self.model_path = Path(model_path)
        self.base_model_name = "google/efficientnet-b2"
        self.params = read_yaml(Path("model/params.yaml"))

        # String keys mirror the id2label format produced at training time.
        self.label_maps = {
            'age_id2label': {'0': '0-2', '1': '3-9', '2': '10-19', '3': '20-29', '4': '30-39', '5': '40-49', '6': '50-59', '7': '60-69', '8': 'more than 70'},
            'gender_id2label': {'0': 'Male', '1': 'Female'}
        }

        print("--- Initializing Prediction Pipeline ---")
        self.processor = AutoImageProcessor.from_pretrained(self.base_model_name)
        self.transforms = Compose([
            Resize((self.params.IMAGE_SIZE, self.params.IMAGE_SIZE)),
            ToTensor(),
            Normalize(mean=self.processor.image_mean, std=self.processor.image_std),
        ])
        self.model = self._load_model()

        # Two detectors: MTCNN (accurate, slower) for offline tasks and a
        # Haar cascade (fast, less accurate) for the live feed.
        self.hq_face_detector = MTCNN()
        haar_cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
        self.lq_face_detector = cv2.CascadeClassifier(haar_cascade_path)

        print(f"--- Pipeline Initialized Successfully on device: {self.device} ---")

    def _load_model(self):
        """Build the multi-task network and load its weights onto CPU.

        Prefers ``model.safetensors``; falls back to ``pytorch_model.bin``.

        Returns:
            The model in eval mode on ``self.device``.

        Raises:
            FileNotFoundError: If neither weight file exists in model_path.
        """
        num_age = len(self.label_maps['age_id2label'])
        num_gender = len(self.label_maps['gender_id2label'])
        num_race = 7  # race head size fixed by the training setup; unused at inference
        model = MultiTaskEfficientNet(self.base_model_name, num_age, num_gender, num_race)
        weight_file = self.model_path / 'model.safetensors'
        if not weight_file.exists():
            weight_file = self.model_path / 'pytorch_model.bin'
        if not weight_file.exists():
            raise FileNotFoundError(f"Weights not found in {self.model_path}")
        if weight_file.suffix == ".safetensors":
            state_dict = load_safetensors(weight_file, device="cpu")
        else:
            state_dict = torch.load(weight_file, map_location="cpu")
        model.load_state_dict(state_dict)
        model.to(self.device)
        model.eval()
        return model

    def _classify_face(self, face_img: np.ndarray) -> dict:
        """Run the classifier on one cropped RGB face and map ids to labels.

        Shared by ``predict_hq`` and ``predict_lq`` so the inference logic
        lives in exactly one place.

        Returns:
            Dict with 'age' and 'gender' label strings ("N/A" when a
            predicted id is missing from the label map).
        """
        pil_face = Image.fromarray(face_img)
        pixel_values = self.transforms(pil_face).unsqueeze(0).to(self.device)
        with torch.no_grad():
            outputs = self.model(pixel_values=pixel_values)
        pred_id_age = str(outputs['age_logits'].argmax(1).item())
        pred_id_gender = str(outputs['gender_logits'].argmax(1).item())
        return {
            "age": self.label_maps['age_id2label'].get(pred_id_age, "N/A"),
            "gender": self.label_maps['gender_id2label'].get(pred_id_gender, "N/A"),
        }

    def _draw_predictions(self, image, box, labels):
        """Draw the face box plus a filled label banner above it, in place."""
        x, y, w, h = [int(c) for c in box]
        font, font_scale, font_thickness = cv2.FONT_HERSHEY_DUPLEX, 0.6, 1
        text_color, bg_color = (255, 255, 255), (255, 75, 75)
        text_lines = [f"Gender: {labels['gender']}", f"Age: {labels['age']}"]
        line_height = 25
        # Banner must be as wide as the widest rendered text line.
        max_width = 0
        for line in text_lines:
            (w_text, _), _ = cv2.getTextSize(line, font, font_scale, font_thickness)
            max_width = max(max_width, w_text)
        total_height = len(text_lines) * line_height - 5
        cv2.rectangle(image, (x, y), (x + w, y + h), bg_color, 2)
        cv2.rectangle(image, (x - 1, y - total_height), (x + max_width + 10, y), bg_color, -1)
        for i, line in enumerate(text_lines):
            y_text = y - total_height + (i * line_height) + 18
            cv2.putText(image, line, (x + 5, y_text), font, font_scale, text_color, font_thickness, cv2.LINE_AA)

    def predict_hq(self, image_array: np.ndarray) -> Tuple[np.ndarray, list]:
        """High-quality prediction using MTCNN for images and videos.

        Args:
            image_array: RGB image as an HxWx3 uint8 array (assumed RGB —
                matches ``predict_lq``'s RGB2GRAY conversion; confirm caller).

        Returns:
            (annotated_image, predictions); each prediction dict holds
            'age', 'gender' and 'box' = (x, y, w, h).
        """
        annotated_image, predictions = image_array.copy(), []
        face_results = self.hq_face_detector.detect_faces(image_array)
        if not face_results:
            return annotated_image, predictions

        for face in face_results:
            if face['confidence'] < 0.95:  # drop low-confidence detections
                continue
            x, y, w, h = face['box']
            # MTCNN boxes may extend past the image edges; clamp before cropping.
            face_img = image_array[max(0, y):min(image_array.shape[0], y + h),
                                   max(0, x):min(image_array.shape[1], x + w)]
            if face_img.size == 0:
                continue
            prediction_labels = self._classify_face(face_img)
            predictions.append({**prediction_labels, 'box': (x, y, w, h)})
            self._draw_predictions(annotated_image, (x, y, w, h), prediction_labels)
        return annotated_image, predictions

    def predict_lq(self, image_array: np.ndarray) -> Tuple[np.ndarray, list]:
        """Lightweight prediction using Haar Cascade for live feed.

        Same contract as ``predict_hq`` but trades accuracy for speed.
        """
        annotated_image, predictions = image_array.copy(), []
        gray_image = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
        faces = self.lq_face_detector.detectMultiScale(gray_image, scaleFactor=1.1, minNeighbors=5, minSize=(60, 60))
        if len(faces) == 0:
            return annotated_image, predictions

        for (x, y, w, h) in faces:
            face_img = image_array[y:y + h, x:x + w]
            if face_img.size == 0:
                continue
            prediction_labels = self._classify_face(face_img)
            predictions.append({**prediction_labels, 'box': (x, y, w, h)})
            self._draw_predictions(annotated_image, (x, y, w, h), prediction_labels)
        return annotated_image, predictions