hetchyy commited on
Commit
d46a954
·
verified ·
1 Parent(s): 3955329

Upload folder using huggingface_hub

Browse files
app.py CHANGED
@@ -54,6 +54,29 @@ else:
54
 
55
  from src.ui.interface import build_interface
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  # =============================================================================
58
  # Module-level demo for Gradio hot-reload (`gradio app.py`)
59
  # =============================================================================
 
54
 
55
  from src.ui.interface import build_interface
56
 
57
# =============================================================================
# Persistent CPU worker pool — spawn BEFORE any GPU use if enabled.
# This keeps the workers free of any inherited CUDA/ZeroGPU state.
# =============================================================================
# Best-effort: any failure here is reported and startup continues.
try:
    from config import CPU_STRATEGY as _CPU_STRATEGY
    from config import CPU_WORKER_MODE as _CPU_WORKER_MODE
    from config import CPU_SUBPROCESS_CONCURRENCY as _CPU_SUBPROCESS_CONCURRENCY
    from config import CPU_POOL_PRELOAD_LARGE as _CPU_POOL_PRELOAD_LARGE
    from src.core.zero_gpu import IS_CPU_WORKER as _IS_CPU_WORKER

    # Only the parent app (not a CPU worker) boots the pool, and only when the
    # subprocess strategy is configured to use persistent workers.
    if (_CPU_STRATEGY, _CPU_WORKER_MODE) == ("subprocess", "persistent") and not _IS_CPU_WORKER:
        print(f"[APP] Bootstrapping persistent CPU pool: {_CPU_SUBPROCESS_CONCURRENCY} worker(s), preload_large={_CPU_POOL_PRELOAD_LARGE}")
        from src.core.cpu_worker_pool import start_pool as _start_pool

        _start_pool(
            _CPU_SUBPROCESS_CONCURRENCY,
            preload_large=_CPU_POOL_PRELOAD_LARGE,
        )
except Exception as _e:
    print(f"[APP] Persistent CPU pool bootstrap failed (non-fatal): {_e}")
79
+
80
  # =============================================================================
81
  # Module-level demo for Gradio hot-reload (`gradio app.py`)
82
  # =============================================================================
config.py CHANGED
@@ -48,14 +48,21 @@ SESSION_EXPIRY_SECONDS = 3600*5 # 5 hours — matches DELETE_CACHE_
48
  CPU_STRATEGY = os.environ.get("CPU_STRATEGY", "subprocess").lower()
49
 
50
  # Max seconds a subprocess CPU job can run before SIGKILL (used by "subprocess" and "both" strategies).
51
- CPU_SUBPROCESS_TIMEOUT = int(os.environ.get("CPU_SUBPROCESS_TIMEOUT", str(3600 * 4)))
52
 
53
- # Max concurrent CPU subprocesses on the main Space. Each subprocess loads its
54
- # own copy of the VAD + ASR models (~3.6 GB RAM). On zero-a10g with 48 GB RAM
55
- # and 8 vCPU we can safely run 2–3; pushing higher risks OOM and CPU thrash
56
- # that would also slow GPU dispatches on main.
57
  CPU_SUBPROCESS_CONCURRENCY = int(os.environ.get("CPU_SUBPROCESS_CONCURRENCY", "2"))
58
 
 
 
 
 
 
 
 
 
 
 
59
  # Model dtype for CPU inference.
60
  # "bfloat16" — default. Routes attention through PyTorch's chunked CPU flash
61
  # kernel (`_scaled_dot_product_flash_attention_for_cpu`), which
@@ -171,12 +178,7 @@ INFERENCE_BATCH_SIZE = 32 # Fixed segments per batch (used when BATCHING_ST
171
 
172
  # Dynamic batching constraints
173
  MAX_BATCH_SECONDS = 600 # GPU: max total audio seconds per batch (sum of durations)
174
- # CPU: tighter cap. SDPA materialises the QK^T tensor per encoder layer; once
175
- # (batch * heads * seq^2 * 2B) exceeds the L3 cache (~32 MB on zero-a10g Xeon),
176
- # every attention layer becomes DRAM-bound instead of cache-bound — on a 22 min
177
- # m4a we saw one batch (59 segs × 12.5 s) hit a ~24× slowdown vs its neighbours.
178
- # 300 s keeps QK^T comfortably under L3 at realistic (batch, seq) combinations.
179
- MAX_BATCH_SECONDS_CPU = 300
180
  MAX_PAD_WASTE = 0.2 # Max fraction of padded tensor that is wasted (0=no waste, 1=all waste)
181
  MIN_BATCH_SIZE = 8 # Minimum segments per batch (prevents underutilization)
182
 
 
48
  CPU_STRATEGY = os.environ.get("CPU_STRATEGY", "subprocess").lower()
49
 
50
  # Max seconds a subprocess CPU job can run before SIGKILL (used by "subprocess" and "both" strategies).
51
+ CPU_SUBPROCESS_TIMEOUT = int(os.environ.get("CPU_SUBPROCESS_TIMEOUT", str(3600 * 2)))
52
 
53
+ # Max concurrent CPU subprocesses on the main Space.
 
 
 
54
  CPU_SUBPROCESS_CONCURRENCY = int(os.environ.get("CPU_SUBPROCESS_CONCURRENCY", "2"))
55
 
