| | """ |
| | Wakee API - Production |
| | ONNX Runtime UNIQUEMENT (pas de PyTorch) |
| | """ |
| | import os |
| | os.environ["HUGGINGFACE_HUB_DISABLE_XET"] = "1" |
| |
|
| | from fastapi import FastAPI, File, UploadFile, HTTPException, Form |
| | from fastapi.middleware.cors import CORSMiddleware |
| | from pydantic import BaseModel, Field |
| | from typing import List, Optional |
| | from huggingface_hub import hf_hub_download |
| | import onnxruntime as ort |
| | import onnxscript |
| | from PIL import Image |
| | import io |
| | import numpy as np |
| | from datetime import datetime |
| | import base64 |
| |
|
| |
|
| | from sqlalchemy import create_engine, text |
| | from sqlalchemy.exc import SQLAlchemyError |
| | import boto3 |
| | from botocore.exceptions import ClientError |
| |
|
| | |
| | |
| | |
| |
|
def preprocess_image(pil_image: Image.Image) -> np.ndarray:
    """Turn an RGB PIL image into a normalized NCHW float32 batch of one.

    Mirrors the training-time pipeline (resize to 256, center-crop 224,
    ImageNet mean/std normalization) using only Pillow and numpy — no
    PyTorch dependency.

    Args:
        pil_image: source image; assumed RGB (3 channels) — callers
            convert with ``.convert("RGB")`` beforehand.

    Returns:
        Array of shape (1, 3, 224, 224), dtype float32.
    """
    # Resize to 256x256, then take the centered 224x224 window.
    resized = pil_image.resize((256, 256), Image.BILINEAR)
    offset = (256 - 224) // 2
    cropped = resized.crop((offset, offset, offset + 224, offset + 224))

    # Scale pixels to [0, 1], then normalize per channel with the
    # ImageNet statistics (float64 intermediates, as in training).
    pixels = np.array(cropped).astype(np.float32) / 255.0
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])
    normalized = (pixels - imagenet_mean) / imagenet_std

    # HWC -> CHW, prepend the batch axis, and force float32 for ONNX.
    chw = np.transpose(normalized, (2, 0, 1))
    return np.expand_dims(chw, axis=0).astype(np.float32)
| |
|
| | |
| | |
| | |
| |
|
def load_env_vars():
    """Load a local ``.env`` file when not running in production.

    Production is detected through the SPACE_ID environment variable
    (set by Hugging Face Spaces); there, secrets come straight from the
    environment and this function does nothing.
    """
    if os.getenv("SPACE_ID") is not None:
        return  # production: env vars are provided by the platform

    from pathlib import Path
    try:
        from dotenv import load_dotenv
    except ImportError:
        print("⚠️ python-dotenv non installé (OK en production)")
        return

    dotenv_path = Path(__file__).resolve().parent.parent / '.env'
    if dotenv_path.exists():
        load_dotenv(dotenv_path)
        print(f"✅ .env chargé depuis : {dotenv_path}")
| |
|
# Populate os.environ from the local .env (no-op in production) before
# reading any configuration below.
load_env_vars()

# Hugging Face Hub location of the trained model.
HF_MODEL_REPO = "Terorra/wakee-reloaded"
MODEL_FILENAME = "model.onnx"  # NOTE(review): startup uses the literal "model.onnx", not this constant — keep in sync

# External services — all optional: a missing value leaves the matching
# global client as None and the related endpoints answer 503.
NEON_DATABASE_URL = os.getenv("NEONDB_WR")  # PostgreSQL (NeonDB) DSN
R2_ACCOUNT_ID = os.getenv("R2_ACCOUNT_ID")  # Cloudflare R2 credentials
R2_ACCESS_KEY_ID = os.getenv("R2_ACCESS_KEY_ID")
R2_SECRET_ACCESS_KEY = os.getenv("R2_SECRET_ACCESS_KEY")
R2_BUCKET_NAME = os.getenv("R2_WR_IMG_BUCKET_NAME", "wr-img-store")  # bucket name has a default
| |
|
| | |
| | |
| | |
| |
|
class PredictionResponse(BaseModel):
    """Response body of /predict: one score per emotion, each in [0, 3]."""
    boredom: float = Field(..., ge=0, le=3)
    confusion: float = Field(..., ge=0, le=3)
    engagement: float = Field(..., ge=0, le=3)
    frustration: float = Field(..., ge=0, le=3)
    # ISO-8601 timestamp of when the prediction was produced
    timestamp: str
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
class InsertResponse(BaseModel):
    """Response body of /insert: outcome of the R2 upload + DB insert."""
    status: str
    message: str
    # Generated object name (timestamp + suffix), also used as the R2 key
    img_name: str
    # Always None in the current implementation (no public URL exposed)
    s3_url: Optional[str] = None
