Spaces:
Runtime error
Runtime error
File size: 6,793 Bytes
eacd6a2 f017803 1f4e421 eacd6a2 1f4e421 eacd6a2 8e3c6f8 1f4e421 8e3c6f8 1f4e421 8e3c6f8 eacd6a2 1f4e421 eacd6a2 1f4e421 8e3c6f8 1f4e421 eacd6a2 1f4e421 eacd6a2 1f4e421 eacd6a2 1f4e421 eacd6a2 1f4e421 eacd6a2 1f4e421 eacd6a2 1f4e421 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import torch
import numpy as np
from PIL import Image
from transformers import AutoImageProcessor
import cv2
from huggingface_hub import hf_hub_download
from mtcnn import MTCNN # For high-quality
from pathlib import Path
import sys
import os
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from safetensors.torch import load_file as load_safetensors
try:
    # Running from a source checkout: put the project root (one directory up
    # from this file) on sys.path so the flat package layout resolves.
    src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
    if src_path not in sys.path: sys.path.append(src_path)
    from components.multi_task_model_trainer import MultiTaskEfficientNet
    from utils.common import read_yaml
except ImportError:
    # Fallback for Hugging Face Spaces
    # (the Space deploys the code under the src.cnnClassifier package root).
    from src.cnnClassifier.components.multi_task_model_trainer import MultiTaskEfficientNet
    from src.cnnClassifier.utils.common import read_yaml
class PredictionPipeline:
    """Face age/gender prediction pipeline.

    Downloads model weights, hyper-parameters and label metadata from a
    Hugging Face model repo, then exposes high-quality (MTCNN) and
    lightweight (Haar cascade) prediction entry points.
    """

    def __init__(self, repo_id: str = "ALYYAN/Facial-Age-Det"):
        """Download artifacts and build the model plus both face detectors.

        Args:
            repo_id: Hugging Face model repo holding the checkpoint,
                params.yaml and the cleaned FairFace CSV.
        """
        self.device = "cpu"
        self.repo_id = repo_id
        print("--- Initializing Prediction Pipeline by downloading artifacts from Hub ---")
        # Download all artifacts from the HF model repo (each call returns a
        # *str* path to the cached file).
        self.model_path = hf_hub_download(repo_id=self.repo_id, filename="checkpoint-26873/model.safetensors")
        self.params_path = hf_hub_download(repo_id=self.repo_id, filename="params.yaml")
        self.data_csv_path = hf_hub_download(repo_id=self.repo_id, filename="fairface_cleaned.csv")
        self.base_model_name = "google/efficientnet-b2"
        self.params = read_yaml(Path(self.params_path))
        self.label_maps = self._load_label_maps()
        self.processor = AutoImageProcessor.from_pretrained(self.base_model_name)
        self.transforms = Compose([
            Resize((self.params.IMAGE_SIZE, self.params.IMAGE_SIZE)),
            ToTensor(),
            Normalize(mean=self.processor.image_mean, std=self.processor.image_std),
        ])
        self.model = self._load_model()
        # BUG FIX: predict_lq() reads self.lq_face_detector and predict_hq()
        # reads self.hq_face_detector, but previously only self.face_detector
        # was assigned and MTCNN was never instantiated -> AttributeError on
        # the first prediction call. Assign both, keep the old name as alias.
        haar_cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
        self.lq_face_detector = cv2.CascadeClassifier(haar_cascade_path)
        self.face_detector = self.lq_face_detector  # backward-compatible alias
        self.hq_face_detector = MTCNN()  # high-quality detector for images/videos
        print(f"--- Pipeline Initialized Successfully on device: {self.device} ---")
