hetchyy (Claude Opus 4.6) committed
Commit 0318d9e · Parent: 8fae5a4

Prevent CPU fallback from poisoning CUDA state via GPU model parking


When GPU quota is exhausted, the CPU fallback path was dropping references
to GPU-resident models, causing Python GC to fire CUDA tensor destructors
outside a GPU lease. This corrupted torch.cuda state permanently, breaking
all subsequent GPU requests for every user.

Fix: "park" GPU model references in _stale_gpu_refs before invalidating
caches, preventing GC. Parked models are safely released inside the next
GPU lease via _drain_stale_models(). Also removes the now-unnecessary CPU
reload blocks from vad.py and phoneme_asr.py (which were the poison source)
and simplifies worker/CUDA error handlers to propagate immediately instead
of retrying (which risked further state corruption).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
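The fix boils down to deferring object destruction to a safe point. Below is a minimal, self-contained sketch of the park/drain idea; the names park_stale_models, drain_stale_models, and _model_cache are placeholders for illustration only, and the actual code and cache structures are in the zero_gpu.py diff further down.

import threading

_stale_gpu_refs = []   # strong references keep GPU models alive, so GC cannot fire their CUDA destructors
_model_cache = {}      # stand-in for the real per-module model caches
_lock = threading.RLock()

def park_stale_models():
    # Called on the CPU-fallback path, outside any GPU lease.
    with _lock:
        # Move references out of the cache instead of dropping them; the parked
        # objects stay alive until the next lease drains them.
        _stale_gpu_refs.extend(m for m in _model_cache.values() if m is not None)
        _model_cache.clear()

def drain_stale_models():
    # Called at the start of the next GPU lease, where CUDA teardown is safe.
    with _lock:
        _stale_gpu_refs.clear()   # destructors run here, inside the lease

The important property is ordering, not the container: any strong reference held at module scope defers destruction until the explicit clear() inside the lease.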

src/alignment/phoneme_asr.py CHANGED
@@ -333,20 +333,10 @@ def transcribe_batch(segment_audios: List[np.ndarray], sample_rate: int, model_n
         return [[] for _ in segment_audios], [], 0.0, 0.0
 
     # Determine inference device.
-    # In CPU fallback mode, do NOT call model.to("cpu") — that requires
-    # CUDA access (to copy from GPU) and would poison the process if no
-    # GPU lease is active. Instead, if the model is stuck on GPU, reload
-    # a fresh copy on CPU.
+    # In CPU fallback mode, models are already on CPU (parked and reloaded
+    # by _park_stale_models in zero_gpu.py before the fallback call).
     if is_quota_exhausted() or is_user_forced_cpu():
         device = torch.device("cpu")
-        if next(model.parameters()).device.type != "cpu":
-            print(f"[PHONEME ASR] CPU fallback but '{model_name}' on GPU — reloading fresh on CPU")
-            with model_device_lock:
-                if model_name in _cache:
-                    del _cache[model_name]
-                model, processor = load_phoneme_asr(model_name)
-            if model is None:
-                return [[] for _ in segment_audios], [], 0.0, 0.0
     else:
         device = next(model.parameters()).device
 
src/core/zero_gpu.py CHANGED
@@ -31,6 +31,7 @@ model_device_lock = threading.RLock()
 _lease_lock = threading.Lock()
 _active_gpu_leases = 0
 _models_stale = False  # Set True at lease end; drained at next lease start
+_stale_gpu_refs = []   # Prevent GC of GPU models outside a lease
 
 
 try:
@@ -140,6 +141,8 @@ def _drain_stale_models():
     if not _models_stale:
         return
     _models_stale = False
+    # Release parked GPU models (CUDA destructors safe inside lease)
+    _stale_gpu_refs.clear()
     from ..segmenter.segmenter_model import invalidate_segmenter_cache
     from ..alignment.phoneme_asr import invalidate_asr_cache
     invalidate_segmenter_cache()
@@ -149,6 +152,31 @@ def _drain_stale_models():
     print("[GPU CLEANUP] Drained stale models from previous lease")
 
 
+def _park_stale_models():
+    """Park GPU models before CPU fallback to prevent CUDA destructors outside lease.
+
+    Keeps references alive in _stale_gpu_refs so GC doesn't trigger CUDA ops.
+    Invalidates caches so fresh CPU models are loaded. The parked models get
+    properly released inside the next GPU lease via _drain_stale_models().
+    """
+    from ..segmenter.segmenter_model import _segmenter_cache
+    from ..alignment.phoneme_asr import _cache as _asr_cache
+
+    # Stash GPU model references to prevent GC
+    if _segmenter_cache.get("model") is not None:
+        _stale_gpu_refs.append(_segmenter_cache["model"])
+    for entry in _asr_cache.values():
+        if entry.get("model") is not None:
+            _stale_gpu_refs.append(entry["model"])
+
+    # Invalidate caches (refs in _stale_gpu_refs keep models alive)
+    from ..segmenter.segmenter_model import invalidate_segmenter_cache
+    from ..alignment.phoneme_asr import invalidate_asr_cache
+    invalidate_segmenter_cache()
+    invalidate_asr_cache()
+    print(f"[GPU PARK] Parked {len(_stale_gpu_refs)} model(s) to prevent CUDA GC outside lease")
+
+
 # =========================================================================
 # GPU decorator with fallback
 # =========================================================================
@@ -165,10 +193,10 @@ def gpu_with_fallback(duration=60):
     cleanup) to prevent concurrent threads from moving models mid-inference.
 
     Error handling strategy:
-    - Quota exhaustion → CPU fallback (per-user, not process issue)
+    - Quota exhaustion → park GPU models, CPU fallback (per-user, not process issue)
     - Timeout → propagate to caller
-    - ZeroGPU worker/runtime/CUDA errors → retry on GPU, then propagate
-      (do NOT silently force CPU; avoids sticky global CPU behavior)
+    - ZeroGPU worker/CUDA errors → propagate immediately (no retry to avoid
+      CUDA state corruption from retries outside a clean lease)
    - Unknown non-timeout errors → propagate (avoid hiding real bugs)
 
     Usage:
@@ -235,6 +263,11 @@ def gpu_with_fallback(duration=60):
                 match = re.search(r'Try again in (\d+:\d{2}:\d{2})', err_str)
                 if match:
                     _request_state.quota_reset_time = match.group(1)
+                # Park GPU models to prevent CUDA destructors outside lease
+                with model_device_lock:
+                    _park_stale_models()
+                with _lease_lock:
+                    _models_stale = True
                 try:
                     import gradio as gr
                     reset_time = get_quota_reset_time()
@@ -254,29 +287,17 @@ def gpu_with_fallback(duration=60):
                print(f"[GPU] Timeout error in {func.__name__}: {e}")
                raise
 
-            # Worker/runtime init errors are often transient in ZeroGPU.
-            # Retry a few times on GPU, then propagate (no auto-CPU fallback).
+            # Worker/runtime init errors — propagate immediately.
            is_worker_error = (
                err_title == "ZeroGPU worker error"
                or "no cuda gpus are available" in err_lower
                or "gpu task aborted" in err_lower
            )
            if is_worker_error:
-                import time
-                max_attempts = 3
-                last_err: Exception = e
-                for attempt in range(2, max_attempts + 1):
-                    delay = 0.35 * (attempt - 1)
-                    print(f"[GPU] Worker error in {func.__name__}, retry {attempt}/{max_attempts} in {delay:.2f}s: {last_err}")
-                    time.sleep(delay)
-                    try:
-                        return gpu_func(*args, **kwargs)
-                    except Exception as retry_e:
-                        last_err = retry_e
-                print(f"[GPU] Worker error persisted after {max_attempts} attempts in {func.__name__}: {last_err}")
-                raise last_err
-
-            # CUDA runtime errors (non-timeout, non-quota): retry once then propagate.
+                print(f"[GPU] Worker error in {func.__name__}: {e}")
+                raise
+
+            # CUDA runtime errors (non-timeout, non-quota): mark stale and propagate.
            is_cuda_runtime_error = (
                "cuda" in err_lower
                or "cudnn" in err_lower
@@ -284,15 +305,11 @@
                or err_title == "CUDA error"
            )
            if is_cuda_runtime_error:
-                print(f"[GPU] CUDA runtime error in {func.__name__}, retrying once: {type(e).__name__}: {e}")
+                print(f"[GPU] CUDA error in {func.__name__}: {e}")
                global _models_stale
                with _lease_lock:
                    _models_stale = True
-                try:
-                    return gpu_func(*args, **kwargs)
-                except Exception as retry_e:
-                    print(f"[GPU] CUDA runtime error persisted in {func.__name__}: {retry_e}")
-                    raise retry_e
+                raise
 
            # Unknown non-timeout errors should propagate so genuine bugs
            # are not silently hidden behind CPU fallback.
src/segmenter/vad.py CHANGED
@@ -6,8 +6,8 @@ import numpy as np
 import torch
 
 from .segmenter_aoti import is_aoti_applied
-from .segmenter_model import load_segmenter, _log_env_once, _segmenter_cache
-from ..core.zero_gpu import is_quota_exhausted, is_user_forced_cpu, model_device_lock
+from .segmenter_model import load_segmenter, _log_env_once
+from ..core.zero_gpu import is_quota_exhausted, is_user_forced_cpu
 
 
 def detect_speech_segments(
@@ -52,20 +52,10 @@ def detect_speech_segments(
     dtype = next(model.parameters()).dtype
 
     # Determine inference device.
-    # In CPU fallback mode, do NOT call model.to("cpu") — that requires
-    # CUDA access (to copy from GPU) and would poison the process if no
-    # GPU lease is active. Instead, if the model is stuck on GPU, reload
-    # a fresh copy on CPU.
+    # In CPU fallback mode, models are already on CPU (parked and reloaded
+    # by _park_stale_models in zero_gpu.py before the fallback call).
     if is_quota_exhausted() or is_user_forced_cpu():
         device = torch.device("cpu")
-        if next(model.parameters()).device.type != "cpu":
-            print("[VAD] CPU fallback but model on GPU — reloading fresh on CPU")
-            with model_device_lock:
-                _segmenter_cache["loaded"] = False
-                model, processor, _ = load_segmenter()
-            if model is None:
-                raise RuntimeError("[VAD] Failed to reload model on CPU")
-            dtype = next(model.parameters()).dtype
     else:
         device = next(model.parameters()).device
 