Creator-090 commited on
Commit
1cf4369
·
1 Parent(s): 3f99a4e

Update: model.py and app.py to remove multiple instances of same methods and add quantization (f16) to reduce inference time

Browse files
Files changed (3) hide show
  1. .gitignore +3 -0
  2. app.py +49 -98
  3. model.py +150 -223
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .vscode/
2
+ .venv/
3
+ __pycache__
app.py CHANGED
@@ -1,16 +1,13 @@
1
  # app.py
2
  from fastapi import FastAPI, File, UploadFile, HTTPException
3
  from fastapi.middleware.cors import CORSMiddleware
 
4
  import uvicorn
5
- from model import load_model, predict, predict_from_frames
6
- from model import load_model, predict, predict_from_frames
7
  import time
8
- from pydantic import BaseModel
9
- from typing import List
10
  import base64
11
- from pydantic import BaseModel
12
  from typing import List
13
- import base64
 
14
 
15
  app = FastAPI(
16
  title="ISL Recognition API",
@@ -18,8 +15,6 @@ app = FastAPI(
18
  version="1.0.0"
19
  )
20
 
21
- # Allow all origins (for Flutter / frontend apps)
22
- # Allow all origins (for Flutter / frontend apps)
23
  app.add_middleware(
24
  CORSMiddleware,
25
  allow_origins=["*"],
@@ -28,55 +23,49 @@ app.add_middleware(
28
  )
29
 
30
  # Global state
31
- model = None
32
  model_loaded = False
33
- model_error = None
34
 
35
 
36
-
37
- # STARTUP
38
  @app.on_event("startup")
39
  async def startup_event():
40
  global model, model_loaded, model_error
41
  try:
42
- model = load_model()
43
  model_loaded = True
44
- model_error = None
45
  print("Model loaded and API is ready!")
46
  except Exception as e:
47
  model_loaded = False
48
- model_error = str(e)
49
  print("Model failed to load:", e)
50
 
51
 
52
-
53
- # ROOT
54
  @app.get("/")
55
  def root():
56
  return {
57
- "status": "ISL API is running",
58
- "message": "Send a POST request to /predict (video) or /predict_frames (frames list)"
59
  }
60
 
61
 
62
- # HEALTH
63
  @app.get("/health")
64
  def health():
65
  if not model_loaded or model is None:
66
- return {
67
- "status": "error",
68
- "model_loaded": False,
69
- "error": model_error
70
- }
71
-
72
  return {
73
- "status": "ok",
74
  "model_loaded": True,
75
- "device": str(next(model.parameters()).device)
 
76
  }
77
 
78
 
79
- # DEEP HEALTH
80
  @app.get("/health/deep")
81
  def health_deep():
82
  if not model_loaded or model is None:
@@ -84,109 +73,71 @@ def health_deep():
84
 
85
  try:
86
  import torch
87
-
88
- dummy = torch.zeros(1, 3, 16, 224, 224).to(
89
- next(model.parameters()).device
90
- )
91
-
92
  with torch.no_grad():
93
  _ = model(dummy)
 
 
 
94
 
95
- return {
96
- "status": "ok",
97
- "inference": "working"
98
- }
99
 
100
- except Exception as e:
101
- raise HTTPException(
102
- status_code=500,
103
- detail=f"Inference failed: {str(e)}"
104
- )
105
-
106
  class FramesPayload(BaseModel):
107
- frames: List[str] # List of base64 encoded JPEG/PNG images
108
- top_k: int = 5
109
 
110
  @app.post("/predict_frames")
111
  async def predict_frames_api(payload: FramesPayload):
112
  if not model_loaded or model is None:
113
  raise HTTPException(status_code=503, detail="Model is not ready")
114
-
115
  if not payload.frames or len(payload.frames) != 16:
116
  raise HTTPException(status_code=400, detail="Exactly 16 frames required")
117
 
118
- start_time = time.time()
119
-
 
120
  try:
121
- # Convert base64 strings to bytes
122
- frames_bytes = [base64.b64decode(f) for f in payload.frames]
123
  result = predict_from_frames(model, frames_bytes, top_k=payload.top_k)
124
  except Exception as e:
125
- raise HTTPException(
126
- status_code=500,
127
- detail=f"Inference error: {str(e)}"
128
- )
129
 
130
- # Standardized response format as per checklist
131
  return {
132
- "prediction": result["prediction"],
133
- "confidence": result["confidence"],
134
- "inference_time_ms": round((time.time() - start_time) * 1000, 2)
 
135
  }
136
 
137
 
138
- # PREDICT
139
- @app.post("/predict")
 
 
140
  async def predict_sign(file: UploadFile = File(...), top_k: int = 5):
141
- # Validate file type
142
- # Validate file type
143
- if not file.filename.lower().endswith(('.mp4', '.mov', '.avi', '.mkv')):
144
  raise HTTPException(
145
  status_code=400,
146
- detail="Invalid file type. Please upload a video (.mp4, .mov, etc.)"
147
  )
148
-
149
- # Ensure model is ready
150
- if not model_loaded or model is None:
151
- raise HTTPException(
152
- status_code=503,
153
- detail="Model is not ready"
154
- )
155
- # Ensure model is ready
156
  if not model_loaded or model is None:
157
- raise HTTPException(
158
- status_code=503,
159
- detail="Model is not ready"
160
- )
161
 
162
- start_time = time.time()
163
  video_bytes = await file.read()
164
 
165
  try:
166
  result = predict(model, video_bytes, top_k=top_k)
167
  except Exception as e:
168
- raise HTTPException(
169
- status_code=500,
170
- detail=f"Inference error: {str(e)}"
171
- )
172
-
173
-
174
- try:
175
- result = predict(model, video_bytes, top_k=top_k)
176
- except Exception as e:
177
- raise HTTPException(
178
- status_code=500,
179
- detail=f"Inference error: {str(e)}"
180
- )
181
-
182
- result["inference_time_ms"] = round((time.time() - start_time) * 1000, 2)
183
- result["filename"] = file.filename
184
-
185
-
186
- return result
187
 
 
 
 
 
 
188
 
189
 
 
190
  if __name__ == "__main__":
191
- uvicorn.run("app:app", host="0.0.0.0", port=7860)
192
  uvicorn.run("app:app", host="0.0.0.0", port=7860)
 
1
  # app.py
2
  from fastapi import FastAPI, File, UploadFile, HTTPException
3
  from fastapi.middleware.cors import CORSMiddleware
4
+ from pydantic import BaseModel
5
  import uvicorn
 
 
6
  import time
 
 
7
  import base64
 
8
  from typing import List
9
+
10
+ from model import load_model, predict, predict_from_frames, DEVICE, _DTYPE
11
 
12
  app = FastAPI(
13
  title="ISL Recognition API",
 
15
  version="1.0.0"
16
  )
17
 
 
 
18
  app.add_middleware(
19
  CORSMiddleware,
20
  allow_origins=["*"],
 
23
  )
24
 
25
  # Global state
26
+ model = None
27
  model_loaded = False
28
+ model_error = None
29
 
30
 
31
+ # Startup
 
32
  @app.on_event("startup")
33
  async def startup_event():
34
  global model, model_loaded, model_error
35
  try:
36
+ model = load_model()
37
  model_loaded = True
38
+ model_error = None
39
  print("Model loaded and API is ready!")
40
  except Exception as e:
41
  model_loaded = False
42
+ model_error = str(e)
43
  print("Model failed to load:", e)
44
 
45
 
46
+ # Root
 
47
  @app.get("/")
48
  def root():
49
  return {
50
+ "status": "ISL API is running",
51
+ "message": "POST to /predict (video file) or /predict_frames (base64 frames)"
52
  }
53
 
54
 
55
+ # Health
56
  @app.get("/health")
57
  def health():
58
  if not model_loaded or model is None:
59
+ return {"status": "error", "model_loaded": False, "error": model_error}
 
 
 
 
 
60
  return {
61
+ "status": "ok",
62
  "model_loaded": True,
63
+ "device": str(DEVICE),
64
+ "fp16": str(_DTYPE),
65
  }
66
 
67
 
68
+ # Deep health
69
  @app.get("/health/deep")
70
  def health_deep():
71
  if not model_loaded or model is None:
 
73
 
74
  try:
75
  import torch
76
+ # Must match the dtype the model now runs in (FP16 on GPU)
77
+ dummy = torch.zeros(1, 3, 16, 224, 224, device=DEVICE, dtype=_DTYPE)
 
 
 
78
  with torch.no_grad():
79
  _ = model(dummy)
80
+ return {"status": "ok", "inference": "working", "device": str(DEVICE)}
81
+ except Exception as e:
82
+ raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}")
83
 
 
 
 
 
84
 
85
+ # Predict from frames (real-time path)
 
 
 
 
 
86
  class FramesPayload(BaseModel):
87
+ frames: List[str] # base64-encoded JPEG/PNG, exactly 16
88
+ top_k: int = 5
89
 
90
  @app.post("/predict_frames")
91
  async def predict_frames_api(payload: FramesPayload):
92
  if not model_loaded or model is None:
93
  raise HTTPException(status_code=503, detail="Model is not ready")
 
94
  if not payload.frames or len(payload.frames) != 16:
95
  raise HTTPException(status_code=400, detail="Exactly 16 frames required")
96
 
97
+ start_time = time.time()
98
+ frames_bytes = [base64.b64decode(f) for f in payload.frames]
99
+
100
  try:
 
 
101
  result = predict_from_frames(model, frames_bytes, top_k=payload.top_k)
102
  except Exception as e:
103
+ raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}")
 
 
 
