Creator-090 commited on
Commit
86c7cf3
·
1 Parent(s): 1cf4369

fix: CPU-safe inference for HF free tier

Browse files

- Skip autocast and torch.compile on CPU
- Reduce warmup to 1 round on CPU (was 3, ~60s saved on cold start)
- Return 503 from /health while model is loading so wake_up() retries correctly

Files changed (2) hide show
  1. app.py +13 -10
  2. model.py +37 -60
app.py CHANGED
@@ -22,7 +22,7 @@ app.add_middleware(
22
  allow_headers=["*"],
23
  )
24
 
25
- # Global state
26
  model = None
27
  model_loaded = False
28
  model_error = None
@@ -43,7 +43,7 @@ async def startup_event():
43
  print("Model failed to load:", e)
44
 
45
 
46
- # Root
47
  @app.get("/")
48
  def root():
49
  return {
@@ -52,11 +52,15 @@ def root():
52
  }
53
 
54
 
55
- # Health
56
  @app.get("/health")
57
  def health():
58
  if not model_loaded or model is None:
59
- return {"status": "error", "model_loaded": False, "error": model_error}
 
 
 
 
60
  return {
61
  "status": "ok",
62
  "model_loaded": True,
@@ -65,7 +69,7 @@ def health():
65
  }
66
 
67
 
68
- # Deep health
69
  @app.get("/health/deep")
70
  def health_deep():
71
  if not model_loaded or model is None:
@@ -73,7 +77,6 @@ def health_deep():
73
 
74
  try:
75
  import torch
76
- # Must match the dtype the model now runs in (FP16 on GPU)
77
  dummy = torch.zeros(1, 3, 16, 224, 224, device=DEVICE, dtype=_DTYPE)
78
  with torch.no_grad():
79
  _ = model(dummy)
@@ -82,9 +85,9 @@ def health_deep():
82
  raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}")
83
 
84
 
85
- # Predict from frames (real-time path)
86
  class FramesPayload(BaseModel):
87
- frames: List[str] # base64-encoded JPEG/PNG, exactly 16
88
  top_k: int = 5
89
 
90
  @app.post("/predict_frames")
@@ -110,7 +113,7 @@ async def predict_frames_api(payload: FramesPayload):
110
  }
111
 
112
 
113
- # Predict from video file
114
  ALLOWED_EXTENSIONS = ('.mp4', '.mov', '.avi', '.mkv')
115
 
116
  @app.post("/predict")
@@ -138,6 +141,6 @@ async def predict_sign(file: UploadFile = File(...), top_k: int = 5):
138
  }
139
 
140
 
141
- # Entry point
142
  if __name__ == "__main__":
143
  uvicorn.run("app:app", host="0.0.0.0", port=7860)
 
22
  allow_headers=["*"],
23
  )
24
 
25
+ # Global state
26
  model = None
27
  model_loaded = False
28
  model_error = None
 
43
  print("Model failed to load:", e)
44
 
45
 
46
+ # Root
47
  @app.get("/")
48
  def root():
49
  return {
 
52
  }
53
 
54
 
55
+ # Health
56
  @app.get("/health")
57
  def health():
58
  if not model_loaded or model is None:
59
+ # Return 503 so the wake_up() retry loop in backend knows to keep waiting
60
+ raise HTTPException(
61
+ status_code=503,
62
+ detail={"status": "error", "model_loaded": False, "error": model_error}
63
+ )
64
  return {
65
  "status": "ok",
66
  "model_loaded": True,
 
69
  }
70
 
71
 
72
+ # Deep health
73
  @app.get("/health/deep")
74
  def health_deep():
75
  if not model_loaded or model is None:
 
77
 
78
  try:
79
  import torch
 
80
  dummy = torch.zeros(1, 3, 16, 224, 224, device=DEVICE, dtype=_DTYPE)
81
  with torch.no_grad():
82
  _ = model(dummy)
 
85
  raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}")
86
 
87
 
88
+ # Predict from frames (real-time path)
89
  class FramesPayload(BaseModel):
90
+ frames: List[str]
91
  top_k: int = 5
92
 
93
  @app.post("/predict_frames")
 
113
  }
114
 
115
 
116
+ # Predict from video file
117
  ALLOWED_EXTENSIONS = ('.mp4', '.mov', '.avi', '.mkv')
118
 
119
  @app.post("/predict")
 
141
  }
142
 
143
 
144
+ # Entry point
145
  if __name__ == "__main__":
146
  uvicorn.run("app:app", host="0.0.0.0", port=7860)
model.py CHANGED
@@ -8,7 +8,7 @@ from decord.bridge import set_bridge
8
  import cv2
9
  import numpy as np
10
 
