Spaces:
Sleeping
Sleeping
"""
TUNILip+ — HuggingFace Spaces feature-extraction service.

New endpoint /extract-features-frames:
  Receives 16 base64-encoded JPEG frames already cropped to the mouth region
  (by MediaPipe, on the browser side)
  → frozen VideoMAE → spatial mean-pool → (8, 768)
  → identical to the Kaggle training pipeline
"""
| from fastapi import FastAPI, UploadFile, File, HTTPException, Form | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from typing import List | |
| import numpy as np | |
| import cv2 | |
| import torch | |
| import tempfile | |
| import os | |
| import base64 | |
| import logging | |
| from contextlib import asynccontextmanager | |
# Model checkpoint and clip length used throughout the service.
NUM_FRAMES = 16
VMAE_MODEL_ID = "MCG-NJU/videomae-base"

# Root logging at INFO level; the service logs under the "tunilip" logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("tunilip")

# Model state: None at import time, filled in by the startup hook (`lifespan`).
DEVICE = None
vmae_model = None
vmae_processor = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the frozen VideoMAE backbone at startup and keep it for the app's lifetime.

    Startup (before ``yield``): selects CUDA when available and CPU otherwise,
    loads the image processor and the model named by ``VMAE_MODEL_ID``, puts the
    model in eval mode, moves it to the device and disables gradients on every
    parameter. ``vmae_processor`` and ``vmae_model`` are published to the module
    globals only after the whole sequence has succeeded, so a half-initialised
    model is never exposed to the request handlers.

    A loading failure is logged (with traceback) and deliberately swallowed: the
    application still starts, and ``vmae_model`` stays ``None``.

    Shutdown (after ``yield``): only logs a message.
    """
    global vmae_processor, vmae_model, DEVICE
    logger.info(f"⏳ Chargement {VMAE_MODEL_ID} …")
    try:
        # Imported here rather than at module top, so that an import failure is
        # caught and logged by the handler below like any other loading error.
        from transformers import VideoMAEModel, VideoMAEImageProcessor

        DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f" Device : {DEVICE}")

        processor = VideoMAEImageProcessor.from_pretrained(VMAE_MODEL_ID)
        model = VideoMAEModel.from_pretrained(VMAE_MODEL_ID)
        model.eval()
        model = model.to(DEVICE)
        # The backbone is used for inference only: freeze every parameter.
        for p in model.parameters():
            p.requires_grad = False
        n = sum(p.numel() for p in model.parameters())

        # Publish only once everything above has succeeded.
        vmae_processor, vmae_model = processor, model
        logger.info(f"✅ VideoMAE chargé — {n:,} params (GELÉS) sur {DEVICE}")
    except Exception as e:
        # Best effort by design: keep serving, the model simply stays unavailable.
        logger.error(f"❌ Erreur chargement VideoMAE : {e}", exc_info=True)
    yield
    logger.info("Shutdown")
# Application object; startup and shutdown work is delegated to `lifespan`.
app = FastAPI(lifespan=lifespan, title="TUNILip+ Feature Extractor")

# CORS is left wide open: any origin, any method and any header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# ── Helper: base64 → numpy (H, W, 3) uint8 ──────────────────
def b64_to_frame(b64str: str) -> np.ndarray:
    """Decode one base64-encoded image into a 224×224 RGB uint8 array.

    Accepts either a bare base64 payload or a ``data:<mime>;base64,<payload>``
    URL; in the latter case everything up to the first comma is dropped before
    decoding.

    Args:
        b64str: base64 text of an encoded image (JPEG expected by the callers).

    Returns:
        Array of shape (224, 224, 3), dtype uint8, channels in RGB order.

    Raises:
        ValueError: if the payload is empty or OpenCV cannot decode it as an
            image. Malformed base64 raises ``binascii.Error``, which is itself
            a ``ValueError`` subclass.
    """
    # ':' and ',' are not part of the base64 alphabet, so a bare payload can
    # never be mistaken for a data URL.
    if isinstance(b64str, str) and b64str.startswith("data:"):
        _, _, b64str = b64str.partition(",")

    img_bytes = base64.b64decode(b64str)
    if not img_bytes:
        # Fail with the same error as an undecodable image instead of handing
        # an empty buffer to OpenCV.
        raise ValueError("Impossible de décoder une frame JPEG")

    arr = np.frombuffer(img_bytes, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError("Impossible de décoder une frame JPEG")

    # OpenCV decodes to BGR; the rest of the pipeline works in RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Make sure the frame really is 224×224.
    if img.shape[:2] != (224, 224):
        img = cv2.resize(img, (224, 224))
    return img  # uint8 RGB
def run_videomae(frames_np: List[np.ndarray]) -> np.ndarray:
    """Run the frozen VideoMAE backbone on one clip and mean-pool over space.

    Per the original author's note, this is meant to mirror
    ``extract_videomae_features()`` from the Kaggle training notebook.

    Args:
        frames_np: frames of a single clip; the callers pass 16 uint8 RGB
            arrays of shape (224, 224, 3).

    Returns:
        float32 array of shape (8, hidden_size): one spatially averaged
        feature vector per temporal slot ((8, 768) for VideoMAE-base).

    Raises:
        RuntimeError: if the model or processor is not loaded, or if the model
            returns fewer tokens than the expected 8 × 196.
    """
    if vmae_model is None or vmae_processor is None:
        raise RuntimeError("VideoMAE non chargé")

    # The processor takes the clip as a list of uint8 (H, W, 3) arrays.
    inputs = vmae_processor(frames_np, return_tensors="pt")
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    # Pure inference: never record an autograd graph for the forward pass.
    with torch.no_grad():
        out = vmae_model(**inputs)
    hidden = out.last_hidden_state.squeeze(0).cpu().numpy()  # (1568, 768)

    # VideoMAE-base: 8 temporal × 196 spatial = 1568 tokens.
    n_temporal, n_spatial = 8, 196
    n_tokens = n_temporal * n_spatial
    if hidden.ndim != 2 or hidden.shape[0] < n_tokens:
        # Explicit error instead of an opaque reshape failure further down.
        raise RuntimeError(
            f"Sortie VideoMAE inattendue : shape {hidden.shape}, "
            f"attendu au moins ({n_tokens}, D)"
        )

    # Hidden size is inferred (-1) rather than hard-coded to 768.
    hidden = hidden[:n_tokens].reshape(n_temporal, n_spatial, -1)
    hidden = hidden.mean(axis=1)  # spatial mean-pool → (8, hidden_size)
    logger.info(f"Features stats — mean:{hidden.mean():.4f} std:{hidden.std():.4f}")
    return hidden.astype(np.float32)
| # ══════════════════════════════════════════════════════════════ | |
| # ROUTES | |
| # ══════════════════════════════════════════════════════════════ | |
@app.get("/health")
def health():
    """Report service status: model availability, device and model id.

    Registered as ``GET /health``, the path advertised by the service index.
    Always answers with ``"status": "ok"``; ``model_ready`` tells whether the
    VideoMAE model has been loaded.
    """
    return {
        "status": "ok",
        "model_ready": vmae_model is not None,
        # DEVICE stays None until the startup hook has selected a device.
        "device": str(DEVICE) if DEVICE is not None else "unknown",
        "model_id": VMAE_MODEL_ID,
    }
@app.post("/extract-features-frames")
async def extract_features_frames(frames_json: str = Form(...)):
    """Extract VideoMAE features from pre-cropped frames sent as base64.

    Request : form data ``{ frames_json: '["<base64>", ...]' }`` — a JSON array
              of base64-encoded frames.
    Response: ``{ "features": [[...], ...], "shape": [8, 768],
              "model_id": ..., "frames_received": <count sent by the client> }``

    Only the first ``NUM_FRAMES`` frames are used; a shorter clip is padded by
    repeating its last frame. Per the original author's note, form data
    (multipart) is used to avoid the CORS preflight on HuggingFace.

    Raises:
        HTTPException 503: the model is not loaded.
        HTTPException 422: invalid JSON, payload that is not a JSON array,
            empty array, or a frame that cannot be decoded.
        HTTPException 500: the VideoMAE forward pass failed.
    """
    # Local import: the module does not import json at top level.
    import json as _json

    if vmae_model is None:
        raise HTTPException(status_code=503, detail="VideoMAE non chargé")

    try:
        frames_list = _json.loads(frames_json)
    except (ValueError, RecursionError):
        raise HTTPException(status_code=422, detail="frames_json invalide") from None
    # Valid JSON that is not an array (object, number, string, null) is a
    # client error too, not an internal one.
    if not isinstance(frames_list, list):
        raise HTTPException(status_code=422, detail="frames_json invalide")

    n = len(frames_list)
    if n == 0:
        raise HTTPException(status_code=422, detail="Aucune frame reçue")

    # Truncate to NUM_FRAMES and decode each distinct frame exactly once.
    try:
        frames_np = [b64_to_frame(f) for f in frames_list[:NUM_FRAMES]]
    except Exception as e:
        raise HTTPException(
            status_code=422, detail=f"Erreur décodage frames: {e}"
        ) from e
    # Pad a short clip by repeating the last decoded frame.
    frames_np.extend([frames_np[-1]] * (NUM_FRAMES - len(frames_np)))

    try:
        features = run_videomae(frames_np)
    except Exception as e:
        logger.error(f"Erreur VideoMAE : {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return JSONResponse({
        "features": features.tolist(),
        "shape": list(features.shape),
        "model_id": VMAE_MODEL_ID,
        "frames_received": n,
    })
# Legacy endpoint kept for compatibility (the client uploads the raw video).
@app.post("/extract-features")
async def extract_features(video: UploadFile = File(...)):
    """Legacy endpoint — extract features from a raw video (no mouth crop).

    The upload is written to a temporary file, ``NUM_FRAMES`` frame indices are
    sampled evenly over the video, and each readable frame is resized to
    224×224 and converted to RGB. Frames that cannot be read are skipped and
    the clip is padded at the end with black frames up to ``NUM_FRAMES``.

    Response: ``{ "features": [[...], ...], "shape": [...], "model_id": ... }``

    Raises:
        HTTPException 503: the model is not loaded.
        HTTPException 422: the video reports no frames, or none can be read.
        HTTPException 500: any other failure.
    """
    # Same readiness contract as /extract-features-frames.
    if vmae_model is None:
        raise HTTPException(status_code=503, detail="VideoMAE non chargé")

    suffix = os.path.splitext(video.filename or "video.mp4")[-1] or ".mp4"
    tmp_path = None
    try:
        # delete=False: the file must outlive the `with` so OpenCV can open it
        # by path; it is removed in the outer `finally`.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp_path = tmp.name
            tmp.write(await video.read())

        frames_np = []
        cap = cv2.VideoCapture(tmp_path)
        try:
            total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # `<= 0` also rejects a negative frame count.
            if total <= 0:
                raise HTTPException(status_code=422, detail="Vidéo vide")
            for idx in np.linspace(0, total - 1, NUM_FRAMES, dtype=int):
                cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
                ret, frame = cap.read()
                if ret:
                    frame = cv2.resize(frame, (224, 224))
                    frames_np.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            # Always release the capture, including on error paths.
            cap.release()

        if not frames_np:
            # Refuse to compute features on an all-black placeholder clip.
            raise HTTPException(
                status_code=422, detail="Aucune frame lisible dans la vidéo"
            )
        while len(frames_np) < NUM_FRAMES:
            frames_np.append(np.zeros((224, 224, 3), dtype=np.uint8))

        features = run_videomae(frames_np)
        return JSONResponse({
            "features": features.tolist(),
            "shape": list(features.shape),
            "model_id": VMAE_MODEL_ID,
        })
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if tmp_path is not None:
            os.unlink(tmp_path)
@app.get("/")
def root():
    """Return the service name and a short index of the available endpoints."""
    return {
        "service": "TUNILip+ VideoMAE Feature Extractor",
        "endpoints": {
            "POST /extract-features-frames": "Frames base64 croppées → (8,768) — RECOMMANDÉ",
            "POST /extract-features": "Vidéo brute → (8,768) — legacy",
            "GET /health": "Statut",
        },
    }
if __name__ == "__main__":
    # Script entry point: uvicorn is imported only when the module is run directly.
    from uvicorn import run as _serve

    # Listen on every interface, port 7860.
    _serve(app, port=7860, host="0.0.0.0")