# isl-api / app.py
# Creator-090's picture
# fix: CPU-safe inference for HF free tier
# 86c7cf3
# app.py
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
import time
import base64
from typing import List
from model import load_model, predict, predict_from_frames, DEVICE, _DTYPE
# ASGI application object; the metadata below is passed straight to FastAPI.
app = FastAPI(
    title="ISL Recognition API",
    description="Indian Sign Language recognition using Swin3D-S",
    version="1.0.0",
)

# CORS is left fully open: every origin, method and header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global state ─
# Module-level state filled in by the startup hook and read by every endpoint:
#   model        - the object returned by load_model(), or None until loaded
#   model_error  - str(exception) from a failed load, otherwise None
#   model_loaded - True only after load_model() has returned successfully
model = model_error = None
model_loaded = False
# Startup
@app.on_event("startup")
async def startup_event():
    """Load the model once at application startup.

    The outcome is recorded in the module globals ``model``,
    ``model_loaded`` and ``model_error``. A load failure is caught and
    stored rather than re-raised, so the server still starts and the
    health endpoints can report the error.
    """
    global model, model_loaded, model_error
    try:
        model = load_model()
        model_loaded, model_error = True, None
        print("Model loaded and API is ready!")
    except Exception as exc:
        # Keep the failure reason so /health can include it in its 503 body.
        model_loaded, model_error = False, str(exc)
        print("Model failed to load:", exc)
# Root ─
@app.get("/")
def root():
    """Return a static banner pointing callers at the prediction endpoints."""
    banner = dict(
        status="ISL API is running",
        message="POST to /predict (video file) or /predict_frames (base64 frames)",
    )
    return banner
# Health ─
@app.get("/health")
def health():
    """Report whether the model is loaded.

    Returns:
        A dict with ``status``, ``model_loaded``, ``device`` and ``fp16``
        (the string form of ``_DTYPE``) when the model is ready.

    Raises:
        HTTPException: 503 with the stored load error while the model is
            unavailable.
    """
    ready = model_loaded and model is not None
    if ready:
        return {
            "status": "ok",
            "model_loaded": True,
            "device": str(DEVICE),
            "fp16": str(_DTYPE),
        }

    # Return 503 so the wake_up() retry loop in backend knows to keep waiting
    failure = {"status": "error", "model_loaded": False, "error": model_error}
    raise HTTPException(status_code=503, detail=failure)
# Deep health
@app.get("/health/deep")
def health_deep():
    """Run one forward pass on an all-zero tensor to prove inference works.

    Returns:
        A dict with ``status``, ``inference`` and ``device`` on success.

    Raises:
        HTTPException: 503 if the model is not loaded, 500 if the forward
            pass (or the torch import) raises.
    """
    if model is None or not model_loaded:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        import torch

        # Presumably (batch, channels, frames, height, width); the 16 matches
        # the frame count enforced by /predict_frames - TODO confirm in model.py.
        probe = torch.zeros(1, 3, 16, 224, 224, device=DEVICE, dtype=_DTYPE)
        with torch.no_grad():
            model(probe)  # output is discarded; only success matters
        return {"status": "ok", "inference": "working", "device": str(DEVICE)}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Inference failed: {exc}")
# Predict from frames (real-time path) ─
class FramesPayload(BaseModel):
    """Request body accepted by ``POST /predict_frames``."""

    # Base64-encoded frames; the endpoint rejects any count other than 16.
    frames: List[str]

    # Forwarded unchanged as ``top_k`` to predict_from_frames().
    top_k: int = 5
@app.post("/predict_frames")
async def predict_frames_api(payload: FramesPayload):
    """Classify a clip supplied as 16 base64-encoded frames.

    Args:
        payload: Request body holding the encoded frames and ``top_k``.

    Returns:
        A dict with ``prediction``, ``confidence`` and ``top_k`` taken from
        the result of ``predict_from_frames``, plus ``inference_time_ms``
        (decode + inference wall time, rounded to 2 decimals).

    Raises:
        HTTPException: 503 if the model is not loaded; 400 if the frame
            count is not exactly 16 or a frame is not valid base64; 500 if
            inference itself fails.
    """
    if not model_loaded or model is None:
        raise HTTPException(status_code=503, detail="Model is not ready")
    if not payload.frames or len(payload.frames) != 16:
        raise HTTPException(status_code=400, detail="Exactly 16 frames required")

    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by wall-clock adjustments.
    start_time = time.perf_counter()

    # Malformed base64 is a client error: binascii.Error subclasses
    # ValueError, so report it as 400 instead of an unhandled 500.
    try:
        frames_bytes = [base64.b64decode(f) for f in payload.frames]
    except ValueError as exc:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid base64 frame data: {exc}",
        ) from exc

    try:
        result = predict_from_frames(model, frames_bytes, top_k=payload.top_k)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Inference error: {str(e)}"
        ) from e

    return {
        "prediction": result["prediction"],
        "confidence": result["confidence"],
        "top_k": result["top_k"],
        "inference_time_ms": round((time.perf_counter() - start_time) * 1000, 2),
    }
# Predict from video file
# File-name suffixes accepted by /predict (matched case-insensitively).
ALLOWED_EXTENSIONS = ('.mp4', '.mov', '.avi', '.mkv')


@app.post("/predict")
async def predict_sign(file: UploadFile = File(...), top_k: int = 5):
    """Classify an uploaded video file.

    Args:
        file: Uploaded video; its name must end with one of
            ``ALLOWED_EXTENSIONS``.
        top_k: Forwarded unchanged as ``top_k`` to ``predict``.

    Returns:
        The dict returned by ``predict`` extended with ``inference_time_ms``
        (read + inference wall time, rounded to 2 decimals) and ``filename``.

    Raises:
        HTTPException: 400 for a missing/unsupported file name or an empty
            upload; 503 if the model is not loaded; 500 if inference fails.
    """
    # The upload's filename may be None; treat that as "no extension" so the
    # request is rejected with 400 instead of crashing on None.lower().
    filename = file.filename or ""
    if not filename.lower().endswith(ALLOWED_EXTENSIONS):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid file type. Allowed: {ALLOWED_EXTENSIONS}"
        )
    if not model_loaded or model is None:
        raise HTTPException(status_code=503, detail="Model is not ready")

    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by wall-clock adjustments.
    start_time = time.perf_counter()
    video_bytes = await file.read()

    # An empty body can never be a valid video: report it as a client error
    # rather than letting it surface as a 500 from the inference step.
    if not video_bytes:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")

    try:
        result = predict(model, video_bytes, top_k=top_k)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Inference error: {str(e)}"
        ) from e

    return {
        **result,
        "inference_time_ms": round((time.perf_counter() - start_time) * 1000, 2),
        "filename": file.filename,
    }
# Entry point
if __name__ == "__main__":
    # Direct execution: serve the "app:app" import string on all interfaces,
    # port 7860.
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,
    )