| |
|
class LoadResponse(BaseModel):
    """Response body of /load: dataset counters, recent rows and averages."""
    total_samples: int
    validated_samples: int
    # Most recent validated rows, newest first (see /load for the dict shape)
    recent_predictions: List[dict]
    # Per-emotion averages for predictions and user corrections
    statistics: dict
| |
|
| | |
| | |
| | |
| |
|
# FastAPI application; interactive docs served at /docs and /redoc.
app = FastAPI(
    title="Wakee Emotion API",
    description="Multi-label emotion detection (ONNX Runtime)",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)
| |
|
# CORS: open to every origin.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# contradictory — browsers reject wildcard origins on credentialed requests,
# and Starlette will echo the request origin instead. If credentials are not
# actually needed, drop allow_credentials; otherwise list explicit origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
| |
|
| | |
| | |
| | |
| |
|
# Shared singletons, populated by startup_event(); each stays None when its
# backend is unavailable, and endpoints check for None before use.
onnx_session = None  # onnxruntime.InferenceSession
db_engine = None     # SQLAlchemy engine for NeonDB
s3_client = None     # boto3 S3 client pointed at Cloudflare R2
| |
|
| | |
| | |
| | |
| |
|
# NOTE(review): @app.on_event is deprecated in recent FastAPI in favour of
# lifespan handlers — it still works; consider migrating.
@app.on_event("startup")
async def startup_event():
    """Initialize the three shared singletons: ONNX session, DB engine, R2 client.

    Model loading strategy:
      1. Try to download a ready-made model.onnx from the HF Hub.
      2. If that fails (or the file is suspiciously small), download the
         PyTorch checkpoint and convert it to ONNX locally (needs torch).
    Every external dependency is optional: on failure the matching global
    stays None and the corresponding endpoints answer 503.
    """
    global onnx_session, db_engine, s3_client

    print("=" * 70)
    print("🚀 DÉMARRAGE API WAKEE (ONNX Runtime)")
    print("=" * 70)

    onnx_session = None

    try:
        print("\n📥 Tentative chargement ONNX depuis HF...")

        # NOTE(review): the MODEL_FILENAME constant exists but a literal is
        # used here — keep the two in sync.
        onnx_path = hf_hub_download(
            repo_id=HF_MODEL_REPO,
            filename="model.onnx",
            cache_dir="/tmp/models"
        )

        # Sanity check: a truncated download would load but predict garbage.
        file_size_mb = os.path.getsize(onnx_path) / 1e6
        print(f" ONNX file size: {file_size_mb:.2f} MB")

        if file_size_mb < 10:
            print(f"⚠️ ONNX file too small ({file_size_mb:.2f} MB), using fallback")
            raise ValueError("ONNX file incomplete")

        onnx_session = ort.InferenceSession(onnx_path)
        print("✅ ONNX chargé directement")

    except Exception as e:
        print(f"⚠️ ONNX indisponible: {e}")
        print("🔁 Fallback → PyTorch .bin → conversion ONNX...")

        try:
            bin_path = hf_hub_download(
                repo_id=HF_MODEL_REPO,
                filename="pytorch_model.bin",
                cache_dir="/tmp/models"
            )

            bin_size_mb = os.path.getsize(bin_path) / 1e6
            print(f" PyTorch .bin size: {bin_size_mb:.2f} MB")

            # torch/torchvision imported lazily: only needed on this path.
            import torch
            from torchvision import models
            import torch.nn as nn

            NUM_CLASSES = 4
            DEVICE = "cpu"

            # Rebuild the EfficientNet-B4 architecture with a 4-output head
            # before loading the checkpoint weights into it.
            model = models.efficientnet_b4(weights=None)
            model.classifier[1] = nn.Linear(
                model.classifier[1].in_features,
                NUM_CLASSES
            )

            # NOTE(review): weights_only=False unpickles arbitrary objects —
            # acceptable only because the checkpoint comes from our own repo.
            state_dict = torch.load(bin_path, map_location=DEVICE, weights_only=False)

            # Some checkpoints wrap the weights under 'model'/'state_dict'.
            if isinstance(state_dict, dict):
                if 'model' in state_dict:
                    state_dict = state_dict['model']
                elif 'state_dict' in state_dict:
                    state_dict = state_dict['state_dict']

            # strict=False tolerates key mismatches (e.g. module prefixes).
            model.load_state_dict(state_dict, strict=False)
            model.eval()

            print("✅ PyTorch chargé")

            tmp_onnx = "/tmp/models/fallback_model.onnx"

            dummy = torch.randn(1, 3, 224, 224)

            # Export with a dynamic batch axis so batched inference works.
            torch.onnx.export(
                model,
                dummy,
                tmp_onnx,
                export_params=True,
                opset_version=17,
                do_constant_folding=True,
                input_names=["input"],
                output_names=["output"],
                dynamic_axes={
                    'input': {0: 'batch_size'},
                    'output': {0: 'batch_size'}
                },
                verbose=False
            )

            print("✅ Conversion ONNX locale OK")

            # Same truncation check as the direct-download path.
            onnx_size_mb = os.path.getsize(tmp_onnx) / 1e6
            print(f" ONNX file size: {onnx_size_mb:.2f} MB")

            if onnx_size_mb < 10:
                raise ValueError(f"ONNX file too small ({onnx_size_mb:.2f} MB)! Weights not exported.")

            onnx_session = ort.InferenceSession(tmp_onnx)

            # Smoke test with random input: fail fast at startup rather
            # than on the first user request.
            test_input = np.random.randn(1, 3, 224, 224).astype(np.float32)
            test_output = onnx_session.run(['output'], {'input': test_input})
            print(f" Test inference OK, output shape: {test_output[0].shape}")

        except Exception as e2:
            print(f"❌ Fallback PyTorch échoué : {e2}")
            onnx_session = None  # degraded mode: /predict will return 503

    if onnx_session:
        input_name = onnx_session.get_inputs()[0].name
        input_shape = onnx_session.get_inputs()[0].shape
        print(f" Input : {input_name} {input_shape}\n")

    # --- NeonDB (PostgreSQL) -------------------------------------------
    if NEON_DATABASE_URL:
        try:
            db_engine = create_engine(NEON_DATABASE_URL)
            with db_engine.connect() as conn:
                conn.execute(text("SELECT 1"))  # cheap connectivity probe
            print("✅ Connexion NeonDB établie\n")
        except Exception as e:
            print(f"⚠️ NeonDB non disponible : {e}\n")
            db_engine = None
    else:
        print("⚠️ NEON_DATABASE_URL non défini\n")

    # --- Cloudflare R2 (S3-compatible object storage) ------------------
    if all([R2_ACCOUNT_ID, R2_ACCESS_KEY_ID, R2_SECRET_ACCESS_KEY]):
        try:
            s3_client = boto3.client(
                's3',
                endpoint_url=f'https://{R2_ACCOUNT_ID}.r2.cloudflarestorage.com',
                aws_access_key_id=R2_ACCESS_KEY_ID,
                aws_secret_access_key=R2_SECRET_ACCESS_KEY,
                region_name='auto'
            )
            s3_client.head_bucket(Bucket=R2_BUCKET_NAME)  # verify bucket access
            print(f"✅ Connexion Cloudflare R2 (bucket: {R2_BUCKET_NAME})\n")
        except Exception as e:
            print(f"⚠️ Cloudflare R2 non disponible : {e}\n")
            s3_client = None
    else:
        print("⚠️ R2 secrets non définis\n")

    print("=" * 70)
    print("🎉 API WAKEE PRÊTE !")
    print("=" * 70)
    print(f"📊 Status :")
    print(f" - Modèle ONNX : {'✅' if onnx_session else '❌'}")
    print(f" - Database : {'✅' if db_engine else '❌'}")
    print(f" - Storage : {'✅' if s3_client else '❌'}")
    print("=" * 70 + "\n")
