# Standard library
import gc
import hashlib
import io
import json
import logging
import os
import sys
from datetime import datetime, timedelta, timezone

# Third-party
import bcrypt
import jwt
import numpy as np
import redis
import requests
import timm
import torch
from authlib.integrations.starlette_client import OAuth
from dotenv import load_dotenv
from fastapi import FastAPI, File, UploadFile, HTTPException, Request, Response, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse
from PIL import Image
from starlette.middleware.sessions import SessionMiddleware
from torchvision import transforms
from ytmusicapi import YTMusic
|
|
| |
# Load variables from a local .env file into the process environment.
load_dotenv()


# Module-wide logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Directory layout used on the Hugging Face Space container.
current_dir = os.path.dirname(os.path.abspath(__file__))
app_dir = '/home/user'
api_dir = os.path.join(app_dir, 'api')


# Make the api/ package and its sub-folders importable as top-level modules.
sys.path.insert(0, current_dir)
sys.path.insert(0, app_dir)
sys.path.insert(0, api_dir)
sys.path.insert(0, os.path.join(api_dir, 'models'))
sys.path.insert(0, os.path.join(api_dir, 'utils'))
sys.path.insert(0, os.path.join(api_dir, 'recommender'))


# Relative paths (model weights, CSV mappings) resolve from /home/user.
os.chdir('/home/user')
|
|
| |
def configure_tensorflow_memory():
    """Constrain TensorFlow to 2 CPU threads per pool and hide all GPUs.

    Returns:
        bool: True when TensorFlow imported and accepted the settings,
        False on any failure (e.g. TensorFlow not installed).
    """
    try:
        import tensorflow as tf
        # Cap threading and force CPU-only so TF stays inside the
        # Space's RAM/CPU budget before DeepFace pulls it in.
        tf.config.threading.set_intra_op_parallelism_threads(2)
        tf.config.threading.set_inter_op_parallelism_threads(2)
        tf.config.set_visible_devices([], 'GPU')
        logger.info("✅ TensorFlow configured for memory optimization")
    except Exception as e:
        logger.error(f"❌ TensorFlow configuration failed: {e}")
        return False
    return True
|
|
| |
# Import DeepFace only after TensorFlow has been configured, so its models
# load under the reduced-thread / CPU-only settings above.
DEEPFACE_AVAILABLE = False
if configure_tensorflow_memory():
    try:
        from deepface import DeepFace
        DEEPFACE_AVAILABLE = True
        logger.info("🎉 DeepFace loaded with memory optimization on Hugging Face Spaces!")
    except Exception as e:
        logger.error(f"❌ DeepFace loading failed: {e}")
        DEEPFACE_AVAILABLE = False
|
|
| |
# FastAPI application; interactive docs exposed at /docs and /redoc.
app = FastAPI(
    title="Vibe Detection Backend API",
    description="Professional emotion recognition API with DeepFace + AA-DCN + HybridResNetViT + YouTube Music + Redis Caching + Authentication",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)
|
|
| |
# Server-side session support (required by authlib's OAuth state handling).
app.add_middleware(
    SessionMiddleware,
    # Prefer an env-provided secret; the literal fallback keeps existing
    # deployments working but should be rotated out of source control.
    secret_key=os.getenv(
        "SESSION_SECRET_KEY",
        "vibe-detection-session-secret-key-very-long-and-secure-fixed-key",
    ),
    max_age=3600,            # 1-hour session lifetime
    same_site="lax",
    https_only=False,        # NOTE(review): enable behind HTTPS in production
    session_cookie="session_id"
)
|
|
| |
| |
# Cross-origin access for the separately-hosted frontend.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "https://huggingface.co",
        "https://vaibhav07112004-vibe-detection-backend-api.hf.space",
        "https://vaibhav07112004-vibestory-frontend.hf.space",
        "http://localhost:3000",
        # NOTE(review): "*" combined with allow_credentials=True is rejected
        # by browsers for credentialed requests; kept for non-credentialed use.
        "*"
    ],
    # Starlette matches allow_origins entries literally, so the former
    # "https://*.hf.space" entry never matched any origin; subdomain
    # wildcards must be expressed as a regex instead.
    allow_origin_regex=r"https://.*\.hf\.space",
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
    allow_headers=["*"],
    expose_headers=["Set-Cookie", "Authorization"]
)
|
|
|
|
| |
# Google OAuth client (authlib); endpoints are discovered from Google's
# OpenID Connect metadata document.
# NOTE(review): the client_id fallback below is a real-looking ID committed
# to source — both values should come exclusively from the environment.
oauth = OAuth()
oauth.register(
    name='google',
    client_id=os.getenv('GOOGLE_CLIENT_ID', '1023474298007-b0obgdmpr9mj2j5o3hf6ah5ulchab9f4.apps.googleusercontent.com'),
    client_secret=os.getenv('GOOGLE_CLIENT_SECRET', 'your-google-client-secret'),
    server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
    client_kwargs={
        'scope': 'openid email profile'
    }
)
|
|
| |
# Availability flags for optional subsystems, flipped by the init code below.
ML_AVAILABLE = False
CUSTOM_MODULES_AVAILABLE = False