104
 
 
105
  return {
106
+ "prediction": result["prediction"],
107
+ "confidence": result["confidence"],
108
+ "top_k": result["top_k"],
109
+ "inference_time_ms": round((time.time() - start_time) * 1000, 2),
110
  }
111
 
112
 
113
+ # Predict from video file
114
+ ALLOWED_EXTENSIONS = ('.mp4', '.mov', '.avi', '.mkv')
115
+
116
+ @app.post("/predict")
117
  async def predict_sign(file: UploadFile = File(...), top_k: int = 5):
118
+ if not file.filename.lower().endswith(ALLOWED_EXTENSIONS):
 
 
119
  raise HTTPException(
120
  status_code=400,
121
+ detail=f"Invalid file type. Allowed: {ALLOWED_EXTENSIONS}"
122
  )
 
 
 
 
 
 
 
 
123
  if not model_loaded or model is None:
124
+ raise HTTPException(status_code=503, detail="Model is not ready")
 
 
 
125
 
126
+ start_time = time.time()
127
  video_bytes = await file.read()
128
 
129
  try:
130
  result = predict(model, video_bytes, top_k=top_k)
131
  except Exception as e:
132
+ raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
+ return {
135
+ **result,
136
+ "inference_time_ms": round((time.time() - start_time) * 1000, 2),
137
+ "filename": file.filename,
138
+ }
139
 