| |
|
| | |
| | |
| | |
| |
|
@app.get("/")
async def root():
    """Service banner: name, version, runtime flavour and model source."""
    return dict(
        message="Wakee Emotion API",
        version="1.0.0",
        runtime="ONNX Runtime (no PyTorch)",
        model_source=HF_MODEL_REPO,
    )
| |
|
@app.get("/health")
async def health_check():
    """Liveness probe: always 200, reports whether the model is loaded."""
    model_ready = onnx_session is not None
    return dict(
        status="healthy",
        model_loaded=model_ready,
        runtime="ONNX",
        timestamp=datetime.now().isoformat(),
    )
| |
|
@app.post("/predict", response_model=PredictionResponse)
async def predict_emotion(file: UploadFile = File(...)):
    """Run the 4-emotion regression on an uploaded image.

    Nothing is persisted at this step — the client must call /insert
    afterwards to store the image and its labels.

    Raises:
        HTTPException 503 when the model is not loaded, 400 for a
        non-image upload, 500 on any processing failure.
    """
    # Guard clauses: model availability, then content-type sanity.
    if onnx_session is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if not file.content_type.startswith('image/'):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        raw_bytes = await file.read()
        rgb_image = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
        batch = preprocess_image(rgb_image)

        # Single forward pass; output is (1, 4) — take the first row.
        scores = onnx_session.run(['output'], {'input': batch})[0][0]

        return PredictionResponse(
            boredom=round(float(scores[0]), 2),
            confusion=round(float(scores[1]), 2),
            engagement=round(float(scores[2]), 2),
            frustration=round(float(scores[3]), 2),
            timestamp=datetime.now().isoformat(),
        )
    except Exception as e:
        print(f"❌ Erreur prédiction : {e}")
        raise HTTPException(status_code=500, detail=str(e))