56
# CPU_WORKER_MODE — when CPU_STRATEGY="subprocess", chooses between:
#   "spawn"      — legacy: fork a fresh subprocess per request (cpu_subprocess.py).
#   "persistent" — new: route to a pool of long-lived workers (cpu_worker_pool.py).
# Semaphore capacity stays = CPU_SUBPROCESS_CONCURRENCY either way.
# Surrounding whitespace is stripped so a padded value still matches.
CPU_WORKER_MODE = os.environ.get("CPU_WORKER_MODE", "persistent").strip().lower()

# Whether the persistent pool preloads ASR Large at boot. If False, Large is
# loaded on-demand inside the worker and cached there.
# Accepts the usual truthy spellings ("1", "true", "yes", "on", any case);
# everything else — including "0" and the empty string — disables preloading.
CPU_POOL_PRELOAD_LARGE = os.environ.get("CPU_POOL_PRELOAD_LARGE", "1").strip().lower() in (
    "1",
    "true",
    "yes",
    "on",
)
65
+
66
  # Model dtype for CPU inference.
67
  # "bfloat16" — default. Routes attention through PyTorch's chunked CPU flash
68
  # kernel (`_scaled_dot_product_flash_attention_for_cpu`), which
 
178
 
179
  # Dynamic batching constraints
180
  MAX_BATCH_SECONDS = 600 # GPU: max total audio seconds per batch (sum of durations)
181
+ MAX_BATCH_SECONDS_CPU = 300 # CPU: tighter cap. SDPA materialises the QK^T tensor per encoder layer; 300 s keeps it within L3 cache so attention stays cache-bound instead of DRAM-bound.
 
 
 
 
 
182
  MAX_PAD_WASTE = 0.2 # Max fraction of padded tensor that is wasted (0=no waste, 1=all waste)
183
  MIN_BATCH_SIZE = 8 # Minimum segments per batch (prevents underutilization)
184
 
src/api/session_api.py CHANGED
@@ -957,6 +957,93 @@ def pool_status(hf_token):
957
  }
958
 
959
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
960
  # ---------------------------------------------------------------------------
961
  # Hidden debug endpoint
962
  # ---------------------------------------------------------------------------
 
957
  }
958
 
959
 
960
def cpu_pool_kill(hf_token, worker_id):
    """Kill a persistent worker for crash-recovery testing. HF-token-gated.

    Args:
        hf_token: Token supplied by the caller. It must equal the ``HF_TOKEN``
            environment variable of this process.
        worker_id: Index of the worker inside the pool (anything ``int()``
            accepts, e.g. the string ``"0"``).

    Returns:
        ``{"error": "Unauthorized"}`` when the token check fails,
        ``{"error": <message>}`` when anything else goes wrong, otherwise a
        dict with ``worker_id``, ``pid``, ``was_alive``, ``kill_sent``,
        ``send_err`` and ``alive_after``.
    """
    import hmac

    space_token = os.environ.get("HF_TOKEN", "")
    # Fail closed: with no HF_TOKEN configured there is nothing to compare
    # against, so nobody is authorised (previously any non-empty token passed).
    # compare_digest keeps the comparison constant-time.
    if (
        not hf_token
        or not space_token
        or not hmac.compare_digest(str(hf_token).encode("utf-8"), space_token.encode("utf-8"))
    ):
        return {"error": "Unauthorized"}
    try:
        from src.core.cpu_worker_pool import _get_pool
        import signal as _signal
        import time as _time

        pool = _get_pool()
        wid = int(worker_id)
        # Reject out-of-range ids explicitly; a negative id would otherwise
        # wrap around and kill a different worker than the one requested.
        if not 0 <= wid < len(pool.workers):
            return {"error": f"worker_id {wid} out of range (pool has {len(pool.workers)} worker(s))"}
        handle = pool.workers[wid]
        pid = handle.pid
        was_alive = handle.process is not None and handle.process.is_alive()
        try:
            os.kill(pid, _signal.SIGKILL)
            sent = True
            send_err = None
        except Exception as ke:
            # Covers a missing pid as well as an already-gone process.
            sent = False
            send_err = str(ke)
        # give OS a moment to reap
        _time.sleep(0.3)
        alive_after = handle.process is not None and handle.process.is_alive()
        return {
            "worker_id": wid,
            "pid": pid,
            "was_alive": was_alive,
            "kill_sent": sent,
            "send_err": send_err,
            "alive_after": alive_after,
        }
    except Exception as e:
        return {"error": str(e)}
