Spaces:
Sleeping
Sleeping
| # app.py | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import uvicorn | |
| import time | |
| import base64 | |
| from typing import List | |
| from model import load_model, predict, predict_from_frames, DEVICE, _DTYPE | |
# FastAPI application object served by uvicorn.
app = FastAPI(
    title="ISL Recognition API",
    description="Indian Sign Language recognition using Swin3D-S",
    version="1.0.0",
)

# Fully permissive CORS: any origin, any method, any header is accepted.
# Credentials are not enabled, so the "*" origin wildcard is allowed by browsers.
# Each option receives its own fresh ["*"] list.
app.add_middleware(
    CORSMiddleware,
    **{option: ["*"] for option in ("allow_origins", "allow_methods", "allow_headers")},
)
# Global state
# Written by the startup handler; read by the health and prediction endpoints.
#   model        -> object returned by load_model(), None until a load succeeds
#   model_loaded -> True only after load_model() returned without raising
#   model_error  -> str() of the load exception, None when there was no failure
model, model_loaded, model_error = None, False, None
# Startup
# Registered as a startup hook so the model is loaded once, before requests are served.
# (FastAPI still supports on_event, though newer releases recommend lifespan handlers.)
@app.on_event("startup")
async def startup_event():
    """Load the model and record the outcome in the module-level state.

    On success: ``model`` holds the loaded model, ``model_loaded`` is True and
    ``model_error`` is cleared. On failure the exception is NOT re-raised: the
    server keeps running with ``model_loaded`` False and ``model_error`` set to
    the exception text, which ``health()`` reports to callers.
    """
    global model, model_loaded, model_error
    try:
        model = load_model()
        model_loaded = True
        model_error = None
        print("Model loaded and API is ready!")
    except Exception as e:  # deliberately broad: a failed load must not crash the server
        model_loaded = False
        model_error = str(e)
        print("Model failed to load:", e)
# Root
@app.get("/")
def root():
    """Return a status banner that tells clients which endpoints to POST to."""
    return {
        "status": "ISL API is running",
        "message": "POST to /predict (video file) or /predict_frames (base64 frames)"
    }
# Health
@app.get("/health")
def health():
    """Cheap readiness probe: report whether the model is loaded.

    Returns:
        dict with ``status`` ("ok"), ``model_loaded`` (True), ``device``
        (``str(DEVICE)``) and ``fp16``. Despite its name, ``fp16`` holds
        ``str(_DTYPE)`` (the dtype's string form), not a boolean. The key name
        is kept for client compatibility.

    Raises:
        HTTPException: 503 whose detail includes the stored load error while
            the model is unavailable.
    """
    if not model_loaded or model is None:
        # Return 503 so the wake_up() retry loop in backend knows to keep waiting
        raise HTTPException(
            status_code=503,
            detail={"status": "error", "model_loaded": False, "error": model_error}
        )
    return {
        "status": "ok",
        "model_loaded": True,
        "device": str(DEVICE),
        "fp16": str(_DTYPE),
    }
# Deep health
# NOTE(review): route path assumed; confirm it matches what monitoring/backend polls.
@app.get("/health/deep")
def health_deep():
    """Run one forward pass on an all-zeros clip to prove inference works end to end.

    Returns:
        dict with ``status``, ``inference`` and ``device`` when the pass succeeds.

    Raises:
        HTTPException: 503 if the model is not loaded; 500 (original exception
            chained) if importing torch or the forward pass fails.
    """
    if not model_loaded or model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        import torch  # imported lazily, as in the original; import errors surface as 500
        # Shape (1, 3, 16, 224, 224): 16 matches the frame count /predict_frames
        # requires. Assumed to be (batch, channels, frames, height, width) — confirm
        # against the model's expected input layout.
        dummy = torch.zeros(1, 3, 16, 224, 224, device=DEVICE, dtype=_DTYPE)
        with torch.no_grad():
            _ = model(dummy)
        return {"status": "ok", "inference": "working", "device": str(DEVICE)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}") from e
