Spaces:
Runtime error
Runtime error
| from fastapi import FastAPI, Request, HTTPException, Body | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from pydantic import BaseModel, Field | |
| from typing import List, Optional, Union, Dict, Any | |
| import uuid | |
| import json | |
| import os | |
| from datetime import datetime | |
| from handler import EndpointHandler | |
| import numpy as np | |
# Startup diagnostics are best-effort: any failure, including the `debug`
# module not being importable, is reported on stdout and startup continues.
try:
    from debug import run_all_checks

    run_all_checks()
except Exception as diagnostics_error:
    print(f"⚠️ Diagnostics failed to run: {diagnostics_error}")
# Application object; the CORS middleware is registered on it just below.
app = FastAPI()

# CORS policy, gathered in one mapping so it can be read at a glance.
# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive policy -- confirm it is intended before exposing this publicly.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# User ids issued by this process. Held in memory only, so the contents are
# lost whenever the process restarts.
user_sessions = {}

# JSON file that stores the per-user preference records.
USER_PREFERENCES_FILE = 'user_preferences.json'

# Single shared handler instance, created at import time.
face_handler = EndpointHandler()
| # Pydantic model for recommendations | |
class RecommendationRequest(BaseModel):
    # Request body for recommendations in the direct (non-wrapped) format.
    # Comments are used instead of a docstring so the generated schema is
    # left exactly as it was.

    # Required: no default value is supplied.
    query_images: List[str] = Field(
        default=...,
        description="List of Azure URLs for query images",
    )
    # Optional; defaults to 'all'.
    gender: Optional[str] = Field(
        default='all',
        description="Gender filter: 'male', 'female', or 'all'",
    )
    # Optional; defaults to 5.
    top_n: Optional[int] = Field(
        default=5,
        description="Number of recommendations to return",
    )
| # Pydantic model for Hugging Face format | |
class HuggingFaceRequest(BaseModel):
    # Envelope for the Hugging Face style body, i.e. {"inputs": {...}}, where
    # the wrapped object carries the same fields as RecommendationRequest.
    inputs: RecommendationRequest
