Spaces:
Running
Running
| import sys | |
| import os | |
| import cv2 | |
| import numpy as np | |
| import tempfile | |
| import logging | |
| import argparse | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
# Module-wide logging: records at INFO and above reach the root handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The MiVOLO sources are cloned into /app/mivolo_repo rather than installed as a
# package; prepending that directory to sys.path lets `mivolo.*` imports resolve.
sys.path[:0] = ["/app/mivolo_repo"]

app = FastAPI(title="MiVOLO Age & Gender Detection API")

# Permissive CORS: browsers on any origin may call any method with any header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Cached MiVOLO predictor. Deliberately None at import time; it is populated
# lazily on the first request so the model is not loaded (and cannot OOM)
# during the build.
predictor = None
def get_predictor():
    """
    Return the cached MiVOLO predictor, building it on the first call.

    The first call does three things:

    1. Locates the age/gender checkpoint, the first ``*.pth.tar`` file
       (alphabetically) in the current working directory.
    2. Downloads the YOLOv8 person+face detector weights from the
       Hugging Face Hub.
    3. Constructs a CPU ``mivolo.predictor.Predictor``.

    The instance is stored in the module-level ``predictor`` global, so
    subsequent calls return it immediately.

    Returns:
        The cached ``Predictor`` instance.

    Raises:
        FileNotFoundError: If no ``*.pth.tar`` checkpoint exists in the
            current working directory. The global stays ``None``, so a later
            call will retry.
    """
    global predictor
    if predictor is not None:
        return predictor

    import glob

    logger.info("Loading MiVOLO predictor for the first time...")

    # The age/gender checkpoint is no longer publicly hosted on HF, so it must
    # be uploaded alongside the app. Look for it *before* importing the heavy
    # libraries and downloading the detector, so a missing file fails fast
    # instead of after a multi-hundred-MB download.
    # glob() order is arbitrary; sorting makes the choice deterministic when
    # more than one checkpoint is present.
    checkpoint_files = sorted(glob.glob("*.pth.tar"))
    if not checkpoint_files:
        raise FileNotFoundError("No checkpoint file ending in .pth.tar found in Space root! Please upload it.")
    checkpoint_path = checkpoint_files[0]
    if len(checkpoint_files) > 1:
        logger.warning(
            "Multiple .pth.tar checkpoints found %s; using %s",
            checkpoint_files,
            checkpoint_path,
        )

    # Imported lazily so the module can be imported (and the app started)
    # without paying the model-library import cost up front.
    from huggingface_hub import hf_hub_download
    from mivolo.predictor import Predictor

    # YOLOv8 person+face detector weights from the public demo repo.
    detector_weights = hf_hub_download(
        repo_id="iitolstykh/demo_yolov8_detector",
        filename="yolov8x_person_face.pt",
    )

    config = argparse.Namespace(
        detector_weights=detector_weights,
        checkpoint=checkpoint_path,
        device="cpu",
        with_persons=True,    # use full-body context in addition to faces
        disable_faces=False,  # keep face features enabled
        draw=False,           # no annotated output image is needed
    )
    predictor = Predictor(config, verbose=False)
    logger.info("MiVOLO predictor loaded successfully.")
    return predictor
# NOTE(review): the function was not registered on `app`, so it was
# unreachable. "/" is the assumed health path; confirm it matches clients.
@app.get("/")
def health_check():
    """Return a static status payload confirming the API process is up.

    This does not load or touch the model, so it stays cheap even before the
    first prediction request.
    """
    return {
        "status": "MiVOLO API is running!",
        "model": "MiVOLO D1 — State-of-the-Art Age & Gender Estimation",
    }
def _first_estimated_index(detected_objects):
    """Return the index of the first detection with an age estimate, or None.

    ``detected_objects.ages`` is indexed per detection. Entries can presumably
    be None for detections that received no estimate; confirm this against
    MiVOLO's PersonAndFaceResult. Those entries are skipped, rather than
    blindly taking index 0, which could yield ``float(None)``.
    """
    if detected_objects is None or not detected_objects.ages:
        return None
    return next(
        (i for i, age in enumerate(detected_objects.ages) if age is not None),
        None,
    )


# NOTE(review): the handler was not registered on `app`, so it was
# unreachable. "/predict" is an assumed path; confirm it matches the
# dashboard client.
@app.post("/predict")
async def predict_age_gender(file: UploadFile = File(...)):
    """
    Estimate age and gender for the primary detection in an uploaded image.

    Args:
        file: Uploaded image bytes in any format ``cv2.imdecode`` can read.

    Returns:
        A dict with these keys:

        - ``success``: always True.
        - ``age``: the age estimate rounded to the nearest integer.
        - ``gender``: "Man" for a raw "male" label, otherwise "Woman".
        - ``confidence``: the gender score rounded to 2 decimal places.
        - ``model_used``: the model name.

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded as an
            image; 422 if no detection carries an estimate; 500 for any
            other failure (including model loading errors).
    """
    try:
        contents = await file.read()
        # cv2.imdecode raises (rather than returning None) on an empty buffer,
        # which previously surfaced as a 500; reject it as a client error.
        if not contents:
            raise HTTPException(status_code=400, detail="Invalid or unreadable image file.")
        img = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
        if img is None:
            raise HTTPException(status_code=400, detail="Invalid or unreadable image file.")

        # NOTE: model loading and recognize() are synchronous CPU work inside an
        # async handler, so they block the event loop while they run. In a
        # single worker this also serializes the lazy load in get_predictor().
        pred = get_predictor()
        detected_objects, _ = pred.recognize(img)

        idx = _first_estimated_index(detected_objects)
        if idx is None:
            raise HTTPException(
                status_code=422,
                detail="No face detected. Please use a clear, well-lit photo."
            )

        age = round(float(detected_objects.ages[idx]))
        gender_raw = detected_objects.genders[idx]  # expected "male" or "female"
        gender_score = float(detected_objects.gender_scores[idx])
        # Map to the labels the dashboard expects.
        gender = "Man" if gender_raw == "male" else "Woman"

        logger.info("MiVOLO Result — Age: %s, Gender: %s (%.2f)", age, gender, gender_score)
        return {
            "success": True,
            "age": age,
            "gender": gender,
            "confidence": round(gender_score, 2),
            "model_used": "MiVOLO D1"
        }
    except HTTPException:
        # Deliberate client-facing errors pass through unchanged.
        raise
    except Exception as e:
        # Log the full traceback, not just the message, for debugging.
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e