# app.py — Face Detection API (Hugging Face Space "retinaface", author: benstaf)
import os
import tempfile
import numpy as np
from PIL import Image
from fastapi import FastAPI, UploadFile, File, HTTPException
from pydantic import BaseModel
from typing import List, Dict, Any
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# On Hugging Face Spaces only /tmp is writable, so every framework cache must
# be redirected there BEFORE importing py-feat (it resolves paths at import
# time).  Keeping env var -> directory in one mapping prevents the two lists
# from drifting apart.
_CACHE_DIRS = {
    'FEAT_MODELS_DIR': '/tmp/feat_models',
    'MPLCONFIGDIR': '/tmp/matplotlib',
    'TORCH_HOME': '/tmp/torch',
    'HF_HOME': '/tmp/huggingface',
    # NOTE(review): TRANSFORMERS_CACHE is deprecated in newer `transformers`
    # releases in favor of HF_HOME; kept for compatibility with older installs.
    'TRANSFORMERS_CACHE': '/tmp/transformers',
}
for _env_key, _dir_path in _CACHE_DIRS.items():
    os.environ[_env_key] = _dir_path
    os.makedirs(_dir_path, exist_ok=True)

# Import py-feat after setting environment variables.  A failure here is
# logged but deliberately non-fatal so the app can still start and expose
# /health for debugging.
try:
    import feat
    print("Py-feat imported.")
    logger.info("Py-feat imported successfully")
except Exception as e:
    print(f"Warning: py-feat import failed at app start: {e}")
    logger.error(f"Py-feat import failed: {e}")

app = FastAPI(title="Face Detection API")

# Global variable to store the detector instance (populated by the startup hook)
global_face_detector = None
def _patch_feat_resource_path():
    """Best-effort monkey patch so py-feat resolves model files under /tmp.

    Some py-feat versions hard-code a package-relative model directory that is
    read-only on Spaces; redirecting get_resource_path avoids PermissionError.
    """
    try:
        import feat.utils

        # Override the get_resource_path function to use /tmp
        def custom_get_resource_path():
            cache_dir = "/tmp/feat_models"
            os.makedirs(cache_dir, exist_ok=True)
            return cache_dir

        # Patch the function only if this py-feat version exposes it.
        if hasattr(feat.utils, 'get_resource_path'):
            feat.utils.get_resource_path = custom_get_resource_path
            logger.info("Patched feat.utils.get_resource_path")
    except Exception as patch_error:
        logger.warning(f"Could not patch resource path function: {patch_error}")


def _log_environment_debug_info():
    """Dump cache env vars and /tmp directory permissions to the log.

    Called only on the model-loading failure path to aid debugging.
    """
    logger.info("Environment variables:")
    for key in ['FEAT_MODELS_DIR', 'MPLCONFIGDIR', 'TORCH_HOME', 'HF_HOME']:
        logger.info(f"  {key}: {os.environ.get(key, 'Not set')}")
    # Check directory permissions
    for dir_path in ['/tmp/feat_models', '/tmp/matplotlib', '/tmp/torch']:
        if os.path.exists(dir_path):
            logger.info(f"Directory {dir_path} exists, writable: {os.access(dir_path, os.W_OK)}")
        else:
            logger.warning(f"Directory {dir_path} does not exist")