11
- # Classes
12
  CLASSES = [
13
  'afternoon', 'animal', 'bad', 'beautiful', 'big', 'bird', 'blind',
14
  'cat', 'cheap', 'clothing', 'cold', 'cow', 'curved', 'deaf', 'dog',
@@ -22,23 +22,23 @@ CLASSES = [
22
  'warm', 'wednesday', 'week', 'wet', 'wide', 'year', 'yesterday', 'young'
23
  ]
24
 
25
- # Constants
26
  CLIP_LENGTH = 16
27
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
- USE_FP16 = DEVICE.type == "cuda"
 
29
 
30
- # Global transform pipeline (built once, runs on GPU)
31
- # Replaces VivitImageProcessor - same operations, but GPU-accelerated via torchvision v2
32
- _DTYPE = torch.float16 if USE_FP16 else torch.float32
33
 
 
34
  TRANSFORMS = v2.Compose([
35
- v2.Resize(224, antialias=True), # shortest edge → 224
36
- v2.CenterCrop(224), # 224×224
37
- v2.ToDtype(_DTYPE, scale=True), # uint8 => float, /255
38
  v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
39
  ])
40
 
41
- # Model
42
  class SwinTClassifications(nn.Module):
43
  def __init__(self, classes, weights="KINETICS400_V1"):
44
  super().__init__()
@@ -56,10 +56,9 @@ class SwinTClassifications(nn.Module):
56
 
57
 
58
  def load_model():
59
- """Downloads model from HF Hub, applies FP16 + torch.compile for max speed."""
60
  from huggingface_hub import hf_hub_download
61
 
62
- print(f"Loading model on {DEVICE} (fp16={USE_FP16}) ...")
63
  model_path = hf_hub_download(
64
  repo_id="Creator-090/isl-swin3d-model",
65
  filename="ISL_best_model.pt"
@@ -71,27 +70,25 @@ def load_model():
71
  )
72
  model = model.to(DEVICE)
73
 
74
- # FP16 on GPU - ~2x faster inference, no accuracy loss for classification
75
  if USE_FP16:
76
  model = model.half()
77
 
78
  model.eval()
79
 
80
- # torch.compile - fuses ops, reduces Python overhead (~20-35% faster after warmup)
81
  if DEVICE.type == "cuda":
82
- print("Compiling model with torch.compile (mode=reduce-overhead) ...")
83
  model = torch.compile(model, mode="reduce-overhead")
84
 
85
- # Warmup - triggers compilation + CUDA kernel caching so first real request is fast
86
  _warmup(model)
87
-
88
  print("Model ready.")
89
  return model
90
 
91
 
92
- def _warmup(model, rounds: int = 3):
93
- """Run a few dummy forward passes to trigger torch.compile and warm CUDA kernels."""
94
- print(f"Warming up model ({rounds} rounds) ...")
 
95
  dummy = torch.zeros(1, 3, CLIP_LENGTH, 224, 224, device=DEVICE, dtype=_DTYPE)
96
  with torch.no_grad():
97
  for _ in range(rounds):
@@ -101,92 +98,74 @@ def _warmup(model, rounds: int = 3):
101
  print("Warmup complete.")
102
 
103
 
104
- # Preprocessing helpers
105
-
106
  def _frames_to_tensor(frames: list) -> torch.Tensor:
107
- """
108
- Converts a list of numpy (H,W,3) RGB frames → (1, C, T, H, W) tensor on DEVICE.
109
- Resize + normalize happen on GPU via torchvision v2 transforms.
110
- """
111
- # Stack => (T, C, H, W) uint8
112
  video = torch.stack([
113
- torch.from_numpy(f).permute(2, 0, 1) # H,W,C => C,H,W
114
  for f in frames
115
- ]) # (T, C, H, W)
116
-
117
- video = video.to(DEVICE) # move to GPU first, then transform
118
- video = TRANSFORMS(video) # resize + crop + normalize on GPU => (T, C, H, W)
119
- video = video.permute(1, 0, 2, 3) # (C, T, H, W) => Swin3D expects this
120
- return video.unsqueeze(0) # (1, C, T, H, W)
121
 
122
 
123
  def _pad_or_trim(frames: list, clip_length: int) -> list:
124
  if len(frames) < clip_length:
125
  frames += [frames[-1]] * (clip_length - len(frames))
126
  elif len(frames) > clip_length:
127
- # Uniform temporal sampling instead of naive truncation
128
  indices = [int(i * len(frames) / clip_length) for i in range(clip_length)]
129
  frames = [frames[i] for i in indices]
130
  return frames
131
 
132
 
133
  def preprocess_video(video_bytes: bytes, clip_length: int = CLIP_LENGTH) -> torch.Tensor:
134
- """
135
- Decodes a video from raw bytes (no disk I/O) and returns a model-ready tensor.
136
- Uses decord's in-memory VideoReader to avoid the tempfile write/read cycle.
137
- """
138
  set_bridge("torch")
139
- vr = VideoReader(io.BytesIO(video_bytes)) # in-memory, no disk write
140
- total = len(vr)
141
- idx = list(range(min(total, clip_length)))
142
  if len(idx) < clip_length:
143
  idx += [idx[-1]] * (clip_length - len(idx))
144
-
145
- batch = vr.get_batch(idx).asnumpy() # (T, H, W, C) uint8 numpy
146
- frames = [batch[i] for i in range(batch.shape[0])] # list of (H, W, C)
147
-
148
  return _frames_to_tensor(frames)
149
 
150
 
151
  def preprocess_frames(frames_list_bytes: list[bytes], clip_length: int = CLIP_LENGTH) -> torch.Tensor:
152
- """
153
- Decodes a list of JPEG/PNG frame bytes and returns a model-ready tensor.
154
- All heavy lifting (resize, normalize) happens on GPU.
155
- """
156
  frames = []
157
  for fb in frames_list_bytes:
158
  arr = np.frombuffer(fb, np.uint8)
159
  img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
160
  if img is None:
161
  continue
162
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # BGR → RGB
163
  frames.append(img)
164
 
165
  if not frames:
166
- raise ValueError("No valid frames could be decoded from the provided bytes.")
167
 
168
  frames = _pad_or_trim(frames, clip_length)
169
  return _frames_to_tensor(frames)
170
 
171
 
172
  # Inference
173
-
174
  def _run_inference(model, pixel_values: torch.Tensor, top_k: int) -> dict:
175
- """Shared inference logic for both predict paths."""
176
  with torch.no_grad():
177
- # autocast is a no-op on CPU; on GPU it enforces FP16 even if something slipped through
178
- with torch.autocast(device_type=DEVICE.type, dtype=_DTYPE, enabled=USE_FP16):
 
 
 
 
179
  outputs = model(pixel_values)
180
 
181
  probs = torch.nn.functional.softmax(outputs, dim=-1)[0]
182
 
183
  top_probs, top_indices = torch.topk(probs, k=top_k)
184
-
185
  results = [
186
  {"class": CLASSES[top_indices[i].item()], "confidence": float(top_probs[i].item())}
187
  for i in range(top_k)
188
  ]
189
-
190
  return {
191
  "prediction": results[0]["class"],
192
  "confidence": results[0]["confidence"],
@@ -195,12 +174,10 @@ def _run_inference(model, pixel_values: torch.Tensor, top_k: int) -> dict:
195
 
196
 
197
  def predict(model, video_bytes: bytes, top_k: int = 5) -> dict:
198
- """Inference from raw video bytes."""
199
  pixel_values = preprocess_video(video_bytes)
200
  return _run_inference(model, pixel_values, top_k)
201
 
202
 
203
  def predict_from_frames(model, frames_list_bytes: list[bytes], top_k: int = 5) -> dict:
204
- """Inference from a list of raw JPEG/PNG frame bytes."""
205
  pixel_values = preprocess_frames(frames_list_bytes)
206
  return _run_inference(model, pixel_values, top_k)
 
8
  import cv2
9
  import numpy as np
10
 
11
+ # Classes
12
  CLASSES = [
13
  'afternoon', 'animal', 'bad', 'beautiful', 'big', 'bird', 'blind',
14
  'cat', 'cheap', 'clothing', 'cold', 'cow', 'curved', 'deaf', 'dog',
 
22
  'warm', 'wednesday', 'week', 'wet', 'wide', 'year', 'yesterday', 'young'
23
  ]
24
 
25
+ # Constants
26
  CLIP_LENGTH = 16
27
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ USE_FP16 = DEVICE.type == "cuda" # False on HF free tier (CPU only)
29
+ _DTYPE = torch.float16 if USE_FP16 else torch.float32
30
 
31
+ print(f"[model] device={DEVICE} | fp16={USE_FP16} | dtype={_DTYPE}")
 
 
32
 
33
+ # Global transform pipeline (built once)
34
  TRANSFORMS = v2.Compose([
35
+ v2.Resize(224, antialias=True),
36
+ v2.CenterCrop(224),
37
+ v2.ToDtype(_DTYPE, scale=True),
38
  v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
39
  ])
40
 
41
+ # Model
42
  class SwinTClassifications(nn.Module):
43
  def __init__(self, classes, weights="KINETICS400_V1"):
44
  super().__init__()
 
56
 
57
 
58
  def load_model():
 
59
  from huggingface_hub import hf_hub_download
60
 
61
+ print(f"Loading model on {DEVICE} ...")
62
  model_path = hf_hub_download(
63
  repo_id="Creator-090/isl-swin3d-model",
64
  filename="ISL_best_model.pt"
 
70
  )
71
  model = model.to(DEVICE)
72
 
 
73
  if USE_FP16:
74
  model = model.half()
75
 
76
  model.eval()
77
 
78
+ # torch.compile only on CUDA can error or be very slow on CPU
79
  if DEVICE.type == "cuda":
80
+ print("Compiling model with torch.compile ...")
81
  model = torch.compile(model, mode="reduce-overhead")
82
 
 
83
  _warmup(model)
 
84
  print("Model ready.")
85
  return model
86
 
87
 
88
+ def _warmup(model):
89
+ # 1 round on CPU (warmup is slow ~30s on CPU Swin3D), 3 on GPU
90
+ rounds = 1 if DEVICE.type == "cpu" else 3
91
+ print(f"Warming up ({rounds} round(s) on {DEVICE}) ...")
92
  dummy = torch.zeros(1, 3, CLIP_LENGTH, 224, 224, device=DEVICE, dtype=_DTYPE)
93
  with torch.no_grad():
94
  for _ in range(rounds):
 
98
  print("Warmup complete.")
99
 
100
 
101
+ # Preprocessing
 
102
  def _frames_to_tensor(frames: list) -> torch.Tensor:
 
 
 
 
 
103
  video = torch.stack([
104
+ torch.from_numpy(f).permute(2, 0, 1)
105
  for f in frames
106
+ ]) # (T, C, H, W) uint8
107
+ video = video.to(DEVICE)
108
+ video = TRANSFORMS(video) # (T, C, H, W) float
109
+ video = video.permute(1, 0, 2, 3) # (C, T, H, W)
110
+ return video.unsqueeze(0) # (1, C, T, H, W)
 
111
 
112
 
113
  def _pad_or_trim(frames: list, clip_length: int) -> list:
114
  if len(frames) < clip_length:
115
  frames += [frames[-1]] * (clip_length - len(frames))
116
  elif len(frames) > clip_length:
 
117
  indices = [int(i * len(frames) / clip_length) for i in range(clip_length)]
118
  frames = [frames[i] for i in indices]
119
  return frames
120
 
121
 
122
  def preprocess_video(video_bytes: bytes, clip_length: int = CLIP_LENGTH) -> torch.Tensor:
 
 
 
 
123
  set_bridge("torch")
124
+ vr = VideoReader(io.BytesIO(video_bytes))
125
+ total = len(vr)
126
+ idx = list(range(min(total, clip_length)))
127
  if len(idx) < clip_length:
128
  idx += [idx[-1]] * (clip_length - len(idx))
129
+ batch = vr.get_batch(idx).asnumpy()
130
+ frames = [batch[i] for i in range(batch.shape[0])]
 
 
131
  return _frames_to_tensor(frames)
132
 
133
 
134
  def preprocess_frames(frames_list_bytes: list[bytes], clip_length: int = CLIP_LENGTH) -> torch.Tensor:
 
 
 
 
135
  frames = []
136
  for fb in frames_list_bytes:
137
  arr = np.frombuffer(fb, np.uint8)
138
  img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
139
  if img is None:
140
  continue
141
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
142
  frames.append(img)
143
 
144
  if not frames:
145
+ raise ValueError("No valid frames could be decoded.")
146
 
147
  frames = _pad_or_trim(frames, clip_length)
148
  return _frames_to_tensor(frames)
149
 
150
 
151
  # Inference
 
152
  def _run_inference(model, pixel_values: torch.Tensor, top_k: int) -> dict:
 
153
  with torch.no_grad():
154
+ if USE_FP16:
155
+ # autocast only valid on CUDA
156
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
157
+ outputs = model(pixel_values)
158
+ else:
159
+ # CPU path — plain fp32, no autocast
160
  outputs = model(pixel_values)
161
 
162
  probs = torch.nn.functional.softmax(outputs, dim=-1)[0]
163
 
164
  top_probs, top_indices = torch.topk(probs, k=top_k)
 
165
  results = [
166
  {"class": CLASSES[top_indices[i].item()], "confidence": float(top_probs[i].item())}
167
  for i in range(top_k)
168
  ]
 
169
  return {
170
  "prediction": results[0]["class"],
171
  "confidence": results[0]["confidence"],
 
174
 
175
 
176
  def predict(model, video_bytes: bytes, top_k: int = 5) -> dict:
 
177
  pixel_values = preprocess_video(video_bytes)
178
  return _run_inference(model, pixel_values, top_k)
179
 
180
 
181
  def predict_from_frames(model, frames_list_bytes: list[bytes], top_k: int = 5) -> dict:
 
182
  pixel_values = preprocess_frames(frames_list_bytes)
183
  return _run_inference(model, pixel_values, top_k)