994
+
995
+
996
def cpu_pool_status(hf_token):
    """Return persistent CPU worker pool state. HF-token-gated.

    Prototype diagnostic: shows per-worker boot snapshots, load times, pids,
    live RSS, and job counts. Safe to call on Spaces where CPU_WORKER_MODE
    is not 'persistent' — just returns `started=False`.

    Args:
        hf_token: Token supplied by the caller. It must equal the ``HF_TOKEN``
            environment variable of this process.

    Returns:
        ``{"error": ...}`` on a failed token check or pool import failure,
        ``{"started": False, ...}`` when the pool is not running, otherwise
        the pool stats dict augmented with ``rss_now`` per worker,
        ``main_rss`` / ``host_mem`` and a ``cgroup`` section.
    """
    import hmac

    space_token = os.environ.get("HF_TOKEN", "")
    # Fail closed: with no HF_TOKEN configured nobody is authorised
    # (previously any non-empty token passed). Constant-time comparison.
    if (
        not hf_token
        or not space_token
        or not hmac.compare_digest(str(hf_token).encode("utf-8"), space_token.encode("utf-8"))
    ):
        return {"error": "Unauthorized"}

    try:
        from src.core.cpu_worker_pool import is_started, stats as pool_stats, probe_rss
    except Exception as e:
        return {"error": f"pool import failed: {e}"}

    if not is_started():
        return {"started": False, "note": "CPU_WORKER_MODE != persistent or pool not yet bootstrapped"}

    s = pool_stats()
    # Augment with live RSS probe per worker
    for w in s.get("workers", []):
        try:
            w["rss_now"] = probe_rss(w["id"])
        except Exception as e:
            # Some exceptions (e.g. queue.Empty) stringify to "", so fall back
            # to the type name to keep the diagnostic readable.
            w["rss_now_error"] = str(e) or type(e).__name__
    # Include main process RSS
    try:
        import psutil as _ps

        s["main_rss"] = _ps.Process(os.getpid()).memory_info().rss
        vm = _ps.virtual_memory()
        s["host_mem"] = {"total": vm.total, "available": vm.available, "used": vm.used, "percent": vm.percent}
    except Exception as e:
        s["main_rss_error"] = str(e)
    # Probe cgroup (container) memory limit — authoritative Space budget.
    # Both the cgroup v2 and v1 paths are tried; unreadable ones are recorded
    # as "err: ..." rather than dropped.
    cgroup = {}
    for path in ("/sys/fs/cgroup/memory.max", "/sys/fs/cgroup/memory/memory.limit_in_bytes"):
        try:
            with open(path) as _f:
                cgroup[path] = _f.read().strip()
        except Exception as e:
            cgroup[path] = f"err: {e}"
    try:
        with open("/sys/fs/cgroup/memory.current") as _f:
            cgroup["memory.current"] = _f.read().strip()
    except Exception:
        # Current usage is optional; simply omitted when not readable.
        pass
    s["cgroup"] = cgroup
    return s
1045
+
1046
+
1047
  # ---------------------------------------------------------------------------
1048
  # Hidden debug endpoint
1049
  # ---------------------------------------------------------------------------
src/core/cpu_worker_pool.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Persistent CPU worker pool — prototype.
2
+
3
+ Replaces the spawn-per-request path (cpu_subprocess.py) with long-lived
4
+ workers that preload VAD + ASR (Base) once and stay ready. ASR Large is
5
+ loaded on-demand in each worker (and cached there) so steady-state RAM
6
+ stays predictable.
7
+
8
+ Selected with env var CPU_WORKER_MODE=persistent (the default set in config.py).
9
+ CPU_WORKER_MODE=spawn keeps the legacy spawn-per-request behavior.
10
+
11
+ Semaphore capacity = CPU_SUBPROCESS_CONCURRENCY (shared with spawn path).
12
+ A free-worker queue gives O(1) idle-worker pickup and guarantees at most
13
+ one concurrent job per worker.
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import importlib
19
+ import multiprocessing as mp
20
+ import os
21
+ import queue as queue_mod
22
+ import signal
23
+ import sys
24
+ import threading
25
+ import time
26
+ import traceback
27
+ from dataclasses import dataclass, field
28
+ from typing import Any, Optional
29
+
30
+
31
+ # ---------------------------------------------------------------------------
32
+ # Worker side
33
+ # ---------------------------------------------------------------------------
34
+
35
+
36
def _resolve_callable(module_name: str, qualname: str):
    """Import *module_name*, walk the dotted *qualname*, and strip ``__wrapped__`` layers."""
    target = importlib.import_module(module_name)
    # The parent sends ``func.__qualname__``, which is dotted for anything that
    # is not a plain module-level function (e.g. ``Class.method``).
    for part in qualname.split("."):
        target = getattr(target, part)
    while hasattr(target, "__wrapped__"):
        target = target.__wrapped__
    return target