@app.on_event("startup")
async def load_model():
    """Load the py-feat face detector once at application startup.

    Tries progressively simpler model configurations (full AU/emotion stack
    down to face detection only) and stores the first one that loads in the
    module-level ``global_face_detector``.

    Raises:
        RuntimeError: if no configuration can be loaded, so the app fails fast.
    """
    global global_face_detector
    try:
        logger.info("Starting model loading process...")
        print("Using CPU for inference.")
        # Define a writable cache directory for models
        model_cache_dir = "/tmp/feat_models_cache"
        os.makedirs(model_cache_dir, exist_ok=True)
        print(f"Set model cache directory to: {model_cache_dir}")
        logger.info(f"Model cache directory: {model_cache_dir}")

        _patch_feat_resource_path()

        # Set additional environment variables that py-feat might use
        os.environ['HF_HOME'] = model_cache_dir
        os.environ['FEAT_CACHE_DIR'] = model_cache_dir

        # Try different model configurations based on what's available
        model_configs = [
            # Configuration 1: full stack (landmarks + AUs + emotions)
            {
                "face_model": "retinaface",
                "landmark_model": "mobilefacenet",
                "au_model": "xgb",
                "emotion_model": "resmasknet",
                "facebox_model": "retinaface",
                "device": "cpu"
            },
            # Configuration 2: simpler landmark/AU models
            {
                "face_model": "retinaface",
                "landmark_model": "mobilenet",
                "au_model": "svm",
                "emotion_model": "resmasknet",
                "device": "cpu"
            },
            # Configuration 3: face detection only
            {
                "face_model": "retinaface",
                "device": "cpu"
            }
        ]
        for i, config in enumerate(model_configs):
            try:
                logger.info(f"Trying model configuration {i+1}: {config}")
                global_face_detector = feat.detector.Detector(**config, verbose=True)
                logger.info(f"Model configuration {i+1} loaded successfully!")
                print("RetinaFace model loaded successfully.")
                break
            except Exception as config_error:
                logger.warning(f"Model configuration {i+1} failed: {config_error}")
                if i == len(model_configs) - 1:  # Last configuration
                    raise config_error
    except PermissionError as pe:
        error_msg = f"Permission error during model loading: {pe}"
        logger.error(error_msg)
        print(f"Failed to load RetinaFace model during startup: {pe}")
        raise RuntimeError(f"Could not load model due to permissions: {pe}")
    except Exception as e:
        error_msg = f"Failed to load RetinaFace model during startup: {e}"
        logger.error(error_msg)
        print(error_msg)
        # Print additional debugging information before failing.
        _log_environment_debug_info()
        raise RuntimeError(f"Could not load model: {e}")
# Define a Pydantic model for the response expected by your client script
class FaceDetectionResult(BaseModel):
    """Response schema for a single detected face, as expected by the client."""

    # Bounding box as [x_min, y_min, x_max, y_max] in pixel coordinates.
    box: List[float]
    # Landmark name -> [x, y] pixel coordinates.
    landmarks: Dict[str, List[float]]
def _extract_box(row):
    """Return a [x_min, y_min, x_max, y_max] box from a detection row.

    py-feat has used different column-naming schemes across versions, so two
    known formats are tried.  Returns None (after logging the available
    columns) when no known format matches.
    """
    if all(col in row for col in ['FaceRectX', 'FaceRectY', 'FaceRectWidth', 'FaceRectHeight']):
        x, y, w, h = row['FaceRectX'], row['FaceRectY'], row['FaceRectWidth'], row['FaceRectHeight']
        return [float(x), float(y), float(x + w), float(y + h)]
    if all(col in row for col in ['x', 'y', 'w', 'h']):
        x, y, w, h = row['x'], row['y'], row['w'], row['h']
        return [float(x), float(y), float(x + w), float(y + h)]
    # Log available columns for debugging
    logger.warning(f"Unknown bounding box format. Available columns: {list(row.index)}")
    # Try to find any numeric columns that might be coordinates
    numeric_cols = [col for col in row.index if isinstance(row[col], (int, float))]
    logger.info(f"Numeric columns: {numeric_cols}")
    return None


