Spaces:
Runtime error
Runtime error
| import os | |
| import cv2 | |
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| import base64 | |
| from typing import List, Dict, Any | |
| import tempfile | |
| import time | |
| from PIL import Image, ImageDraw | |
| import json | |
| import io | |
| # Import RetinaFace model components | |
| from models.retinaface import RetinaFace | |
| from utils.prior_box import PriorBox | |
| from utils.py_cpu_nms import py_cpu_nms | |
| from utils.box_utils import decode, decode_landm | |
# Module-level handles for the two detectors. Both start out empty and are
# assigned by load_models().
mobilenet_model, resnet_model = None, None

# Run on the GPU when torch reports one as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``module.`` removed from each key.

    ``torch.nn.DataParallel`` registers the wrapped network under the
    attribute ``module``, so a state dict saved from a wrapped model has
    every key prefixed with ``module.`` and cannot be loaded into an
    unwrapped model as-is. Keys without the prefix are kept unchanged.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _load_retinaface(cfg, weights_path):
    """Build a test-phase RetinaFace from *cfg*, load *weights_path*, and return it in eval mode on ``device``."""
    model = RetinaFace(cfg=cfg, phase='test')
    checkpoint = torch.load(weights_path, map_location=device)
    # Some checkpoints nest the weights under a "state_dict" entry.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    # NOTE(review): the ResNet config below records ngpu=4, so its checkpoint
    # was presumably saved from a DataParallel wrapper -- stripping is a no-op
    # when the prefix is absent.
    model.load_state_dict(_strip_module_prefix(checkpoint))
    model.eval()
    return model.to(device)


def load_models():
    """Load both MobileNet and ResNet RetinaFace models.

    Builds each network, loads its weights from ``mobilenet0.25_Final.pth``
    and ``Resnet50_Final.pth`` (paths relative to the working directory) and
    stores the results in the module globals ``mobilenet_model`` and
    ``resnet_model``.

    Returns:
        True when both models were loaded; False when any step raised (the
        error is printed). If the MobileNet model loaded before the failure,
        its global stays assigned.
    """
    global mobilenet_model, resnet_model

    # Settings that are identical for both backbones.
    shared_cfg = {
        'min_sizes': [[16, 32], [64, 128], [256, 512]],
        'steps': [8, 16, 32],
        'variance': [0.1, 0.2],
        'clip': False,
        'loc_weight': 2.0,
        'gpu_train': True,
        'pretrain': False,
    }
    mobilenet_cfg = {
        **shared_cfg,
        'name': 'mobilenet0.25',
        'batch_size': 32,
        'ngpu': 1,
        'epoch': 250,
        'decay1': 190,
        'decay2': 220,
        'image_size': 640,
        'return_layers': {'stage1': 1, 'stage2': 2, 'stage3': 3},
        'in_channel': 32,
        'out_channel': 64,
    }
    resnet_cfg = {
        **shared_cfg,
        'name': 'Resnet50',
        'batch_size': 24,
        'ngpu': 4,
        'epoch': 100,
        'decay1': 70,
        'decay2': 90,
        'image_size': 840,
        'return_layers': {'layer2': 1, 'layer3': 2, 'layer4': 3},
        'in_channel': 256,
        'out_channel': 256,
    }

    # Startup boundary: any failure (missing file, key mismatch, ...) is
    # reported and turned into a False return so the UI can still come up.
    try:
        mobilenet_model = _load_retinaface(mobilenet_cfg, 'mobilenet0.25_Final.pth')
        resnet_model = _load_retinaface(resnet_cfg, 'Resnet50_Final.pth')
        print("✅ Models loaded successfully!")
        return True
    except Exception as e:
        print(f"❌ Error loading models: {e}")
        return False
