hetchyy Claude Opus 4.6 committed on
Commit
df9144c
·
1 Parent(s): 7711775

Isolate CPU fallback in spawn subprocess to prevent CUDA state poisoning

Browse files

Running torch ops in the main Gradio process during CPU fallback triggers
C-level CUDA runtime queries that partially initialize CUDA state. Since
ZeroGPU uses fork() for GPU workers, all future workers inherit this
corrupted state ("bad fork"), causing permanent "No CUDA GPUs are available"
errors for ALL users until Space restart.

Fix: CPU fallback now runs in a multiprocessing.spawn subprocess with
SPACES_ZERO_GPU="" and CUDA_VISIBLE_DEVICES="" — a clean Python interpreter
with no ZeroGPU patches and no CUDA access. The main process never touches
torch, keeping its CUDA state pristine for future forked workers.

Also converts VAD tensor outputs to numpy for subprocess picklability and
adds _is_in_bad_fork diagnostic logging at GPU worker entry.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

src/core/cpu_subprocess.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Subprocess-isolated CPU inference to prevent CUDA state poisoning.
2
+
3
+ On HuggingFace Spaces with ZeroGPU, the main Gradio process has PyTorch
4
+ monkey-patched (TorchFunctionMode, fake CUDA availability). Running torch
5
+ operations in the main process can trigger C-level CUDA runtime queries
6
+ that partially initialize CUDA state. Since ZeroGPU uses fork() for GPU
7
+ workers, this corrupted state is inherited by ALL future workers, causing
8
+ permanent "No CUDA GPUs are available" errors.
9
+
10
+ Solution: run CPU inference in a spawn-context subprocess. spawn creates
11
+ a clean Python interpreter without inherited CUDA state or ZeroGPU patches.
12
+ """
13
+
14
+ import importlib
15
+ import multiprocessing
16
+ import os
17
+ import sys
18
+ import traceback
19
+
20
+
21
def _cpu_worker(func_module, func_name, extra_paths, args, kwargs, result_queue):
    """Entry point executed inside the clean CPU-only subprocess.

    Disables ZeroGPU and CUDA so the target function runs in a plain
    CPU PyTorch environment with no monkey patches, then imports and
    calls it, pushing either ("ok", result) or ("error", info) onto
    result_queue.
    """
    # Disable ZeroGPU — prevents spaces package from patching torch
    os.environ["SPACES_ZERO_GPU"] = ""
    # Disable CUDA — guarantees CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""

    # Mirror the parent's sys.path entries so src/, config, etc. resolve.
    for entry in extra_paths:
        if entry and entry not in sys.path:
            sys.path.insert(0, entry)

    try:
        target = getattr(importlib.import_module(func_module), func_name)
        # Peel off the @gpu_with_fallback decorator (and any other
        # functools.wraps layer) to reach the raw implementation.
        while hasattr(target, "__wrapped__"):
            target = target.__wrapped__
        result_queue.put(("ok", target(*args, **kwargs)))
    except Exception as exc:
        result_queue.put(
            ("error", (type(exc).__name__, str(exc), traceback.format_exc()))
        )
49
+
50
+
51
def run_in_cpu_subprocess(func, args, kwargs, timeout=600):
    """Run a function in an isolated CPU subprocess.

    Uses 'spawn' context to create a clean Python interpreter that does
    not inherit the main process's CUDA state or ZeroGPU monkey patches.

    All args, kwargs, and return values must be picklable (numpy arrays,
    lists, dicts, strings, numbers — no torch tensors or Gradio objects).

    Args:
        func: The function to call. Must be importable by module + name.
        args: Positional arguments tuple.
        kwargs: Keyword arguments dict.
        timeout: Max seconds to wait (default 600 = 10 min).

    Returns:
        The function's return value.

    Raises:
        TimeoutError: If subprocess exceeds timeout.
        RuntimeError: If subprocess fails or exits without result.
    """
    import queue  # stdlib: for the Empty sentinel raised by mp queue gets

    ctx = multiprocessing.get_context("spawn")
    result_queue = ctx.Queue()

    func_module = func.__module__
    func_name = func.__qualname__
    # Pass sys.path so the subprocess can find all modules (app dir, etc.)
    extra_paths = list(sys.path)

    print(f"[CPU SUBPROCESS] Spawning for {func_module}.{func_name}")

    p = ctx.Process(
        target=_cpu_worker,
        args=(func_module, func_name, extra_paths, args, kwargs, result_queue),
        daemon=True,
    )
    p.start()
    p.join(timeout=timeout)
    # NOTE(review): a very large result can stall the child while it flushes
    # the queue's pipe, making join() hit the timeout even though the work
    # finished — worth knowing if spurious timeouts ever appear.

    if p.is_alive():
        p.kill()
        p.join(timeout=5)
        raise TimeoutError(f"CPU subprocess timed out after {timeout}s")

    # Queue.empty()/get_nowait() are unreliable for multiprocessing queues:
    # the child's feeder thread may still have the result in flight right
    # after join() returns. Block briefly on get() instead, so a result that
    # WAS produced is never misreported as "exited without result".
    try:
        status, payload = result_queue.get(timeout=5)
    except queue.Empty:
        raise RuntimeError(
            f"CPU subprocess exited without result (exit code {p.exitcode})"
        ) from None

    if status == "ok":
        print(f"[CPU SUBPROCESS] {func_name} completed successfully")
        return payload

    exc_type, exc_msg, exc_tb = payload
    print(f"[CPU SUBPROCESS] Error traceback:\n{exc_tb}")
    raise RuntimeError(f"CPU subprocess error ({exc_type}): {exc_msg}")
src/core/zero_gpu.py CHANGED
@@ -119,6 +119,16 @@ def force_cpu_mode():
119
  # Model cleanup helpers
120
  # =========================================================================
121
 
 
 
 
 
 
 
 
 
 
 
122
  def _cleanup_after_gpu():
123
  """Cleanup that runs at the END of a GPU lease.
