File size: 4,453 Bytes
e5f3fc8
 
fe47fa8
 
9e03ade
fe47fa8
e5f3fc8
 
 
 
fe47fa8
e5f3fc8
 
fe47fa8
e5f3fc8
 
 
 
fe47fa8
 
 
 
 
 
 
 
e5f3fc8
 
 
 
 
 
 
 
 
 
 
 
 
 
33572cd
e5f3fc8
 
33572cd
ab46ed3
e5f3fc8
ab46ed3
e5f3fc8
 
33572cd
ab46ed3
20083da
 
 
 
 
 
e5f3fc8
 
 
 
ab46ed3
e5f3fc8
 
 
 
 
 
 
 
 
33572cd
 
fe47fa8
e5f3fc8
 
 
 
 
33572cd
fe47fa8
 
 
 
e5f3fc8
fe47fa8
 
 
33572cd
fe47fa8
e5f3fc8
fe47fa8
251fe98
 
 
 
 
 
 
 
 
04c6577
 
 
51ee15b
e5f3fc8
04c6577
e5f3fc8
 
 
 
 
 
fe47fa8
e5f3fc8
 
 
 
fe47fa8
e5f3fc8
 
fe47fa8
e5f3fc8
33572cd
fe47fa8
 
e5f3fc8
fe47fa8
e5f3fc8
 
fe47fa8
 
e5f3fc8
 
fe47fa8
e5f3fc8
fe47fa8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import sys
import os
import cv2
import numpy as np
import tempfile
import logging
import argparse

from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware

# Root logging is configured at INFO; this module logs under its own name.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Put the cloned MiVOLO checkout at the front of the import search path.
sys.path.insert(0, '/app/mivolo_repo')

app = FastAPI(title="MiVOLO Age & Gender Detection API")

# CORS: any origin, any HTTP method, any request header is accepted.
app.add_middleware(
    CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]
)

# Shared predictor instance. It stays None until the first request needs it,
# so that the model is not loaded (and cannot OOM) during the build.
predictor = None


def get_predictor():
    """
    Return the shared MiVOLO predictor, building it on the first call.

    The result is stored in the module-level ``predictor`` global, so every
    later call returns the same object without reloading anything.

    On the first call this:
      1. picks a ``*.pth.tar`` age/gender checkpoint from the current
         working directory,
      2. downloads the YOLOv8 person+face detector weights from the
         Hugging Face Hub,
      3. constructs ``mivolo.predictor.Predictor`` configured for CPU.

    Raises:
        FileNotFoundError: if no ``*.pth.tar`` file exists in the current
            working directory. This is checked before any download or heavy
            import so a misconfigured deployment fails fast.
    """
    global predictor
    if predictor is not None:
        return predictor

    logger.info("Loading MiVOLO predictor for the first time...")

    # The age/gender checkpoint is no longer publicly hosted on HF.
    # We auto-detect any .pth.tar file uploaded to the Space.
    # glob() returns matches in arbitrary order, so sort to make the choice
    # reproducible when more than one checkpoint is present.
    import glob
    checkpoint_files = sorted(glob.glob("*.pth.tar"))
    if not checkpoint_files:
        raise FileNotFoundError("No checkpoint file ending in .pth.tar found in Space root! Please upload it.")
    checkpoint_path = checkpoint_files[0]
    if len(checkpoint_files) > 1:
        logger.warning(
            "Multiple .pth.tar checkpoints found (%s); using %s",
            ", ".join(checkpoint_files),
            checkpoint_path,
        )

    # Imported lazily: these are only needed once, when the model is built.
    from huggingface_hub import hf_hub_download
    from mivolo.predictor import Predictor

    # Download the YOLOv8 person+face detector weights from the public working repo
    detector_weights = hf_hub_download(
        repo_id="iitolstykh/demo_yolov8_detector",
        filename="yolov8x_person_face.pt",
    )

    # Build MiVOLO config
    config = argparse.Namespace(
        detector_weights=detector_weights,
        checkpoint=checkpoint_path,
        device="cpu",
        with_persons=True,     # Use full-body context for better accuracy
        disable_faces=False,   # Also use face features
        draw=False,
    )

    predictor = Predictor(config, verbose=False)
    logger.info("MiVOLO predictor loaded successfully.")
    return predictor