# Prior-box / decoding settings; the values are the same for both backbones.
_DETECT_CFG = {
    'min_sizes': [[16, 32], [64, 128], [256, 512]],
    'steps': [8, 16, 32],
    'variance': [0.1, 0.2],
    'clip': False,
}
# 'image_size' entry placed in the cfg handed to PriorBox, per backbone.
_DETECT_IMAGE_SIZE = {'resnet': 840, 'mobilenet': 640}
# Per-channel means subtracted before inference.
# NOTE(review): the reference RetinaFace PyTorch pipeline applies these to
# images read with cv2.imread, i.e. in B, G, R order -- confirm against the
# training code of the checkpoints in use.
_BGR_MEAN = (104, 117, 123)
# Upper bound on candidates handed to NMS.
_PRE_NMS_TOP_K = 5000


def _as_rgb_array(image):
    """Return *image* (PIL image or array-like) as a 3-channel RGB array."""
    if isinstance(image, Image.Image):
        # convert() turns grayscale, palette and RGBA uploads into 3 channels.
        return np.array(image.convert("RGB"))
    pixels = np.asarray(image)
    if pixels.ndim == 2:
        # Single-channel array: replicate it into three channels.
        pixels = np.stack((pixels,) * 3, axis=-1)
    elif pixels.ndim == 3 and pixels.shape[2] == 4:
        # Drop the alpha channel.
        pixels = pixels[:, :, :3]
    return np.ascontiguousarray(pixels)


def detect_faces(image, model_type="mobilenet", confidence_threshold=0.5, nms_threshold=0.4):
    """Core face detection function.

    Runs the selected RetinaFace model on *image*, keeps detections scoring
    above *confidence_threshold*, applies NMS with *nms_threshold*, and draws
    boxes, scores and landmarks onto a copy of the image.

    Args:
        image: PIL image or array-like of pixels; ``None`` is reported as an
            error instead of being processed.
        model_type: ``"resnet"`` selects the ResNet model; any other value
            selects the MobileNet model.
        confidence_threshold: minimum score (exclusive) for a detection.
        nms_threshold: overlap threshold passed to ``py_cpu_nms``.

    Returns:
        ``(annotated PIL image, markdown summary)`` on success, or
        ``(None, message)`` when no image was given, the model is not loaded,
        or processing raised.
    """
    try:
        start_time = time.time()

        if image is None:
            return None, "Error: no image provided"

        # Choose model
        backbone = "resnet" if model_type == "resnet" else "mobilenet"
        model = resnet_model if backbone == "resnet" else mobilenet_model
        if model is None:
            return None, "Models not loaded"
        cfg = dict(_DETECT_CFG, image_size=_DETECT_IMAGE_SIZE[backbone])

        rgb = _as_rgb_array(image)
        im_height, im_width = rgb.shape[:2]

        # Preprocessing: reverse the channel axis (RGB -> BGR) so the channel
        # order matches the order of the subtracted means, then go to NCHW.
        bgr = np.ascontiguousarray(rgb[:, :, ::-1], dtype=np.float32)
        bgr -= _BGR_MEAN
        batch = torch.from_numpy(bgr.transpose(2, 0, 1)).unsqueeze(0).to(device)

        # Forward pass
        with torch.no_grad():
            loc, conf, landms = model(batch)

        # Generate priors for this exact image size
        priors = PriorBox(cfg, image_size=(im_height, im_width)).forward().to(device)
        prior_data = priors.data

        # Decoded coordinates are multiplied by (width, height) pairs to get
        # pixel positions: two pairs per box, five pairs per landmark set.
        box_scale = torch.tensor([im_width, im_height] * 2, dtype=torch.float32, device=device)
        landm_scale = torch.tensor([im_width, im_height] * 5, dtype=torch.float32, device=device)

        boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance'])
        boxes = (boxes * box_scale).cpu().numpy()
        scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
        landms = decode_landm(landms.data.squeeze(0), prior_data, cfg['variance'])
        landms = (landms * landm_scale).cpu().numpy()

        # Ignore low scores
        confident = scores > confidence_threshold
        boxes, landms, scores = boxes[confident], landms[confident], scores[confident]

        # Keep top-K before NMS
        order = scores.argsort()[::-1][:_PRE_NMS_TOP_K]
        boxes, landms, scores = boxes[order], landms[order], scores[order]

        # Apply NMS
        dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
        keep = py_cpu_nms(dets, nms_threshold)
        dets = dets[keep, :]
        landms = landms[keep]

        # Draw results on the RGB image (not on the BGR model input)
        result_image = Image.fromarray(rgb)
        draw = ImageDraw.Draw(result_image)
        for det, points in zip(dets, landms):
            x1, y1, x2, y2, score = (float(v) for v in det)
            draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
            # Keep the score label inside the image for boxes touching the top edge.
            draw.text((x1, max(y1 - 15, 0.0)), f'{score:.2f}', fill="red")
            # The ten landmark values are (x, y) pairs in the order:
            # left eye, right eye, nose, left mouth corner, right mouth corner.
            for px, py in points.reshape(-1, 2):
                px, py = float(px), float(py)
                draw.ellipse([px - 2, py - 2, px + 2, py + 2], fill="blue")

        processing_time = time.time() - start_time
        result_text = "\n".join([
            "",
            "**Detection Results:**",
            f"- **Faces Detected:** {len(dets)}",
            f"- **Model Used:** {model_type}",
            f"- **Processing Time:** {processing_time:.3f}s",
            f"- **Confidence Threshold:** {confidence_threshold}",
            f"- **NMS Threshold:** {nms_threshold}",
            "",
        ])
        return result_image, result_text
    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        return None, f"Error: {str(e)}"