140
 
141
+ # Entry point
142
  if __name__ == "__main__":
 
143
  uvicorn.run("app:app", host="0.0.0.0", port=7860)
model.py CHANGED
@@ -1,49 +1,52 @@
 
1
  import torch
2
  import torch.nn as nn
3
  from torchvision.models import video as ptv
4
  from torchvision.transforms import v2
5
- from transformers import VivitImageProcessor
6
  from decord import VideoReader
7
  from decord.bridge import set_bridge
8
- import gc
9
- import tempfile
10
- import os
11
- import cv2
12
- import numpy as np
13
  import cv2
14
  import numpy as np
15
 
16
- # Exactly 76 classes from your notebook metadata
17
  CLASSES = [
18
- 'afternoon', 'animal', 'bad', 'beautiful', 'big', 'bird', 'blind',
19
- 'cat', 'cheap', 'clothing', 'cold', 'cow', 'curved', 'deaf', 'dog',
20
- 'dress', 'dry', 'evening', 'expensive', 'famous', 'fast', 'female',
21
- 'fish', 'flat', 'friday', 'good', 'happy', 'hat', 'healthy', 'horse',
22
- 'hot', 'hour', 'light', 'long', 'loose', 'loud', 'minute', 'monday',
23
- 'month', 'morning', 'mouse', 'narrow', 'new', 'night', 'old', 'pant',
24
- 'pocket', 'quiet', 'sad', 'saturday', 'second', 'shirt', 'shoes',
25
- 'short', 'sick', 'skirt', 'slow', 'small', 'suit', 'sunday', 't_shirt',
26
- 'tall', 'thursday', 'time', 'today', 'tomorrow', 'tuesday', 'ugly',
27
  'warm', 'wednesday', 'week', 'wet', 'wide', 'year', 'yesterday', 'young'
28
  ]
