Spaces:
Running
Running
File size: 4,453 Bytes
e5f3fc8 fe47fa8 9e03ade fe47fa8 e5f3fc8 fe47fa8 e5f3fc8 fe47fa8 e5f3fc8 fe47fa8 e5f3fc8 33572cd e5f3fc8 33572cd ab46ed3 e5f3fc8 ab46ed3 e5f3fc8 33572cd ab46ed3 20083da e5f3fc8 ab46ed3 e5f3fc8 33572cd fe47fa8 e5f3fc8 33572cd fe47fa8 e5f3fc8 fe47fa8 33572cd fe47fa8 e5f3fc8 fe47fa8 251fe98 04c6577 51ee15b e5f3fc8 04c6577 e5f3fc8 fe47fa8 e5f3fc8 fe47fa8 e5f3fc8 fe47fa8 e5f3fc8 33572cd fe47fa8 e5f3fc8 fe47fa8 e5f3fc8 fe47fa8 e5f3fc8 fe47fa8 e5f3fc8 fe47fa8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 | import sys
import os
import cv2
import numpy as np
import tempfile
import logging
import argparse
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Location of the cloned MiVOLO repository. Defaults to the previously
# hard-coded path; set MIVOLO_REPO_DIR to use another checkout (e.g. when
# running outside the container). An empty value is treated as unset so it
# cannot silently put "" (the CWD) on the import path.
MIVOLO_REPO_DIR = os.environ.get("MIVOLO_REPO_DIR") or "/app/mivolo_repo"
# Prepend so `mivolo` resolves to this checkout ahead of any other copy.
sys.path.insert(0, MIVOLO_REPO_DIR)
# The ASGI application; the route handlers below are registered on it.
app = FastAPI(title="MiVOLO Age & Gender Detection API")

# Permissive CORS policy: every origin, HTTP method and request header is
# accepted, so browser clients served from any domain can call this API.
_CORS_ALLOW_ALL = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_ALLOW_ALL)

# Module-level MiVOLO predictor shared by every request. It stays None at
# import time so the model is only built on the first request (to avoid OOM
# during the build); get_predictor() populates it.
predictor = None
def get_predictor():
    """
    Return the module-level MiVOLO predictor, building it on first use.

    On the first call this locates the age/gender checkpoint (the
    lexicographically first ``*.pth.tar`` file in the current working
    directory) and downloads the YOLOv8 person+face detector weights from the
    Hugging Face Hub (``iitolstykh/demo_yolov8_detector``). It then constructs
    a CPU ``mivolo.predictor.Predictor`` and caches it in the global
    ``predictor``. Subsequent calls return the cached instance.

    Returns:
        The cached ``Predictor`` instance.

    Raises:
        FileNotFoundError: if no ``*.pth.tar`` checkpoint exists in the
            current working directory. Nothing is cached in that case, so the
            next call retries.
    """
    global predictor
    if predictor is not None:
        return predictor

    logger.info("Loading MiVOLO predictor for the first time...")
    # Imports are deferred to first use, in keeping with the lazy-load design.
    import glob
    from huggingface_hub import hf_hub_download
    from mivolo.predictor import Predictor

    # The age/gender checkpoint is no longer publicly hosted on HF, so it must
    # be uploaded alongside the app. Check for it *before* downloading the
    # detector so a misconfigured deployment fails fast. glob() returns
    # matches in arbitrary order; sort so the choice is deterministic.
    checkpoint_files = sorted(glob.glob("*.pth.tar"))
    if not checkpoint_files:
        raise FileNotFoundError("No checkpoint file ending in .pth.tar found in Space root! Please upload it.")
    checkpoint_path = checkpoint_files[0]
    if len(checkpoint_files) > 1:
        logger.warning(
            "Found %d checkpoints %s; using %s",
            len(checkpoint_files), checkpoint_files, checkpoint_path,
        )
    logger.info("Using MiVOLO checkpoint: %s", checkpoint_path)

    # YOLOv8 person+face detector weights from the public demo repo.
    detector_weights = hf_hub_download(
        repo_id="iitolstykh/demo_yolov8_detector",
        filename="yolov8x_person_face.pt",
    )

    config = argparse.Namespace(
        detector_weights=detector_weights,
        checkpoint=checkpoint_path,
        device="cpu",
        with_persons=True,   # use full-body crops as extra context
        disable_faces=False, # and keep face crops as well
        draw=False,          # no annotated output image is needed
    )

    predictor = Predictor(config, verbose=False)
    logger.info("MiVOLO predictor loaded successfully.")
    return predictor