124
 
@@ -190,6 +200,7 @@ def gpu_with_fallback(duration=60):
190
  # to prevent concurrent threads from moving models mid-inference.
191
  @wraps(func)
192
  def func_with_cleanup(*args, **kwargs):
 
193
  _enter_gpu_lease()
194
  with model_device_lock:
195
  try:
@@ -214,10 +225,21 @@ def gpu_with_fallback(duration=60):
214
  @wraps(func)
215
  def wrapper(*args, **kwargs):
216
  global _models_stale
217
- # If user explicitly chose CPU mode, skip GPU entirely
 
 
 
218
  if is_user_forced_cpu():
219
- print("[CPU] User selected CPU mode")
220
- return func(*args, **kwargs)
 
 
 
 
 
 
 
 
221
 
222
  # Try GPU
223
  try:
 
119
  # Model cleanup helpers
120
  # =========================================================================
121
 
122
def _check_cuda_fork_state(label: str):
    """Print a diagnostic line saying whether this process is in a
    bad-fork CUDA state (torch.cuda._is_in_bad_fork); never raises."""
    try:
        import torch

        in_bad_fork = torch.cuda._is_in_bad_fork()
        print(f"[CUDA DIAG] {label}: _is_in_bad_fork={in_bad_fork}")
    except Exception as err:
        # torch missing or the private API unavailable — report and move on.
        print(f"[CUDA DIAG] {label}: check failed: {err}")
130
+
131
+
132
  def _cleanup_after_gpu():