29
 
30
- # Constants matched to your hyperparameters
31
- CLIP_LENGTH = 16
32
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
33
 
 
34
  class SwinTClassifications(nn.Module):
35
- """Model architecture from your notebook cell 79/197"""
36
  def __init__(self, classes, weights="KINETICS400_V1"):
37
  super().__init__()
38
  self.classes = classes
39
- # Load Swin3D-S backbone
40
  self.base_model = ptv.swin3d_s(weights=weights)
41
-
42
- # Classification head with your 76 output features
43
  self.classification_head = nn.Sequential(
44
  nn.Linear(self.base_model.head.in_features, len(self.classes))
45
  )
46
- # Head replaced with Identity as per your architecture
47
  self.base_model.head = nn.Identity()
48
 
49
  def forward(self, x):
@@ -51,229 +54,153 @@ class SwinTClassifications(nn.Module):
51
  x = self.classification_head(x)
52
  return x
53
 
 
54
  def load_model():
55
- """Downloads best model from your HF repo and loads weights"""
56
  from huggingface_hub import hf_hub_download
57
-
58
- print("Fetching model from Hugging Face Hub...")
59
  model_path = hf_hub_download(
60
- repo_id="Creator-090/isl-swin3d-model",
61
  filename="ISL_best_model.pt"
62
  )
63
-
64
  model = SwinTClassifications(classes=CLASSES)
65
  model.load_state_dict(
66
  torch.load(model_path, map_location=DEVICE, weights_only=True)
67
  )
68
  model = model.to(DEVICE)
 
 
 
 
 
69
  model.eval()
 
 
 
 
 
 
 
 
 
 
70
  return model
71
 
72
- def preprocess_video(video_bytes: bytes, clip_length: int = 16):
73
- set_bridge("torch")
74
- with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
75
- f.write(video_bytes)
76
- tmp_path = f.name
77
- try:
78
- image_processor = VivitImageProcessor(
79
- do_resize=True,
80
- size={"shortest_edge": 224},
81
- do_center_crop=True,
82
- crop_size={"height": 224, "width": 224},
83
- do_rescale=True,
84
- rescale_factor=1/255,
85
- do_normalize=True,
86
- image_mean=[0.5, 0.5, 0.5],
87
- image_std=[0.5, 0.5, 0.5],
88
- )
89
- vr = VideoReader(tmp_path)
90
- total_frames = len(vr)
91
- indices = list(range(min(total_frames, clip_length)))
92
- if len(indices) < clip_length:
93
- indices += [indices[-1]] * (clip_length - len(indices))
94
-
95
- # Ensure video is a torch tensor in (Frames, Channels, Height, Width)
96
- video = vr.get_batch(indices)
97
- video = video.permute(0, 3, 1, 2).to(torch.uint8) # Convert to Float for the processor
98
-
99
- # Pass as a list of Tensors
100
- processed = image_processor(
101
- list(video),
102
- return_tensors='pt',
103
- input_data_format='channels_first'
104
- )
105
-
106
- pixel_values = processed['pixel_values'].squeeze(0) # (T, C, H, W)
107
- pixel_values = pixel_values.permute(1, 0, 2, 3) # (C, T, H, W) for Swin3D
108
-
109
- return pixel_values.unsqueeze(0)
110
- finally:
111
- if os.path.exists(tmp_path):
112
- os.remove(tmp_path)
113
-
114
- def preprocess_frames(frames_list_bytes: list[bytes], clip_length: int = 16):
115
  """
116
- Processes a list of raw frame bytes into the Swin3D model input format.
117
- Following the exact 'no-BS' checklist implementation.
118
  """