# Select the torch device.  torch is imported unconditionally at module top,
# so the ImportError branch here is effectively unreachable; kept as a guard.
try:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ML_AVAILABLE = True
    logger.info(f"✅ ML dependencies loaded successfully on {device}")
except ImportError as e:
    logger.warning(f"❌ ML dependencies not available: {e}")
    ML_AVAILABLE = False
|
|
| |
# Optional Redis connection used for result caching and the user-account
# store.  Missing REDIS_URL degrades gracefully to "no cache / no accounts".
redis_client = None
try:
    redis_url = os.getenv("REDIS_URL")
    if redis_url:
        redis_client = redis.from_url(redis_url, decode_responses=True)
        redis_client.ping()  # fail fast so a bad URL is detected at startup
        logger.info("✅ Redis connected successfully!")
    else:
        logger.warning("⚠️ REDIS_URL not found in environment variables")
except Exception as e:
    logger.error(f"❌ Redis connection failed: {e}")
    redis_client = None


# Unauthenticated YouTube Music client used for song search.
try:
    ytmusic = YTMusic()
    logger.info("✅ YouTube Music API initialized successfully!")
except Exception as e:
    logger.error(f"❌ YouTube Music API initialization failed: {e}")
    ytmusic = None
|
|
| |
def recommend_songs_for_moods_and_genres(moods, genres, limit_per_mood=50, limit_per_genre=50):
    """Search YouTube Music for songs matching each mood and genre term.

    Args:
        moods: iterable of mood strings (e.g. 'Happiness').
        genres: iterable of genre strings (e.g. 'pop').
        limit_per_mood / limit_per_genre: max results taken per search term.

    Returns:
        list[dict]: {title, artist, url, source} entries de-duplicated by
        videoId.  A one-element placeholder list is returned when the API is
        unavailable or the whole search raises.

    Fix over the previous version: a single malformed search result (missing
    'title' or empty 'artists') raised KeyError/IndexError and discarded ALL
    songs accumulated so far; malformed entries are now skipped individually.
    """
    if not ytmusic:
        logger.warning("⚠️ YouTube Music API not available")
        return [
            {"title": "Sample Song", "artist": "Sample Artist", "url": "https://music.youtube.com", "source": "fallback"}
        ]

    all_results = []
    seen_ids = set()

    def _collect(terms, limit, source):
        # Shared search loop for both mood and genre terms.
        for term in terms:
            search_results = ytmusic.search(f"{term} songs", filter="songs")
            for song in search_results[:limit]:
                song_id = song.get('videoId')
                if not song_id or song_id in seen_ids:
                    continue
                artists = song.get('artists') or []
                if not artists:
                    # Skip malformed results instead of aborting the batch.
                    continue
                seen_ids.add(song_id)
                all_results.append({
                    "title": song.get('title', 'Unknown'),
                    "artist": artists[0].get('name', 'Unknown'),
                    "url": f"https://music.youtube.com/watch?v={song_id}",
                    "source": source
                })

    try:
        _collect(moods, limit_per_mood, "mood")
        _collect(genres, limit_per_genre, "genre")
        logger.info(f"🎵 YouTube Music API returned {len(all_results)} songs")
        return all_results
    except Exception as e:
        logger.error(f"❌ YouTube Music API error: {e}")
        return [
            {"title": "Fallback Song", "artist": "Fallback Artist", "url": "https://music.youtube.com", "source": "error"}
        ]
|
|
| |
# Project-local model code lives under api/ (importable thanks to the
# sys.path entries added near the top of the file).
face_model = None
try:

    from api.models.model_definitions import create_aadcn_model, HybridResNetViT
    from api.utils.image_utils import detect_face, preprocess_image
    from api.utils.mood_utils import map_vibe_to_moods, get_genres_for_moods, load_mood_genre_mapping
    from api.models.vibe_model import predict_vibe
    from api.config import ensure_models_available, FACE_MODEL_PATH

    CUSTOM_MODULES_AVAILABLE = True
    logger.info("✅ Custom modules loaded successfully from api/ folder!")

    # Populate the in-memory mood → genre table from its CSV source.
    load_mood_genre_mapping()

    # Verify/download model weights, then load the 8-class AA-DCN face model
    # onto the selected device in eval mode.
    if ensure_models_available():
        face_model = create_aadcn_model(num_classes=8)
        face_model.load_state_dict(torch.load(FACE_MODEL_PATH, map_location=device))
        face_model.eval()
        logger.info("🎉 AA-DCN face model loaded successfully!")

except ImportError as e:
    logger.error(f"❌ Custom modules import failed: {e}")
    CUSTOM_MODULES_AVAILABLE = False