133
  """Cleanup that runs at the END of a GPU lease.
134
 
 
200
  # to prevent concurrent threads from moving models mid-inference.
201
  @wraps(func)
202
  def func_with_cleanup(*args, **kwargs):
203
+ _check_cuda_fork_state(f"GPU worker entry ({func.__name__})")
204
  _enter_gpu_lease()
205
  with model_device_lock:
206
  try:
 
225
  @wraps(func)
226
  def wrapper(*args, **kwargs):
227
  global _models_stale
228
+ # If user explicitly chose CPU mode, skip GPU entirely.
229
+ # On ZeroGPU, run in an isolated subprocess to prevent CUDA
230
+ # state poisoning — torch ops in the main process corrupt the
231
+ # C-level CUDA runtime, making all future forked workers fail.
232
  if is_user_forced_cpu():
233
+ if ZERO_GPU_AVAILABLE:
234
+ _check_cuda_fork_state(f"before CPU subprocess ({func.__name__})")
235
+ print(f"[CPU] Running {func.__name__} in isolated subprocess")
236
+ from .cpu_subprocess import run_in_cpu_subprocess
237
+ result = run_in_cpu_subprocess(func, args, kwargs)
238
+ _check_cuda_fork_state(f"after CPU subprocess ({func.__name__})")
239
+ return result
240
+ else:
241
+ print("[CPU] User selected CPU mode (local dev)")
242
+ return func(*args, **kwargs)
243
 
244
  # Try GPU
245
  try:
src/pipeline.py CHANGED
@@ -679,9 +679,16 @@ def resegment_audio(
679
  print("[STAGE] Resegmenting...")
680
 
681
  # Re-clean speech intervals with new parameters (CPU, no GPU needed)
 
 
 
 
 
 
 
682
  from recitations_segmenter import clean_speech_intervals
683
  clean_out = clean_speech_intervals(
684
- cached_speech_intervals,
685
  cached_is_complete,
686
  min_silence_duration_ms=int(min_silence_ms),
687
  min_speech_duration_ms=int(min_speech_ms),
 
679
  print("[STAGE] Resegmenting...")
680
 
681
  # Re-clean speech intervals with new parameters (CPU, no GPU needed)
682
+ # Convert numpy→torch if needed (VAD returns numpy for picklability)
683
+ import torch as _torch
684
+ _intervals_tensor = (
685
+ _torch.from_numpy(cached_speech_intervals)
686
+ if isinstance(cached_speech_intervals, np.ndarray)
687
+ else cached_speech_intervals
688
+ )
689
  from recitations_segmenter import clean_speech_intervals
690
  clean_out = clean_speech_intervals(
691
+ _intervals_tensor,
692
  cached_is_complete,
693
  min_silence_duration_ms=int(min_silence_ms),
694
  min_speech_duration_ms=int(min_speech_ms),
src/segmenter/vad.py CHANGED
@@ -92,11 +92,12 @@ def detect_speech_segments(
92
 
93
  raw_speech_intervals = outputs[0].speech_intervals
94
  raw_is_complete = outputs[0].is_complete
95
- # Detach from GPU to prevent CUDA tensor refs escaping the lease
96
- if hasattr(raw_speech_intervals, 'cpu'):
97
- raw_speech_intervals = raw_speech_intervals.cpu()
98
- if hasattr(raw_is_complete, 'cpu'):
99
- raw_is_complete = raw_is_complete.cpu()
 
100
 
101
  return [(start, end) for start, end in intervals], {"model_load_time": model_load_time, "inference_time": inference_time}, raw_speech_intervals, raw_is_complete
102
 
 
92
 
93
  raw_speech_intervals = outputs[0].speech_intervals
94
  raw_is_complete = outputs[0].is_complete
95
+ # Convert to numpy: prevents CUDA tensor refs escaping the lease
96
+ # and ensures picklability for CPU subprocess isolation
97
+ if hasattr(raw_speech_intervals, 'detach'):
98
+ raw_speech_intervals = raw_speech_intervals.detach().cpu().numpy()
99
+ if hasattr(raw_is_complete, 'detach'):
100
+ raw_is_complete = raw_is_complete.detach().cpu().numpy()
101
 
102
  return [(start, end) for start, end in intervals], {"model_load_time": model_load_time, "inference_time": inference_time}, raw_speech_intervals, raw_is_complete
103