def _worker_loop(
    worker_id: int,
    extra_paths: list,
    req_q: mp.Queue,
    res_q: mp.Queue,
    ready_ev,
    preload_large: bool,
):
    """Long-lived worker body. Runs in a spawn-context process.

    Steps:
      1. Env hygiene — hide CUDA, disable ZeroGPU patches.
      2. Import project modules, call force_cpu_mode().
      3. Preload VAD + ASR Base (Large optional).
      4. Signal `ready_ev`.
      5. Loop: pull task → execute → push result.

    Tasks are pickled tuples: (task_id, kind, payload)
      kind="run":         payload=(func_module, func_name, args, kwargs)
      kind="rss":         payload=None  (return rss bytes)
      kind="load_large":  payload=None  (preload ASR Large if not cached)
      kind="shutdown":    payload=None  (exit loop)

    Every reply on ``res_q`` is ``(tag, status, payload)``. ``tag`` is the
    task id, or ``"__ready__"`` / ``"__boot_error__"`` during boot. ``status``
    is ``"ok"`` or ``"error"``; error payloads are
    ``(exception type name, message, formatted traceback)``.

    Args:
        worker_id: Index of this worker; used in log lines and the ready message.
        extra_paths: Parent ``sys.path`` entries to make importable here.
        req_q: Queue the parent puts tasks on.
        res_q: Queue this worker puts replies on.
        ready_ev: Event set once boot has completed successfully.
        preload_large: When true, ASR Large is loaded during boot as well.
    """
    # ---- Env hygiene BEFORE any torch import ----------------------------
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    os.environ["SPACES_ZERO_GPU"] = ""
    # Suppress the HF download progress bars — the parent app sets this but
    # the spawned child inherits only env, not module state.
    os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")

    # Restore sys.path from parent so src/ is importable. Iterating in reverse
    # while inserting at index 0 keeps the parent's relative order.
    for p in reversed(extra_paths):
        if p and p not in sys.path:
            sys.path.insert(0, p)

    # Helpful process label for ps/htop
    try:
        import setproctitle  # type: ignore
        setproctitle.setproctitle(f"cpu-worker-{worker_id}")
    except Exception:
        pass

    def log(msg):
        print(f"[CPU-POOL/W{worker_id}] {msg}", flush=True)

    # ---- RSS probe ------------------------------------------------------
    def _rss_bytes() -> int:
        """Resident set size in bytes via psutil, else /proc, else 0."""
        try:
            import psutil  # type: ignore
            return psutil.Process(os.getpid()).memory_info().rss
        except Exception:
            try:
                with open(f"/proc/{os.getpid()}/status") as f:
                    for line in f:
                        if line.startswith("VmRSS:"):
                            # Value is reported in kB.
                            return int(line.split()[1]) * 1024
            except Exception:
                return 0
        return 0

    def _report_boot_error(exc):
        # Call from inside the ``except`` block so format_exc() still sees the
        # exception being handled.
        res_q.put(("__boot_error__", "error", (type(exc).__name__, str(exc), traceback.format_exc())))

    snapshots = {"start": _rss_bytes()}
    t0 = time.time()
    log(f"booted pid={os.getpid()} rss={snapshots['start']/1e6:.1f}MB")

    # ---- Imports + force CPU mode --------------------------------------
    try:
        from src.core.zero_gpu import force_cpu_mode
        force_cpu_mode()
        snapshots["after_imports"] = _rss_bytes()
        log(f"imports done +{(time.time()-t0):.1f}s rss={snapshots['after_imports']/1e6:.1f}MB")
    except Exception as e:
        log(f"FATAL imports: {e}\n{traceback.format_exc()}")
        _report_boot_error(e)
        return

    load_times = {}

    # ---- Preload VAD (fatal on failure) --------------------------------
    try:
        t = time.time()
        from src.segmenter.segmenter_model import load_segmenter
        load_segmenter()
        load_times["vad"] = time.time() - t
        snapshots["after_vad"] = _rss_bytes()
        log(f"VAD loaded in {load_times['vad']:.2f}s rss={snapshots['after_vad']/1e6:.1f}MB")
    except Exception as e:
        log(f"VAD load failed: {e}")
        _report_boot_error(e)
        return

    # ---- Preload ASR Base (fatal on failure) ---------------------------
    try:
        t = time.time()
        from src.alignment.phoneme_asr import load_phoneme_asr
        load_phoneme_asr("Base")
        load_times["asr_base"] = time.time() - t
        snapshots["after_asr_base"] = _rss_bytes()
        log(f"ASR Base loaded in {load_times['asr_base']:.2f}s rss={snapshots['after_asr_base']/1e6:.1f}MB")
    except Exception as e:
        log(f"ASR Base load failed: {e}")
        _report_boot_error(e)
        return

    # ---- Preload caches (ngram index, phoneme chapters) ----------------
    try:
        t = time.time()
        from src.alignment.ngram_index import get_ngram_index
        from src.alignment.phoneme_matcher_cache import preload_all_chapters
        get_ngram_index()
        preload_all_chapters()
        load_times["caches"] = time.time() - t
        snapshots["after_caches"] = _rss_bytes()
        log(f"caches loaded in {load_times['caches']:.2f}s rss={snapshots['after_caches']/1e6:.1f}MB")
    except Exception as e:
        log(f"caches load failed (non-fatal): {e}")

    # ---- Optionally preload ASR Large (failure is logged, not fatal) ---
    if preload_large:
        try:
            t = time.time()
            from src.alignment.phoneme_asr import load_phoneme_asr
            load_phoneme_asr("Large")
            load_times["asr_large"] = time.time() - t
            snapshots["after_asr_large"] = _rss_bytes()
            log(f"ASR Large loaded in {load_times['asr_large']:.2f}s rss={snapshots['after_asr_large']/1e6:.1f}MB")
        except Exception as e:
            log(f"ASR Large preload failed: {e}")

    # ---- Warm up resampler (best-effort) -------------------------------
    try:
        import numpy as np, librosa
        from config import RESAMPLE_TYPE
        _ = librosa.resample(np.zeros(1600, dtype=np.float32),
                             orig_sr=44100, target_sr=16000, res_type=RESAMPLE_TYPE)
    except Exception:
        pass

    snapshots["ready"] = _rss_bytes()
    total_boot = time.time() - t0
    log(f"READY in {total_boot:.2f}s, final rss={snapshots['ready']/1e6:.1f}MB")

    # Signal parent that this worker booted successfully
    res_q.put(("__ready__", "ok", {
        "worker_id": worker_id,
        "pid": os.getpid(),
        "snapshots": snapshots,
        "load_times": load_times,
        "boot_time": total_boot,
    }))
    ready_ev.set()

    # ---- Main loop -----------------------------------------------------
    while True:
        try:
            item = req_q.get()
        except (EOFError, OSError, KeyboardInterrupt):
            break
        if item is None:
            break
        # Unpack inside the try so a malformed item produces an error reply
        # instead of terminating the worker.
        task_id = None
        try:
            task_id, kind, payload = item
            if kind == "shutdown":
                break
            if kind == "rss":
                res_q.put((task_id, "ok", _rss_bytes()))
            elif kind == "load_large":
                from src.alignment.phoneme_asr import load_phoneme_asr
                t = time.time()
                load_phoneme_asr("Large")
                res_q.put((task_id, "ok", {"load_time": time.time() - t, "rss": _rss_bytes()}))
            elif kind == "run":
                func_module, func_name, args, kwargs = payload
                func = _resolve_callable(func_module, func_name)
                result = func(*args, **kwargs)
                # NOTE(review): the result is pickled later by the queue's
                # background thread; an unpicklable result would presumably
                # never reach the parent — confirm and handle if it can occur.
                res_q.put((task_id, "ok", result))
            else:
                res_q.put((task_id, "error", ("ValueError", f"unknown kind {kind!r}", "")))
        except Exception as e:
            # Catch-all so the loop survives
            res_q.put((task_id, "error", (type(e).__name__, str(e), traceback.format_exc())))

    log("exiting cleanly")