@app.get("/")
def health_check():
    """Liveness probe: confirm the service is up and name the model it serves."""
    payload = dict(
        status="MiVOLO API is running!",
        model="MiVOLO D1 — State-of-the-Art Age & Gender Estimation",
    )
    return payload
def _decode_upload(contents):
    """Decode raw upload bytes into an OpenCV BGR image.

    Raises:
        HTTPException: 400 if the payload is empty or cannot be decoded.
    """
    # cv2.imdecode asserts (raising cv2.error) on an empty buffer instead of
    # returning None, which would otherwise surface as a 500; reject it here.
    if not contents:
        raise HTTPException(status_code=400, detail="Invalid or unreadable image file.")
    img = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise HTTPException(status_code=400, detail="Invalid or unreadable image file.")
    return img


def _equalize_lighting(img_bgr):
    """Apply CLAHE to the lightness channel of a BGR image; return a BGR image."""
    # Equalizing only L in LAB space leaves colour (a/b) untouched; intended
    # to improve face visibility in poorly lit webcam shots.
    lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2LAB)
    l_channel, a_channel, b_channel = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    equalized = cv2.merge((clahe.apply(l_channel), a_channel, b_channel))
    return cv2.cvtColor(equalized, cv2.COLOR_LAB2BGR)


def _first_estimate(detected_objects):
    """Return (age, gender, gender_score) of the first fully-estimated detection, else None."""
    # MiVOLO's result lists are indexed per detection and may hold None for
    # detections that received no estimate, so a non-empty `ages` list does
    # not guarantee that entry 0 is usable.
    for age, gender, score in zip(
        detected_objects.ages, detected_objects.genders, detected_objects.gender_scores
    ):
        if age is not None and gender is not None and score is not None:
            return age, gender, score
    return None


@app.post("/predict")
async def predict_age_gender(file: UploadFile = File(...)):
    """Estimate age and gender for the first estimated person/face in an uploaded image.

    Args:
        file: Uploaded image in any format OpenCV can decode.

    Returns:
        dict with ``success`` (True), ``age`` (int), ``gender`` ("Man" or
        "Woman"), ``confidence`` (gender score rounded to 2 dp) and
        ``model_used``.

    Raises:
        HTTPException: 400 for an empty or undecodable image, 422 when no
            detection received an estimate, 500 for any other failure
            (including predictor loading errors).
    """
    try:
        img = _decode_upload(await file.read())
        img = _equalize_lighting(img)

        # Keep OpenCV's native BGR order: the ultralytics detector treats numpy
        # input as BGR, and MiVOLO's own demo passes cv2.imread output straight
        # to Predictor.recognize. Converting to RGB here swapped the channels.
        # NOTE(review): loading and inference run synchronously inside this
        # async handler, so they block the event loop (and serialize requests).
        pred = get_predictor()
        detected_objects, _ = pred.recognize(img)

        estimate = None if detected_objects is None else _first_estimate(detected_objects)
        if estimate is None:
            raise HTTPException(
                status_code=422,
                detail="No face detected. Please use a clear, well-lit photo."
            )
        raw_age, gender_raw, raw_score = estimate

        age = round(float(raw_age))
        gender_score = float(raw_score)
        # Dashboard expects "Man"/"Woman" rather than MiVOLO's "male"/"female".
        gender = "Man" if gender_raw == "male" else "Woman"

        logger.info("MiVOLO Result — Age: %s, Gender: %s (%.2f)", age, gender, gender_score)
        return {
            "success": True,
            "age": age,
            "gender": gender,
            "confidence": round(gender_score, 2),
            "model_used": "MiVOLO D1"
        }
    except HTTPException:
        raise
    except Exception as e:
        # logger.exception keeps the traceback, which the old error log dropped.
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|