119
- image_processor = VivitImageProcessor(
120
- do_resize=True,
121
- size={"shortest_edge": 224},
122
- do_center_crop=True,
123
- crop_size={"height": 224, "width": 224},
124
- do_rescale=True,
125
- rescale_factor=1/255,
126
- do_normalize=True,
127
- image_mean=[0.5, 0.5, 0.5],
128
- image_std=[0.5, 0.5, 0.5],
129
- )
130
-
131
- # 1. Decode bytes to PIL Images
132
- from io import BytesIO
133
- from PIL import Image
134
-
135
- decoded_frames = []
136
- for f_bytes in frames_list_bytes:
137
- img = Image.open(BytesIO(f_bytes)).convert("RGB")
138
- decoded_frames.append(img)
139
-
140
- if len(decoded_frames) != clip_length:
141
- raise ValueError(f"Exactly {clip_length} frames required, got {len(decoded_frames)}")
142
-
143
- # 2. Convert to tensor stack (T, C, H, W)
144
- # Note: User's snippet used torch.from_numpy(np.array(img)).permute(2, 0, 1)
145
  video = torch.stack([
146
- torch.from_numpy(np.array(img)).permute(2, 0, 1)
147
- for img in decoded_frames
148
- ])
149
-
150
- # 3. Apply ImageProcessor
151
- processed = image_processor(
152
- list(video),
153
- return_tensors='pt',
154
- input_data_format='channels_first'
155
- )
156
-
157
- # 4. Standardize dimensions for Swin3D: (Batch, Channels, Time, Height, Width)
158
- pixel_values = processed['pixel_values'].squeeze(0) # (T, C, H, W)
159
- pixel_values = pixel_values.permute(1, 0, 2, 3) # (C, T, H, W)
160
-
161
- return pixel_values.unsqueeze(0) # (1, C, T, H, W)
162
-
163
- def predict_from_frames(model, frames_list_bytes: list[bytes], top_k: int = 5):
164
- """Runs inference from raw frame bytes"""
165
- pixel_values = preprocess_frames(frames_list_bytes).to(DEVICE)
166
-
167
- with torch.no_grad():
168
- outputs = model(pixel_values)
169
- probabilities = torch.nn.functional.softmax(outputs, dim=-1)[0]
170
-
171
- top_probs, top_indices = torch.topk(probabilities, k=top_k)
172
-
173
- results = []
174
- for i in range(top_k):
175
- results.append({
176
- "class": CLASSES[top_indices[i].item()],
177
- "confidence": float(top_probs[i].item())
178
- })
179
-
180
- return {
181
- "prediction": results[0]["class"],
182
- "confidence": results[0]["confidence"],
183
- "top_k": results
184
- }
185
 
186
- def preprocess_frames(frames_list_bytes: list[bytes], clip_length: int = 16):
 
187
  """
188
- Processes a list of raw frame bytes (JPEG/PNG encoded) into the Swin3D model input format.
189
- Eliminates video encoding/decoding and disk I/O.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  """
191
- image_processor = VivitImageProcessor(
192
- do_resize=True,
193
- size={"shortest_edge": 224},
194
- do_center_crop=True,
195
- crop_size={"height": 224, "width": 224},
196
- do_rescale=True,
197
- rescale_factor=1/255,
198
- do_normalize=True,
199
- image_mean=[0.5, 0.5, 0.5],
200
- image_std=[0.5, 0.5, 0.5],
201
- )
202
-
203
  frames = []
204
- for frame_bytes in frames_list_bytes:
205
- # Decode image from bytes
206
- nparr = np.frombuffer(frame_bytes, np.uint8)
207
- img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
208
- if img is not None:
209
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
210
- frames.append(img)
211
-
212
  if not frames:
