import torch
import numpy as np
from PIL import Image
from transformers import AutoImageProcessor
import cv2
from huggingface_hub import hf_hub_download
from mtcnn import MTCNN
from pathlib import Path
import sys
import os
from typing import Tuple
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from safetensors.torch import load_file as load_safetensors
import pandas as pd
# Allow imports to resolve both when running inside the Space and from the repo root.
try:
    src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
    if src_path not in sys.path:
        sys.path.append(src_path)
    from components.multi_task_model_trainer import MultiTaskEfficientNet
    from utils.common import read_yaml
except ImportError:
    from src.cnnClassifier.components.multi_task_model_trainer import MultiTaskEfficientNet
    from src.cnnClassifier.utils.common import read_yaml
class PredictionPipeline:
    def __init__(self, repo_id: str = "ALYYAN/Facial-Age-Det"):
        self.device = "cpu"
        self.repo_id = repo_id
        print("--- Initializing Prediction Pipeline by downloading artifacts from Hub ---")
        # Download model weights, training params, and the label CSV from the Hub.
        cache_dir = os.getenv("HF_HOME", "/app/hf_cache")
        self.model_file = hf_hub_download(repo_id=self.repo_id, filename="checkpoint-26873/model.safetensors", cache_dir=cache_dir)
        self.params_path = hf_hub_download(repo_id=self.repo_id, filename="params.yaml", cache_dir=cache_dir)
        self.data_csv_path = hf_hub_download(repo_id=self.repo_id, filename="fairface_cleaned.csv", cache_dir=cache_dir)
        self.base_model_name = "google/efficientnet-b2"
        self.params = read_yaml(Path(self.params_path))
        # Build id2label maps from the dataset CSV rather than from a hard-coded config.
        self.label_maps = self._load_label_maps_from_csv()
        self.processor = AutoImageProcessor.from_pretrained(self.base_model_name)
        self.transforms = Compose([
            Resize((self.params.IMAGE_SIZE, self.params.IMAGE_SIZE)),
            ToTensor(),
            Normalize(mean=self.processor.image_mean, std=self.processor.image_std),
        ])
        self.model = self._load_model()
        # Two detectors: a Haar cascade for low-quality frames and MTCNN for high-quality images.
        haar_cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
        self.lq_face_detector = cv2.CascadeClassifier(haar_cascade_path)
        self.hq_face_detector = MTCNN()
        print(f"--- Pipeline Initialized Successfully on device: {self.device} ---")
    def _load_label_maps_from_csv(self):
        """Build id2label maps for the age and gender heads from the dataset CSV."""
        print(f"Generating label maps from downloaded CSV: {self.data_csv_path}")
        df = pd.read_csv(self.data_csv_path)
        label_maps = {}
        # Age buckets are sorted numerically by their lower bound; gender labels alphabetically.
        tasks = {'age': lambda x: int(str(x).split('-')[0]), 'gender': None}
        for task, sort_key in tasks.items():
            labels_str = [str(label) for label in df[task].unique()]
            sorted_labels = sorted(labels_str, key=sort_key)
            label_maps[f'{task}_id2label'] = {str(i): label for i, label in enumerate(sorted_labels)}
        return label_maps
    def _load_model(self):
        num_age = len(self.label_maps['age_id2label'])
        num_gender = len(self.label_maps['gender_id2label'])
        num_race = 7  # The race head is not used by the prediction methods below, but is needed to rebuild the architecture.
        model = MultiTaskEfficientNet(self.base_model_name, num_age, num_gender, num_race)
        weight_file = Path(self.model_file)
        if not weight_file.exists():
            raise FileNotFoundError(f"Weights not found: {weight_file}")
        # Support both safetensors and plain PyTorch checkpoints.
        if weight_file.suffix == ".safetensors":
            state_dict = load_safetensors(weight_file, device="cpu")
        else:
            state_dict = torch.load(weight_file, map_location="cpu")
        model.load_state_dict(state_dict)
        model.to(self.device)
        model.eval()
        return model
    def _draw_predictions(self, image, box, labels):
        x, y, w, h = [int(c) for c in box]
        font, font_scale, font_thickness = cv2.FONT_HERSHEY_DUPLEX, 0.6, 1
        text_color, bg_color = (255, 255, 255), (255, 75, 75)
        text_lines = [f"Gender: {labels['gender']}", f"Age: {labels['age']}"]
        max_width, line_height = 0, 25
        for line in text_lines:
            (w_text, _), _ = cv2.getTextSize(line, font, font_scale, font_thickness)
            if w_text > max_width:
                max_width = w_text
        total_height = len(text_lines) * line_height - 5
        cv2.rectangle(image, (x, y), (x + w, y + h), bg_color, 2)
        cv2.rectangle(image, (x - 1, y - total_height), (x + max_width + 10, y), bg_color, -1)
        for i, line in enumerate(text_lines):
            y_text = y - total_height + (i * line_height) + 18
            cv2.putText(image, line, (x + 5, y_text), font, font_scale, text_color, font_thickness, cv2.LINE_AA)
    def predict_hq(self, image_array: np.ndarray) -> Tuple[np.ndarray, list]:
        """Detect faces with MTCNN (high-quality images) and predict age/gender for each."""
        annotated_image, predictions = image_array.copy(), []
        face_results = self.hq_face_detector.detect_faces(image_array)
        if not face_results:
            return annotated_image, predictions
        for face in face_results:
            if face['confidence'] < 0.95:
                continue
            x, y, w, h = face['box']
            # Clamp the crop to the image bounds; MTCNN can return slightly out-of-range coordinates.
            face_img = image_array[max(0, y):min(image_array.shape[0], y + h), max(0, x):min(image_array.shape[1], x + w)]
            if face_img.size == 0:
                continue
            pil_face = Image.fromarray(face_img)
            pixel_values = self.transforms(pil_face).unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = self.model(pixel_values=pixel_values)
            pred_id_age = str(outputs['age_logits'].argmax(1).item())
            pred_id_gender = str(outputs['gender_logits'].argmax(1).item())
            age_label = self.label_maps['age_id2label'].get(pred_id_age, "N/A")
            gender_label = self.label_maps['gender_id2label'].get(pred_id_gender, "N/A")
            prediction_labels = {"age": age_label, "gender": gender_label}
            predictions.append({**prediction_labels, 'box': (x, y, w, h)})
            self._draw_predictions(annotated_image, (x, y, w, h), prediction_labels)
        return annotated_image, predictions
    def predict_lq(self, image_array: np.ndarray) -> Tuple[np.ndarray, list]:
        """Detect faces with the Haar cascade (low-quality frames) and predict age/gender for each."""
        annotated_image, predictions = image_array.copy(), []
        gray_image = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
        faces = self.lq_face_detector.detectMultiScale(gray_image, scaleFactor=1.1, minNeighbors=5, minSize=(60, 60))
        if len(faces) == 0:
            return annotated_image, predictions
        for (x, y, w, h) in faces:
            face_img = image_array[y:y + h, x:x + w]
            if face_img.size == 0:
                continue
            pil_face = Image.fromarray(face_img)
            pixel_values = self.transforms(pil_face).unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = self.model(pixel_values=pixel_values)
            pred_id_age = str(outputs['age_logits'].argmax(1).item())
            pred_id_gender = str(outputs['gender_logits'].argmax(1).item())
            age_label = self.label_maps['age_id2label'].get(pred_id_age, "N/A")
            gender_label = self.label_maps['gender_id2label'].get(pred_id_gender, "N/A")
            prediction_labels = {"age": age_label, "gender": gender_label}
            predictions.append({**prediction_labels, 'box': (x, y, w, h)})
            self._draw_predictions(annotated_image, (x, y, w, h), prediction_labels)
        return annotated_image, predictions
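
# Usage sketch (not part of the original Space code): shows how the pipeline might be
# driven end to end. The image path below is a hypothetical placeholder, and both
# predict methods assume an RGB numpy array as produced by PIL.
if __name__ == "__main__":
    pipeline = PredictionPipeline()
    # Hypothetical input file; replace with a real image path.
    input_image = np.array(Image.open("sample_face.jpg").convert("RGB"))
    annotated, results = pipeline.predict_hq(input_image)
    for result in results:
        print(f"Face at {result['box']}: age={result['age']}, gender={result['gender']}")
    Image.fromarray(annotated).save("annotated_output.jpg")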