229
+
230
+
231
+ # ---------------------------------------------------------------------------
232
+ # Parent side
233
+ # ---------------------------------------------------------------------------
234
+
235
+
236
@dataclass
class _WorkerHandle:
    """Parent-side record describing one worker process of the pool."""

    # Identity
    worker_id: int

    # Process object and its communication endpoints (filled in when spawned)
    process: Optional[Any] = None
    req_q: Optional[Any] = None
    res_q: Optional[Any] = None
    ready_ev: Optional[Any] = None

    # Boot diagnostics copied from the worker's "__ready__" message
    snapshots: dict = field(default_factory=dict)
    load_times: dict = field(default_factory=dict)
    boot_time: float = 0.0
    pid: Optional[int] = None

    # Bookkeeping
    total_jobs: int = 0
    lock: threading.Lock = field(default_factory=threading.Lock)
249
+
250
+
251
class _Pool:
    """Parent-side manager of the persistent CPU worker processes.

    ``free_q`` holds the ids of idle workers; ``run()`` takes one id out for
    the duration of a task and puts it back afterwards. Each handle's ``lock``
    serialises all traffic on that worker's result queue, so a diagnostic
    probe can never consume the reply of an in-flight task.
    """

    def __init__(self):
        self.ctx = mp.get_context("spawn")
        self.workers: "list[_WorkerHandle]" = []
        self.free_q: "queue_mod.Queue[int]" = queue_mod.Queue()
        self._started = False
        self._lock = threading.Lock()
        self._task_counter = 0
        self._preload_large = False
        self._extra_paths: "list[str]" = []

    # ---- lifecycle -------------------------------------------------------

    def start(self, n_workers: int, preload_large: bool = False, boot_timeout: float = 600.0):
        """Spawn ``n_workers`` workers and block until all are ready. Idempotent.

        Raises:
            RuntimeError: a worker died or reported a boot error.
            TimeoutError: a worker was not ready within ``boot_timeout`` seconds.

        On failure every spawned process is stopped and the pool stays
        "not started", so a later ``start()`` can retry from a clean state.
        """
        with self._lock:
            if self._started:
                return
            self._preload_large = preload_large
            self._extra_paths = list(sys.path)
            print(f"[CPU-POOL] Starting {n_workers} persistent worker(s) preload_large={preload_large}")
            try:
                # All processes are launched first; readiness is then awaited
                # one worker at a time.
                for i in range(n_workers):
                    self.workers.append(self._spawn_worker(i))
                for h in self.workers:
                    self._wait_ready(h, timeout=boot_timeout)
                    self.free_q.put(h.worker_id)
            except BaseException:
                # Roll back: no half-booted pool, no orphaned processes.
                self._stop_workers(timeout=0.0)
                raise
            self._started = True
            print(f"[CPU-POOL] All {n_workers} workers READY")

    def _spawn_worker(self, worker_id: int) -> "_WorkerHandle":
        """Start one daemon worker process and return its (not yet ready) handle."""
        req_q = self.ctx.Queue()
        res_q = self.ctx.Queue()
        ready_ev = self.ctx.Event()
        p = self.ctx.Process(
            target=_worker_loop,
            args=(worker_id, self._extra_paths, req_q, res_q, ready_ev, self._preload_large),
            daemon=True,
            name=f"cpu-worker-{worker_id}",
        )
        p.start()
        return _WorkerHandle(
            worker_id=worker_id,
            process=p,
            req_q=req_q,
            res_q=res_q,
            ready_ev=ready_ev,
            pid=p.pid,
        )

    def _wait_ready(self, h: "_WorkerHandle", timeout: float):
        """Drain res_q until we see the __ready__ tag or a __boot_error__."""
        deadline = time.time() + timeout
        while time.time() < deadline:
            try:
                # Poll in slices of at most 10 s so a dead process is noticed;
                # clamp at 0 because the deadline may pass between the checks.
                wait_for = max(0.0, min(10.0, deadline - time.time()))
                tag, status, payload = h.res_q.get(timeout=wait_for)
            except queue_mod.Empty:
                if h.process is not None and not h.process.is_alive():
                    raise RuntimeError(f"Worker {h.worker_id} died during boot (exit={h.process.exitcode})")
                continue
            if tag == "__ready__":
                h.snapshots = payload.get("snapshots", {})
                h.load_times = payload.get("load_times", {})
                h.boot_time = payload.get("boot_time", 0.0)
                h.pid = payload.get("pid", h.pid)
                return
            if tag == "__boot_error__":
                exc_type, exc_msg, tb = payload
                raise RuntimeError(f"Worker {h.worker_id} boot failed: {exc_type}: {exc_msg}\n{tb}")
            # Unexpected tag during boot — ignore and keep waiting.
        raise TimeoutError(f"Worker {h.worker_id} did not become ready within {timeout}s")

    @staticmethod
    def _kill_process(h: "_WorkerHandle"):
        """Best-effort hard kill of a worker process."""
        try:
            if h.process is not None and h.process.is_alive():
                h.process.kill()
                h.process.join(timeout=2)
        except Exception as e:
            print(f"[CPU-POOL] could not kill W{h.worker_id}: {e}")

    def _stop_workers(self, timeout: float):
        """Ask every worker to exit, kill stragglers, and reset workers/free_q."""
        for h in self.workers:
            try:
                h.req_q.put((0, "shutdown", None))
            except Exception:
                pass
        for h in self.workers:
            try:
                if h.process is not None:
                    h.process.join(timeout=timeout)
                    self._kill_process(h)
            except Exception:
                pass
        self.workers.clear()
        # Drop stale ids so a later start() does not see phantom idle workers.
        while True:
            try:
                self.free_q.get_nowait()
            except queue_mod.Empty:
                break

    def shutdown(self, timeout: float = 5.0):
        """Stop all workers, waiting up to ``timeout`` seconds per worker. No-op if not started."""
        with self._lock:
            if not self._started:
                return
            self._stop_workers(timeout)
            self._started = False

    # ---- task dispatch ---------------------------------------------------

    def _next_task_id(self) -> int:
        with self._lock:
            self._task_counter += 1
            return self._task_counter

    def _acquire_worker(self, timeout: Optional[float] = None) -> "_WorkerHandle":
        """Take an idle worker off ``free_q``; raises ``queue.Empty`` on timeout."""
        wid = self.free_q.get(timeout=timeout)
        # Validate the worker is still alive; if not, respawn in-place.
        h = self.workers[wid]
        if h.process is None or not h.process.is_alive():
            print(f"[CPU-POOL] Worker {wid} dead on acquire — respawning")
            try:
                self._respawn_worker(wid)
            except BaseException:
                # Give the slot back so pool capacity is not lost for good;
                # the next acquire will retry the respawn.
                self.free_q.put(wid)
                raise
            h = self.workers[wid]
        return h

    def _release_worker(self, h: "_WorkerHandle"):
        self.free_q.put(h.worker_id)

    def _respawn_worker(self, worker_id: int):
        """Replace a dead worker in-place. Blocks until ready."""
        t0 = time.time()
        new_h = self._spawn_worker(worker_id)
        try:
            self._wait_ready(new_h, timeout=600.0)
        except BaseException:
            # Do not leave a half-booted replacement process behind.
            self._kill_process(new_h)
            raise
        self.workers[worker_id] = new_h
        print(f"[CPU-POOL] Worker {worker_id} respawned in {time.time()-t0:.1f}s (new pid={new_h.pid})")

    def run(self, func, args, kwargs, timeout: Optional[float] = None) -> Any:
        """Run ``func(*args, **kwargs)`` on an idle worker and return its result.

        ``func`` is sent by module and qualified name, not pickled. ``timeout``
        bounds both the wait for an idle worker and the task itself; the task
        limit defaults to 4 hours when no timeout is given.

        Raises:
            RuntimeError: pool not started, worker died mid-task, or the
                function raised inside the worker.
            TimeoutError: the task exceeded its time limit (the worker is
                killed and replaced).
            queue.Empty: no worker became idle within ``timeout``.
        """
        if not self._started:
            raise RuntimeError("Pool not started")

        h = self._acquire_worker(timeout=timeout)
        try:
            # Hold the per-worker lock for the whole exchange so probes cannot
            # read this task's reply off res_q.
            with h.lock:
                return self._run_on(h, func, args, kwargs, timeout)
        finally:
            self._recycle(h)

    def _run_on(self, h: "_WorkerHandle", func, args, kwargs, timeout: Optional[float]) -> Any:
        """Send one "run" task to ``h`` and wait for its reply."""
        task_id = self._next_task_id()
        func_module = func.__module__
        func_name = func.__qualname__
        print(f"[CPU-POOL] dispatch task#{task_id} {func_module}.{func_name} -> W{h.worker_id} (pid={h.pid})")
        t0 = time.time()
        h.req_q.put((task_id, "run", (func_module, func_name, args, kwargs)))

        # Drain res_q; tolerate process death.
        limit = timeout or 3600 * 4
        deadline = time.time() + limit
        while True:
            try:
                tag, status, payload = h.res_q.get(timeout=min(30.0, max(1.0, deadline - time.time())))
            except queue_mod.Empty:
                if not h.process.is_alive():
                    # worker died mid-task. respawn and raise so caller can retry.
                    print(f"[CPU-POOL] Worker {h.worker_id} died mid-task (exit={h.process.exitcode})")
                    self._respawn_worker(h.worker_id)
                    raise RuntimeError(f"Worker {h.worker_id} died mid-task")
                if time.time() >= deadline:
                    # The worker is still busy with this task. Kill it so it is
                    # replaced by a fresh one instead of being handed out
                    # again while the old job keeps running.
                    print(f"[CPU-POOL] task#{task_id} timed out on W{h.worker_id} — killing worker")
                    self._kill_process(h)
                    raise TimeoutError(f"CPU pool task timed out after {limit}s")
                continue

            if tag == task_id:
                break
            # stray message (e.g. leftover rss reply). Drop.
            print(f"[CPU-POOL] W{h.worker_id} stray message tag={tag!r}, ignoring")

        h.total_jobs += 1
        dt = time.time() - t0
        if status == "ok":
            print(f"[CPU-POOL] task#{task_id} ok in {dt:.2f}s on W{h.worker_id}")
            return payload
        exc_type, exc_msg, tb = payload
        print(f"[CPU-POOL] task#{task_id} error on W{h.worker_id}: {exc_type}: {exc_msg}\n{tb}")
        raise RuntimeError(f"Worker error ({exc_type}): {exc_msg}")

    def _recycle(self, h: "_WorkerHandle"):
        """Return the slot used by ``h`` to ``free_q`` after a task, replacing a dead worker."""
        if h.process is None or h.process.is_alive():
            self._release_worker(h)
            return
        if h.worker_id >= len(self.workers):
            # Pool was shut down while the task was running; nothing to recycle.
            return
        current = self.workers[h.worker_id]
        if current is not h:
            # Already respawned while the task was being handled; the new
            # handle has not been put on free_q yet.
            self.free_q.put(current.worker_id)
            return
        # lost — dead and not replaced. Try a respawn now.
        try:
            self._respawn_worker(h.worker_id)
        except Exception as e:
            print(f"[CPU-POOL] could not respawn W{h.worker_id}: {e}")
        # Put the slot back even if the respawn failed: _acquire_worker
        # retries the respawn for a dead worker, so capacity is not lost.
        self.free_q.put(h.worker_id)

    # ---- diagnostics -----------------------------------------------------

    def _exchange(self, worker_id: int, kind: str, timeout: float, what: str):
        """Send a payload-less control task and return ``(status, payload)``.

        Waits for the worker's lock first, so it never reads from ``res_q``
        while ``run()`` is waiting for a task result on the same queue.
        Raises ``TimeoutError`` if the worker stays busy or does not answer
        within ``timeout`` seconds.
        """
        h = self.workers[worker_id]
        deadline = time.time() + timeout
        if not h.lock.acquire(timeout=max(0.0, timeout)):
            raise TimeoutError(f"{what} timed out: worker {worker_id} is busy")
        try:
            task_id = self._next_task_id()
            h.req_q.put((task_id, kind, None))
            while True:
                remaining = deadline - time.time()
                if remaining <= 0:
                    break
                try:
                    tag, status, payload = h.res_q.get(timeout=remaining)
                except queue_mod.Empty:
                    break
                if tag == task_id:
                    return status, payload
            raise TimeoutError(f"{what} timed out")
        finally:
            h.lock.release()

    def probe_rss(self, worker_id: int, timeout: float = 10.0) -> int:
        """Return the worker's resident set size in bytes as reported by the worker."""
        _status, payload = self._exchange(worker_id, "rss", timeout, "rss probe")
        return int(payload)

    def load_large(self, worker_id: int, timeout: float = 300.0) -> dict:
        """Ask a worker to load ASR Large; returns the worker's ``load_time`` / ``rss`` report."""
        status, payload = self._exchange(worker_id, "load_large", timeout, "load_large")
        if status == "ok":
            return payload
        raise RuntimeError(f"load_large failed: {payload}")

    def stats(self) -> dict:
        """Snapshot of pool state: started flag, worker count and per-worker details."""
        return {
            "started": self._started,
            "n_workers": len(self.workers),
            "workers": [
                {
                    "id": h.worker_id,
                    "pid": h.pid,
                    "alive": h.process is not None and h.process.is_alive(),
                    "total_jobs": h.total_jobs,
                    "boot_time": h.boot_time,
                    "snapshots": dict(h.snapshots),
                    "load_times": h.load_times,
                }
                for h in self.workers
            ],
        }