213
- raise ValueError("No valid frames decoded")
214
-
215
- # Temporal sampling/padding
216
- if len(frames) < clip_length:
217
- frames += [frames[-1]] * (clip_length - len(frames))
218
- elif len(frames) > clip_length:
219
- frames = frames[:clip_length]
220
-
221
- # Processor expects list of numpy arrays (H, W, C)
222
- processed = image_processor(
223
- frames,
224
- return_tensors='pt',
225
- # image_processor handles (T, C, H, W) return with return_tensors='pt'
226
- # but we need to check internal dimension order
227
- )
228
-
229
- pixel_values = processed['pixel_values'].squeeze(0) # (T, C, H, W)
230
- pixel_values = pixel_values.permute(1, 0, 2, 3) # (C, T, H, W) for Swin3D
231
-
232
- return pixel_values.unsqueeze(0)
233
-
234
- def predict_from_frames(model, frames_list_bytes: list[bytes], top_k: int = 5):
235
- """Runs inference from raw frame bytes"""
236
- pixel_values = preprocess_frames(frames_list_bytes).to(DEVICE)
237
-
238
  with torch.no_grad():
239
- outputs = model(pixel_values)
240
- probabilities = torch.nn.functional.softmax(outputs, dim=-1)[0]
241
-
242
- top_probs, top_indices = torch.topk(probabilities, k=top_k)
243
-
244
- results = []
245
- for i in range(top_k):
246
- results.append({
247
- "class": CLASSES[top_indices[i].item()],
248
- "confidence": float(top_probs[i].item())
249
- })
250
-
 
251
  return {
252
  "prediction": results[0]["class"],
253
  "confidence": results[0]["confidence"],
254
- "top_k": results
255
  }
256
 
257
- def predict(model, video_bytes: bytes, top_k: int = 5):
258
- """Runs inference and returns the top results"""
259
- pixel_values = preprocess_video(video_bytes).to(DEVICE)
260
-
261
- with torch.no_grad():
262
- # Standardize for CPU/GPU mixed precision
263
- outputs = model(pixel_values)
264
- probabilities = torch.nn.functional.softmax(outputs, dim=-1)[0]
265
-
266
- top_probs, top_indices = torch.topk(probabilities, k=top_k)
267
-
268
- results = []
269
- for i in range(top_k):
270
- results.append({
271
- "class": CLASSES[top_indices[i].item()],
272
- "confidence": float(top_probs[i].item())
273
- })
274
-
275
- return {
276
- "prediction": results[0]["class"],
277
- "confidence": results[0]["confidence"],
278
- "top_k": results
279
- }
 
1
+ import io
2
  import torch
3
  import torch.nn as nn
4
  from torchvision.models import video as ptv
5
  from torchvision.transforms import v2
 
6
  from decord import VideoReader
7
  from decord.bridge import set_bridge
 
 
 
 
 
8
  import cv2
9
  import numpy as np
10
 
11
+ # Classes
12
  CLASSES = [
13
+ 'afternoon', 'animal', 'bad', 'beautiful', 'big', 'bird', 'blind',
14
+ 'cat', 'cheap', 'clothing', 'cold', 'cow', 'curved', 'deaf', 'dog',
15
+ 'dress', 'dry', 'evening', 'expensive', 'famous', 'fast', 'female',
16
+ 'fish', 'flat', 'friday', 'good', 'happy', 'hat', 'healthy', 'horse',
17
+ 'hot', 'hour', 'light', 'long', 'loose', 'loud', 'minute', 'monday',
18
+ 'month', 'morning', 'mouse', 'narrow', 'new', 'night', 'old', 'pant',
19
+ 'pocket', 'quiet', 'sad', 'saturday', 'second', 'shirt', 'shoes',
20
+ 'short', 'sick', 'skirt', 'slow', 'small', 'suit', 'sunday', 't_shirt',
21
+ 'tall', 'thursday', 'time', 'today', 'tomorrow', 'tuesday', 'ugly',
22
  'warm', 'wednesday', 'week', 'wet', 'wide', 'year', 'yesterday', 'young'
23
  ]
24
 