def _load_model(self):
num_age, num_gender, num_race = len(self.label_maps['age_id2label']), len(self.label_maps['gender_id2label']), 7
model = MultiTaskEfficientNet(self.base_model_name, num_age, num_gender, num_race)
weight_file = self.model_path / 'model.safetensors'
if not weight_file.exists(): weight_file = self.model_path / 'pytorch_model.bin'
if not weight_file.exists(): raise FileNotFoundError(f"Weights not found in {self.model_path}")
state_dict = load_safetensors(weight_file, device="cpu") if weight_file.suffix == ".safetensors" else torch.load(weight_file, map_location="cpu")
model.load_state_dict(state_dict)
model.to(self.device)
model.eval()
return model
def _draw_predictions(self, image, box, labels):
x, y, w, h = [int(c) for c in box]
font, font_scale, font_thickness = cv2.FONT_HERSHEY_DUPLEX, 0.6, 1
text_color, bg_color = (255, 255, 255), (255, 75, 75)
text_lines = [f"Gender: {labels['gender']}", f"Age: {labels['age']}"]
max_width, line_height = 0, 25
for line in text_lines:
(w_text, _), _ = cv2.getTextSize(line, font, font_scale, font_thickness)
if w_text > max_width: max_width = w_text
total_height = len(text_lines) * line_height - 5
cv2.rectangle(image, (x, y), (x + w, y + h), bg_color, 2)
cv2.rectangle(image, (x-1, y - total_height), (x + max_width + 10, y), bg_color, -1)
for i, line in enumerate(text_lines):
y_text = y - total_height + (i * line_height) + 18
cv2.putText(image, line, (x + 5, y_text), font, font_scale, text_color, font_thickness, cv2.LINE_AA)
def predict_hq(self, image_array: np.ndarray) -> (np.ndarray, list):
"""High-quality prediction using MTCNN for images and videos."""
annotated_image, predictions = image_array.copy(), []
face_results = self.hq_face_detector.detect_faces(image_array)
if not face_results: return annotated_image, predictions
for face in face_results:
if face['confidence'] < 0.95: continue
x, y, w, h = face['box']
face_img = image_array[max(0,y):min(image_array.shape[0],y+h), max(0,x):min(image_array.shape[1],x+w)]
if face_img.size == 0: continue
pil_face = Image.fromarray(face_img)
pixel_values = self.transforms(pil_face).unsqueeze(0).to(self.device)
with torch.no_grad(): outputs = self.model(pixel_values=pixel_values)
pred_id_age = str(outputs['age_logits'].argmax(1).item())
pred_id_gender = str(outputs['gender_logits'].argmax(1).item())
age_label = self.label_maps['age_id2label'].get(pred_id_age, "N/A")
gender_label = self.label_maps['gender_id2label'].get(pred_id_gender, "N/A")
prediction_labels = {"age": age_label, "gender": gender_label}
predictions.append({**prediction_labels, 'box': (x, y, w, h)})
self._draw_predictions(annotated_image, (x, y, w, h), prediction_labels)
return annotated_image, predictions
def predict_lq(self, image_array: np.ndarray) -> (np.ndarray, list):
"""Lightweight prediction using Haar Cascade for live feed."""
annotated_image, predictions = image_array.copy(), []
gray_image = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
faces = self.lq_face_detector.detectMultiScale(gray_image, scaleFactor=1.1, minNeighbors=5, minSize=(60, 60))
if len(faces) == 0: return annotated_image, predictions
for (x, y, w, h) in faces:
face_img = image_array[y:y+h, x:x+w]
if face_img.size == 0: continue
pil_face = Image.fromarray(face_img)
pixel_values = self.transforms(pil_face).unsqueeze(0).to(self.device)
with torch.no_grad(): outputs = self.model(pixel_values=pixel_values)
pred_id_age = str(outputs['age_logits'].argmax(1).item())
pred_id_gender = str(outputs['gender_logits'].argmax(1).item())
age_label = self.label_maps['age_id2label'].get(pred_id_age, "N/A")
gender_label = self.label_maps['gender_id2label'].get(pred_id_gender, "N/A")
prediction_labels = {"age": age_label, "gender": gender_label}
predictions.append({**prediction_labels, 'box': (x, y, w, h)})
self._draw_predictions(annotated_image, (x, y, w, h), prediction_labels)
return annotated_image, predictions |