474
+
475
+
476
+ # ---------------------------------------------------------------------------
477
+ # Module-level singleton API
478
+ # ---------------------------------------------------------------------------
479
+
480
_POOL: Optional[_Pool] = None
_START_LOCK = threading.Lock()


def _get_pool() -> _Pool:
    """Return the process-wide pool object, creating it on first use."""
    global _POOL
    if _POOL is not None:
        return _POOL
    with _START_LOCK:
        # Re-check under the lock: another thread may have created it meanwhile.
        if _POOL is None:
            _POOL = _Pool()
        return _POOL


def start_pool(n_workers: int, preload_large: bool = False):
    """Spawn the persistent worker pool. Idempotent."""
    pool = _get_pool()
    pool.start(n_workers, preload_large=preload_large)


def is_started() -> bool:
    """Tell whether a pool object exists and reports itself as started."""
    pool = _POOL
    if pool is None:
        return False
    return pool._started


def stats() -> dict:
    """Return the pool's stats dict (creates the pool object if needed)."""
    pool = _get_pool()
    return pool.stats()


def probe_rss(worker_id: int) -> int:
    """Ask one worker for its current RSS in bytes."""
    pool = _get_pool()
    return pool.probe_rss(worker_id)


def load_large(worker_id: int) -> dict:
    """Ask one worker to load ASR Large and return its report."""
    pool = _get_pool()
    return pool.load_large(worker_id)