|
|
| |
# Index → label mapping for the 8-class AA-DCN output head
# (alphabetical class order).
emotion_idx_to_label = dict(enumerate([
    'angry', 'contempt', 'disgust', 'fear',
    'happy', 'neutral', 'sad', 'surprise',
]))
|
|
# Mapping from the 8 dataset emotion labels to the richer custom mood
# vocabulary consumed by the music-recommendation pipeline.
dataset_to_custom = {
    'angry':    ['Anger', 'Annoyance', 'Disapproval'],
    'contempt': ['Disapproval', 'Disconnection', 'Annoyance'],
    'disgust':  ['Aversion', 'Disapproval', 'Disconnection'],
    'fear':     ['Fear', 'Disquietment', 'Doubt/Confusion'],
    'happy':    ['Happiness', 'Affection', 'Pleasure', 'Excitement'],
    'neutral':  ['Peace', 'Esteem', 'Confidence'],
    'sad':      ['Sadness', 'Fatigue', 'Suffering'],
    'surprise': ['Surprise', 'Anticipation', 'Excitement'],
}
|
|
| |
# ImageNet-style preprocessing for the AA-DCN face model: 224x224 resize,
# tensor conversion, ImageNet mean/std normalization.  None when torch
# initialization failed (ML_AVAILABLE is False).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
]) if ML_AVAILABLE else None
|
|
| |
|
|
| |
def hash_password(password: str) -> str:
    """Bcrypt-hash *password* with a fresh salt; returns the hash as text."""
    salt = bcrypt.gensalt()
    digest = bcrypt.hashpw(password.encode('utf-8'), salt)
    return digest.decode('utf-8')
|
|
def verify_password(password: str, hashed: str) -> bool:
    """Check a plaintext password against a stored bcrypt hash."""
    candidate = password.encode('utf-8')
    expected = hashed.encode('utf-8')
    return bcrypt.checkpw(candidate, expected)
|
|
def create_jwt_token(user_data: dict) -> str:
    """Create an HS256-signed JWT embedding *user_data*, expiring in 24 h.

    NOTE(review): the signing secret is hardcoded and must stay in sync with
    the hardcoded secret at the jwt.decode() call sites; move both to a
    single env-driven constant in a follow-up.
    """
    token_data = {
        **user_data,
        # Timezone-aware now; datetime.utcnow() is deprecated since 3.12
        # and PyJWT treats the naive and aware values identically for "exp".
        "exp": datetime.now(timezone.utc) + timedelta(hours=24)
    }
    return jwt.encode(token_data, "vibe-detection-secret-key", algorithm="HS256")
|
|
| |
def store_user_account(email: str, user_data: dict):
    """Persist a user record in Redis and link it to its primary email.

    Writes three structures:
      user:<email>            - hash of the account fields
      user_accounts:<primary> - set of emails linked to the primary account
      primary_lookup:<email>  - reverse pointer from any email to its primary
    """
    if not redis_client:
        logger.warning("⚠️ Redis not available - cannot store user account")
        return

    try:
        redis_client.hset(f"user:{email}", mapping=user_data)
        primary = user_data.get('primary_email', email)
        redis_client.sadd(f"user_accounts:{primary}", email)
        redis_client.set(f"primary_lookup:{email}", primary)
        logger.info(f"✅ User account stored and linked in Redis: {email}")
    except Exception as e:
        logger.error(f"❌ Redis user storage error: {e}")
|
|
def get_user_account(email: str) -> dict:
    """Fetch the Redis account hash for *email*; None on miss/error/no Redis."""
    if not redis_client:
        return None

    try:
        record = redis_client.hgetall(f"user:{email}")
    except Exception as e:
        logger.error(f"❌ Redis user retrieval error: {e}")
        return None
    return record or None
|
|
def get_all_linked_accounts(current_email: str) -> list:
    """List every account linked to *current_email* via its primary email.

    Each entry carries email, name, auth_type, picture, created_at and an
    is_current flag.  Returns [] when Redis is unavailable or on any error.
    """
    if not redis_client:
        return []

    try:
        # Resolve the primary email; an unlinked address is its own primary.
        primary = redis_client.get(f"primary_lookup:{current_email}") or current_email

        accounts = []
        for linked_email in redis_client.smembers(f"user_accounts:{primary}"):
            record = redis_client.hgetall(f"user:{linked_email}")
            if not record:
                continue
            accounts.append({
                'email': linked_email,
                'name': record.get('name'),
                'auth_type': record.get('auth_type'),
                'picture': record.get('picture'),
                'is_current': linked_email == current_email,
                'created_at': record.get('created_at')
            })
        return accounts
    except Exception as e:
        logger.error(f"❌ Get linked accounts error: {e}")
        return []