| # Helper functions | |
def load_user_preferences() -> Dict[str, Any]:
    """Load the stored user preferences from ``USER_PREFERENCES_FILE``.

    Returns:
        The parsed JSON content of the file, or an empty dict when the file
        does not exist.

    Raises:
        json.JSONDecodeError: If the file exists but does not hold valid JSON.
    """
    try:
        # Open directly rather than checking os.path.exists() first: the file
        # could disappear between the check and the open. The encoding is
        # explicit so reads do not depend on the platform default.
        with open(USER_PREFERENCES_FILE, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        return {}
def save_user_preferences(preferences):
    """Persist ``preferences`` as indented JSON in ``USER_PREFERENCES_FILE``.

    The data is first written to a sibling ``.tmp`` file and then moved over
    the real file with ``os.replace``. If serialization or writing fails
    part-way, the previously stored preferences are left untouched instead of
    being truncated.

    Args:
        preferences: JSON-serializable object to store.

    Raises:
        TypeError: If ``preferences`` contains a value ``json`` cannot encode.
        OSError: If the file cannot be written or replaced.
    """
    tmp_path = f"{USER_PREFERENCES_FILE}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(preferences, f, indent=2)
        os.replace(tmp_path, USER_PREFERENCES_FILE)
    except Exception:
        # Do not leave a partial temp file behind; the original error is
        # what the caller needs to see, so cleanup problems are ignored.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def index():
    """Return a small static HTML banner confirming the service is running."""
    # NOTE(review): no route decorator is visible in this view of the file;
    # confirm how this handler is registered on the app.
    banner = "<h2>FaceMatch FastAPI is running!</h2>"
    return banner
def health_check():
    """Health check endpoint for Azure Container Apps.

    Returns:
        A dict with a fixed ``status`` and ``service`` name, plus
        ``model_loaded``, which is True when ``face_handler.app`` is not None.
    """
    model_ready = face_handler.app is not None
    report = {
        "status": "healthy",
        "service": "facematch-api",
        "model_loaded": model_ready,
    }
    return report
def init_user():
    """Create a new user id and make sure it has a stored preference record.

    A random UUID4 string is generated, remembered in ``user_sessions``, and,
    if the preferences file has no entry for it yet, given an empty record
    (no liked or disliked images, no embedding, creation timestamp) which is
    then saved.

    Returns:
        ``{"user_id": <new id>, "status": "initialized"}``.
    """
    new_id = str(uuid.uuid4())
    user_sessions[new_id] = True
    response = {"user_id": new_id, "status": "initialized"}

    stored = load_user_preferences()
    if new_id in stored:
        # A record already exists for this id; nothing needs to be written.
        return response

    stored[new_id] = {
        'liked_images': [],
        'disliked_images': [],
        'preference_embedding': None,
        'created_at': datetime.now().isoformat(),
    }
    save_user_preferences(stored)
    return response
def get_training_images():
    """Return up to 10 image URLs from the ``ai-images/men`` and ``ai-images/women`` folders.

    Blobs are listed folder by folder ('men' first, then 'women') and only
    names ending in ``.jpg``, ``.jpeg`` or ``.png`` are kept. Listing stops as
    soon as 10 URLs have been collected, so the container is not enumerated in
    full just to return its first few entries.

    Returns:
        ``{"training_images": [...], "status": "success"}`` on success, or a
        500 ``JSONResponse`` carrying the error text if anything fails.
    """
    max_images = 10
    image_suffixes = ('.jpg', '.jpeg', '.png')
    try:
        # The URL prefix is the same for every blob, so build it once.
        account_name = face_handler.blob_service_client.account_name
        base_url = f'https://{account_name}.blob.core.windows.net/{face_handler.container_name}'

        training_images = []
        for gender_folder in ('men', 'women'):
            if len(training_images) >= max_images:
                break
            gender_prefix = f'ai-images/{gender_folder}/'
            blob_list = face_handler.container_client.list_blobs(name_starts_with=gender_prefix)
            for blob in blob_list:
                if not blob.name.endswith(image_suffixes):
                    continue
                training_images.append(f'{base_url}/{blob.name}')
                if len(training_images) >= max_images:
                    # Enough images collected; stop consuming the listing.
                    break
        return {"training_images": training_images, "status": "success"}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
async def record_preference(request: Request):
    """Record a like or dislike for an image on behalf of a user.

    The JSON body must contain ``user_id``, ``image_url`` and ``preference``.
    A ``preference`` of ``'like'`` or ``'dislike'`` appends the URL to the
    matching list unless it is already there; any other value leaves both
    lists unchanged. The preferences are saved afterwards in every case.

    Returns:
        ``{"status": "preference_recorded"}`` on success, or a 500
        ``JSONResponse`` with the error text for unexpected failures.

    Raises:
        HTTPException: 400 when a required field is missing or empty,
            404 when ``user_id`` has no stored preferences.
    """
    try:
        data = await request.json()
        user_id = data.get('user_id')
        image_url = data.get('image_url')
        preference = data.get('preference')
        if not user_id or not image_url or not preference:
            raise HTTPException(status_code=400, detail="Missing required parameters")

        preferences = load_user_preferences()
        if user_id not in preferences:
            raise HTTPException(status_code=404, detail="User not found")

        # Map the preference value to the list it belongs in (None if unknown).
        target_key = {'like': 'liked_images', 'dislike': 'disliked_images'}.get(preference)
        if target_key is not None:
            target_list = preferences[user_id][target_key]
            if image_url not in target_list:
                target_list.append(image_url)

        save_user_preferences(preferences)
        return {"status": "preference_recorded"}
    except HTTPException:
        # HTTPException is a subclass of Exception. Without this clause the
        # deliberate 400/404 above would be caught below and sent as a 500.
        raise
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
async def get_matches(request: Request):
    """Return images similar to the average of a user's liked faces.

    The JSON body supplies ``user_id`` (required), ``gender`` (default
    ``'all'``) and ``top_n`` (default 10). For each liked image the first
    detected face embedding is collected. The mean of those embeddings is
    stored as the user's ``preference_embedding`` and passed, together with
    the gender filter, ``top_n`` and the user's disliked images, to
    ``face_handler.find_similar_images_by_embedding``.

    Liked images that cannot be loaded or processed are skipped, and a message
    naming the skipped image is printed.

    Returns:
        ``{"similar_images": [...]}``; the list is empty when no embedding
        could be derived from the liked images. Unexpected failures produce a
        500 ``JSONResponse`` with the error text.

    Raises:
        HTTPException: 404 when ``user_id`` is missing from the body or has
            no stored preferences.
    """
    try:
        data = await request.json()
        user_id = data.get('user_id')
        gender = data.get('gender', 'all')
        top_n = data.get('top_n', 10)
        if not user_id:
            raise HTTPException(status_code=404, detail="User not found")

        preferences = load_user_preferences()
        if user_id not in preferences:
            raise HTTPException(status_code=404, detail="User preferences not found")
        user_prefs = preferences[user_id]

        liked_embeddings = []
        for image_url in user_prefs['liked_images']:
            try:
                img = face_handler.load_image_from_url(image_url)
                faces = face_handler.app.get(img)
            except Exception as e:
                # Best-effort: one bad image must not fail the whole request,
                # but the skip is reported instead of being silent.
                print(f"⚠️ Skipping liked image {image_url}: {e}")
                continue
            if len(faces) > 0:
                # Only the first detected face of each image is used.
                liked_embeddings.append(faces[0].embedding)

        if not liked_embeddings:
            return {"similar_images": []}

        preference_embedding = np.mean(liked_embeddings, axis=0)
        # Stored as a plain list so it can be serialized to JSON.
        user_prefs['preference_embedding'] = preference_embedding.tolist()
        save_user_preferences(preferences)

        similar_images = face_handler.find_similar_images_by_embedding(
            preference_embedding, gender, top_n, user_prefs['disliked_images']
        )
        return {"similar_images": similar_images}
    except HTTPException:
        # Keep the deliberate 404 responses from being caught below and
        # sent as a 500.
        raise
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
async def get_recommendations(
    body: Union[RecommendationRequest, HuggingFaceRequest] = Body(...)
):
    """Return images similar to the supplied query images.

    Two body formats are accepted: the direct format (the fields of
    ``RecommendationRequest`` at the top level) and the Hugging Face format
    (the same fields wrapped under ``"inputs"``). A falsy ``gender`` falls
    back to ``'all'`` and a falsy ``top_n`` falls back to 5.

    Returns:
        ``{"similar_images": [...]}`` when matches are found, a ``message``
        dict when there are none, or a 500 ``JSONResponse`` with the error
        text for unexpected failures.

    Raises:
        HTTPException: 400 when ``query_images`` is empty.
    """
    try:
        # Unwrap the Hugging Face envelope once so both formats share one path.
        payload = body.inputs if isinstance(body, HuggingFaceRequest) else body
        query_images = payload.query_images
        gender = payload.gender or 'all'
        top_n = payload.top_n or 5

        if not query_images:
            raise HTTPException(status_code=400, detail="No query images provided")

        similar_images = face_handler.find_similar_images_aggregate(query_images, gender, top_n)
        if not similar_images:
            return {"message": "No suggestions found please try with other images."}
        return {"similar_images": similar_images}
    except HTTPException:
        # Without this clause the 400 above would be caught below and
        # reported to the client as a 500.
        raise
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
def extract_embeddings():
    """Run ``face_handler.extract_and_save_embeddings()`` and report the outcome.

    Returns:
        A status dict when the call completes, or a 500 ``JSONResponse``
        carrying the error text when it raises.
    """
    try:
        face_handler.extract_and_save_embeddings()
    except Exception as extraction_error:
        return JSONResponse(status_code=500, content={"error": str(extraction_error)})
    return {"status": "Embeddings extraction completed"}