Spaces:
Running on Zero
Running on Zero
Upload folder using huggingface_hub
Browse files- app.py +23 -0
- config.py +13 -11
- src/api/session_api.py +87 -0
- src/core/cpu_worker_pool.py +525 -0
- src/core/zero_gpu.py +20 -6
- src/ui/event_wiring.py +16 -0
app.py
CHANGED
|
@@ -54,6 +54,29 @@ else:
|
|
| 54 |
|
| 55 |
from src.ui.interface import build_interface
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
# =============================================================================
|
| 58 |
# Module-level demo for Gradio hot-reload (`gradio app.py`)
|
| 59 |
# =============================================================================
|
|
|
|
| 54 |
|
| 55 |
from src.ui.interface import build_interface
|
| 56 |
|
| 57 |
+
# =============================================================================
# Persistent CPU worker pool — spawn BEFORE any GPU use if enabled.
# This keeps the workers free of any inherited CUDA/ZeroGPU state.
# =============================================================================
# The whole bootstrap is best-effort: any failure (missing config key, worker
# boot error, boot timeout) is printed and the app keeps starting without it.
try:
    from config import (
        CPU_STRATEGY as _CPU_STRATEGY,
        CPU_WORKER_MODE as _CPU_WORKER_MODE,
        CPU_SUBPROCESS_CONCURRENCY as _CPU_SUBPROCESS_CONCURRENCY,
        CPU_POOL_PRELOAD_LARGE as _CPU_POOL_PRELOAD_LARGE,
    )
    from src.core.zero_gpu import IS_CPU_WORKER as _IS_CPU_WORKER

    if (
        _CPU_STRATEGY == "subprocess"
        and _CPU_WORKER_MODE == "persistent"
        and not _IS_CPU_WORKER
        # With the "spawn" start method a child re-imports the parent's main
        # module under the name "__mp_main__"; never start a nested pool there.
        and __name__ != "__mp_main__"
    ):
        print(f"[APP] Bootstrapping persistent CPU pool: {_CPU_SUBPROCESS_CONCURRENCY} worker(s), preload_large={_CPU_POOL_PRELOAD_LARGE}")
        from src.core.cpu_worker_pool import start_pool as _start_pool

        _start_pool(_CPU_SUBPROCESS_CONCURRENCY, preload_large=_CPU_POOL_PRELOAD_LARGE)
except Exception as _e:
    print(f"[APP] Persistent CPU pool bootstrap failed (non-fatal): {_e}")