@app.get("/")
def health_check():
    """Report that the service is up and name the model it serves."""
    return dict(
        status="MiVOLO API is running!",
        model="MiVOLO D1 — State-of-the-Art Age & Gender Estimation",
    )


def _enhance_contrast(img):
    """Return the BGR image *img* with CLAHE applied to its lightness channel."""
    # Apply CLAHE (Contrast Limited Adaptive Histogram Equalization)
    # This dramatically improves face visibility in bad webcam lighting
    lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
    l_channel, a_channel, b_channel = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    equalized = cv2.merge((clahe.apply(l_channel), a_channel, b_channel))
    return cv2.cvtColor(equalized, cv2.COLOR_LAB2BGR)


def _first_complete_detection(detected_objects):
    """Return ``(age, gender, gender_score)`` of the first usable detection.

    A detection is usable when none of its three values is ``None``.
    Returns ``None`` when *detected_objects* is ``None``, has no ages, or has
    no usable entry.
    """
    if detected_objects is None or not detected_objects.ages:
        return None
    # NOTE(review): MiVOLO presumably leaves None in these per-object lists
    # for objects it could not estimate — confirm against the mivolo package.
    candidates = zip(
        detected_objects.ages,
        detected_objects.genders,
        detected_objects.gender_scores,
    )
    for age, gender, score in candidates:
        if age is not None and gender is not None and score is not None:
            return age, gender, score
    return None


@app.post("/predict")
async def predict_age_gender(file: UploadFile = File(...)):
    """
    Estimate age and gender for one person in an uploaded image.

    The upload is decoded with OpenCV, contrast-enhanced with CLAHE, and
    passed to the lazily loaded MiVOLO predictor. The first detection that
    carries an age, a gender and a gender score is reported.

    Args:
        file: the uploaded image file.

    Returns:
        A dict with ``success``, ``age`` (rounded to an int), ``gender``
        (``"Man"`` or ``"Woman"``), ``confidence`` (gender score rounded to
        two decimals) and ``model_used``.

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded as an
            image; 422 if no usable detection is found; 500 for any other
            failure, with the error text as the detail.
    """
    try:
        # Read and decode image
        contents = await file.read()
        if not contents:
            # cv2.imdecode raises on an empty buffer instead of returning
            # None, which would otherwise surface as a 500.
            raise HTTPException(status_code=400, detail="Invalid or unreadable image file.")

        nparr = np.frombuffer(contents, np.uint8)
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

        if img is None:
            raise HTTPException(status_code=400, detail="Invalid or unreadable image file.")

        img = _enhance_contrast(img)

        # Convert BGR (OpenCV default) to RGB before handing the image over.
        # NOTE(review): the upstream MiVOLO demo and Ultralytics appear to
        # take BGR ndarrays — confirm this conversion is actually wanted.
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Run MiVOLO prediction directly on the numpy image array
        pred = get_predictor()
        detected_objects, _ = pred.recognize(img_rgb)

        detection = _first_complete_detection(detected_objects)
        if detection is None:
            raise HTTPException(
                status_code=422,
                detail="No face detected. Please use a clear, well-lit photo."
            )

        raw_age, gender_raw, raw_score = detection   # gender_raw: "male" or "female"
        age = round(float(raw_age))
        gender_score = float(raw_score)

        # Format gender to match dashboard expectations
        gender = "Man" if gender_raw == "male" else "Woman"

        logger.info(f"MiVOLO Result — Age: {age}, Gender: {gender} ({gender_score:.2f})")

        return {
            "success": True,
            "age": age,
            "gender": gender,
            "confidence": round(gender_score, 2),
            "model_used": "MiVOLO D1"
        }

    except HTTPException:
        # Deliberate HTTP errors raised above pass through unchanged.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e