def shutdown():
    """Shut the pool down if one was ever created; otherwise do nothing."""
    pool = _POOL
    if pool is None:
        return
    pool.shutdown()


def run_on_persistent_worker(func, args, kwargs, timeout: Optional[float] = None):
    """Run a function on a free persistent worker. Blocks until done.

    Caller is responsible for concurrency gating (the wrapper in zero_gpu.py
    uses the same semaphore as the spawn path).
    """
    pool = _get_pool()
    return pool.run(func, args, kwargs, timeout=timeout)
src/core/zero_gpu.py CHANGED
@@ -269,18 +269,32 @@ def gpu_with_fallback(duration=60):
269
 
270
  if CPU_STRATEGY == "subprocess":
271
  import time as _time
 
272
  sem = _get_subprocess_semaphore()
273
  _t_acq = _time.time()
274
  sem.acquire()
275
  _wait = _time.time() - _t_acq
276
  try:
277
  _check_cuda_fork_state(f"before CPU subprocess ({func.__name__})")
278
- print(
279
- f"[CPU] Running {func.__name__} in isolated subprocess "
280
- f"(CPU_STRATEGY=subprocess, queue_wait={_wait:.2f}s)"
281
- )
282
- from .cpu_subprocess import run_in_cpu_subprocess
283
- result = run_in_cpu_subprocess(func, args, kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
284
  _check_cuda_fork_state(f"after CPU subprocess ({func.__name__})")
285
  return result
286
  finally:
 
269
 
270
  if CPU_STRATEGY == "subprocess":
271
  import time as _time
272
+ from config import CPU_WORKER_MODE
273
  sem = _get_subprocess_semaphore()
274
  _t_acq = _time.time()
275
  sem.acquire()
276
  _wait = _time.time() - _t_acq
277
  try:
278
  _check_cuda_fork_state(f"before CPU subprocess ({func.__name__})")
279
+ if CPU_WORKER_MODE == "persistent":
280
+ from .cpu_worker_pool import run_on_persistent_worker, is_started, start_pool
281
+ if not is_started():
282
+ # Lazy start fallback — slow first request.
283
+ from config import CPU_SUBPROCESS_CONCURRENCY, CPU_POOL_PRELOAD_LARGE
284
+ print("[CPU] Pool not started — lazy-starting now")
285
+ start_pool(CPU_SUBPROCESS_CONCURRENCY, preload_large=CPU_POOL_PRELOAD_LARGE)
286
+ print(
287
+ f"[CPU] Running {func.__name__} on persistent worker "
288
+ f"(CPU_WORKER_MODE=persistent, queue_wait={_wait:.2f}s)"
289
+ )
290
+ result = run_on_persistent_worker(func, args, kwargs)
291
+ else:
292
+ print(
293
+ f"[CPU] Running {func.__name__} in isolated subprocess "
294
+ f"(CPU_STRATEGY=subprocess, queue_wait={_wait:.2f}s)"
295
+ )
296
+ from .cpu_subprocess import run_in_cpu_subprocess
297
+ result = run_in_cpu_subprocess(func, args, kwargs)
298
  _check_cuda_fork_state(f"after CPU subprocess ({func.__name__})")
299
  return result
300
  finally:
src/ui/event_wiring.py CHANGED
@@ -17,6 +17,8 @@ from src.api.session_api import (
17
  debug_process,
18
  cpu_exec,
19
  pool_status,
 
 
20
  )
21
  from src.mfa import compute_mfa_timestamps
22
  from src.ui.progress_bar import pipeline_progress_bar_html
@@ -935,6 +937,20 @@ def _wire_api_endpoint(c):
935
  outputs=[c.api_result],
936
  api_name="pool_status",
937
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
938
 
939
 
940
  def _wire_dev_tab(c):
 
17
  debug_process,
18
  cpu_exec,
19
  pool_status,
20
+ cpu_pool_status,
21
+ cpu_pool_kill,
22
  )
23
  from src.mfa import compute_mfa_timestamps
24
  from src.ui.progress_bar import pipeline_progress_bar_html
 
937
  outputs=[c.api_result],
938
  api_name="pool_status",
939
  )
940
+ # Persistent CPU worker pool status — HF-token-gated.
941
+ gr.Button(visible=False).click(
942
+ fn=cpu_pool_status,
943
+ inputs=[c.api_pool_status_token],
944
+ outputs=[c.api_result],
945
+ api_name="cpu_pool_status",
946
+ )
947
+ # Kill a persistent worker — crash-recovery test helper, HF-token-gated.
948
+ gr.Button(visible=False).click(
949
+ fn=cpu_pool_kill,
950
+ inputs=[c.api_pool_status_token, c.api_cpu_exec_module], # reuse token + a string input
951
+ outputs=[c.api_result],
952
+ api_name="cpu_pool_kill",
953
+ )
954
 
955
 
956
  def _wire_dev_tab(c):