# Startup: load both detectors once, at import time. The boolean result is
# kept in `model_loaded` so create_interface() can pick the status line.
print("Loading RetinaFace models...")
model_loaded = load_models()
| # Create simple Gradio interface | |
def create_interface():
    """Build the Gradio Blocks UI for face detection and return it.

    Layout: a header with the model-loading status, an input column (image
    upload, model dropdown, confidence and NMS sliders, detect button), an
    output column (annotated image and a markdown summary), and a short API
    usage note. The button runs ``detect_faces``; the handler is registered
    with ``api_name="predict"`` so its endpoint name matches the usage note.

    Returns:
        The constructed ``gr.Blocks`` instance (not launched).
    """
    with gr.Blocks(title="RetinaFace Face Detection") as demo:
        gr.Markdown("# 🔥 RetinaFace Face Detection API")
        gr.Markdown("Real-time face detection using RetinaFace with MobileNet and ResNet backbones")

        # model_loaded is set once at import time by the startup load.
        if model_loaded:
            gr.Markdown("✅ **Status**: Models loaded successfully!")
        else:
            gr.Markdown("❌ **Status**: Error loading models")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Upload Image")
                model_choice = gr.Dropdown(
                    choices=["mobilenet", "resnet"],
                    value="mobilenet",
                    label="Model",
                )
                confidence = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.5, step=0.1,
                    label="Confidence",
                )
                nms = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.4, step=0.1,
                    label="NMS Threshold",
                )
                detect_btn = gr.Button("🔍 Detect Faces", variant="primary")
            with gr.Column():
                output_image = gr.Image(label="Results")
                output_text = gr.Markdown()

        detect_btn.click(
            fn=detect_faces,
            inputs=[input_image, model_choice, confidence, nms],
            outputs=[output_image, output_text],
            # Name the endpoint explicitly so it matches the usage note below.
            api_name="predict",
        )

        gr.Markdown("""
        ## API Usage
        Use `/api/predict` endpoint with:
        ```json
        {
        "data": [image, "mobilenet", 0.5, 0.4]
        }
        ```
        """)
    return demo
# Build the UI at import time so `demo` exists as a module-level object.
demo = create_interface()

if __name__ == "__main__":
    # Serve on all network interfaces, port 7860, and request a share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    demo.launch(**launch_options)