def _extract_landmarks(row):
    """Collect up to 10 landmarks from a detection row.

    Tolerates the several landmark column-naming conventions py-feat has used;
    the first pattern that matches a given index wins.
    """
    current_landmarks = {}
    for k in range(10):  # Check up to 10 landmarks
        patterns = [
            (f'Landmark{k}_x', f'Landmark{k}_y'),
            (f'landmark_{k}_x', f'landmark_{k}_y'),
            (f'lm{k}_x', f'lm{k}_y'),
            (f'point{k}_x', f'point{k}_y')
        ]
        for x_col, y_col in patterns:
            if x_col in row and y_col in row:
                lm_x, lm_y = row[x_col], row[y_col]
                if lm_x is not None and lm_y is not None:
                    current_landmarks[f'point{k}'] = [float(lm_x), float(lm_y)]
                break
    return current_landmarks


@app.post("/extract-faces/", response_model=List[FaceDetectionResult])
async def extract_faces_api(file: UploadFile = File(...)):
    """Detect faces in an uploaded image and return boxes plus landmarks.

    Raises:
        HTTPException(500): if the detector is not initialized or detection fails.
    """
    if not global_face_detector:
        raise HTTPException(status_code=500, detail="Face detector not initialized.")

    # Read the uploaded file content first, then persist it to a writable
    # temp file.  Using tempfile.mkstemp (instead of joining the
    # client-supplied filename onto /tmp) prevents path traversal via a
    # malicious filename and collisions between concurrent requests.
    contents = await file.read()
    suffix = os.path.splitext(file.filename or "")[1]
    fd, temp_image_path = tempfile.mkstemp(suffix=suffix, dir="/tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(contents)
        logger.info(f"Image saved to temporary path: {temp_image_path}")
        print(f"Server: Image saved to temporary path: {temp_image_path}")
        # --- THIS IS THE CRITICAL SECTION FOR ERROR LOGGING ---
        try:
            # Use the global detector to detect faces
            logger.info("Starting face detection...")
            detected_faces_df = global_face_detector.detect_faces(
                input_path=temp_image_path
            )
            results = []
            if detected_faces_df is not None and not detected_faces_df.empty:
                logger.info(f"Found {len(detected_faces_df)} faces in DataFrame")
                for idx, row in detected_faces_df.iterrows():
                    try:
                        box_coords = _extract_box(row)
                        if box_coords is None:
                            continue
                        current_landmarks = _extract_landmarks(row)
                        results.append(FaceDetectionResult(box=box_coords, landmarks=current_landmarks))
                    except Exception as row_error:
                        logger.error(f"Error processing face {idx}: {row_error}")
                        continue
            else:
                logger.info("No faces detected or empty DataFrame returned")
            logger.info(f"Successfully processed {len(results)} faces")
            print(f"Server: Detected {len(results)} faces.")
            return results
        except Exception as e:
            # THIS IS WHERE THE ERROR WILL BE PRINTED
            error_msg = f"Exception during face extraction: {e}"
            logger.error(f"Server ERROR: {error_msg}")
            print(f"Server ERROR: {error_msg}")
            import traceback
            traceback.print_exc()  # Prints the full traceback to logs
            raise HTTPException(status_code=500, detail=f"Error processing image: {e}")
    finally:
        # Clean up the temporary file
        if os.path.exists(temp_image_path):
            os.remove(temp_image_path)
            logger.info(f"Cleaned up temporary file: {temp_image_path}")
            print(f"Server: Cleaned up temporary file: {temp_image_path}")
@app.get("/health")
async def health_check():
    """Report service liveness, detector state, and configured cache paths."""
    cache_dirs = {
        "feat_models": os.environ.get('FEAT_MODELS_DIR'),
        "matplotlib": os.environ.get('MPLCONFIGDIR'),
        "torch": os.environ.get('TORCH_HOME'),
        "hf_home": os.environ.get('HF_HOME'),
    }
    return {
        "status": "healthy",
        "model_loaded": global_face_detector is not None,
        "cache_dirs": cache_dirs,
    }
@app.get("/")
async def root():
    """Describe the API to a client hitting the base URL."""
    available_endpoints = ["/extract-faces/", "/health"]
    return {
        "message": "Face Detection API is running",
        "endpoints": available_endpoints,
        "model": "RetinaFace via py-feat",
    }