25
+ # Constants
26
+ CLIP_LENGTH = 16
27
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ USE_FP16 = DEVICE.type == "cuda"
29
+
30
+ # Global transform pipeline (built once, runs on GPU)
31
+ # Replaces VivitImageProcessor - same operations, but GPU-accelerated via torchvision v2
32
+ _DTYPE = torch.float16 if USE_FP16 else torch.float32
33
+
34
+ TRANSFORMS = v2.Compose([
35
+ v2.Resize(224, antialias=True), # shortest edge → 224
36
+ v2.CenterCrop(224), # 224×224
37
+ v2.ToDtype(_DTYPE, scale=True), # uint8 => float, /255
38
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
39
+ ])
40
 
41
+ # Model
42
  class SwinTClassifications(nn.Module):
 
43
  def __init__(self, classes, weights="KINETICS400_V1"):
44
  super().__init__()
45
  self.classes = classes
 
46
  self.base_model = ptv.swin3d_s(weights=weights)
 
 
47
  self.classification_head = nn.Sequential(
48
  nn.Linear(self.base_model.head.in_features, len(self.classes))
49
  )
 
50
  self.base_model.head = nn.Identity()
51
 
52
  def forward(self, x):
 
54
  x = self.classification_head(x)
55
  return x
56
 
57
+
58
  def load_model():
59
+ """Downloads model from HF Hub, applies FP16 + torch.compile for max speed."""
60
  from huggingface_hub import hf_hub_download
61
+
62
+ print(f"Loading model on {DEVICE} (fp16={USE_FP16}) ...")
63
  model_path = hf_hub_download(
64
+ repo_id="Creator-090/isl-swin3d-model",
65
  filename="ISL_best_model.pt"
66
  )
67
+
68
  model = SwinTClassifications(classes=CLASSES)
69
  model.load_state_dict(
70
  torch.load(model_path, map_location=DEVICE, weights_only=True)
71
  )
72
  model = model.to(DEVICE)
73
+
74
+ # FP16 on GPU - ~2x faster inference, no accuracy loss for classification
75
+ if USE_FP16:
76
+ model = model.half()
77
+
78
  model.eval()
79
+
80
+ # torch.compile - fuses ops, reduces Python overhead (~20-35% faster after warmup)
81
+ if DEVICE.type == "cuda":
82
+ print("Compiling model with torch.compile (mode=reduce-overhead) ...")
83
+ model = torch.compile(model, mode="reduce-overhead")
84
+
85
+ # Warmup - triggers compilation + CUDA kernel caching so first real request is fast
86
+ _warmup(model)
87
+
88
+ print("Model ready.")
89
  return model
90
 
91
+
92
+ def _warmup(model, rounds: int = 3):
93
+ """Run a few dummy forward passes to trigger torch.compile and warm CUDA kernels."""
94
+ print(f"Warming up model ({rounds} rounds) ...")
95
+ dummy = torch.zeros(1, 3, CLIP_LENGTH, 224, 224, device=DEVICE, dtype=_DTYPE)
96
+ with torch.no_grad():
97
+ for _ in range(rounds):
98
+ _ = model(dummy)
99
+ if DEVICE.type == "cuda":
100
+ torch.cuda.synchronize()
101
+ print("Warmup complete.")
102
+
103
+
104
+ # Preprocessing helpers
105
+
106
+ def _frames_to_tensor(frames: list) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  """
108
+ Converts a list of numpy (H,W,3) RGB frames (1, C, T, H, W) tensor on DEVICE.
109
+ Resize + normalize happen on GPU via torchvision v2 transforms.
110
  """
