hetchyy Claude Opus 4.6 committed on
Commit
47224bd
·
1 Parent(s): bed2244

Fix race condition where concurrent quota exhaustion permanently breaks GPU

Browse files

GPU quota state was a process-global variable shared across request threads.
When an unlogged user exhausted ZeroGPU quota, a concurrent request's
reset_quota_flag() could clear the flag before the CPU fallback path checked
it, causing model.to("cuda") outside a GPU context — permanently poisoning
CUDA init for all users until space restart.

Replace globals with threading.local() for per-request isolation and add
RuntimeError safety net in ensure_models_on_gpu to prevent CUDA init from
ever escaping. Also add row_group_size and write_page_index to parquet writes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

src/core/usage_logger.py CHANGED
@@ -183,7 +183,12 @@ if _HAS_DEPS:
183
  try:
184
  import tempfile
185
  archive = tempfile.NamedTemporaryFile(suffix=".parquet", delete=False)
186
- pq.write_table(table, archive.name)
 
 
 
 
 
187
  self.api.upload_file(
188
  repo_id=self.repo_id,
189
  repo_type=self.repo_type,
 
183
  try:
184
  import tempfile
185
  archive = tempfile.NamedTemporaryFile(suffix=".parquet", delete=False)
186
+ pq.write_table(
187
+ table,
188
+ archive.name,
189
+ row_group_size=1,
190
+ write_page_index=True,
191
+ )
192
  self.api.upload_file(
193
  repo_id=self.repo_id,
194
  repo_type=self.repo_type,
src/core/zero_gpu.py CHANGED
@@ -4,6 +4,7 @@ local or non-ZeroGPU environments.
4
  """
5
 
6
  import re
 
7
  from typing import Callable, TypeVar
8
  from functools import wraps
9
 
@@ -12,10 +13,8 @@ T = TypeVar("T", bound=Callable)
12
  # Default values in case the spaces package is unavailable (e.g., local runs).
13
  ZERO_GPU_AVAILABLE = False
14
 
15
- # Track whether we've fallen back to CPU due to quota exhaustion
16
- _gpu_quota_exhausted = False
17
- _quota_reset_time = None # e.g. "13:53:59"
18
- _user_forced_cpu = False
19
 
20
  try:
21
  import spaces # type: ignore
@@ -39,32 +38,30 @@ except Exception:
39
 
40
 
41
  def is_quota_exhausted() -> bool:
42
- """Check if GPU quota has been exhausted this session."""
43
- return _gpu_quota_exhausted
44
 
45
 
46
  def is_user_forced_cpu() -> bool:
47
- """Check if the user manually selected CPU mode."""
48
- return _user_forced_cpu
49
 
50
 
51
  def get_quota_reset_time() -> str | None:
52
  """Return the quota reset time string (e.g. '13:53:59'), or None."""
53
- return _quota_reset_time
54
 
55
 
56
  def reset_quota_flag():
57
- """Reset the quota exhausted flag (e.g., after quota resets)."""
58
- global _gpu_quota_exhausted, _quota_reset_time, _user_forced_cpu
59
- _gpu_quota_exhausted = False
60
- _quota_reset_time = None
61
- _user_forced_cpu = False
62
 
63
 
64
  def force_cpu_mode():
65
- """Force all GPU-decorated functions to skip GPU and run on CPU."""
66
- global _user_forced_cpu
67
- _user_forced_cpu = True
68
  _move_models_to_cpu()
69
 
70
 
@@ -100,15 +97,13 @@ def gpu_with_fallback(duration=60):
100
 
101
  @wraps(func)
102
  def wrapper(*args, **kwargs):
103
- global _gpu_quota_exhausted, _quota_reset_time
104
-
105
  # If user explicitly chose CPU mode, skip GPU entirely
106
- if _user_forced_cpu:
107
  print("[CPU] User selected CPU mode")
108
  return func(*args, **kwargs)
109
 
110
  # If quota already exhausted, go straight to CPU
111
- if _gpu_quota_exhausted:
112
  print("[GPU] Quota exhausted, using CPU fallback")
113
  _move_models_to_cpu()
114
  return func(*args, **kwargs)
@@ -124,15 +119,16 @@ def gpu_with_fallback(duration=60):
124
 
125
  if is_quota_error:
126
  print(f"[GPU] Quota exceeded, falling back to CPU: {e}")
127
- _gpu_quota_exhausted = True
128
  # Parse reset time from message like "Try again in 13:53:59"
129
  match = re.search(r'Try again in (\d+:\d{2}:\d{2})', str(e))
130
  if match:
131
- _quota_reset_time = match.group(1)
132
  # Show immediate toast notification
133
  try:
134
  import gradio as gr
135
- reset_msg = f" Resets in {_quota_reset_time}." if _quota_reset_time else ""
 
136
  gr.Warning(f"GPU quota reached — switching to CPU (slower).{reset_msg}")
137
  except Exception:
138
  pass # Not in a Gradio context (e.g., CLI usage)
 
4
  """
5
 
6
  import re
7
+ import threading
8
  from typing import Callable, TypeVar
9
  from functools import wraps
10
 
 
13
  # Default values in case the spaces package is unavailable (e.g., local runs).
14
  ZERO_GPU_AVAILABLE = False
15
 
16
+ # Per-thread (per-request) GPU state so concurrent requests don't interfere
17
+ _request_state = threading.local()
 
 
18
 
19
  try:
20
  import spaces # type: ignore
 
38
 
39
 
40
  def is_quota_exhausted() -> bool:
41
+ """Check if GPU quota has been exhausted for this request's thread."""
42
+ return getattr(_request_state, 'gpu_quota_exhausted', False)
43
 
44
 
45
  def is_user_forced_cpu() -> bool:
46
+ """Check if the user manually selected CPU mode for this request."""
47
+ return getattr(_request_state, 'user_forced_cpu', False)
48
 
49
 
50
  def get_quota_reset_time() -> str | None:
51
  """Return the quota reset time string (e.g. '13:53:59'), or None."""
52
+ return getattr(_request_state, 'quota_reset_time', None)
53
 
54
 
55
  def reset_quota_flag():
56
+ """Reset the quota exhausted flag for this request's thread."""
57
+ _request_state.gpu_quota_exhausted = False
58
+ _request_state.quota_reset_time = None
59
+ _request_state.user_forced_cpu = False
 
60
 
61
 
62
  def force_cpu_mode():
63
+ """Force GPU-decorated functions to skip GPU and run on CPU for this request."""
64
+ _request_state.user_forced_cpu = True
 
65
  _move_models_to_cpu()
66
 
67
 
 
97
 
98
  @wraps(func)
99
  def wrapper(*args, **kwargs):
 
 
100
  # If user explicitly chose CPU mode, skip GPU entirely
101
+ if is_user_forced_cpu():
102
  print("[CPU] User selected CPU mode")
103
  return func(*args, **kwargs)
104
 
105
  # If quota already exhausted, go straight to CPU
106
+ if is_quota_exhausted():
107
  print("[GPU] Quota exhausted, using CPU fallback")
108
  _move_models_to_cpu()
109
  return func(*args, **kwargs)
 
119
 
120
  if is_quota_error:
121
  print(f"[GPU] Quota exceeded, falling back to CPU: {e}")
122
+ _request_state.gpu_quota_exhausted = True
123
  # Parse reset time from message like "Try again in 13:53:59"
124
  match = re.search(r'Try again in (\d+:\d{2}:\d{2})', str(e))
125
  if match:
126
+ _request_state.quota_reset_time = match.group(1)
127
  # Show immediate toast notification
128
  try:
129
  import gradio as gr
130
+ reset_time = get_quota_reset_time()
131
+ reset_msg = f" Resets in {reset_time}." if reset_time else ""
132
  gr.Warning(f"GPU quota reached — switching to CPU (slower).{reset_msg}")
133
  except Exception:
134
  pass # Not in a Gradio context (e.g., CLI usage)
src/segmenter/segmenter_model.py CHANGED
@@ -74,19 +74,24 @@ def ensure_models_on_gpu(asr_model_name=None):
74
  dtype = _TORCH_DTYPE
75
  move_start = time.time()
76
 
77
- # Move segmenter to GPU
78
- if _segmenter_cache["loaded"] and _segmenter_cache["model"] is not None:
79
- model = _segmenter_cache["model"]
80
- if next(model.parameters()).device.type != "cuda":
81
- print("[GPU] Moving segmenter to CUDA...")
82
- model.to(device, dtype=dtype)
83
- _segmenter_cache["model"] = model
84
- _segmenter_cache["device"] = "cuda"
85
- print("[GPU] Segmenter on CUDA")
86
-
87
- # Move phoneme ASR to GPU (only the requested model)
88
- if asr_model_name is not None:
89
- move_phoneme_asr_to_gpu(asr_model_name)
 
 
 
 
 
90
 
91
  return time.time() - move_start
92
 
 
74
  dtype = _TORCH_DTYPE
75
  move_start = time.time()
76
 
77
+ try:
78
+ # Move segmenter to GPU
79
+ if _segmenter_cache["loaded"] and _segmenter_cache["model"] is not None:
80
+ model = _segmenter_cache["model"]
81
+ if next(model.parameters()).device.type != "cuda":
82
+ print("[GPU] Moving segmenter to CUDA...")
83
+ model.to(device, dtype=dtype)
84
+ _segmenter_cache["model"] = model
85
+ _segmenter_cache["device"] = "cuda"
86
+ print("[GPU] Segmenter on CUDA")
87
+
88
+ # Move phoneme ASR to GPU (only the requested model)
89
+ if asr_model_name is not None:
90
+ move_phoneme_asr_to_gpu(asr_model_name)
91
+ except RuntimeError as e:
92
+ # Prevent CUDA init outside GPU context from poisoning the process
93
+ print(f"[GPU] CUDA move failed, staying on CPU: {e}")
94
+ return 0.0
95
 
96
  return time.time() - move_start
97