|
|
| |
|
|
@app.post("/auth/signup")
async def signup(
    response: Response,
    email: str = Form(...),
    password: str = Form(...),
    name: str = Form(...)
):
    """Create an email/password account in Redis and set a JWT auth cookie.

    Returns 400 when the email is already registered; on success returns the
    new user's id/email/name and sets a 24h ``access_token`` cookie.
    """
    try:
        logger.info(f"📝 Signup attempt for: {email}")

        # Reject duplicate registrations.
        existing_user = get_user_account(email)
        if existing_user:
            raise HTTPException(status_code=400, detail="User already exists")

        # Only the bcrypt hash of the password is persisted.
        hashed_password = hash_password(password)
        user_data = {
            "id": f"user_{email.split('@')[0]}_{int(datetime.utcnow().timestamp())}",
            "email": email,
            "name": name,
            "password": hashed_password,
            "auth_type": "email",
            "primary_email": email,  # email accounts are their own primary
            "created_at": datetime.utcnow().isoformat()
        }

        store_user_account(email, user_data)

        # JWT claims embedded for the frontend.
        token_data = {
            "sub": user_data["id"],
            "email": email,
            "name": name,
            "auth_type": "email"
        }
        token = create_jwt_token(token_data)

        # Cross-site cookie for the separately-hosted frontend.
        # NOTE(review): httponly=False exposes the token to page scripts
        # (XSS risk); confirm the frontend genuinely needs JS access.
        response.set_cookie(
            key="access_token",
            value=token,
            httponly=False,
            secure=True,
            samesite="none",
            max_age=86400,
            path="/"

        )

        logger.info(f"✅ User signed up and stored in Redis: {email}")
        return {
            "message": "Signup successful",
            "user": {
                "id": user_data["id"],
                "email": email,
                "name": name
            }
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Signup error: {e}")
        raise HTTPException(status_code=500, detail=f"Signup failed: {str(e)}")
|
|
@app.post("/auth/login")
async def login(
    response: Response,
    email: str = Form(...),
    password: str = Form(...)
):
    """Authenticate an email/password user from Redis and set a JWT cookie."""
    try:
        logger.info(f"🔐 Login attempt for: {email}")

        # Missing user and wrong password deliberately return the same 401
        # to avoid user enumeration.
        user = get_user_account(email)
        if not user:
            raise HTTPException(status_code=401, detail="Invalid credentials")

        if not verify_password(password, user["password"]):
            raise HTTPException(status_code=401, detail="Invalid credentials")

        # Issue a fresh 24h JWT with the identity claims the frontend uses.
        token_data = {
            "sub": user["id"],
            "email": email,
            "name": user["name"],
            "auth_type": "email"
        }
        token = create_jwt_token(token_data)

        # Cross-site cookie for the separately-hosted frontend.
        # NOTE(review): httponly=False exposes the token to page scripts (XSS).
        response.set_cookie(
            key="access_token",
            value=token,
            httponly=False,
            secure=True,
            samesite="none",
            max_age=86400,
            path="/"

        )

        logger.info(f"✅ User logged in from Redis: {email}")
        return {
            "message": "Login successful",
            "user": {
                "id": user["id"],
                "email": email,
                "name": user["name"]
            }
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Login error: {e}")
        raise HTTPException(status_code=500, detail=f"Login failed: {str(e)}")
|
|
@app.get("/auth/google/login")
async def login_google():
    """Build the Google OAuth consent-screen URL and return it as JSON.

    Parameters are URL-encoded with ``urlencode``; the previous manual
    string join emitted literal spaces in the ``scope`` value, producing an
    invalid/fragile URL.
    """
    from urllib.parse import urlencode  # stdlib; local to avoid touching top-level imports

    params = {
        "response_type": "code",
        "client_id": os.getenv('GOOGLE_CLIENT_ID'),
        # Must match the redirect URI registered in the Google console.
        "redirect_uri": "https://vaibhav07112004-vibe-detection-backend-api.hf.space/auth/google/callback",
        "scope": "openid email profile"
    }

    auth_url = f"https://accounts.google.com/o/oauth2/v2/auth?{urlencode(params)}"

    return {"url": auth_url}
|
|
@app.get("/auth/google")
async def google_login_redirect(request: Request, primary_email: str = None):
    """Start the Google OAuth flow by redirecting to the consent screen."""
    try:
        # SessionMiddleware must be installed for authlib's state handling.
        if not hasattr(request, 'session'):
            logger.error("❌ Session not available in request")
            raise HTTPException(status_code=500, detail="Session middleware not configured")

        # Remember which primary account this login should be linked to;
        # read back in /auth/google/callback.
        if primary_email:
            request.session['primary_email'] = primary_email
            logger.info(f"🔗 Storing primary_email in session: {primary_email}")

        # Must exactly match the redirect URI registered in the Google console.
        redirect_uri = "https://vaibhav07112004-vibe-detection-backend-api.hf.space/auth/google/callback"
        logger.info(f"🔗 OAuth redirect URI: {redirect_uri}")

        # Drop stale authlib state entries so a retried login cannot fail a
        # state-mismatch check against an older attempt.
        session_keys_to_clear = [k for k in request.session.keys() if k.startswith('_state_')]
        for key in session_keys_to_clear:
            del request.session[key]

        return await oauth.google.authorize_redirect(request, redirect_uri)

    except Exception as e:
        logger.error(f"❌ Google OAuth initialization error: {e}")
        raise HTTPException(status_code=500, detail=f"Google OAuth initialization failed: {str(e)}")
|
|
@app.get("/auth/google/callback")
async def google_callback(request: Request, response: Response, code: str = None, state: str = None):
    """Complete the Google OAuth flow: exchange the code, store the user,
    set the JWT cookie, then redirect the browser back to the frontend.

    Two exchange strategies: authlib first (validates state from the
    session), then a manual HTTPS exchange as fallback.
    """
    try:
        if not code:
            raise HTTPException(status_code=400, detail="No code provided")

        logger.info(f"🔄 Processing Google OAuth callback with state: {state}")

        # Primary path: let authlib validate state and exchange the code.
        try:
            token = await oauth.google.authorize_access_token(request)
            user_info = token.get('userinfo')
            logger.info("✅ OAuth token retrieved via authlib")
        except Exception as authlib_error:
            logger.warning(f"⚠️ Authlib failed: {authlib_error}")

            # Fallback: exchange the code against Google's token endpoint by
            # hand, skipping authlib's session-state check (which fails when
            # the session cookie is lost between redirect and callback).
            try:
                logger.info("🔄 Attempting manual token exchange")

                token_data = {
                    "code": code,
                    "client_id": os.getenv('GOOGLE_CLIENT_ID'),
                    "client_secret": os.getenv('GOOGLE_CLIENT_SECRET'),
                    "redirect_uri": "https://vaibhav07112004-vibe-detection-backend-api.hf.space/auth/google/callback",
                    "grant_type": "authorization_code",
                }

                token_response = requests.post(
                    "https://oauth2.googleapis.com/token",
                    data=token_data,
                    headers={"Content-Type": "application/x-www-form-urlencoded"}
                )

                logger.info(f"📝 Token exchange response status: {token_response.status_code}")

                if not token_response.ok:
                    logger.error(f"❌ Token exchange failed: {token_response.text}")
                    raise HTTPException(status_code=400, detail=f"Token exchange failed: {token_response.text}")

                token_data = token_response.json()
                access_token = token_data.get("access_token")

                if not access_token:
                    raise HTTPException(status_code=400, detail="No access token received")

                # Fetch the user's profile with the fresh access token.
                user_info_response = requests.get(
                    "https://www.googleapis.com/oauth2/v2/userinfo",
                    headers={"Authorization": f"Bearer {access_token}"}
                )

                if not user_info_response.ok:
                    logger.error(f"❌ User info request failed: {user_info_response.text}")
                    raise HTTPException(status_code=400, detail="Failed to get user info")

                user_info = user_info_response.json()
                logger.info("✅ OAuth token retrieved via manual method")

            except requests.RequestException as req_error:
                logger.error(f"❌ Manual token exchange request error: {req_error}")
                raise HTTPException(status_code=400, detail=f"Network error during token exchange: {str(req_error)}")

        if not user_info:
            raise HTTPException(status_code=400, detail="Failed to get user info from Google")

        # Normalise profile fields ('id' from the v2 endpoint, 'sub' from OIDC).
        email = user_info.get('email')
        name = user_info.get('name')
        google_id = user_info.get('id') or user_info.get('sub')
        picture = user_info.get('picture')

        # Link to the primary account chosen before the redirect (see /auth/google).
        primary_email = request.session.get('primary_email', email)
        logger.info(f"🔗 Using primary_email: {primary_email}")

        user_data = {
            "id": f"google_{google_id}",
            "email": email,
            "name": name,
            "google_id": google_id,
            "picture": picture,
            "auth_type": "google",
            "primary_email": primary_email,
            "created_at": datetime.utcnow().isoformat()
        }

        store_user_account(email, user_data)

        token_data = {
            "sub": user_data["id"],
            "email": email,
            "name": name,
            "picture": picture,
            "auth_type": "google"
        }
        jwt_token = create_jwt_token(token_data)

        # Cross-site auth cookie for the frontend.
        # NOTE(review): httponly=False exposes the token to page scripts (XSS).
        response.set_cookie(
            key="access_token",
            value=jwt_token,
            httponly=False,
            secure=True,
            samesite="none",
            max_age=86400,
            path="/"

        )

        logger.info(f"✅ Google user authenticated: {email}")

        # Send the browser back to the frontend Space.
        frontend_url = "https://huggingface.co/spaces/vaibhav07112004/vibestory-frontend"
        return RedirectResponse(url=frontend_url)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Google OAuth callback error: {e}")
        raise HTTPException(status_code=400, detail=f"Google authentication failed: {str(e)}")
|
|
@app.get("/auth/me")
async def get_current_user(request: Request):
    """Return the authenticated user's JWT claims plus all linked accounts."""
    try:
        # Token can arrive either as the cookie or as a Bearer header.
        token = request.cookies.get("access_token")

        if not token:
            auth_header = request.headers.get("authorization")
            if auth_header and auth_header.startswith("Bearer "):
                token = auth_header.split(" ")[1]

        if not token:
            logger.warning("❌ No access token found in cookies or headers")
            raise HTTPException(status_code=401, detail="Not authenticated")

        # Verify signature + expiry, then enrich with linked accounts.
        try:
            payload = jwt.decode(token, "vibe-detection-secret-key", algorithms=["HS256"])
            user_id = payload.get("sub")
            email = payload.get("email")

            if not user_id:
                raise HTTPException(status_code=401, detail="Invalid token")

            linked_accounts = get_all_linked_accounts(email)

            logger.info(f"✅ User authenticated: {user_id}")
            return {
                "id": user_id,
                "email": email,
                "name": payload.get("name"),
                "picture": payload.get("picture"),
                "auth_type": payload.get("auth_type"),
                "authenticated": True,
                "linked_accounts": linked_accounts,
                "total_accounts": len(linked_accounts)
            }
        except jwt.ExpiredSignatureError:
            raise HTTPException(status_code=401, detail="Token expired")
        except jwt.InvalidTokenError:
            raise HTTPException(status_code=401, detail="Invalid token")

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Auth check error: {e}")
        raise HTTPException(status_code=500, detail="Authentication check failed")
|
|
@app.get("/auth/accounts/all")
async def get_all_user_accounts(request: Request):
    """Return every account linked to the current user's email.

    Fix: the previous blanket ``except Exception`` also caught the 401
    HTTPException raised below and the PyJWT validation errors, re-reporting
    all of them as a 500; authentication failures now surface as 401.
    """
    try:
        token = request.cookies.get("access_token")
        if not token:
            raise HTTPException(status_code=401, detail="Not authenticated")

        payload = jwt.decode(token, "vibe-detection-secret-key", algorithms=["HS256"])
        email = payload.get("email")

        linked_accounts = get_all_linked_accounts(email)

        return {
            "accounts": linked_accounts,
            "total_accounts": len(linked_accounts),
            "current_email": email
        }

    except HTTPException:
        raise
    except jwt.ExpiredSignatureError:
        raise HTTPException(status_code=401, detail="Token expired")
    except jwt.InvalidTokenError:
        raise HTTPException(status_code=401, detail="Invalid token")
    except Exception as e:
        logger.error(f"❌ Get all accounts error: {e}")
        raise HTTPException(status_code=500, detail="Failed to retrieve accounts")
|
|
@app.post("/auth/logout")
async def logout(response: Response):
    """Log the user out by clearing the access_token cookie."""
    try:
        response.delete_cookie("access_token", path="/")
        logger.info("✅ User logged out successfully")
    except Exception as e:
        logger.error(f"❌ Logout error: {e}")
        raise HTTPException(status_code=500, detail="Logout failed")
    return {"message": "Logout successful"}
|
|
@app.get("/debug/cookies")
async def debug_cookies(request: Request):
    """Echo cookie/header state to diagnose cookie-transmission issues."""
    auth_header = request.headers.get("authorization")
    cookie_header = request.headers.get("cookie")
    return {
        "cookies": dict(request.cookies),
        "headers": dict(request.headers),
        "has_access_token": "access_token" in request.cookies,
        "auth_header": auth_header,
        "cookie_header": cookie_header,
    }
|
|
@app.get("/debug/oauth-state")
async def debug_oauth_state(request: Request):
    """Expose the OAuth state entries currently held in the session."""
    session_data = dict(request.session)
    state_keys = [key for key in session_data if 'state' in key.lower()]
    has_google_state = any('google' in key for key in state_keys)
    return {
        "all_session_keys": list(session_data.keys()),
        "state_keys": state_keys,
        "session_data": session_data,
        "has_google_state": has_google_state,
    }
|
|
@app.get("/debug/session")
async def debug_session(request: Request):
    """Dump the raw session contents for debugging."""
    session = request.session
    return {
        "session_keys": list(session.keys()),
        "session_data": dict(session),
        "has_session": hasattr(request, 'session'),
        "session_id": getattr(session, 'session_id', 'No session ID'),
    }
|
|
| |
|
|
def predict_face_emotion_hybrid(image_bytes):
    """Hybrid face-emotion pipeline: DeepFace for 'happy', AA-DCN otherwise.

    Returns a (label, moods) tuple where moods is the custom mood list for
    the detected emotion.  Falls back to a static 'neutral' result when no
    model is usable.
    """
    # Stage 1: DeepFace.  Its result is trusted only for 'happy'; every
    # other emotion is re-evaluated by the AA-DCN model below.
    if DEEPFACE_AVAILABLE:
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
            np_image = np.array(image)
            result = DeepFace.analyze(
                np_image,
                actions=['emotion'],
                enforce_detection=False,  # analyze even if no face is confidently found
                detector_backend='opencv'
            )
            # DeepFace may return a list (one entry per detected face).
            if isinstance(result, list):
                result = result[0]

            if result['dominant_emotion'] == 'happy':
                logger.info("🎉 DeepFace detected HAPPY - returning happy result!")
                gc.collect()  # free TF buffers promptly on the RAM-limited Space
                return 'happy', dataset_to_custom['happy']
            else:
                logger.info(f"🔄 DeepFace detected {result['dominant_emotion']} (not happy) - sending to AA-DCN")
                gc.collect()

        except Exception as e:
            logger.warning(f"DeepFace error: {e}. Falling back to AA-DCN model.")
            gc.collect()

    # Stage 2: AA-DCN 8-class classifier on the normalized 224x224 image.
    if CUSTOM_MODULES_AVAILABLE and face_model is not None and transform is not None:
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
            img_t = transform(image).unsqueeze(0).to(device)
            with torch.no_grad():
                output = face_model(img_t)
            pred_idx = output.argmax(dim=1).item()
            dataset_label = emotion_idx_to_label.get(pred_idx, 'unknown')
            custom_moods = dataset_to_custom.get(dataset_label, ['Unknown'])

            logger.info(f"✅ AA-DCN detected: {dataset_label} - returning AA-DCN result!")
            return dataset_label, custom_moods

        except Exception as e:
            logger.error(f"❌ AA-DCN prediction error: {e}")

    # Stage 3: static neutral fallback when neither model produced a result.
    logger.warning("⚠️ Using basic emotion fallback")
    return 'neutral', ['Peace', 'Esteem', 'Confidence']
|
|
def predict_environment_vibe_resnetvit(image_bytes):
    """Classify the surrounding-environment vibe with the HybridResNetViT model.

    Returns (label, moods, source); label is always None here because the
    vibe model yields a mood list rather than a single emotion label.
    """
    try:
        if CUSTOM_MODULES_AVAILABLE:
            image_tensor = preprocess_image(image_bytes)
            vibe_idx = predict_vibe(image_tensor)
            moods = map_vibe_to_moods(vibe_idx)

            logger.info(f"🌍 YOUR HybridResNetViT detected environment vibe: {moods} (vibe_idx: {vibe_idx})")
            return None, moods, "vibe_hybridresnetvit"
        else:
            logger.error("❌ Custom modules not available - cannot perform vibe detection")
            return None, ['error'], "vibe_unavailable"

    except Exception as e:
        logger.error(f"❌ Environment vibe detection failed: {e}")
        return None, ['error'], "vibe_error"
|
|
| |
def get_cache_key(image_bytes):
    """Derive the Redis cache key from the MD5 digest of the raw image bytes."""
    digest = hashlib.md5(image_bytes).hexdigest()  # content fingerprint, not security
    return f"emotion_analysis:{digest}"
|
|
def get_cached_result(cache_key):
    """Return the cached JSON payload for *cache_key*, or None on miss/error."""
    if not redis_client:
        return None
    try:
        payload = redis_client.get(cache_key)
        if payload:
            logger.info(f"🎯 Cache HIT for {cache_key[:20]}...")
            return json.loads(payload)
    except Exception as e:
        logger.error(f"❌ Cache retrieval error: {e}")
    return None
|
|
def cache_result(cache_key, result, ttl=3600):
    """Store *result* as JSON under *cache_key*, expiring after *ttl* seconds."""
    if redis_client is None:
        return
    try:
        payload = json.dumps(result)
        redis_client.setex(cache_key, ttl, payload)
        logger.info(f"💾 Cached result for {cache_key[:20]}...")
    except Exception as e:
        logger.error(f"❌ Cache storage error: {e}")
|
|
| |
|
|
@app.get("/")
async def root():
    """Service banner: endpoint directory plus model/dependency status flags."""
    return {
        "message": "🎭 Vibe Detection Backend API with Authentication, YouTube Music and Redis Caching",
        "status": "healthy",
        "version": "1.0.0",
        "endpoints": {
            "analyze": "/analyze",
            "health": "/health",
            "docs": "/docs",
            "auth": {
                "signup": "/auth/signup",
                "login": "/auth/login",
                "google": "/auth/google",
                "google_login": "/auth/google/login",
                "me": "/auth/me",
                "accounts": "/auth/accounts/all",
                "logout": "/auth/logout",
                "debug": "/debug/session"
            }
        },
        "models": {
            "deepface_available": DEEPFACE_AVAILABLE,
            "custom_modules_available": CUSTOM_MODULES_AVAILABLE,
            "face_model_loaded": face_model is not None,
            "youtube_music_available": ytmusic is not None,
            "redis_available": redis_client is not None
        }
    }
|
|
@app.get("/health")
async def health_check():
    """Health probe reporting subsystem availability and pipeline summary."""
    return {
        "status": "healthy",
        "ml_available": ML_AVAILABLE,
        "deepface_available": DEEPFACE_AVAILABLE,
        "custom_modules_available": CUSTOM_MODULES_AVAILABLE,
        "face_model_loaded": face_model is not None,
        "youtube_music_available": ytmusic is not None,
        "redis_available": redis_client is not None,
        "device": str(device) if ML_AVAILABLE else "N/A",
        "platform": "Hugging Face Spaces (16GB RAM)",
        "authentication": {
            "email_signup": "Enabled",
            "google_oauth": "Enabled with dual method",
            "redis_user_storage": "Enabled",
            "cross_device_accounts": "Enabled"
        },
        "your_exact_logic": {
            "face_detection": "DeepFace (happy) → AA-DCN (other emotions)",
            "environment_detection": "HybridResNetViT (surrounding vibe)",
            "csv_mapping": "Mood to genre mapping from CSV file",
            "music_api": "YouTube Music API songs retrieved",
            "redis_caching": "Enabled for emotion analysis results"
        }
    }
|
|
@app.post("/analyze")
async def analyze_emotion(file: UploadFile = File(...)):
    """Analyze an uploaded image: emotion/vibe → moods → genres → tracks.

    Pipeline: Redis cache check → face detection → (DeepFace/AA-DCN face
    emotion | HybridResNetViT environment vibe) → mood→genre CSV mapping →
    YouTube Music recommendations.  Results are cached for one hour.
    """
    if not file.content_type.startswith('image/'):
        raise HTTPException(status_code=400, detail="File must be an image")
    if not ML_AVAILABLE:
        raise HTTPException(status_code=503, detail="ML models not available")
    try:
        image_bytes = await file.read()
        # Identical image bytes are served straight from Redis.
        cache_key = get_cache_key(image_bytes)
        cached_result = get_cached_result(cache_key)
        if cached_result:
            logger.info("🎯 Returning cached emotion analysis result")
            cached_result['cached'] = True
            return JSONResponse(content=cached_result)
        if CUSTOM_MODULES_AVAILABLE:
            # Route by face presence: face → emotion models,
            # no face → environment-vibe model.
            if detect_face(image_bytes):
                logger.info("👤 Face detected - Using YOUR EXACT LOGIC: DeepFace → AA-DCN")
                label, moods = predict_face_emotion_hybrid(image_bytes)
                source = "face_emotion"
                logger.info(f"🎭 Face emotion result: {label}")
            else:
                logger.info("🌍 No face detected - Using HybridResNetViT for environment vibe")
                label, moods, source = predict_environment_vibe_resnetvit(image_bytes)
                logger.info(f"🌍 Environment vibe result: {moods}")
            genres = []
            tracks = []
            try:
                # Map detected moods to genres (CSV table) and fetch songs.
                genres = get_genres_for_moods(moods)
                tracks = recommend_songs_for_moods_and_genres(moods, genres, limit_per_mood=50, limit_per_genre=50)
                logger.info(f"🎵 Music recommendations: {len(tracks)} tracks from {len(genres)} genres")
            except Exception as e:
                # Music failure is non-fatal; serve static fallbacks.
                logger.warning(f"⚠️ Music recommendation failed: {e}")
                genres = ['pop', 'indie']
                tracks = [{"title": "Happy Song", "artist": "Sample Artist", "url": "https://music.youtube.com", "source": "fallback"}]
        else:
            # Custom modules missing: return a neutral canned response.
            label = 'neutral'
            moods = ['Peace', 'Esteem', 'Confidence']
            genres = ['pop', 'indie']
            tracks = [{"title": "Fallback Track", "artist": "Fallback Artist", "url": "https://music.youtube.com", "source": "fallback"}]
            source = "fallback"
        result = {
            "emotion": label,
            "moods": moods,
            "genres": genres,
            "tracks": tracks,
            "source": source,
            "status": "success",
            "cached": False,
            "redis_available": redis_client is not None,
            "your_exact_logic": {
                "face_detection": "DeepFace (happy) → AA-DCN (other emotions)",
                "environment_detection": "HybridResNetViT (surrounding vibe)",
                "csv_mapping": "Mood to genre mapping completed",
                "music_api": "YouTube Music API songs retrieved",
                "redis_caching": "Enabled for emotion analysis results"
            },
            "model_info": {
                "device": str(device),
                "deepface_available": DEEPFACE_AVAILABLE,
                "custom_modules_available": CUSTOM_MODULES_AVAILABLE,
                "face_model_loaded": face_model is not None,
                "youtube_music_available": ytmusic is not None,
                "memory_optimized": True,
                "platform": "Hugging Face Spaces (16GB RAM)"
            }
        }
        cache_result(cache_key, result)
        logger.info(f"🎉 YOUR EXACT LOGIC COMPLETE: {label} - {source}")
        gc.collect()
        return JSONResponse(content=result)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Analysis error: {e}")
        gc.collect()
        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
|
|
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    # Last-resort handler: log and return a generic 500 payload.
    # NOTE(review): str(exc) leaks internal error details to clients —
    # consider a fixed message in production.
    logger.error(f"❌ Global error: {exc}")
    return JSONResponse(
        status_code=500,
        content={"detail": f"Internal server error: {str(exc)}"}
    )
|
|
if __name__ == "__main__":
    # Local / Spaces entry point; 7860 is the Hugging Face Spaces default port.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
|
|