File size: 4,203 Bytes
c126626
 
 
1cf4369
c126626
 
053568e
de9af52
1cf4369
 
c126626
 
 
 
 
 
 
 
053568e
 
c126626
 
 
 
86c7cf3
1cf4369
053568e
1cf4369
c126626
de9af52
1cf4369
c126626
 
053568e
 
1cf4369
053568e
1cf4369
053568e
 
 
1cf4369
053568e
 
de9af52
86c7cf3
c126626
 
053568e
1cf4369
 
053568e
 
 
86c7cf3
053568e
 
 
86c7cf3
 
 
 
 
053568e
1cf4369
053568e
1cf4369
 
053568e
 
 
86c7cf3
053568e
 
 
 
 
 
 
1cf4369
053568e
 
1cf4369
 
 
053568e
 
86c7cf3
053568e
86c7cf3
1cf4369
053568e
 
 
 
 
 
 
c126626
1cf4369
 
 
053568e
 
 
1cf4369
053568e
 
1cf4369
 
 
 
053568e
 
 
86c7cf3
1cf4369
 
 
c126626
1cf4369
c126626
 
1cf4369
c126626
053568e
1cf4369
c126626
1cf4369
c126626
053568e
de9af52
 
 
1cf4369
c126626
1cf4369
 
 
 
 
053568e
de9af52
86c7cf3
c126626
053568e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# app.py
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
import time
import base64
from typing import List

from model import load_model, predict, predict_from_frames, DEVICE, _DTYPE

# Application object carrying the title/description/version metadata.
app = FastAPI(
    version="1.0.0",
    title="ISL Recognition API",
    description="Indian Sign Language recognition using Swin3D-S",
)

# CORS is left fully open: every origin, method, and header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

#  Global state ─
# Populated by startup_event(); every handler checks these before using the model.
model = None          # object returned by load_model(), or None until loading succeeds
model_loaded = False  # flipped to True only after load_model() returns
model_error = None    # str() of the load failure, if any


#  Startup 
@app.on_event("startup")
async def startup_event():
    """Load the model once at application startup and record the outcome.

    On success the module-level ``model`` holds the result of ``load_model()``
    and ``model_loaded`` is True. On any exception the failure text is kept in
    ``model_error`` and the app keeps running, so the health endpoints can
    report the problem instead of the process dying.
    """
    global model, model_loaded, model_error
    try:
        model = load_model()
        model_loaded = True
        model_error = None
        print("Model loaded and API is ready!")
    except Exception as exc:
        # Deliberately broad: any load failure is reported, never fatal.
        model_loaded = False
        model_error = str(exc)
        print("Model failed to load:", exc)


#  Root ─
@app.get("/")
def root():
    # Static landing response that points callers at the two prediction routes.
    return dict(
        status="ISL API is running",
        message="POST to /predict (video file) or /predict_frames (base64 frames)",
    )


#  Health ─
@app.get("/health")
def health():
    # Cheap readiness probe: reports whether the model finished loading,
    # plus the device and dtype exported by the model module.
    ready = model_loaded and model is not None
    if ready:
        return {
            "status": "ok",
            "model_loaded": True,
            "device": str(DEVICE),
            "fp16": str(_DTYPE),
        }

    # Answer 503 (not 200) so the backend's wake_up() retry loop keeps waiting.
    failure = {"status": "error", "model_loaded": False, "error": model_error}
    raise HTTPException(status_code=503, detail=failure)


#  Deep health 
@app.get("/health/deep")
def health_deep():
    # Deep probe: runs one real forward pass on an all-zero input.
    # 503 if the model is not loaded, 500 if the forward pass raises.
    if model is None or not model_loaded:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        import torch

        # Zero tensor of shape (1, 3, 16, 224, 224) on the model's device/dtype.
        # NOTE(review): presumably (batch, channels, frames, height, width) - confirm.
        probe = torch.zeros(1, 3, 16, 224, 224, device=DEVICE, dtype=_DTYPE)
        with torch.no_grad():
            model(probe)  # output is discarded; only "did it run" matters
        report = {"status": "ok", "inference": "working", "device": str(DEVICE)}
        return report
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Inference failed: {exc!s}")


