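"""Prediction pipeline for multi-task facial age and gender classification.

Downloads the trained checkpoint, params, and label metadata from the
Hugging Face Hub, detects faces (MTCNN for high-quality inputs, a Haar
cascade for low-quality ones), and classifies each detected face.
"""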
import torch
import numpy as np
from PIL import Image
from transformers import AutoImageProcessor
import cv2
from huggingface_hub import hf_hub_download
from mtcnn import MTCNN
from pathlib import Path
import sys
import os
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from safetensors.torch import load_file as load_safetensors
import pandas as pd
from typing import Tuple
# Support running both from inside the package and as a standalone script.
try:
    src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
    if src_path not in sys.path:
        sys.path.append(src_path)
    from components.multi_task_model_trainer import MultiTaskEfficientNet
    from utils.common import read_yaml
except ImportError:
    from src.cnnClassifier.components.multi_task_model_trainer import MultiTaskEfficientNet
    from src.cnnClassifier.utils.common import read_yaml
class PredictionPipeline:
    def __init__(self, repo_id: str = "ALYYAN/Facial-Age-Det"):
        self.device = "cpu"
        self.repo_id = repo_id
        print("--- Initializing Prediction Pipeline by downloading artifacts from Hub ---")
        cache_dir = os.getenv("HF_HOME", "/app/hf_cache")
        self.model_file = hf_hub_download(repo_id=self.repo_id, filename="checkpoint-26873/model.safetensors", cache_dir=cache_dir)
        self.params_path = hf_hub_download(repo_id=self.repo_id, filename="params.yaml", cache_dir=cache_dir)
        self.data_csv_path = hf_hub_download(repo_id=self.repo_id, filename="fairface_cleaned.csv", cache_dir=cache_dir)
        self.base_model_name = "google/efficientnet-b2"
        self.params = read_yaml(Path(self.params_path))
        # Label maps are derived from the downloaded CSV rather than shipped as a separate artifact.
        self.label_maps = self._load_label_maps_from_csv()
        self.processor = AutoImageProcessor.from_pretrained(self.base_model_name)
        self.transforms = Compose([
            Resize((self.params.IMAGE_SIZE, self.params.IMAGE_SIZE)),
            ToTensor(),
            Normalize(mean=self.processor.image_mean, std=self.processor.image_std),
        ])
        self.model = self._load_model()
        # Haar cascade for low-quality inputs, MTCNN for high-quality ones.
        haar_cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
        self.lq_face_detector = cv2.CascadeClassifier(haar_cascade_path)
        self.hq_face_detector = MTCNN()
        print(f"--- Pipeline Initialized Successfully on device: {self.device} ---")
    def _load_label_maps_from_csv(self):
        """Build id -> label maps for each task from the cleaned FairFace CSV."""
        print(f"Generating label maps from downloaded CSV: {self.data_csv_path}")
        df = pd.read_csv(self.data_csv_path)
        label_maps = {}
        # Age buckets sort numerically by their lower bound (e.g. "20-29" -> 20);
        # gender labels use the default lexicographic sort (key=None).
        tasks = {'age': lambda x: int(str(x).split('-')[0]), 'gender': None}
        for task, sort_key in tasks.items():
            labels_str = [str(label) for label in df[task].unique()]
            sorted_labels = sorted(labels_str, key=sort_key)
            label_maps[f'{task}_id2label'] = {str(i): label for i, label in enumerate(sorted_labels)}
        return label_maps
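    # Illustrative result shape (the bucket names below are assumptions; the
    # actual values come from the CSV):
    #   {'age_id2label': {'0': '0-2', '1': '3-9', ...},
    #    'gender_id2label': {'0': 'Female', '1': 'Male'}}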
    def _load_model(self):
        num_age = len(self.label_maps['age_id2label'])
        num_gender = len(self.label_maps['gender_id2label'])
        num_race = 7  # FairFace defines seven race categories; the race head is unused at inference.
        model = MultiTaskEfficientNet(self.base_model_name, num_age, num_gender, num_race)
        weight_file = Path(self.model_file)
        if not weight_file.exists():
            raise FileNotFoundError(f"Weights not found: {weight_file}")
        state_dict = (load_safetensors(weight_file, device="cpu")
                      if weight_file.suffix == ".safetensors"
                      else torch.load(weight_file, map_location="cpu"))
        model.load_state_dict(state_dict)
        model.to(self.device)
        model.eval()
        return model
    def _draw_predictions(self, image, box, labels):
        """Draw the face box and a filled label banner above it, in place."""
        x, y, w, h = [int(c) for c in box]
        font, font_scale, font_thickness = cv2.FONT_HERSHEY_DUPLEX, 0.6, 1
        text_color, bg_color = (255, 255, 255), (255, 75, 75)
        text_lines = [f"Gender: {labels['gender']}", f"Age: {labels['age']}"]
        max_width, line_height = 0, 25
        for line in text_lines:
            (w_text, _), _ = cv2.getTextSize(line, font, font_scale, font_thickness)
            max_width = max(max_width, w_text)
        total_height = len(text_lines) * line_height - 5
        cv2.rectangle(image, (x, y), (x + w, y + h), bg_color, 2)
        cv2.rectangle(image, (x - 1, y - total_height), (x + max_width + 10, y), bg_color, -1)
        for i, line in enumerate(text_lines):
            y_text = y - total_height + (i * line_height) + 18
            cv2.putText(image, line, (x + 5, y_text), font, font_scale, text_color, font_thickness, cv2.LINE_AA)
    def predict_hq(self, image_array: np.ndarray) -> Tuple[np.ndarray, list]:
        """Detect faces with MTCNN and classify each; returns (annotated RGB image, predictions)."""
        annotated_image, predictions = image_array.copy(), []
        face_results = self.hq_face_detector.detect_faces(image_array)
        if not face_results:
            return annotated_image, predictions
        for face in face_results:
            # Skip low-confidence detections to reduce false positives.
            if face['confidence'] < 0.95:
                continue
            x, y, w, h = face['box']
            # Clamp the box to the image bounds before cropping.
            face_img = image_array[max(0, y):min(image_array.shape[0], y + h),
                                   max(0, x):min(image_array.shape[1], x + w)]
            if face_img.size == 0:
                continue
            pil_face = Image.fromarray(face_img)
            pixel_values = self.transforms(pil_face).unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = self.model(pixel_values=pixel_values)
            pred_id_age = str(outputs['age_logits'].argmax(1).item())
            pred_id_gender = str(outputs['gender_logits'].argmax(1).item())
            age_label = self.label_maps['age_id2label'].get(pred_id_age, "N/A")
            gender_label = self.label_maps['gender_id2label'].get(pred_id_gender, "N/A")
            prediction_labels = {"age": age_label, "gender": gender_label}
            predictions.append({**prediction_labels, 'box': (x, y, w, h)})
            self._draw_predictions(annotated_image, (x, y, w, h), prediction_labels)
        return annotated_image, predictions
    def predict_lq(self, image_array: np.ndarray) -> Tuple[np.ndarray, list]:
        """Detect faces with a Haar cascade (faster, less accurate) and classify each."""
        annotated_image, predictions = image_array.copy(), []
        gray_image = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
        faces = self.lq_face_detector.detectMultiScale(gray_image, scaleFactor=1.1, minNeighbors=5, minSize=(60, 60))
        if len(faces) == 0:
            return annotated_image, predictions
        for (x, y, w, h) in faces:
            face_img = image_array[y:y + h, x:x + w]
            if face_img.size == 0:
                continue
            pil_face = Image.fromarray(face_img)
            pixel_values = self.transforms(pil_face).unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = self.model(pixel_values=pixel_values)
            pred_id_age = str(outputs['age_logits'].argmax(1).item())
            pred_id_gender = str(outputs['gender_logits'].argmax(1).item())
            age_label = self.label_maps['age_id2label'].get(pred_id_age, "N/A")
            gender_label = self.label_maps['gender_id2label'].get(pred_id_gender, "N/A")
            prediction_labels = {"age": age_label, "gender": gender_label}
            predictions.append({**prediction_labels, 'box': (x, y, w, h)})
            self._draw_predictions(annotated_image, (x, y, w, h), prediction_labels)
        return annotated_image, predictions
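
# A minimal usage sketch, not part of the pipeline itself. It assumes an
# image at a hypothetical path "test.jpg"; the predict_* methods expect RGB
# arrays, so cv2.imread's BGR output must be converted first.
if __name__ == "__main__":
    pipeline = PredictionPipeline()
    bgr = cv2.imread("test.jpg")  # hypothetical example image
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    annotated, preds = pipeline.predict_hq(rgb)
    for p in preds:
        print(f"box={p['box']} age={p['age']} gender={p['gender']}")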