# Gender-Detection / main.py
# (Hugging Face upload metadata — uploader "Ni3SinghR", "Upload 4 files",
#  commit d4e1911 verified — commented out so the module parses as Python.)
# main.py
import uvicorn
import numpy as np
import clip
import torch
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from retinaface import RetinaFace
from PIL import Image
import io
import os
# --- Constants & Configuration ---
# Run on GPU when available; CLIP inference falls back to CPU otherwise.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Local cache directory for downloaded model weights (created at startup).
MODELS_DIR = "models"
# Zero-shot CLIP text prompts; the label returned to clients is the text
# after "of a " (i.e. "man" / "woman") — see predict_gender_with_clip.
GENDER_PROMPTS = ["a photo of a man", "a photo of a woman"]
# --- Error Messages ---
# User-facing detail strings for HTTPException responses.
ERROR_MESSAGES = {
    "NO_FACE": "No face detected. Please upload a clear, front-facing picture of a single person.",
    "MULTIPLE_FACES": "Multiple faces detected. Please upload an image with only one face.",
    "ANALYSIS_ERROR": "An unexpected error occurred during analysis. Please try again.",
    "FILE_READ_ERROR": "Could not read the uploaded file. Please ensure it's a valid image."
}
# --- Model Loading ---
# Create models directory if it doesn't exist
os.makedirs(MODELS_DIR, exist_ok=True)
try:
    print(f"Loading CLIP model on device: {DEVICE}...")
    # Load the model, downloading weights into MODELS_DIR if necessary.
    model, preprocess = clip.load("ViT-B/32", device=DEVICE, download_root=MODELS_DIR)
    print("✓ CLIP model loaded successfully.")
except Exception as e:
    print(f"✗ Failed to load CLIP model: {e}")
    # Bug fix: exit() terminates with status 0 (reads as success to any
    # supervisor/orchestrator) and the `exit` builtin is only injected by the
    # `site` module, so it may not exist in all runtimes. Raise SystemExit(1)
    # so the startup failure is reported with a non-zero exit code.
    raise SystemExit(1) from e
# --- FastAPI App Initialization ---
app = FastAPI(
    title="Gender Detection API",
    description="A simple API using CLIP to predict gender from an image."
)
# Permissive CORS so browser frontends on any origin can call the API.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# wide open — tighten origins before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins for simplicity
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- Core Logic ---
def predict_gender_with_clip(image: Image.Image) -> dict:
    """
    Run zero-shot gender classification on a PIL image with CLIP.

    Args:
        image (Image.Image): The input image (RGB).

    Returns:
        dict: Mapping of gender label ("man" / "woman") to confidence score.
    """
    # Prepare a single-image batch and the fixed text prompts on the device.
    pixel_batch = preprocess(image).unsqueeze(0).to(DEVICE)
    prompt_tokens = clip.tokenize(GENDER_PROMPTS).to(DEVICE)

    # Inference only — no gradients needed.
    with torch.no_grad():
        image_logits, _ = model(pixel_batch, prompt_tokens)

    # Softmax over the prompts yields a probability per label.
    scores = image_logits.softmax(dim=-1).cpu().numpy()[0]

    # "a photo of a man" -> "man", "a photo of a woman" -> "woman".
    labels = [prompt.split("of a ")[-1] for prompt in GENDER_PROMPTS]
    return {label: float(score) for label, score in zip(labels, scores)}
# --- API Endpoints ---
@app.get("/health")
async def health_check():
    """Liveness probe: reports that the service is up."""
    payload = {"status": "healthy"}
    return payload
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """
    Predict gender probabilities for an uploaded single-face image.

    Workflow:
      1. Decode the upload into an RGB PIL image (400 on failure).
      2. Detect faces with RetinaFace; require exactly one (422 otherwise).
      3. Classify with CLIP and return label -> probability.

    Raises:
        HTTPException: 400 (unreadable file), 422 (zero or multiple faces),
            500 (unexpected analysis failure).
    """
    try:
        # 1. Read and validate the uploaded image
        contents = await file.read()
        if not contents:
            # Robustness fix: an empty body would otherwise surface as a
            # less specific PIL error; treat it as an unreadable file.
            raise ValueError("empty upload")
        image_pil = Image.open(io.BytesIO(contents)).convert("RGB")
        # RetinaFace expects a BGR numpy array (OpenCV channel order).
        image_np = np.array(image_pil)[:, :, ::-1].copy()
    except Exception:
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES["FILE_READ_ERROR"])
    try:
        # 2. Detect faces using RetinaFace
        faces = RetinaFace.detect_faces(image_np)
        # Robustness fix: detect_faces returns a dict of faces on success, but
        # some versions return a non-dict (e.g. a tuple) when nothing is
        # detected — previously that could crash into the 500 branch instead
        # of reporting "no face". Normalize to a face count first.
        num_faces = len(faces) if isinstance(faces, dict) else 0
        if num_faces == 0:
            raise HTTPException(status_code=422, detail=ERROR_MESSAGES["NO_FACE"])
        if num_faces > 1:
            raise HTTPException(status_code=422, detail=ERROR_MESSAGES["MULTIPLE_FACES"])
        # 3. Predict gender using CLIP
        return predict_gender_with_clip(image_pil)
    except HTTPException:
        # Bug fix: bare `raise` re-raises the active exception with its
        # original traceback/context intact (vs. `raise e`).
        raise
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        raise HTTPException(status_code=500, detail=ERROR_MESSAGES["ANALYSIS_ERROR"])
# --- Main Execution ---
if __name__ == "__main__":
    # Dev entry point: binds to loopback only. For deployment, run via an
    # ASGI server command instead (e.g. `uvicorn main:app --host 0.0.0.0`).
    uvicorn.run(app, host="127.0.0.1", port=8000)