#  Predict from frames (real-time path) ─
# Request body for POST /predict_frames.
# (Documented with comments, not a docstring, so the generated schema is unchanged.)
class FramesPayload(BaseModel):
    # Base64-encoded frames, one string per frame; the handler decodes each
    # with base64.b64decode and requires exactly 16 of them.
    frames: List[str]
    # Forwarded to predict_from_frames as top_k; defaults to 5.
    top_k:  int = 5

@app.post("/predict_frames")
async def predict_frames_api(payload: FramesPayload):
    # Real-time path: classify one clip sent as 16 base64-encoded frames.
    #
    # Responses:
    #   200 - prediction, confidence, top_k and inference_time_ms
    #   400 - wrong number of frames, or a frame that is not valid base64
    #   503 - the model has not been loaded
    #   500 - predict_from_frames raised
    #
    # NOTE(review): the inference call is synchronous inside an async handler,
    # so it occupies the event loop while it runs - confirm this is acceptable.
    if not model_loaded or model is None:
        raise HTTPException(status_code=503, detail="Model is not ready")
    if len(payload.frames) != 16:
        raise HTTPException(status_code=400, detail="Exactly 16 frames required")

    start_time = time.time()

    # Malformed base64 is a client error, so it must surface as 400 rather
    # than escaping as an unhandled exception (which the client sees as 500).
    # binascii.Error is a ValueError subclass, so ValueError covers both the
    # bad-padding and the non-ASCII-input failures.
    try:
        frames_bytes = [base64.b64decode(frame) for frame in payload.frames]
    except ValueError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid base64 frame data: {e}",
        ) from e

    try:
        result = predict_from_frames(model, frames_bytes, top_k=payload.top_k)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}") from e

    return {
        "prediction":        result["prediction"],
        "confidence":        result["confidence"],
        "top_k":             result["top_k"],
        # Wall-clock time covering both base64 decoding and inference.
        "inference_time_ms": round((time.time() - start_time) * 1000, 2),
    }


#  Predict from video file 
# Lower-case suffixes accepted by /predict. Kept as a tuple because it is
# handed straight to str.endswith and echoed verbatim in the 400 message.
ALLOWED_EXTENSIONS = (
    ".mp4",
    ".mov",
    ".avi",
    ".mkv",
)

@app.post("/predict")
async def predict_sign(file: UploadFile = File(...), top_k: int = 5):
    # Classify an uploaded video file.
    #
    # Responses:
    #   200 - everything predict() returned, plus inference_time_ms and filename
    #   400 - missing filename or an extension outside ALLOWED_EXTENSIONS
    #   503 - the model has not been loaded
    #   500 - predict() raised
    #
    # UploadFile.filename is optional; treat a missing name as "no extension"
    # so the request is rejected with 400 instead of failing on None.lower().
    filename = file.filename or ""
    if not filename.lower().endswith(ALLOWED_EXTENSIONS):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid file type. Allowed: {ALLOWED_EXTENSIONS}"
        )
    if not model_loaded or model is None:
        raise HTTPException(status_code=503, detail="Model is not ready")

    start_time  = time.time()
    # The whole upload is read into memory because predict() takes raw bytes.
    video_bytes = await file.read()

    try:
        result = predict(model, video_bytes, top_k=top_k)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}") from e

    return {
        **result,
        # Wall-clock time covering both the upload read and inference.
        "inference_time_ms": round((time.time() - start_time) * 1000, 2),
        "filename":          file.filename,
    }


#  Entry point 
if __name__ == "__main__":
    # Run directly as a script: serve this module's `app` on every interface, port 7860.
    uvicorn.run(
        "app:app",
        port=7860,
        host="0.0.0.0",
    )