| |
|
@app.post("/insert", response_model=InsertResponse)
async def insert_annotation(
    file: UploadFile = File(...),
    predicted_boredom: float = Form(...),
    predicted_confusion: float = Form(...),
    predicted_engagement: float = Form(...),
    predicted_frustration: float = Form(...),
    user_boredom: float = Form(...),
    user_confusion: float = Form(...),
    user_engagement: float = Form(...),
    user_frustration: float = Form(...)
):
    """Persist a user-validated annotation: image to R2, labels to NeonDB.

    Receives the image directly as multipart form data (no base64), plus
    the model's predictions and the user's corrected scores.

    Raises:
        HTTPException 503 when DB or storage is unavailable, 400 for a
        non-image upload, 500 on upload or insert failure.
    """
    # Guard clauses: both backends must be up, and the upload an image.
    if not db_engine:
        raise HTTPException(status_code=503, detail="Database not available")

    if not s3_client:
        raise HTTPException(status_code=503, detail="Storage not available")

    if not file.content_type.startswith('image/'):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        image_bytes = await file.read()

        # Object name: second-resolution timestamp + hash-derived suffix.
        # NOTE(review): hash() is salted per process (PYTHONHASHSEED), so
        # the suffix is not reproducible across runs, and two uploads in
        # the same second can still collide — consider uuid4 instead.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        random_suffix = hash(image_bytes) % 10000
        img_name = f"{timestamp}_{random_suffix:04d}.jpg"
        s3_key = f"{img_name}"

        print(f"📤 Upload vers R2 : {s3_key}")
        try:
            # NOTE(review): ContentType is always image/jpeg even if the
            # client sent PNG/WebP — confirm this is intended.
            s3_client.put_object(
                Bucket=R2_BUCKET_NAME,
                Key=s3_key,
                Body=image_bytes,
                ContentType='image/jpeg'
            )
            print(f"✅ Upload R2 réussi : {img_name}")
        except ClientError as e:
            print(f"❌ Erreur upload R2 : {e}")
            # NOTE(review): this HTTPException is re-caught by the outer
            # `except Exception` below and re-wrapped with a different detail.
            raise HTTPException(status_code=500, detail=f"R2 upload failed: {e}")

        # Parameterized insert — all values are bound, never interpolated.
        query = text("""
            INSERT INTO emotion_labels
            (img_name, s3_path,
             predicted_boredom, predicted_confusion, predicted_engagement, predicted_frustration,
             user_boredom, user_confusion, user_engagement, user_frustration,
             source, is_validated, timestamp)
            VALUES
            (:img_name, :s3_path,
             :pred_boredom, :pred_confusion, :pred_engagement, :pred_frustration,
             :user_boredom, :user_confusion, :user_engagement, :user_frustration,
             'app_sourcing', TRUE, :timestamp)
        """)

        # NOTE(review): if this insert fails after a successful upload, the
        # R2 object is orphaned — no compensating delete is attempted.
        with db_engine.connect() as conn:
            conn.execute(query, {
                'img_name': img_name,
                's3_path': s3_key,
                'pred_boredom': predicted_boredom,
                'pred_confusion': predicted_confusion,
                'pred_engagement': predicted_engagement,
                'pred_frustration': predicted_frustration,
                'user_boredom': user_boredom,
                'user_confusion': user_confusion,
                'user_engagement': user_engagement,
                'user_frustration': user_frustration,
                'timestamp': datetime.now()
            })
            conn.commit()

        print(f"✅ Insert NeonDB réussi : {img_name}")

        return InsertResponse(
            status="success",
            message="Image uploaded and labels saved",
            img_name=img_name,
            s3_url=None
        )

    except SQLAlchemyError as e:
        print(f"❌ Erreur NeonDB : {e}")
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")

    except Exception as e:
        print(f"❌ Erreur insert : {e}")
        raise HTTPException(status_code=500, detail=str(e))