111
+ # Stack => (T, C, H, W) uint8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  video = torch.stack([
113
+ torch.from_numpy(f).permute(2, 0, 1) # H,W,C => C,H,W
114
+ for f in frames
115
+ ]) # (T, C, H, W)
116
+
117
+ video = video.to(DEVICE) # move to GPU first, then transform
118
+ video = TRANSFORMS(video) # resize + crop + normalize on GPU => (T, C, H, W)
119
+ video = video.permute(1, 0, 2, 3) # (C, T, H, W) => Swin3D expects this
120
+ return video.unsqueeze(0) # (1, C, T, H, W)
121
+
122
+
123
+ def _pad_or_trim(frames: list, clip_length: int) -> list:
124
+ if len(frames) < clip_length:
125
+ frames += [frames[-1]] * (clip_length - len(frames))
126
+ elif len(frames) > clip_length:
127
+ # Uniform temporal sampling instead of naive truncation
128
+ indices = [int(i * len(frames) / clip_length) for i in range(clip_length)]
129
+ frames = [frames[i] for i in indices]
130
+ return frames
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
+
133
+ def preprocess_video(video_bytes: bytes, clip_length: int = CLIP_LENGTH) -> torch.Tensor:
134
  """
135
+ Decodes a video from raw bytes (no disk I/O) and returns a model-ready tensor.
136
+ Uses decord's in-memory VideoReader to avoid the tempfile write/read cycle.
137
+ """
138
+ set_bridge("torch")
139
+ vr = VideoReader(io.BytesIO(video_bytes)) # in-memory, no disk write
140
+ total = len(vr)
141
+ idx = list(range(min(total, clip_length)))
142
+ if len(idx) < clip_length:
143
+ idx += [idx[-1]] * (clip_length - len(idx))
144
+
145
+ batch = vr.get_batch(idx).asnumpy() # (T, H, W, C) uint8 numpy
146
+ frames = [batch[i] for i in range(batch.shape[0])] # list of (H, W, C)
147
+
148
+ return _frames_to_tensor(frames)
149
+
150
+
151
+ def preprocess_frames(frames_list_bytes: list[bytes], clip_length: int = CLIP_LENGTH) -> torch.Tensor:
152
+ """
153
+ Decodes a list of JPEG/PNG frame bytes and returns a model-ready tensor.
154
+ All heavy lifting (resize, normalize) happens on GPU.
155
  """
 
 
 
 
 
 
 
 
 
 
 
 
156
  frames = []
157
+ for fb in frames_list_bytes:
158
+ arr = np.frombuffer(fb, np.uint8)
159
+ img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
160
+ if img is None:
161
+ continue
162
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # BGR → RGB
163
+ frames.append(img)
164
+
165
  if not frames:
166
+ raise ValueError("No valid frames could be decoded from the provided bytes.")
167
+
168
+ frames = _pad_or_trim(frames, clip_length)
169
+ return _frames_to_tensor(frames)
170
+
171
+
172
+ # Inference
173
+
174
+ def _run_inference(model, pixel_values: torch.Tensor, top_k: int) -> dict:
175
+ """Shared inference logic for both predict paths."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  with torch.no_grad():
177
+ # autocast is a no-op on CPU; on GPU it enforces FP16 even if something slipped through
178
+ with torch.autocast(device_type=DEVICE.type, dtype=_DTYPE, enabled=USE_FP16):
179
+ outputs = model(pixel_values)
180
+
181
+ probs = torch.nn.functional.softmax(outputs, dim=-1)[0]
182
+
183
+ top_probs, top_indices = torch.topk(probs, k=top_k)
184
+
185
+ results = [
186
+ {"class": CLASSES[top_indices[i].item()], "confidence": float(top_probs[i].item())}
187
+ for i in range(top_k)
188
+ ]
189
+
190
  return {
191
  "prediction": results[0]["class"],
192
  "confidence": results[0]["confidence"],
193
+ "top_k": results,
194
  }
195
 
196
+
197
+ def predict(model, video_bytes: bytes, top_k: int = 5) -> dict:
198
+ """Inference from raw video bytes."""
199
+ pixel_values = preprocess_video(video_bytes)
200
+ return _run_inference(model, pixel_values, top_k)
201
+
202
+
203
+ def predict_from_frames(model, frames_list_bytes: list[bytes], top_k: int = 5) -> dict:
204
+ """Inference from a list of raw JPEG/PNG frame bytes."""
205
+ pixel_values = preprocess_frames(frames_list_bytes)
206
+ return _run_inference(model, pixel_values, top_k)