# Predict from frames (real-time path)
class FramesPayload(BaseModel):
    """JSON request body for the frame-based prediction handler.

    ``frames`` holds base64-encoded frame images; the handler rejects any
    count other than 16. ``top_k`` is forwarded unchanged to
    ``predict_from_frames`` and defaults to 5.
    """

    frames: List[str]  # base64 text, one entry per frame
    top_k: int = 5  # passed through as predict_from_frames(..., top_k=...)
@app.post("/predict_frames")
async def predict_frames_api(payload: FramesPayload):
    """Classify a clip sent as exactly 16 base64-encoded frames.

    Each frame may be raw base64 or a data URL (``data:<mime>;base64,<data>``,
    as produced by browser ``canvas.toDataURL()``); the header is stripped
    before decoding.

    Args:
        payload: request body with ``frames`` and ``top_k``.

    Returns:
        dict with ``prediction``, ``confidence`` and ``top_k`` taken from
        ``predict_from_frames``, plus ``inference_time_ms`` (decode + inference).

    Raises:
        HTTPException: 503 if the model is not loaded; 400 if the frame count
            is not 16 or a frame is not valid base64; 500 if inference raises.
    """
    if not model_loaded or model is None:
        raise HTTPException(status_code=503, detail="Model is not ready")
    if len(payload.frames) != 16:
        raise HTTPException(status_code=400, detail="Exactly 16 frames required")

    # perf_counter is monotonic, so the elapsed time is immune to clock adjustments.
    start_time = time.perf_counter()

    frames_bytes = []
    for index, frame in enumerate(payload.frames):
        # ':' is not in the base64 alphabet, so a "data:" prefix can only be a
        # data-URL header. Without stripping it, b64decode would silently
        # skip the invalid characters and decode the header as garbage.
        if frame.startswith("data:"):
            frame = frame.partition(",")[2]
        try:
            frames_bytes.append(base64.b64decode(frame))
        except ValueError as e:  # binascii.Error and non-ASCII input are both ValueError
            raise HTTPException(
                status_code=400, detail=f"Frame {index} is not valid base64"
            ) from e

    # NOTE(review): predict_from_frames is synchronous and runs on the event loop,
    # so other requests (including /health) stall during inference.
    try:
        result = predict_from_frames(model, frames_bytes, top_k=payload.top_k)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}") from e

    return {
        "prediction": result["prediction"],
        "confidence": result["confidence"],
        "top_k": result["top_k"],
        "inference_time_ms": round((time.perf_counter() - start_time) * 1000, 2),
    }
# Predict from video file
ALLOWED_EXTENSIONS = ('.mp4', '.mov', '.avi', '.mkv')


@app.post("/predict")
async def predict_sign(file: UploadFile = File(...), top_k: int = 5):
    """Classify an uploaded video file.

    Args:
        file: multipart upload; its filename must end with one of
            ``ALLOWED_EXTENSIONS`` (case-insensitive).
        top_k: forwarded unchanged to ``predict``.

    Returns:
        Every key returned by ``predict``, plus ``inference_time_ms``
        (upload read + inference) and ``filename``.

    Raises:
        HTTPException: 400 for a missing or unsupported filename; 503 if the
            model is not loaded; 500 if inference raises.
    """
    # UploadFile.filename can be None; treat that as an unsupported type
    # instead of crashing on None.lower().
    filename = file.filename or ""
    if not filename.lower().endswith(ALLOWED_EXTENSIONS):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid file type. Allowed: {ALLOWED_EXTENSIONS}"
        )
    if not model_loaded or model is None:
        raise HTTPException(status_code=503, detail="Model is not ready")

    start_time = time.perf_counter()
    try:
        video_bytes = await file.read()
    finally:
        # Release the spooled upload as soon as its bytes are in memory.
        await file.close()

    # NOTE(review): predict is synchronous and blocks the event loop while it runs.
    try:
        result = predict(model, video_bytes, top_k=top_k)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}") from e

    return {
        **result,
        "inference_time_ms": round((time.perf_counter() - start_time) * 1000, 2),
        "filename": filename,
    }
# Entry point: running this file directly starts the API under uvicorn.
if __name__ == "__main__":
    # "app:app" is an import string, so uvicorn imports this module itself.
    # 0.0.0.0 binds every interface. Port 7860 is presumably the port the
    # hosting platform routes to — confirm before changing.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run("app:app", **server_options)