Creator-090 committed on
Commit
103cb2c
·
1 Parent(s): 99f8a05

update: api calls and data preprocessing

Browse files
Files changed (2) hide show
  1. app.py +122 -21
  2. model.py +77 -3
app.py CHANGED
@@ -2,9 +2,11 @@
2
  from fastapi import FastAPI, File, UploadFile, HTTPException
3
  from fastapi.middleware.cors import CORSMiddleware
4
  import uvicorn
5
- from model import load_model, predict
6
  import time
7
- import os
 
 
8
 
9
  app = FastAPI(
10
  title="ISL Recognition API",
@@ -12,50 +14,149 @@ app = FastAPI(
12
  version="1.0.0"
13
  )
14
 
15
- # Allow all origins so your Flutter app can talk to it
16
  app.add_middleware(
17
- CORSMiddleware, allow_origins=["*"],
 
18
  allow_methods=["*"],
19
  allow_headers=["*"],
20
  )
21
 
22
- # Global variable for the model
23
  model = None
 
 
24
 
 
 
25
  @app.on_event("startup")
26
  async def startup_event():
27
- global model
28
- # This calls the function in your model.py to download and load the .pt file
29
- model = load_model()
30
- print("Model loaded and API is ready!")
 
 
 
 
 
 
 
31
 
 
32
  @app.get("/")
33
  def root():
34
- return {"status": "ISL API is running", "message": "Send a POST request to /predict"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- @app.post("/predict")
 
37
  async def predict_sign(file: UploadFile = File(...), top_k: int = 5):
38
- # Validate that it's a video
39
  if not file.filename.lower().endswith(('.mp4', '.mov', '.avi', '.mkv')):
40
  raise HTTPException(
41
  status_code=400,
42
  detail="Invalid file type. Please upload a video (.mp4, .mov, etc.)"
43
  )
44
 
45
- if model is None:
46
- raise HTTPException(status_code=503, detail="Model is still loading...")
 
 
 
 
47
 
48
  start_time = time.time()
49
  video_bytes = await file.read()
50
-
51
- # This calls the prediction logic in your model.py
52
- result = predict(model, video_bytes, top_k=top_k)
53
-
 
 
 
 
 
54
  result["inference_time_ms"] = round((time.time() - start_time) * 1000, 2)
55
  result["filename"] = file.filename
56
-
57
  return result
58
 
 
59
  if __name__ == "__main__":
60
- # Hugging Face Spaces usually look for port 7860
61
- uvicorn.run("app.py:app", host="0.0.0.0", port=7860)
 
2
  from fastapi import FastAPI, File, UploadFile, HTTPException
3
  from fastapi.middleware.cors import CORSMiddleware
4
  import uvicorn
5
+ from model import load_model, predict, predict_from_frames
6
  import time
7
+ from pydantic import BaseModel
8
+ from typing import List
9
+ import base64
10
 
11
  app = FastAPI(
12
  title="ISL Recognition API",
 
14
  version="1.0.0"
15
  )
16
 
17
+ # Allow all origins (for Flutter / frontend apps)
18
  app.add_middleware(
19
+ CORSMiddleware,
20
+ allow_origins=["*"],
21
  allow_methods=["*"],
22
  allow_headers=["*"],
23
  )
24
 
25
+ # Global state
26
  model = None
27
+ model_loaded = False
28
+ model_error = None
29
 
30
+
31
+ # STARTUP
32
  @app.on_event("startup")
33
  async def startup_event():
34
+ global model, model_loaded, model_error
35
+ try:
36
+ model = load_model()
37
+ model_loaded = True
38
+ model_error = None
39
+ print("Model loaded and API is ready!")
40
+ except Exception as e:
41
+ model_loaded = False
42
+ model_error = str(e)
43
+ print("Model failed to load:", e)
44
+
45
 
46
+ # ROOT
47
  @app.get("/")
48
  def root():
49
+ return {
50
+ "status": "ISL API is running",
51
+ "message": "Send a POST request to /predict (video) or /predict_frames (frames list)"
52
+ }
53
+
54
+
55
+ # HEALTH
56
+ @app.get("/health")
57
+ def health():
58
+ if not model_loaded or model is None:
59
+ return {
60
+ "status": "error",
61
+ "model_loaded": False,
62
+ "error": model_error
63
+ }
64
+
65
+ return {
66
+ "status": "ok",
67
+ "model_loaded": True,
68
+ "device": str(next(model.parameters()).device)
69
+ }
70
+
71
+
72
+ # DEEP HEALTH
73
+ @app.get("/health/deep")
74
+ def health_deep():
75
+ if not model_loaded or model is None:
76
+ raise HTTPException(status_code=503, detail="Model not loaded")
77
+
78
+ try:
79
+ import torch
80
+
81
+ dummy = torch.zeros(1, 3, 16, 224, 224).to(
82
+ next(model.parameters()).device
83
+ )
84
+
85
+ with torch.no_grad():
86
+ _ = model(dummy)
87
+
88
+ return {
89
+ "status": "ok",
90
+ "inference": "working"
91
+ }
92
+
93
+ except Exception as e:
94
+ raise HTTPException(
95
+ status_code=500,
96
+ detail=f"Inference failed: {str(e)}"
97
+ )
98
+
99
+ class FramesPayload(BaseModel):
100
+ frames: List[str] # List of base64 encoded JPEG/PNG images
101
+ top_k: int = 5
102
+
103
+ @app.post("/predict_frames")
104
+ async def predict_frames_api(payload: FramesPayload):
105
+ if not model_loaded or model is None:
106
+ raise HTTPException(status_code=503, detail="Model is not ready")
107
+
108
+ if not payload.frames:
109
+ raise HTTPException(status_code=400, detail="No frames provided")
110
+
111
+ start_time = time.time()
112
+
113
+ try:
114
+ # Convert base64 strings to bytes
115
+ frames_bytes = [base64.b64decode(f) for f in payload.frames]
116
+ result = predict_from_frames(model, frames_bytes, top_k=payload.top_k)
117
+ except Exception as e:
118
+ raise HTTPException(
119
+ status_code=500,
120
+ detail=f"Inference error: {str(e)}"
121
+ )
122
+
123
+ result["inference_time_ms"] = round((time.time() - start_time) * 1000, 2)
124
+ return result
125
+
126
 
127
+ # PREDICT
128
+ @app.post("/predict")
129
  async def predict_sign(file: UploadFile = File(...), top_k: int = 5):
130
+ # Validate file type
131
  if not file.filename.lower().endswith(('.mp4', '.mov', '.avi', '.mkv')):
132
  raise HTTPException(
133
  status_code=400,
134
  detail="Invalid file type. Please upload a video (.mp4, .mov, etc.)"
135
  )
136
 
137
+ # Ensure model is ready
138
+ if not model_loaded or model is None:
139
+ raise HTTPException(
140
+ status_code=503,
141
+ detail="Model is not ready"
142
+ )
143
 
144
  start_time = time.time()
145
  video_bytes = await file.read()
146
+
147
+ try:
148
+ result = predict(model, video_bytes, top_k=top_k)
149
+ except Exception as e:
150
+ raise HTTPException(
151
+ status_code=500,
152
+ detail=f"Inference error: {str(e)}"
153
+ )
154
+
155
  result["inference_time_ms"] = round((time.time() - start_time) * 1000, 2)
156
  result["filename"] = file.filename
157
+
158
  return result
159
 
160
+
161
  if __name__ == "__main__":
162
+ uvicorn.run("app:app", host="0.0.0.0", port=7860)
 
model.py CHANGED
@@ -8,6 +8,8 @@ from decord.bridge import set_bridge
8
  import gc
9
  import tempfile
10
  import os
 
 
11
 
12
  # Exactly 76 classes from your notebook metadata
13
  CLASSES = [
@@ -90,12 +92,13 @@ def preprocess_video(video_bytes: bytes, clip_length: int = 16):
90
 
91
  # Ensure video is a torch tensor in (Frames, Channels, Height, Width)
92
  video = vr.get_batch(indices)
93
- video = video.permute(0, 3, 1, 2).float() # Convert to Float for the processor
94
 
95
  # Pass as a list of Tensors
96
  processed = image_processor(
97
- list(video),
98
- return_tensors='pt'
 
99
  )
100
 
101
  pixel_values = processed['pixel_values'].squeeze(0) # (T, C, H, W)
@@ -106,6 +109,77 @@ def preprocess_video(video_bytes: bytes, clip_length: int = 16):
106
  if os.path.exists(tmp_path):
107
  os.remove(tmp_path)
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  def predict(model, video_bytes: bytes, top_k: int = 5):
110
  """Runs inference and returns the top results"""
111
  pixel_values = preprocess_video(video_bytes).to(DEVICE)
 
8
  import gc
9
  import tempfile
10
  import os
11
+ import cv2
12
+ import numpy as np
13
 
14
  # Exactly 76 classes from your notebook metadata
15
  CLASSES = [
 
92
 
93
  # Ensure video is a torch tensor in (Frames, Channels, Height, Width)
94
  video = vr.get_batch(indices)
95
+ video = video.permute(0, 3, 1, 2).to(torch.uint8) # Convert to uint8 for the processor
96
 
97
  # Pass as a list of Tensors
98
  processed = image_processor(
99
+ list(video),
100
+ return_tensors='pt',
101
+ input_data_format='channels_first'
102
  )
103
 
104
  pixel_values = processed['pixel_values'].squeeze(0) # (T, C, H, W)
 
109
  if os.path.exists(tmp_path):
110
  os.remove(tmp_path)
111
 
112
+ def preprocess_frames(frames_list_bytes: list[bytes], clip_length: int = 16):
113
+ """
114
+ Processes a list of raw frame bytes (JPEG/PNG encoded) into the Swin3D model input format.
115
+ Eliminates video encoding/decoding and disk I/O.
116
+ """
117
+ image_processor = VivitImageProcessor(
118
+ do_resize=True,
119
+ size={"shortest_edge": 224},
120
+ do_center_crop=True,
121
+ crop_size={"height": 224, "width": 224},
122
+ do_rescale=True,
123
+ rescale_factor=1/255,
124
+ do_normalize=True,
125
+ image_mean=[0.5, 0.5, 0.5],
126
+ image_std=[0.5, 0.5, 0.5],
127
+ )
128
+
129
+ frames = []
130
+ for frame_bytes in frames_list_bytes:
131
+ # Decode image from bytes
132
+ nparr = np.frombuffer(frame_bytes, np.uint8)
133
+ img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
134
+ if img is not None:
135
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
136
+ frames.append(img)
137
+
138
+ if not frames:
139
+ raise ValueError("No valid frames decoded")
140
+
141
+ # Temporal sampling/padding
142
+ if len(frames) < clip_length:
143
+ frames += [frames[-1]] * (clip_length - len(frames))
144
+ elif len(frames) > clip_length:
145
+ frames = frames[:clip_length]
146
+
147
+ # Processor expects list of numpy arrays (H, W, C)
148
+ processed = image_processor(
149
+ frames,
150
+ return_tensors='pt',
151
+ # image_processor handles (T, C, H, W) return with return_tensors='pt'
152
+ # but we need to check internal dimension order
153
+ )
154
+
155
+ pixel_values = processed['pixel_values'].squeeze(0) # (T, C, H, W)
156
+ pixel_values = pixel_values.permute(1, 0, 2, 3) # (C, T, H, W) for Swin3D
157
+
158
+ return pixel_values.unsqueeze(0)
159
+
160
+ def predict_from_frames(model, frames_list_bytes: list[bytes], top_k: int = 5):
161
+ """Runs inference from raw frame bytes"""
162
+ pixel_values = preprocess_frames(frames_list_bytes).to(DEVICE)
163
+
164
+ with torch.no_grad():
165
+ outputs = model(pixel_values)
166
+ probabilities = torch.nn.functional.softmax(outputs, dim=-1)[0]
167
+
168
+ top_probs, top_indices = torch.topk(probabilities, k=top_k)
169
+
170
+ results = []
171
+ for i in range(top_k):
172
+ results.append({
173
+ "class": CLASSES[top_indices[i].item()],
174
+ "confidence": float(top_probs[i].item())
175
+ })
176
+
177
+ return {
178
+ "prediction": results[0]["class"],
179
+ "confidence": results[0]["confidence"],
180
+ "top_k": results
181
+ }
182
+
183
  def predict(model, video_bytes: bytes, top_k: int = 5):
184
  """Runs inference and returns the top results"""
185
  pixel_values = preprocess_video(video_bytes).to(DEVICE)