|
| 79 |
+
|
| 80 |
# =============================================================================
|
| 81 |
# Module-level demo for Gradio hot-reload (`gradio app.py`)
|
| 82 |
# =============================================================================
|
config.py
CHANGED
|
@@ -48,14 +48,21 @@ SESSION_EXPIRY_SECONDS = 3600*5 # 5 hours — matches DELETE_CACHE_
|
|
| 48 |
CPU_STRATEGY = os.environ.get("CPU_STRATEGY", "subprocess").lower()
|
| 49 |
|
| 50 |
# Max seconds a subprocess CPU job can run before SIGKILL (used by "subprocess" and "both" strategies).
|
| 51 |
-
CPU_SUBPROCESS_TIMEOUT = int(os.environ.get("CPU_SUBPROCESS_TIMEOUT", str(3600 *
|
| 52 |
|
| 53 |
-
# Max concurrent CPU subprocesses on the main Space.
|
| 54 |
-
# own copy of the VAD + ASR models (~3.6 GB RAM). On zero-a10g with 48 GB RAM
|
| 55 |
-
# and 8 vCPU we can safely run 2–3; pushing higher risks OOM and CPU thrash
|
| 56 |
-
# that would also slow GPU dispatches on main.
|
| 57 |
CPU_SUBPROCESS_CONCURRENCY = int(os.environ.get("CPU_SUBPROCESS_CONCURRENCY", "2"))
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
# Model dtype for CPU inference.
|
| 60 |
# "bfloat16" — default. Routes attention through PyTorch's chunked CPU flash
|
| 61 |
# kernel (`_scaled_dot_product_flash_attention_for_cpu`), which
|
|
@@ -171,12 +178,7 @@ INFERENCE_BATCH_SIZE = 32 # Fixed segments per batch (used when BATCHING_ST
|
|
| 171 |
|
| 172 |
# Dynamic batching constraints
|
| 173 |
MAX_BATCH_SECONDS = 600 # GPU: max total audio seconds per batch (sum of durations)
|
| 174 |
-
# CPU: tighter cap. SDPA materialises the QK^T tensor per encoder layer
|
| 175 |
-
# (batch * heads * seq^2 * 2B) exceeds the L3 cache (~32 MB on zero-a10g Xeon),
|
| 176 |
-
# every attention layer becomes DRAM-bound instead of cache-bound — on a 22 min
|
| 177 |
-
# m4a we saw one batch (59 segs × 12.5 s) hit a ~24× slowdown vs its neighbours.
|
| 178 |
-
# 300 s keeps QK^T comfortably under L3 at realistic (batch, seq) combinations.
|
| 179 |
-
MAX_BATCH_SECONDS_CPU = 300
|
| 180 |
MAX_PAD_WASTE = 0.2 # Max fraction of padded tensor that is wasted (0=no waste, 1=all waste)
|
| 181 |
MIN_BATCH_SIZE = 8 # Minimum segments per batch (prevents underutilization)
|
| 182 |
|
|
|
|
| 48 |
CPU_STRATEGY = os.environ.get("CPU_STRATEGY", "subprocess").lower()
|
| 49 |
|
| 50 |
# Max seconds a subprocess CPU job can run before SIGKILL (used by "subprocess" and "both" strategies).
|
| 51 |
+
CPU_SUBPROCESS_TIMEOUT = int(os.environ.get("CPU_SUBPROCESS_TIMEOUT", str(3600 * 2)))
|
| 52 |
|
| 53 |
+
# Max concurrent CPU subprocesses on the main Space.
|
|
|
|
|
|
|
|
|
|
| 54 |
CPU_SUBPROCESS_CONCURRENCY = int(os.environ.get("CPU_SUBPROCESS_CONCURRENCY", "2"))
|
| 55 |
|
| 56 |
+
# CPU_WORKER_MODE — when CPU_STRATEGY="subprocess", chooses between:
|
| 57 |
+
# "spawn" — legacy: fork a fresh subprocess per request (cpu_subprocess.py).
|
| 58 |
+
# "persistent" — new: route to a pool of long-lived workers (cpu_worker_pool.py).
|
| 59 |
+
# Semaphore capacity stays = CPU_SUBPROCESS_CONCURRENCY either way.
|
| 60 |
+
CPU_WORKER_MODE = os.environ.get("CPU_WORKER_MODE", "persistent").lower()
|
| 61 |
+
|
| 62 |
+
# Whether the persistent pool preloads ASR Large at boot. If False, Large is
|
| 63 |
+
# loaded on-demand inside the worker and cached there.
|
| 64 |
+
CPU_POOL_PRELOAD_LARGE = os.environ.get("CPU_POOL_PRELOAD_LARGE", "1") == "1"
|
| 65 |
+
|
| 66 |
# Model dtype for CPU inference.
|
| 67 |
# "bfloat16" — default. Routes attention through PyTorch's chunked CPU flash
|
| 68 |
# kernel (`_scaled_dot_product_flash_attention_for_cpu`), which
|
|
|
|
| 178 |
|
| 179 |
# Dynamic batching constraints
|
| 180 |
MAX_BATCH_SECONDS = 600 # GPU: max total audio seconds per batch (sum of durations)
|
| 181 |
+
# CPU: tighter cap than the GPU limit above. SDPA materialises the QK^T tensor
# per encoder layer (batch * heads * seq^2 * 2B); when that exceeds the L3 cache
# (~32 MB on the zero-a10g Xeon) every attention layer becomes DRAM-bound
# instead of cache-bound. 300 s keeps QK^T comfortably under L3 at realistic
# (batch, seq) combinations.
MAX_BATCH_SECONDS_CPU = 300
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
MAX_PAD_WASTE = 0.2 # Max fraction of padded tensor that is wasted (0=no waste, 1=all waste)
|
| 183 |
MIN_BATCH_SIZE = 8 # Minimum segments per batch (prevents underutilization)
|
| 184 |
|
src/api/session_api.py
CHANGED
|
@@ -957,6 +957,93 @@ def pool_status(hf_token):
|
|
| 957 |
}
|
| 958 |
|
| 959 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 960 |
# ---------------------------------------------------------------------------
|
| 961 |
# Hidden debug endpoint
|
| 962 |
# ---------------------------------------------------------------------------
|
|
|
|
| 957 |
}
|
| 958 |
|
| 959 |
|
| 960 |
+
def cpu_pool_kill(hf_token, worker_id):
    """Kill a persistent worker for crash-recovery testing. HF-token-gated.

    Args:
        hf_token: Caller-supplied token; must equal the ``HF_TOKEN`` env var.
        worker_id: Index of the worker in the pool (anything ``int()`` accepts).

    Returns:
        ``{"error": ...}`` when the caller is not authorised, the index is out
        of range, or anything else fails; otherwise a dict with ``worker_id``,
        ``pid``, ``was_alive``, ``kill_sent``, ``send_err`` and ``alive_after``.
    """
    import hmac

    space_token = os.environ.get("HF_TOKEN", "")
    # Fail closed: when no HF_TOKEN is configured nobody can be authenticated,
    # so this destructive endpoint must refuse instead of accepting any token.
    # compare_digest avoids leaking the secret through comparison timing.
    if (
        not hf_token
        or not space_token
        or not hmac.compare_digest(str(hf_token).encode("utf-8"), space_token.encode("utf-8"))
    ):
        return {"error": "Unauthorized"}
    try:
        from src.core.cpu_worker_pool import _get_pool
        import signal as _signal
        import time as _time

        pool = _get_pool()
        wid = int(worker_id)
        n_workers = len(pool.workers)
        # Reject negatives explicitly; list indexing would wrap -1 to the last worker.
        if not 0 <= wid < n_workers:
            return {"error": f"worker_id {wid} out of range (pool has {n_workers} worker(s))"}

        handle = pool.workers[wid]
        pid = handle.pid
        was_alive = handle.process is not None and handle.process.is_alive()

        sent = False
        send_err = None
        if pid is None:
            send_err = "worker has no pid"
        else:
            try:
                os.kill(pid, _signal.SIGKILL)
                sent = True
            except OSError as kill_exc:
                send_err = str(kill_exc)

        # give OS a moment to reap
        _time.sleep(0.3)
        alive_after = handle.process is not None and handle.process.is_alive()
        return {
            "worker_id": wid,
            "pid": pid,
            "was_alive": was_alive,
            "kill_sent": sent,
            "send_err": send_err,
            "alive_after": alive_after,
        }
    except Exception as e:
        # Diagnostic endpoint: report the failure to the caller instead of raising.
        return {"error": str(e)}
|
| 994 |
+
|
| 995 |
+
|
| 996 |
+
def cpu_pool_status(hf_token):
    """Return persistent CPU worker pool state. HF-token-gated.

    Prototype diagnostic: shows per-worker boot snapshots, load times, pids,
    live RSS, and job counts. Safe to call on Spaces where CPU_WORKER_MODE
    is not 'persistent' — just returns `started=False`.

    Args:
        hf_token: Caller-supplied token; must equal the ``HF_TOKEN`` env var.

    Returns:
        ``{"error": ...}`` when unauthorised or the pool module cannot be
        imported, ``{"started": False, ...}`` when the pool is not running,
        otherwise the pool stats dict augmented with per-worker ``rss_now``
        (or ``rss_now_error``), ``main_rss`` / ``host_mem`` (or
        ``main_rss_error``) and a ``cgroup`` section.
    """
    import hmac

    space_token = os.environ.get("HF_TOKEN", "")
    # Fail closed when no HF_TOKEN is configured, and compare in constant time.
    if (
        not hf_token
        or not space_token
        or not hmac.compare_digest(str(hf_token).encode("utf-8"), space_token.encode("utf-8"))
    ):
        return {"error": "Unauthorized"}

    try:
        from src.core.cpu_worker_pool import is_started, stats as pool_stats, probe_rss
    except Exception as e:
        return {"error": f"pool import failed: {e}"}

    if not is_started():
        return {"started": False, "note": "CPU_WORKER_MODE != persistent or pool not yet bootstrapped"}

    report = pool_stats()

    # Augment with live RSS probe per worker; one failing probe must not hide the rest.
    for worker in report.get("workers", []):
        try:
            worker["rss_now"] = probe_rss(worker["id"])
        except Exception as e:
            worker["rss_now_error"] = str(e)

    # Include main process RSS (psutil is optional, hence the broad guard).
    try:
        import psutil as _ps

        report["main_rss"] = _ps.Process(os.getpid()).memory_info().rss
        vm = _ps.virtual_memory()
        report["host_mem"] = {"total": vm.total, "available": vm.available, "used": vm.used, "percent": vm.percent}
    except Exception as e:
        report["main_rss_error"] = str(e)

    # Probe cgroup (container) memory limit — authoritative Space budget.
    # Both the cgroup-v2 and cgroup-v1 locations are tried; the one that does
    # not exist is recorded as an "err: ..." string.
    cgroup = {}
    for path in ("/sys/fs/cgroup/memory.max", "/sys/fs/cgroup/memory/memory.limit_in_bytes"):
        try:
            with open(path) as _f:
                cgroup[path] = _f.read().strip()
        except OSError as e:
            cgroup[path] = f"err: {e}"
    try:
        with open("/sys/fs/cgroup/memory.current") as _f:
            cgroup["memory.current"] = _f.read().strip()
    except OSError:
        # Deliberately optional: simply omitted when the file is unavailable.
        pass
    report["cgroup"] = cgroup
    return report
|
| 1045 |
+
|
| 1046 |
+
|
| 1047 |
# ---------------------------------------------------------------------------
|
| 1048 |
# Hidden debug endpoint
|
| 1049 |
# ---------------------------------------------------------------------------
|
src/core/cpu_worker_pool.py
ADDED
|
@@ -0,0 +1,525 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Persistent CPU worker pool — prototype.
|
| 2 |
+
|
| 3 |
+
Replaces the spawn-per-request path (cpu_subprocess.py) with long-lived
|
| 4 |
+
workers that preload VAD + ASR (Base) once and stay ready. ASR Large is
|
| 5 |
+
loaded on-demand in each worker (and cached there) so steady-state RAM
|
| 6 |
+
stays predictable.
|
| 7 |
+
|
| 8 |
+
Gated behind env var CPU_WORKER_MODE=persistent. CPU_WORKER_MODE=spawn
|
| 9 |
+
(legacy) keeps the existing spawn-per-request behavior; note that config.py defaults CPU_WORKER_MODE to "persistent".
|
| 10 |
+
|
| 11 |
+
Semaphore capacity = CPU_SUBPROCESS_CONCURRENCY (shared with spawn path).
|
| 12 |
+
A free-worker queue gives O(1) idle-worker pickup and guarantees at most
|
| 13 |
+
one concurrent job per worker.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import importlib
|
| 19 |
+
import multiprocessing as mp
|
| 20 |
+
import os
|
| 21 |
+
import queue as queue_mod
|
| 22 |
+
import signal
|
| 23 |
+
import sys
|
| 24 |
+
import threading
|
| 25 |
+
import time
|
| 26 |
+
import traceback
|
| 27 |
+
from dataclasses import dataclass, field
|
| 28 |
+
from typing import Any, Optional
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ---------------------------------------------------------------------------
|
| 32 |
+
# Worker side
|
| 33 |
+
# ---------------------------------------------------------------------------
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _worker_loop(
    worker_id: int,
    extra_paths: list,
    req_q: mp.Queue,
    res_q: mp.Queue,
    ready_ev,
    preload_large: bool,
):
    """Long-lived worker body. Runs in a spawn-context process.

    Steps:
      1. Env hygiene — hide CUDA, disable ZeroGPU patches.
      2. Import project modules, call force_cpu_mode().
      3. Preload VAD + ASR Base (Large optional).
      4. Signal `ready_ev`.
      5. Loop: pull task → execute → push result.

    Tasks are pickled tuples: (task_id, kind, payload)
      kind="run":        payload=(func_module, func_name, args, kwargs)
      kind="rss":        payload=None   (return rss bytes)
      kind="load_large": payload=None   (preload ASR Large if not cached)
      kind="shutdown":   payload=None   (exit loop)

    Replies on ``res_q`` are ``(task_id, "ok", value)`` or
    ``(task_id, "error", (exc_type_name, message, traceback_text))``.
    Boot outcome is reported with the tags ``"__ready__"`` / ``"__boot_error__"``.

    Args:
        worker_id: Index used in log lines and the process title.
        extra_paths: ``sys.path`` entries to add so project modules import.
        req_q: Queue the parent puts tasks on; ``None`` also ends the loop.
        res_q: Queue this worker puts boot messages and task replies on.
        ready_ev: Event set once boot has succeeded.
        preload_large: Also load ASR "Large" during boot.
    """
    # ---- Env hygiene BEFORE any torch import ----------------------------
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    os.environ["SPACES_ZERO_GPU"] = ""
    # Suppress the HF download progress bars — the parent app sets this but
    # the spawned child inherits only env, not module state.
    os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")

    # Restore sys.path from parent so src/ is importable
    for p in extra_paths:
        if p and p not in sys.path:
            sys.path.insert(0, p)

    # Helpful process label for ps/htop (optional dependency, so best-effort).
    try:
        import setproctitle  # type: ignore
        setproctitle.setproctitle(f"cpu-worker-{worker_id}")
    except Exception:
        pass

    def log(msg):
        print(f"[CPU-POOL/W{worker_id}] {msg}", flush=True)

    def _error_payload(exc):
        # Must be called from inside an ``except`` block so format_exc() sees it.
        return (type(exc).__name__, str(exc), traceback.format_exc())

    def _fail_boot(exc):
        # Tell the parent why boot failed so it does not wait for the timeout.
        res_q.put(("__boot_error__", "error", _error_payload(exc)))

    # ---- RSS probe ------------------------------------------------------
    def _rss_bytes() -> int:
        """Return this process's RSS in bytes, or 0 if it cannot be read."""
        try:
            import psutil  # type: ignore
            return psutil.Process(os.getpid()).memory_info().rss
        except Exception:
            # psutil missing or failing: fall back to /proc (VmRSS is in kB).
            try:
                with open(f"/proc/{os.getpid()}/status") as f:
                    for line in f:
                        if line.startswith("VmRSS:"):
                            return int(line.split()[1]) * 1024
            except Exception:
                return 0
        return 0

    snapshots = {"start": _rss_bytes()}
    t0 = time.time()
    log(f"booted pid={os.getpid()} rss={snapshots['start']/1e6:.1f}MB")

    # ---- Imports + force CPU mode --------------------------------------
    try:
        from src.core.zero_gpu import force_cpu_mode
        force_cpu_mode()
        snapshots["after_imports"] = _rss_bytes()
        log(f"imports done +{(time.time()-t0):.1f}s rss={snapshots['after_imports']/1e6:.1f}MB")
    except Exception as e:
        log(f"FATAL imports: {e}\n{traceback.format_exc()}")
        _fail_boot(e)
        return

    load_times = {}

    # ---- Preload VAD ---------------------------------------------------
    try:
        t = time.time()
        from src.segmenter.segmenter_model import load_segmenter
        load_segmenter()
        load_times["vad"] = time.time() - t
        snapshots["after_vad"] = _rss_bytes()
        log(f"VAD loaded in {load_times['vad']:.2f}s rss={snapshots['after_vad']/1e6:.1f}MB")
    except Exception as e:
        log(f"VAD load failed: {e}")
        _fail_boot(e)
        return

    # ---- Preload ASR Base ---------------------------------------------
    try:
        t = time.time()
        from src.alignment.phoneme_asr import load_phoneme_asr
        load_phoneme_asr("Base")
        load_times["asr_base"] = time.time() - t
        snapshots["after_asr_base"] = _rss_bytes()
        log(f"ASR Base loaded in {load_times['asr_base']:.2f}s rss={snapshots['after_asr_base']/1e6:.1f}MB")
    except Exception as e:
        log(f"ASR Base load failed: {e}")
        _fail_boot(e)
        return

    # ---- Preload caches (ngram index, phoneme chapters) ----------------
    # Non-fatal: the worker still becomes ready if this step fails.
    try:
        t = time.time()
        from src.alignment.ngram_index import get_ngram_index
        from src.alignment.phoneme_matcher_cache import preload_all_chapters
        get_ngram_index()
        preload_all_chapters()
        load_times["caches"] = time.time() - t
        snapshots["after_caches"] = _rss_bytes()
        log(f"caches loaded in {load_times['caches']:.2f}s rss={snapshots['after_caches']/1e6:.1f}MB")
    except Exception as e:
        log(f"caches load failed (non-fatal): {e}")

    # ---- Optionally preload ASR Large ---------------------------------
    # Also non-fatal: a failure is only logged.
    if preload_large:
        try:
            t = time.time()
            load_phoneme_asr("Large")
            load_times["asr_large"] = time.time() - t
            snapshots["after_asr_large"] = _rss_bytes()
            log(f"ASR Large loaded in {load_times['asr_large']:.2f}s rss={snapshots['after_asr_large']/1e6:.1f}MB")
        except Exception as e:
            log(f"ASR Large preload failed: {e}")

    # ---- Warm up resampler --------------------------------------------
    # Pure warm-up on a dummy buffer; any failure here is irrelevant to readiness.
    try:
        import numpy as np, librosa
        from config import RESAMPLE_TYPE
        _ = librosa.resample(np.zeros(1600, dtype=np.float32),
                             orig_sr=44100, target_sr=16000, res_type=RESAMPLE_TYPE)
    except Exception:
        pass

    snapshots["ready"] = _rss_bytes()
    total_boot = time.time() - t0
    log(f"READY in {total_boot:.2f}s, final rss={snapshots['ready']/1e6:.1f}MB")

    # Signal parent that this worker booted successfully
    res_q.put(("__ready__", "ok", {
        "worker_id": worker_id,
        "pid": os.getpid(),
        "snapshots": snapshots,
        "load_times": load_times,
        "boot_time": total_boot,
    }))
    ready_ev.set()

    # ---- Main loop -----------------------------------------------------
    while True:
        try:
            item = req_q.get()
        except (EOFError, OSError, KeyboardInterrupt):
            break
        if item is None:
            break
        try:
            task_id, kind, payload = item
        except (TypeError, ValueError):
            # No task id to reply to; drop it rather than crash the worker.
            log(f"dropping malformed task item: {item!r}")
            continue
        try:
            if kind == "shutdown":
                break
            if kind == "rss":
                res_q.put((task_id, "ok", _rss_bytes()))
            elif kind == "load_large":
                t = time.time()
                load_phoneme_asr("Large")
                res_q.put((task_id, "ok", {"load_time": time.time() - t, "rss": _rss_bytes()}))
            elif kind == "run":
                func_module, func_name, args, kwargs = payload
                func = importlib.import_module(func_module)
                # The name may be a dotted qualname (e.g. "Class.method"), so
                # resolve it one attribute at a time.
                for part in func_name.split("."):
                    func = getattr(func, part)
                # Strip decorator wrappers so the undecorated function runs here.
                while hasattr(func, "__wrapped__"):
                    func = func.__wrapped__
                result = func(*args, **kwargs)
                # NOTE(review): if `result` cannot be pickled the queue's feeder
                # thread presumably drops it and the parent waits for its
                # timeout — confirm task results are always picklable.
                res_q.put((task_id, "ok", result))
            else:
                res_q.put((task_id, "error", ("ValueError", f"unknown kind {kind!r}", "")))
        except Exception as e:
            # Catch-all so the loop survives; the parent gets the traceback.
            res_q.put((task_id, "error", _error_payload(e)))

    log("exiting cleanly")
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
# ---------------------------------------------------------------------------
|
| 232 |
+
# Parent side
|
| 233 |
+
# ---------------------------------------------------------------------------
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
@dataclass
class _WorkerHandle:
    """Parent-side record for one persistent worker process.

    Holds the process object, its private request/result queues, the boot
    report the worker sent back, and a running job counter.
    """

    # Index of this worker inside the pool's ``workers`` list.
    worker_id: int
    # Process object returned by the spawn context (None until spawned).
    process: Optional[Any] = None
    # Parent -> worker task queue.
    req_q: Optional[Any] = None
    # Worker -> parent reply queue (boot messages and task results).
    res_q: Optional[Any] = None
    # Event the worker sets once boot has succeeded.
    ready_ev: Optional[Any] = None
    # RSS snapshots (bytes) reported by the worker at each boot stage.
    snapshots: dict = field(default_factory=dict)
    # Seconds spent loading each model/cache during boot.
    load_times: dict = field(default_factory=dict)
    # Total boot duration in seconds, as reported by the worker.
    boot_time: float = 0.0
    # OS pid of the worker process.
    pid: Optional[int] = None
    # Number of tasks completed through this handle.
    total_jobs: int = 0
    # Per-worker lock; left out of the generated __repr__/__eq__ because a
    # lock's identity says nothing about whether two handles are equal.
    lock: threading.Lock = field(default_factory=threading.Lock, repr=False, compare=False)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class _Pool:
|
| 252 |
+
def __init__(self):
|
| 253 |
+
self.ctx = mp.get_context("spawn")
|
| 254 |
+
self.workers: list[_WorkerHandle] = []
|
| 255 |
+
self.free_q: "queue_mod.Queue[int]" = queue_mod.Queue()
|
| 256 |
+
self._started = False
|
| 257 |
+
self._lock = threading.Lock()
|
| 258 |
+
self._task_counter = 0
|
| 259 |
+
self._preload_large = False
|
| 260 |
+
self._extra_paths: list[str] = []
|
| 261 |
+
|
| 262 |
+
# ---- lifecycle -------------------------------------------------------
|
| 263 |
+
|
| 264 |
+
def start(self, n_workers: int, preload_large: bool = False, boot_timeout: float = 600.0):
|
| 265 |
+
with self._lock:
|
| 266 |
+
if self._started:
|
| 267 |
+
return
|
| 268 |
+
self._started = True
|
| 269 |
+
self._preload_large = preload_large
|
| 270 |
+
self._extra_paths = list(sys.path)
|
| 271 |
+
print(f"[CPU-POOL] Starting {n_workers} persistent worker(s) preload_large={preload_large}")
|
| 272 |
+
for i in range(n_workers):
|
| 273 |
+
h = self._spawn_worker(i)
|
| 274 |
+
self.workers.append(h)
|
| 275 |
+
# Wait for ready signal from each (serial — avoids RAM spike)
|
| 276 |
+
for h in self.workers:
|
| 277 |
+
self._wait_ready(h, timeout=boot_timeout)
|
| 278 |
+
self.free_q.put(h.worker_id)
|
| 279 |
+
print(f"[CPU-POOL] All {n_workers} workers READY")
|
| 280 |
+
|
| 281 |
+
def _spawn_worker(self, worker_id: int) -> _WorkerHandle:
|
| 282 |
+
req_q = self.ctx.Queue()
|
| 283 |
+
res_q = self.ctx.Queue()
|
| 284 |
+
ready_ev = self.ctx.Event()
|
| 285 |
+
p = self.ctx.Process(
|
| 286 |
+
target=_worker_loop,
|
| 287 |
+
args=(worker_id, self._extra_paths, req_q, res_q, ready_ev, self._preload_large),
|
| 288 |
+
daemon=True,
|
| 289 |
+
name=f"cpu-worker-{worker_id}",
|
| 290 |
+
)
|
| 291 |
+
p.start()
|
| 292 |
+
return _WorkerHandle(
|
| 293 |
+
worker_id=worker_id,
|
| 294 |
+
process=p,
|
| 295 |
+
req_q=req_q,
|
| 296 |
+
res_q=res_q,
|
| 297 |
+
ready_ev=ready_ev,
|
| 298 |
+
pid=p.pid,
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
def _wait_ready(self, h: _WorkerHandle, timeout: float):
|
| 302 |
+
"""Drain res_q until we see the __ready__ tag or a __boot_error__."""
|
| 303 |
+
deadline = time.time() + timeout
|
| 304 |
+
while time.time() < deadline:
|
| 305 |
+
try:
|
| 306 |
+
tag, status, payload = h.res_q.get(timeout=min(10.0, deadline - time.time()))
|
| 307 |
+
except queue_mod.Empty:
|
| 308 |
+
if h.process is not None and not h.process.is_alive():
|
| 309 |
+
raise RuntimeError(f"Worker {h.worker_id} died during boot (exit={h.process.exitcode})")
|
| 310 |
+
continue
|
| 311 |
+
if tag == "__ready__":
|
| 312 |
+
h.snapshots = payload.get("snapshots", {})
|
| 313 |
+
h.load_times = payload.get("load_times", {})
|
| 314 |
+
h.boot_time = payload.get("boot_time", 0.0)
|
| 315 |
+
h.pid = payload.get("pid", h.pid)
|
| 316 |
+
return
|
| 317 |
+
if tag == "__boot_error__":
|
| 318 |
+
exc_type, exc_msg, tb = payload
|
| 319 |
+
raise RuntimeError(f"Worker {h.worker_id} boot failed: {exc_type}: {exc_msg}\n{tb}")
|
| 320 |
+
# Unexpected tag during boot — ignore and keep waiting.
|
| 321 |
+
raise TimeoutError(f"Worker {h.worker_id} did not become ready within {timeout}s")
|
| 322 |
+
|
| 323 |
+
def shutdown(self, timeout: float = 5.0):
|
| 324 |
+
with self._lock:
|
| 325 |
+
if not self._started:
|
| 326 |
+
return
|
| 327 |
+
for h in self.workers:
|
| 328 |
+
try:
|
| 329 |
+
h.req_q.put((0, "shutdown", None))
|
| 330 |
+
except Exception:
|
| 331 |
+
pass
|
| 332 |
+
for h in self.workers:
|
| 333 |
+
try:
|
| 334 |
+
if h.process is not None:
|
| 335 |
+
h.process.join(timeout=timeout)
|
| 336 |
+
if h.process.is_alive():
|
| 337 |
+
h.process.kill()
|
| 338 |
+
h.process.join(timeout=2)
|
| 339 |
+
except Exception:
|
| 340 |
+
pass
|
| 341 |
+
self.workers.clear()
|
| 342 |
+
self._started = False
|
| 343 |
+
|
| 344 |
+
# ---- task dispatch ---------------------------------------------------
|
| 345 |
+
|
| 346 |
+
def _next_task_id(self) -> int:
|
| 347 |
+
with self._lock:
|
| 348 |
+
self._task_counter += 1
|
| 349 |
+
return self._task_counter
|
| 350 |
+
|
| 351 |
+
def _acquire_worker(self, timeout: Optional[float] = None) -> _WorkerHandle:
|
| 352 |
+
wid = self.free_q.get(timeout=timeout)
|
| 353 |
+
# Validate the worker is still alive; if not, respawn in-place.
|
| 354 |
+
h = self.workers[wid]
|
| 355 |
+
if h.process is None or not h.process.is_alive():
|
| 356 |
+
print(f"[CPU-POOL] Worker {wid} dead on acquire — respawning")
|
| 357 |
+
self._respawn_worker(wid)
|
| 358 |
+
h = self.workers[wid]
|
| 359 |
+
return h
|
| 360 |
+
|
| 361 |
+
def _release_worker(self, h: _WorkerHandle):
|
| 362 |
+
self.free_q.put(h.worker_id)
|
| 363 |
+
|
| 364 |
+
def _respawn_worker(self, worker_id: int):
|
| 365 |
+
"""Replace a dead worker in-place. Blocks until ready."""
|
| 366 |
+
t0 = time.time()
|
| 367 |
+
new_h = self._spawn_worker(worker_id)
|
| 368 |
+
self._wait_ready(new_h, timeout=600.0)
|
| 369 |
+
self.workers[worker_id] = new_h
|
| 370 |
+
print(f"[CPU-POOL] Worker {worker_id} respawned in {time.time()-t0:.1f}s (new pid={new_h.pid})")
|
| 371 |
+
|
| 372 |
+
def run(self, func, args, kwargs, timeout: Optional[float] = None) -> Any:
|
| 373 |
+
if not self._started:
|
| 374 |
+
raise RuntimeError("Pool not started")
|
| 375 |
+
|
| 376 |
+
h = self._acquire_worker(timeout=timeout)
|
| 377 |
+
try:
|
| 378 |
+
task_id = self._next_task_id()
|
| 379 |
+
func_module = func.__module__
|
| 380 |
+
func_name = func.__qualname__
|
| 381 |
+
print(f"[CPU-POOL] dispatch task#{task_id} {func_module}.{func_name} -> W{h.worker_id} (pid={h.pid})")
|
| 382 |
+
t0 = time.time()
|
| 383 |
+
h.req_q.put((task_id, "run", (func_module, func_name, args, kwargs)))
|
| 384 |
+
|
| 385 |
+
# Drain res_q; tolerate process death.
|
| 386 |
+
deadline = time.time() + (timeout or 3600 * 4)
|
| 387 |
+
while True:
|
| 388 |
+
try:
|
| 389 |
+
tag, status, payload = h.res_q.get(timeout=min(30.0, max(1.0, deadline - time.time())))
|
| 390 |
+
except queue_mod.Empty:
|
| 391 |
+
if not h.process.is_alive():
|
| 392 |
+
# worker died mid-task. respawn and raise so caller can retry.
|
| 393 |
+
print(f"[CPU-POOL] Worker {h.worker_id} died mid-task (exit={h.process.exitcode})")
|
| 394 |
+
self._respawn_worker(h.worker_id)
|
| 395 |
+
raise RuntimeError(f"Worker {h.worker_id} died mid-task")
|
| 396 |
+
if time.time() >= deadline:
|
| 397 |
+
raise TimeoutError(f"CPU pool task timed out after {timeout}s")
|
| 398 |
+
continue
|
| 399 |
+
|
| 400 |
+
if tag == task_id:
|
| 401 |
+
break
|
| 402 |
+
# stray message (e.g. leftover rss reply). Drop.
|
| 403 |
+
print(f"[CPU-POOL] W{h.worker_id} stray message tag={tag!r}, ignoring")
|
| 404 |
+
|
| 405 |
+
h.total_jobs += 1
|
| 406 |
+
dt = time.time() - t0
|
| 407 |
+
if status == "ok":
|
| 408 |
+
print(f"[CPU-POOL] task#{task_id} ok in {dt:.2f}s on W{h.worker_id}")
|
| 409 |
+
return payload
|
| 410 |
+
exc_type, exc_msg, tb = payload
|
| 411 |
+
print(f"[CPU-POOL] task#{task_id} error on W{h.worker_id}: {exc_type}: {exc_msg}\n{tb}")
|
| 412 |
+
raise RuntimeError(f"Worker error ({exc_type}): {exc_msg}")
|
| 413 |
+
finally:
|
| 414 |
+
# If the worker died we may have respawned it inside _run. In that
|
| 415 |
+
# case it's already in workers[] but not in free_q. Add it back.
|
| 416 |
+
if h.process is not None and not h.process.is_alive():
|
| 417 |
+
# respawn already put nothing back on free_q; add the *new* handle
|
| 418 |
+
new_h = self.workers[h.worker_id]
|
| 419 |
+
if new_h is not h:
|
| 420 |
+
self.free_q.put(new_h.worker_id)
|
| 421 |
+
else:
|
| 422 |
+
# lost — dead and not replaced. Try a respawn now.
|
| 423 |
+
try:
|
| 424 |
+
self._respawn_worker(h.worker_id)
|
| 425 |
+
self.free_q.put(h.worker_id)
|
| 426 |
+
except Exception as e:
|
| 427 |
+
print(f"[CPU-POOL] could not respawn W{h.worker_id}: {e}")
|
| 428 |
+
else:
|
| 429 |
+
self._release_worker(h)
|
| 430 |
+
|
| 431 |
+
# ---- diagnostics -----------------------------------------------------
|
| 432 |
+
|
| 433 |
+
def probe_rss(self, worker_id: int, timeout: float = 10.0) -> int:
    """Ask one worker for its RSS figure and return it as an int.

    Puts an ``"rss"`` request on the worker's request queue, then reads the
    worker's result queue until the reply tagged with this request's task id
    shows up. Replies carrying any other tag are consumed and dropped.

    Args:
        worker_id: Index into ``self.workers``.
        timeout: Total seconds to wait for the matching reply.

    Returns:
        ``int(payload)`` of the matching reply.

    Raises:
        TimeoutError: No matching reply arrived within ``timeout`` seconds.
    """
    # Local import keeps this method self-contained; the queue's ``get``
    # signals an expired wait with ``queue.Empty``.
    from queue import Empty

    handle = self.workers[worker_id]
    task_id = self._next_task_id()
    handle.req_q.put((task_id, "rss", None))

    deadline = time.time() + timeout
    while True:
        remaining = deadline - time.time()
        if remaining <= 0:
            # Never hand a zero/negative timeout to ``get``.
            break
        try:
            tag, _status, payload = handle.res_q.get(timeout=remaining)
        except Empty:
            # Worker stayed silent for the rest of the budget; report it as
            # the documented TimeoutError instead of leaking queue.Empty.
            break
        if tag == task_id:
            return int(payload)
        # NOTE(review): a reply with a different tag is discarded here; if a
        # task can be in flight on this worker at the same time, its result
        # would be swallowed — confirm callers only probe idle workers.
    raise TimeoutError("rss probe timed out")
|
| 443 |
+
|
| 444 |
+
def load_large(self, worker_id: int, timeout: float = 300.0) -> dict:
    """Send a ``"load_large"`` request to one worker and return its reply.

    Puts the request on the worker's request queue, then reads the worker's
    result queue until the reply tagged with this request's task id shows up.
    Replies carrying any other tag are consumed and dropped.
    (Presumably this makes the worker load the large model — confirm against
    the worker loop.)

    Args:
        worker_id: Index into ``self.workers``.
        timeout: Total seconds to wait for the matching reply.

    Returns:
        The payload of the matching reply when its status is ``"ok"``.

    Raises:
        RuntimeError: The matching reply had a status other than ``"ok"``.
        TimeoutError: No matching reply arrived within ``timeout`` seconds.
    """
    # Local import keeps this method self-contained; the queue's ``get``
    # signals an expired wait with ``queue.Empty``.
    from queue import Empty

    handle = self.workers[worker_id]
    task_id = self._next_task_id()
    handle.req_q.put((task_id, "load_large", None))

    deadline = time.time() + timeout
    while True:
        remaining = deadline - time.time()
        if remaining <= 0:
            # Never hand a zero/negative timeout to ``get``.
            break
        try:
            tag, status, payload = handle.res_q.get(timeout=remaining)
        except Empty:
            # Report a silent worker as the documented TimeoutError instead
            # of leaking queue.Empty to the caller.
            break
        if tag != task_id:
            # Reply to some other request; dropped.
            continue
        if status == "ok":
            return payload
        raise RuntimeError(f"load_large failed: {payload}")
    raise TimeoutError("load_large timed out")
|
| 456 |
+
|
| 457 |
+
def stats(self) -> dict:
    """Build a plain-dict report of the pool and each worker handle.

    Returns:
        A dict with ``"started"``, ``"n_workers"`` and ``"workers"``; the
        last is a list with one entry per handle in ``self.workers`` holding
        its id, pid, liveness, job count, boot time, a shallow copy of its
        snapshots mapping, and its load_times object.
    """
    started = self._started
    worker_count = len(self.workers)

    report = []
    for handle in self.workers:
        # A handle with no process object is reported as not alive.
        if handle.process is None:
            alive = False
        else:
            alive = handle.process.is_alive()
        report.append(
            {
                "id": handle.worker_id,
                "pid": handle.pid,
                "alive": alive,
                "total_jobs": handle.total_jobs,
                "boot_time": handle.boot_time,
                "snapshots": dict(handle.snapshots.items()),
                "load_times": handle.load_times,
            }
        )

    return {
        "started": started,
        "n_workers": worker_count,
        "workers": report,
    }
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
# ---------------------------------------------------------------------------
|
| 477 |
+
# Module-level singleton API
|
| 478 |
+
# ---------------------------------------------------------------------------
|
| 479 |
+
|
| 480 |
+
# Process-wide pool instance; stays None until _get_pool() first creates it.
_POOL: Optional[_Pool] = None
|
| 481 |
+
# Held by _get_pool() while it creates the _POOL instance.
_START_LOCK = threading.Lock()
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
def _get_pool() -> _Pool:
    """Return the module's single ``_Pool``, constructing it on first call.

    The instance is only constructed, not started. Creation happens while
    holding ``_START_LOCK`` and re-checks ``_POOL`` under the lock so that
    concurrent first callers end up sharing one instance.
    """
    global _POOL
    if _POOL is not None:
        # Already created; no lock needed on this path.
        return _POOL
    with _START_LOCK:
        # Another thread may have created it while we waited for the lock.
        if _POOL is None:
            _POOL = _Pool()
    return _POOL
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
def start_pool(n_workers: int, preload_large: bool = False):
    """Start the shared persistent worker pool. Idempotent.

    Args:
        n_workers: Worker count forwarded to the pool's ``start``.
        preload_large: Forwarded to the pool's ``start`` as a keyword.
    """
    pool = _get_pool()
    pool.start(n_workers, preload_large=preload_large)
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
def is_started() -> bool:
    """Report whether the shared pool exists and is flagged as started.

    Unlike the other module-level helpers, this does not create the pool
    object when it is missing.
    """
    pool = _POOL
    if pool is None:
        return False
    return pool._started
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
def stats() -> dict:
    """Return the shared pool's stats dict, creating the pool object if needed."""
    pool = _get_pool()
    return pool.stats()
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
def probe_rss(worker_id: int, timeout: float = 10.0) -> int:
    """Query one worker of the shared pool for its RSS figure.

    Args:
        worker_id: Index of the worker in the pool.
        timeout: Seconds to wait for the worker's reply. Defaults to the
            same 10 seconds the pool method uses, so existing callers that
            pass only ``worker_id`` behave as before.

    Returns:
        The integer value returned by the pool's ``probe_rss``.
    """
    return _get_pool().probe_rss(worker_id, timeout=timeout)
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
def load_large(worker_id: int, timeout: float = 300.0) -> dict:
    """Send a ``load_large`` request to one worker of the shared pool.

    Args:
        worker_id: Index of the worker in the pool.
        timeout: Seconds to wait for the worker's reply. Defaults to the
            same 300 seconds the pool method uses, so existing callers that
            pass only ``worker_id`` behave as before.

    Returns:
        The payload returned by the pool's ``load_large``.
    """
    return _get_pool().load_large(worker_id, timeout=timeout)
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
def shutdown():
    """Shut down the shared pool if one was ever created; otherwise do nothing."""
    pool = _POOL
    if pool is None:
        return
    pool.shutdown()
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
def run_on_persistent_worker(func, args, kwargs, timeout: Optional[float] = None):
    """Execute ``func(*args, **kwargs)`` on a free persistent worker and wait.

    Concurrency gating is the caller's job: the wrapper in zero_gpu.py
    guards this call with the same semaphore it uses for the spawn path.

    Args:
        func: Callable to run on the worker.
        args: Positional arguments for ``func``.
        kwargs: Keyword arguments for ``func``.
        timeout: Forwarded to the pool's ``run``; ``None`` by default.

    Returns:
        Whatever the pool's ``run`` returns for this call.
    """
    pool = _get_pool()
    return pool.run(func, args, kwargs, timeout=timeout)
|
src/core/zero_gpu.py
CHANGED
|
@@ -269,18 +269,32 @@ def gpu_with_fallback(duration=60):
|
|
| 269 |
|
| 270 |
if CPU_STRATEGY == "subprocess":
|
| 271 |
import time as _time
|
|
|
|
| 272 |
sem = _get_subprocess_semaphore()
|
| 273 |
_t_acq = _time.time()
|
| 274 |
sem.acquire()
|
| 275 |
_wait = _time.time() - _t_acq
|
| 276 |
try:
|
| 277 |
_check_cuda_fork_state(f"before CPU subprocess ({func.__name__})")
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
_check_cuda_fork_state(f"after CPU subprocess ({func.__name__})")
|
| 285 |
return result
|
| 286 |
finally:
|
|
|
|
| 269 |
|
| 270 |
if CPU_STRATEGY == "subprocess":
|
| 271 |
import time as _time
|
| 272 |
+
from config import CPU_WORKER_MODE
|
| 273 |
sem = _get_subprocess_semaphore()
|
| 274 |
_t_acq = _time.time()
|
| 275 |
sem.acquire()
|
| 276 |
_wait = _time.time() - _t_acq
|
| 277 |
try:
|
| 278 |
_check_cuda_fork_state(f"before CPU subprocess ({func.__name__})")
|
| 279 |
+
if CPU_WORKER_MODE == "persistent":
|
| 280 |
+
from .cpu_worker_pool import run_on_persistent_worker, is_started, start_pool
|
| 281 |
+
if not is_started():
|
| 282 |
+
# Lazy start fallback — slow first request.
|
| 283 |
+
from config import CPU_SUBPROCESS_CONCURRENCY, CPU_POOL_PRELOAD_LARGE
|
| 284 |
+
print("[CPU] Pool not started — lazy-starting now")
|
| 285 |
+
start_pool(CPU_SUBPROCESS_CONCURRENCY, preload_large=CPU_POOL_PRELOAD_LARGE)
|
| 286 |
+
print(
|
| 287 |
+
f"[CPU] Running {func.__name__} on persistent worker "
|
| 288 |
+
f"(CPU_WORKER_MODE=persistent, queue_wait={_wait:.2f}s)"
|
| 289 |
+
)
|
| 290 |
+
result = run_on_persistent_worker(func, args, kwargs)
|
| 291 |
+
else:
|
| 292 |
+
print(
|
| 293 |
+
f"[CPU] Running {func.__name__} in isolated subprocess "
|
| 294 |
+
f"(CPU_STRATEGY=subprocess, queue_wait={_wait:.2f}s)"
|
| 295 |
+
)
|
| 296 |
+
from .cpu_subprocess import run_in_cpu_subprocess
|
| 297 |
+
result = run_in_cpu_subprocess(func, args, kwargs)
|
| 298 |
_check_cuda_fork_state(f"after CPU subprocess ({func.__name__})")
|
| 299 |
return result
|
| 300 |
finally:
|
src/ui/event_wiring.py
CHANGED
|
@@ -17,6 +17,8 @@ from src.api.session_api import (
|
|
| 17 |
debug_process,
|
| 18 |
cpu_exec,
|
| 19 |
pool_status,
|
|
|
|
|
|
|
| 20 |
)
|
| 21 |
from src.mfa import compute_mfa_timestamps
|
| 22 |
from src.ui.progress_bar import pipeline_progress_bar_html
|
|
@@ -935,6 +937,20 @@ def _wire_api_endpoint(c):
|
|
| 935 |
outputs=[c.api_result],
|
| 936 |
api_name="pool_status",
|
| 937 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 938 |
|
| 939 |
|
| 940 |
def _wire_dev_tab(c):
|
|
|
|
| 17 |
debug_process,
|
| 18 |
cpu_exec,
|
| 19 |
pool_status,
|
| 20 |
+
cpu_pool_status,
|
| 21 |
+
cpu_pool_kill,
|
| 22 |
)
|
| 23 |
from src.mfa import compute_mfa_timestamps
|
| 24 |
from src.ui.progress_bar import pipeline_progress_bar_html
|
|
|
|
| 937 |
outputs=[c.api_result],
|
| 938 |
api_name="pool_status",
|
| 939 |
)
|
| 940 |
+
# Persistent CPU worker pool status — HF-token-gated.
|
| 941 |
+
gr.Button(visible=False).click(
|
| 942 |
+
fn=cpu_pool_status,
|
| 943 |
+
inputs=[c.api_pool_status_token],
|
| 944 |
+
outputs=[c.api_result],
|
| 945 |
+
api_name="cpu_pool_status",
|
| 946 |
+
)
|
| 947 |
+
# Kill a persistent worker — crash-recovery test helper, HF-token-gated.
|
| 948 |
+
gr.Button(visible=False).click(
|
| 949 |
+
fn=cpu_pool_kill,
|
| 950 |
+
inputs=[c.api_pool_status_token, c.api_cpu_exec_module], # reuse token + a string input
|
| 951 |
+
outputs=[c.api_result],
|
| 952 |
+
api_name="cpu_pool_kill",
|
| 953 |
+
)
|
| 954 |
|
| 955 |
|
| 956 |
def _wire_dev_tab(c):
|