| |
|
@app.get("/load", response_model=LoadResponse)
async def load_data(limit: int = 10):
    """Return dataset counters, recent validated rows and per-emotion averages.

    Args:
        limit: maximum number of recent rows to return (bound SQL parameter).

    Raises:
        HTTPException 503 when the database is unavailable, 500 on SQL error.
    """
    if not db_engine:
        raise HTTPException(status_code=503, detail="Database not available")

    try:
        with db_engine.connect() as conn:
            # Total rows, validated or not.
            total = conn.execute(text(
                "SELECT COUNT(*) FROM emotion_labels"
            )).scalar()

            # Rows confirmed by a user.
            validated = conn.execute(text(
                "SELECT COUNT(*) FROM emotion_labels WHERE is_validated = TRUE"
            )).scalar()

            # Most recent validated rows, newest first. NOTE(review): the
            # f-prefix on this string is unnecessary (no interpolation —
            # :limit is a bound parameter) and risks accidental SQL
            # injection if braces are added later.
            recent = conn.execute(text(f"""
                SELECT
                    img_name,
                    s3_path,
                    predicted_boredom,
                    predicted_confusion,
                    predicted_engagement,
                    predicted_frustration,
                    user_boredom,
                    user_confusion,
                    user_engagement,
                    user_frustration,
                    timestamp
                FROM emotion_labels
                WHERE is_validated = TRUE
                ORDER BY timestamp DESC
                LIMIT :limit
            """), {'limit': limit}).fetchall()

            # Map each row tuple to the nested dict shape the frontend expects.
            recent_list = [
                {
                    'img_name': row[0],
                    's3_path': row[1],
                    'predicted': {
                        'boredom': float(row[2]),
                        'confusion': float(row[3]),
                        'engagement': float(row[4]),
                        'frustration': float(row[5])
                    },
                    'user_corrected': {
                        'boredom': float(row[6]),
                        'confusion': float(row[7]),
                        'engagement': float(row[8]),
                        'frustration': float(row[9])
                    },
                    'timestamp': row[10].isoformat() if row[10] else None
                }
                for row in recent
            ]

            # Per-emotion averages over validated rows only.
            stats = conn.execute(text("""
                SELECT
                    AVG(predicted_boredom) as avg_pred_boredom,
                    AVG(predicted_confusion) as avg_pred_confusion,
                    AVG(predicted_engagement) as avg_pred_engagement,
                    AVG(predicted_frustration) as avg_pred_frustration,
                    AVG(user_boredom) as avg_user_boredom,
                    AVG(user_confusion) as avg_user_confusion,
                    AVG(user_engagement) as avg_user_engagement,
                    AVG(user_frustration) as avg_user_frustration
                FROM emotion_labels
                WHERE is_validated = TRUE
            """)).fetchone()

            # `or 0` guards the NULLs AVG() returns on an empty table.
            statistics = {
                'predictions': {
                    'boredom': round(float(stats[0] or 0), 2),
                    'confusion': round(float(stats[1] or 0), 2),
                    'engagement': round(float(stats[2] or 0), 2),
                    'frustration': round(float(stats[3] or 0), 2)
                },
                'user_corrections': {
                    'boredom': round(float(stats[4] or 0), 2),
                    'confusion': round(float(stats[5] or 0), 2),
                    'engagement': round(float(stats[6] or 0), 2),
                    'frustration': round(float(stats[7] or 0), 2)
                }
            }

            return LoadResponse(
                total_samples=total or 0,
                validated_samples=validated or 0,
                recent_predictions=recent_list,
                statistics=statistics
            )

    except SQLAlchemyError as e:
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")
| |
|
| |
|
# Local development entry point; in production the ASGI server is launched
# by the hosting platform instead.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)