# NOTE: pasted from a hosted "Spaces" page whose status banner read "Runtime error".
# main.py
import io
import os

import clip
import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from retinaface import RetinaFace
# --- Constants & Configuration ---

# Run inference on GPU when available; fall back to CPU otherwise.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Directory where downloaded model weights are cached between runs.
MODELS_DIR = "models"

# CLIP zero-shot classification scores the image against text prompts;
# these two prompts define the two gender classes.
GENDER_PROMPTS = ["a photo of a man", "a photo of a woman"]

# --- Error Messages ---
# User-facing strings returned in HTTP error responses.
ERROR_MESSAGES = {
    "NO_FACE": "No face detected. Please upload a clear, front-facing picture of a single person.",
    "MULTIPLE_FACES": "Multiple faces detected. Please upload an image with only one face.",
    "ANALYSIS_ERROR": "An unexpected error occurred during analysis. Please try again.",
    "FILE_READ_ERROR": "Could not read the uploaded file. Please ensure it's a valid image.",
}
# --- Model Loading ---
# Ensure the weights cache directory exists before CLIP tries to download into it.
os.makedirs(MODELS_DIR, exist_ok=True)

try:
    print(f"Loading CLIP model on device: {DEVICE}...")
    # Load ViT-B/32, downloading weights into MODELS_DIR on first run.
    model, preprocess = clip.load("ViT-B/32", device=DEVICE, download_root=MODELS_DIR)
    print("✓ CLIP model loaded successfully.")
except Exception as e:
    print(f"✗ Failed to load CLIP model: {e}")
    # Abort with a non-zero status so supervisors/orchestrators see the failure.
    # (The original called the interactive `exit()` builtin, which is provided
    # by the `site` module — not guaranteed in all runtimes — and exits with
    # status 0, making a failed startup look successful.)
    raise SystemExit(1)
# --- FastAPI App Initialization ---
app = FastAPI(
    title="Gender Detection API",
    description="A simple API using CLIP to predict gender from an image.",
)

# Open CORS policy so browser front-ends on any origin can call the API.
# SECURITY NOTE(review): allow_origins=["*"] combined with
# allow_credentials=True is too permissive for production (and browsers
# reject credentialed requests against a wildcard origin) — restrict
# origins before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins for simplicity
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- Core Logic ---
def predict_gender_with_clip(image: Image.Image) -> dict:
    """
    Predict gender from a PIL image via CLIP zero-shot classification.

    The image is scored against every prompt in GENDER_PROMPTS and the
    per-prompt logits are softmax-normalised into probabilities.

    Args:
        image (Image.Image): The input image (RGB).

    Returns:
        dict: Mapping of gender label -> confidence, e.g.
        {"man": 0.9, "woman": 0.1}; values sum to 1.
    """
    # Preprocess to the model's expected tensor and add a batch dimension.
    image_input = preprocess(image).unsqueeze(0).to(DEVICE)
    text_inputs = clip.tokenize(GENDER_PROMPTS).to(DEVICE)

    # Inference only — no gradients needed.
    with torch.no_grad():
        logits_per_image, _ = model(image_input, text_inputs)
        # Softmax over the prompts; take the single image in the batch.
        probabilities = logits_per_image.softmax(dim=-1).cpu().numpy()[0]

    # "a photo of a man" -> "man", "a photo of a woman" -> "woman".
    return {
        prompt.split("of a ")[-1]: float(prob)
        for prompt, prob in zip(GENDER_PROMPTS, probabilities)
    }
# --- API Endpoints ---
@app.get("/health")
async def health_check():
    """Health check endpoint to verify if the API is running."""
    # BUG FIX: the coroutine was never registered with the app (no route
    # decorator), so the endpoint was unreachable; @app.get wires it up.
    return {"status": "healthy"}
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """
    Main prediction endpoint: validate the upload, require exactly one
    detected face, and return gender probabilities from CLIP.

    Raises:
        HTTPException 400: unreadable / non-image upload.
        HTTPException 422: zero or more than one face detected.
        HTTPException 500: unexpected analysis failure.
    """
    # BUG FIX: the coroutine was never registered with the app (no route
    # decorator), so the endpoint was unreachable; @app.post wires it up.
    try:
        # 1. Read and decode the uploaded image.
        contents = await file.read()
        image_pil = Image.open(io.BytesIO(contents)).convert("RGB")
        # RetinaFace expects an OpenCV-style BGR array; flip the channel axis.
        image_np = np.array(image_pil)[:, :, ::-1].copy()  # RGB -> BGR
    except Exception:
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES["FILE_READ_ERROR"])

    try:
        # 2. Detect faces. NOTE(review): RetinaFace.detect_faces returns a
        # dict of faces on success but may return a non-dict (e.g. a tuple)
        # when nothing is found — treat any non-dict result as "no faces".
        # TODO: confirm against the installed retinaface version.
        faces = RetinaFace.detect_faces(image_np)
        num_faces = len(faces) if isinstance(faces, dict) else 0

        if num_faces == 0:
            raise HTTPException(status_code=422, detail=ERROR_MESSAGES["NO_FACE"])
        if num_faces > 1:
            raise HTTPException(status_code=422, detail=ERROR_MESSAGES["MULTIPLE_FACES"])

        # 3. Exactly one face: classify with CLIP.
        return predict_gender_with_clip(image_pil)
    except HTTPException:
        # Propagate the validation errors raised above unchanged.
        raise
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        raise HTTPException(status_code=500, detail=ERROR_MESSAGES["ANALYSIS_ERROR"])
# --- Main Execution ---
if __name__ == "__main__":
    # Local development server; bind host="0.0.0.0" to expose it externally.
    uvicorn.run(app, host="127.0.0.1", port=8000)