quranic-universal-aligner / src /core /cpu_worker_pool.py
hetchyy's picture
Upload folder using huggingface_hub
419fe6e verified
"""Persistent CPU worker pool — prototype.
Replaces the spawn-per-request path (cpu_subprocess.py) with long-lived
workers that preload VAD + ASR (Base) once and stay ready. ASR Large is
loaded on-demand in each worker (and cached there) so steady-state RAM
stays predictable.
Gated behind env var CPU_WORKER_MODE=persistent. CPU_WORKER_MODE=spawn
(default) keeps the existing spawn-per-request behavior.
Semaphore capacity = CPU_SUBPROCESS_CONCURRENCY (shared with spawn path).
A free-worker queue gives O(1) idle-worker pickup and guarantees at most
one concurrent job per worker.
"""
from __future__ import annotations
import importlib
import multiprocessing as mp
import os
import queue as queue_mod
import signal
import sys
import threading
import time
import traceback
from dataclasses import dataclass, field
from typing import Any, Optional
# ---------------------------------------------------------------------------
# Worker side
# ---------------------------------------------------------------------------
def _resolve_callable(func_module: str, func_name: str):
    """Import ``func_module`` and return the undecorated callable named ``func_name``.

    ``func_name`` may be a dotted qualified name (``"Class.method"``); each
    component is looked up with ``getattr`` in turn, starting from the module.
    Any ``__wrapped__`` chain left by decorators is then followed to the
    innermost function, so the worker calls the raw implementation.

    Raises:
        ImportError: the module cannot be imported.
        AttributeError: a component of ``func_name`` does not exist (this
            includes names containing ``<locals>``, which are not reachable
            from the module namespace).
    """
    target = importlib.import_module(func_module)
    for part in func_name.split("."):
        target = getattr(target, part)
    while hasattr(target, "__wrapped__"):
        target = target.__wrapped__
    return target


def _worker_loop(
    worker_id: int,
    extra_paths: list,
    req_q: mp.Queue,
    res_q: mp.Queue,
    ready_ev,
    preload_large: bool,
):
    """Long-lived worker body. Runs in a spawn-context process.

    Steps:
        1. Env hygiene — hide CUDA, disable ZeroGPU patches.
        2. Import project modules, call force_cpu_mode().
        3. Preload VAD + ASR Base (Large optional).
        4. Signal `ready_ev`.
        5. Loop: pull task → execute → push result.

    Tasks are pickled tuples: (task_id, kind, payload)
        kind="run":        payload=(func_module, func_name, args, kwargs)
        kind="rss":        payload=None   (return rss bytes)
        kind="load_large": payload=None   (preload ASR Large if not cached)
        kind="shutdown":   payload=None   (exit loop)
    A bare ``None`` item also ends the loop.

    Replies pushed on ``res_q`` are ``(tag, status, payload)`` tuples:
        ("__ready__", "ok", {...boot metrics...})    once, after a good boot
        ("__boot_error__", "error", (type, msg, tb)) once, then the worker returns
        (task_id, "ok", result)                      task succeeded
        (task_id, "error", (type, msg, tb))          task raised

    Args:
        worker_id: Index of this worker; used in log prefixes and the ready message.
        extra_paths: Entries to add to ``sys.path`` so project modules import.
        req_q: Queue the parent uses to send tasks.
        res_q: Queue this worker uses to send replies.
        ready_ev: Event set once boot has completed successfully.
        preload_large: When true, ASR Large is also loaded during boot
            (a failure there is logged but does not abort the boot).
    """
    # ---- Env hygiene BEFORE any torch import ----------------------------
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    os.environ["SPACES_ZERO_GPU"] = ""
    # Suppress the HF download progress bars — the parent app sets this but
    # the spawned child inherits only env, not module state.
    os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")

    # Restore sys.path from parent so src/ is importable
    for p in extra_paths:
        if p and p not in sys.path:
            sys.path.insert(0, p)

    # Helpful process label for ps/htop (optional dependency, best-effort)
    try:
        import setproctitle  # type: ignore
        setproctitle.setproctitle(f"cpu-worker-{worker_id}")
    except Exception:
        pass

    def log(msg):
        print(f"[CPU-POOL/W{worker_id}] {msg}", flush=True)

    # ---- RSS probe ------------------------------------------------------
    def _rss_bytes() -> int:
        """Resident set size in bytes: psutil first, /proc fallback, else 0."""
        try:
            import psutil  # type: ignore
            return psutil.Process(os.getpid()).memory_info().rss
        except Exception:
            try:
                with open(f"/proc/{os.getpid()}/status") as f:
                    for line in f:
                        if line.startswith("VmRSS:"):
                            # VmRSS is reported in kB
                            return int(line.split()[1]) * 1024
            except Exception:
                return 0
        return 0

    def _exc_payload(exc):
        """Error payload tuple; must be called while ``exc`` is being handled."""
        return (type(exc).__name__, str(exc), traceback.format_exc())

    snapshots = {"start": _rss_bytes()}
    t0 = time.time()
    log(f"booted pid={os.getpid()} rss={snapshots['start']/1e6:.1f}MB")

    # ---- Imports + force CPU mode --------------------------------------
    try:
        from src.core.zero_gpu import force_cpu_mode
        force_cpu_mode()
        snapshots["after_imports"] = _rss_bytes()
        log(f"imports done +{(time.time()-t0):.1f}s rss={snapshots['after_imports']/1e6:.1f}MB")
    except Exception as e:
        log(f"FATAL imports: {e}\n{traceback.format_exc()}")
        res_q.put(("__boot_error__", "error", _exc_payload(e)))
        return

    load_times = {}

    # ---- Preload VAD ---------------------------------------------------
    try:
        t = time.time()
        from src.segmenter.segmenter_model import load_segmenter
        load_segmenter()
        load_times["vad"] = time.time() - t
        snapshots["after_vad"] = _rss_bytes()
        log(f"VAD loaded in {load_times['vad']:.2f}s rss={snapshots['after_vad']/1e6:.1f}MB")
    except Exception as e:
        log(f"VAD load failed: {e}")
        res_q.put(("__boot_error__", "error", _exc_payload(e)))
        return

    # ---- Preload ASR Base ---------------------------------------------
    try:
        t = time.time()
        from src.alignment.phoneme_asr import load_phoneme_asr
        load_phoneme_asr("Base")
        load_times["asr_base"] = time.time() - t
        snapshots["after_asr_base"] = _rss_bytes()
        log(f"ASR Base loaded in {load_times['asr_base']:.2f}s rss={snapshots['after_asr_base']/1e6:.1f}MB")
    except Exception as e:
        log(f"ASR Base load failed: {e}")
        res_q.put(("__boot_error__", "error", _exc_payload(e)))
        return

    # ---- Preload caches (ngram index, phoneme chapters) ----------------
    try:
        t = time.time()
        from src.alignment.ngram_index import get_ngram_index
        from src.alignment.phoneme_matcher_cache import preload_all_chapters
        get_ngram_index()
        preload_all_chapters()
        load_times["caches"] = time.time() - t
        snapshots["after_caches"] = _rss_bytes()
        log(f"caches loaded in {load_times['caches']:.2f}s rss={snapshots['after_caches']/1e6:.1f}MB")
    except Exception as e:
        log(f"caches load failed (non-fatal): {e}")

    # ---- Optionally preload ASR Large ---------------------------------
    if preload_large:
        try:
            t = time.time()
            from src.alignment.phoneme_asr import load_phoneme_asr
            load_phoneme_asr("Large")
            load_times["asr_large"] = time.time() - t
            snapshots["after_asr_large"] = _rss_bytes()
            log(f"ASR Large loaded in {load_times['asr_large']:.2f}s rss={snapshots['after_asr_large']/1e6:.1f}MB")
        except Exception as e:
            log(f"ASR Large preload failed: {e}")

    # ---- Warm up resampler (deliberately best-effort) -------------------
    try:
        import numpy as np, librosa
        from config import RESAMPLE_TYPE
        _ = librosa.resample(np.zeros(1600, dtype=np.float32),
                             orig_sr=44100, target_sr=16000, res_type=RESAMPLE_TYPE)
    except Exception:
        pass

    snapshots["ready"] = _rss_bytes()
    total_boot = time.time() - t0
    log(f"READY in {total_boot:.2f}s, final rss={snapshots['ready']/1e6:.1f}MB")

    # Signal parent that this worker booted successfully
    res_q.put(("__ready__", "ok", {
        "worker_id": worker_id,
        "pid": os.getpid(),
        "snapshots": snapshots,
        "load_times": load_times,
        "boot_time": total_boot,
    }))
    ready_ev.set()

    # ---- Main loop -----------------------------------------------------
    while True:
        try:
            item = req_q.get()
        except (EOFError, OSError, KeyboardInterrupt):
            break
        if item is None:
            break

        # A request that is not a 3-tuple has no task id to reply to; drop it
        # instead of letting the unpack error kill a fully loaded worker.
        try:
            task_id, kind, payload = item
        except (TypeError, ValueError):
            log(f"dropping malformed task: {item!r}")
            continue

        try:
            if kind == "shutdown":
                break
            if kind == "rss":
                res_q.put((task_id, "ok", _rss_bytes()))
            elif kind == "load_large":
                from src.alignment.phoneme_asr import load_phoneme_asr
                t = time.time()
                load_phoneme_asr("Large")
                res_q.put((task_id, "ok", {"load_time": time.time() - t, "rss": _rss_bytes()}))
            elif kind == "run":
                func_module, func_name, args, kwargs = payload
                func = _resolve_callable(func_module, func_name)
                result = func(*args, **kwargs)
                # NOTE(review): if res_q is a multiprocessing queue, pickling
                # presumably happens off-thread, so an unpicklable result
                # would not raise here — confirm.
                res_q.put((task_id, "ok", result))
            else:
                res_q.put((task_id, "error", ("ValueError", f"unknown kind {kind!r}", "")))
        except Exception as e:
            # Catch-all so the loop survives any task failure
            res_q.put((task_id, "error", _exc_payload(e)))

    log("exiting cleanly")
# ---------------------------------------------------------------------------
# Parent side
# ---------------------------------------------------------------------------
@dataclass
class _WorkerHandle:
    """Parent-side record describing one worker process and its channels."""

    # Index of the worker; the pool stores handles at workers[worker_id].
    worker_id: int

    # Process object and its private communication primitives.
    process: Optional[Any] = None
    req_q: Optional[Any] = None      # parent -> worker task tuples
    res_q: Optional[Any] = None      # worker -> parent reply tuples
    ready_ev: Optional[Any] = None   # set by the worker once boot completes

    # Boot metrics copied from the worker's "__ready__" message.
    snapshots: dict = field(default_factory=dict)
    load_times: dict = field(default_factory=dict)
    boot_time: float = 0.0

    # OS pid of the worker process, if known.
    pid: Optional[int] = None

    # Number of "run" tasks for which a reply was received.
    total_jobs: int = 0

    # Per-handle lock (a fresh Lock per instance).
    lock: threading.Lock = field(default_factory=threading.Lock)
class _Pool:
    """Parent-side manager for the persistent CPU worker processes.

    Workers are spawn-context processes running ``_worker_loop``; each has its
    own request/result queue pair. Ids of idle workers live on ``free_q``: a
    job takes an id off the queue, talks to that worker, and puts the id back,
    so a worker never has more than one job dispatched to it at a time.

    Every exchange on a worker's queues (``run``, ``probe_rss``,
    ``load_large``) is done while holding that handle's ``lock`` so one
    caller cannot consume another caller's reply.
    """

    def __init__(self):
        self.ctx = mp.get_context("spawn")
        self.workers: list[_WorkerHandle] = []
        self.free_q: "queue_mod.Queue[int]" = queue_mod.Queue()
        self._started = False
        self._lock = threading.Lock()
        self._task_counter = 0
        self._respawn_count = 0
        self._preload_large = False
        self._extra_paths: list[str] = []

    # ---- internal helpers ------------------------------------------------
    def _drain_free_q(self):
        """Discard every id currently waiting on the free queue."""
        while True:
            try:
                self.free_q.get_nowait()
            except queue_mod.Empty:
                return

    @staticmethod
    def _terminate(h, grace: float = 2.0):
        """Best-effort hard stop of the process behind handle ``h``."""
        proc = h.process
        if proc is None:
            return
        try:
            if proc.is_alive():
                proc.kill()
            proc.join(timeout=grace)
        except Exception as e:
            print(f"[CPU-POOL] could not terminate W{h.worker_id}: {e}")

    # ---- lifecycle -------------------------------------------------------
    def start(self, n_workers: int, preload_large: bool = False, boot_timeout: float = 600.0):
        """Spawn ``n_workers`` workers and block until all report ready.

        Idempotent: returns immediately if the pool is already started. If
        any worker fails to boot, every worker spawned so far is killed, the
        pool is reset to the not-started state and the error is re-raised,
        so a later ``start`` can retry.

        Raises:
            RuntimeError: a worker reported a boot error or died during boot.
            TimeoutError: a worker did not become ready within ``boot_timeout``.
        """
        with self._lock:
            if self._started:
                return
            self._started = True
            self._preload_large = preload_large
            self._extra_paths = list(sys.path)
            print(f"[CPU-POOL] Starting {n_workers} persistent worker(s) preload_large={preload_large}")
            try:
                for i in range(n_workers):
                    self.workers.append(self._spawn_worker(i))
                # Wait for the ready signal from each worker in turn.
                for h in self.workers:
                    self._wait_ready(h, timeout=boot_timeout)
                    self.free_q.put(h.worker_id)
            except BaseException:
                # Roll back: a half-booted pool with _started=True would make
                # start() a no-op and leave run() waiting on an empty free_q.
                for h in self.workers:
                    self._terminate(h)
                self.workers.clear()
                self._drain_free_q()
                self._started = False
                raise
            print(f"[CPU-POOL] All {n_workers} workers READY")

    def _spawn_worker(self, worker_id: int) -> "_WorkerHandle":
        """Start one daemon worker process and return its (not yet ready) handle."""
        req_q = self.ctx.Queue()
        res_q = self.ctx.Queue()
        ready_ev = self.ctx.Event()
        p = self.ctx.Process(
            target=_worker_loop,
            args=(worker_id, self._extra_paths, req_q, res_q, ready_ev, self._preload_large),
            daemon=True,
            name=f"cpu-worker-{worker_id}",
        )
        p.start()
        return _WorkerHandle(
            worker_id=worker_id,
            process=p,
            req_q=req_q,
            res_q=res_q,
            ready_ev=ready_ev,
            pid=p.pid,
        )

    def _wait_ready(self, h: "_WorkerHandle", timeout: float):
        """Drain res_q until we see the __ready__ tag or a __boot_error__.

        On ``__ready__`` the boot metrics are copied onto ``h``.

        Raises:
            RuntimeError: boot error reported, or the process exited while
                the queue stayed empty.
            TimeoutError: nothing conclusive arrived within ``timeout`` seconds.
        """
        deadline = time.time() + timeout
        while time.time() < deadline:
            # Clamp: the deadline may pass between the loop test and this call.
            wait = max(0.0, min(10.0, deadline - time.time()))
            try:
                tag, status, payload = h.res_q.get(timeout=wait)
            except queue_mod.Empty:
                if h.process is not None and not h.process.is_alive():
                    raise RuntimeError(f"Worker {h.worker_id} died during boot (exit={h.process.exitcode})")
                continue
            if tag == "__ready__":
                h.snapshots = payload.get("snapshots", {})
                h.load_times = payload.get("load_times", {})
                h.boot_time = payload.get("boot_time", 0.0)
                h.pid = payload.get("pid", h.pid)
                return
            if tag == "__boot_error__":
                exc_type, exc_msg, tb = payload
                raise RuntimeError(f"Worker {h.worker_id} boot failed: {exc_type}: {exc_msg}\n{tb}")
            # Unexpected tag during boot — ignore and keep waiting.
        raise TimeoutError(f"Worker {h.worker_id} did not become ready within {timeout}s")

    def shutdown(self, timeout: float = 5.0):
        """Ask every worker to exit, kill stragglers, and reset the pool.

        Each worker gets ``timeout`` seconds to exit after the shutdown
        request before it is killed. The free queue is emptied as well, so a
        subsequent ``start`` does not see ids left over from this run.
        """
        with self._lock:
            if not self._started:
                return
            for h in self.workers:
                try:
                    h.req_q.put((0, "shutdown", None))
                except Exception:
                    pass
            for h in self.workers:
                try:
                    if h.process is not None:
                        h.process.join(timeout=timeout)
                        if h.process.is_alive():
                            h.process.kill()
                            h.process.join(timeout=2)
                except Exception:
                    pass
            self.workers.clear()
            # Stale ids would otherwise be duplicated by the next start(),
            # allowing two jobs on the same worker.
            self._drain_free_q()
            self._started = False

    # ---- task dispatch ---------------------------------------------------
    def _next_task_id(self) -> int:
        """Return a fresh, monotonically increasing task id."""
        with self._lock:
            self._task_counter += 1
            return self._task_counter

    def _acquire_worker(self, timeout: Optional[float] = None) -> "_WorkerHandle":
        """Take an idle worker off the free queue, respawning it if it is dead.

        Raises:
            queue.Empty: no worker became free within ``timeout``.
            RuntimeError / TimeoutError: the worker was dead and its
                replacement failed to boot (the id is returned to the free
                queue first, so a later call can try again).
        """
        wid = self.free_q.get(timeout=timeout)
        # Validate the worker is still alive; if not, respawn in-place.
        h = self.workers[wid]
        if h.process is None or not h.process.is_alive():
            print(f"[CPU-POOL] Worker {wid} dead on acquire — respawning")
            try:
                self._respawn_worker(wid)
            except BaseException:
                # Do not lose the slot: put the id back so capacity recovers
                # once a respawn eventually succeeds.
                self.free_q.put(wid)
                raise
            h = self.workers[wid]
        return h

    def _release_worker(self, h: "_WorkerHandle"):
        """Mark the worker behind ``h`` as idle again."""
        self.free_q.put(h.worker_id)

    def _respawn_worker(self, worker_id: int):
        """Replace a dead worker in-place. Blocks until ready.

        If the replacement fails to boot it is killed and the error is
        re-raised; ``workers[worker_id]`` is left unchanged in that case.
        """
        t0 = time.time()
        new_h = self._spawn_worker(worker_id)
        try:
            self._wait_ready(new_h, timeout=600.0)
        except BaseException:
            # Don't leave a half-booted process behind.
            self._terminate(new_h)
            raise
        self.workers[worker_id] = new_h
        self._respawn_count += 1
        print(f"[CPU-POOL] Worker {worker_id} respawned in {time.time()-t0:.1f}s (new pid={new_h.pid})")

    def run(self, func, args, kwargs, timeout: Optional[float] = None) -> Any:
        """Execute ``func(*args, **kwargs)`` on an idle worker and return its result.

        The function is sent by name (``__module__`` + ``__qualname__``), not
        by value. ``timeout`` bounds both the wait for a free worker and,
        separately, the wait for the reply; when falsy the reply wait
        defaults to four hours.

        Raises:
            RuntimeError: pool not started, the worker died mid-task, or the
                function raised inside the worker.
            TimeoutError: no reply before the deadline.
            queue.Empty: no worker became free within ``timeout``.
        """
        if not self._started:
            raise RuntimeError("Pool not started")
        h = self._acquire_worker(timeout=timeout)
        try:
            # Hold the handle lock so diagnostics cannot read our reply.
            with h.lock:
                task_id = self._next_task_id()
                func_module = func.__module__
                func_name = func.__qualname__
                print(f"[CPU-POOL] dispatch task#{task_id} {func_module}.{func_name} -> W{h.worker_id} (pid={h.pid})")
                t0 = time.time()
                h.req_q.put((task_id, "run", (func_module, func_name, args, kwargs)))
                # Drain res_q; tolerate process death.
                deadline = time.time() + (timeout or 3600 * 4)
                while True:
                    try:
                        tag, status, payload = h.res_q.get(timeout=min(30.0, max(1.0, deadline - time.time())))
                    except queue_mod.Empty:
                        if not h.process.is_alive():
                            # worker died mid-task. respawn and raise so caller can retry.
                            print(f"[CPU-POOL] Worker {h.worker_id} died mid-task (exit={h.process.exitcode})")
                            try:
                                self._respawn_worker(h.worker_id)
                            except Exception as e:
                                # The next acquire will try again.
                                print(f"[CPU-POOL] could not respawn W{h.worker_id}: {e}")
                            raise RuntimeError(f"Worker {h.worker_id} died mid-task")
                        if time.time() >= deadline:
                            raise TimeoutError(f"CPU pool task timed out after {timeout}s")
                        continue
                    if tag == task_id:
                        break
                    # stray message (e.g. leftover rss reply). Drop.
                    print(f"[CPU-POOL] W{h.worker_id} stray message tag={tag!r}, ignoring")
                h.total_jobs += 1
                dt = time.time() - t0
                if status == "ok":
                    print(f"[CPU-POOL] task#{task_id} ok in {dt:.2f}s on W{h.worker_id}")
                    return payload
                exc_type, exc_msg, tb = payload
                print(f"[CPU-POOL] task#{task_id} error on W{h.worker_id}: {exc_type}: {exc_msg}\n{tb}")
                raise RuntimeError(f"Worker error ({exc_type}): {exc_msg}")
        finally:
            # Always hand the slot back. If the process behind it is dead,
            # _acquire_worker respawns it on the next pickup, so the pool
            # never permanently loses capacity.
            self.free_q.put(h.worker_id)

    # ---- diagnostics -----------------------------------------------------
    def _control_request(self, h, kind: str, timeout: float, timeout_msg: str):
        """Send a payload-less control task to ``h`` and return ``(status, payload)``.

        Replies carrying a different tag are discarded. Waiting for the
        handle lock counts against ``timeout``.

        Raises:
            TimeoutError: the lock or the reply was not obtained in time.
        """
        deadline = time.time() + timeout
        if not h.lock.acquire(timeout=max(0.0, timeout)):
            raise TimeoutError(timeout_msg)
        try:
            task_id = self._next_task_id()
            h.req_q.put((task_id, kind, None))
            while True:
                remaining = deadline - time.time()
                if remaining <= 0:
                    raise TimeoutError(timeout_msg)
                try:
                    tag, status, payload = h.res_q.get(timeout=remaining)
                except queue_mod.Empty:
                    raise TimeoutError(timeout_msg) from None
                if tag == task_id:
                    return status, payload
        finally:
            h.lock.release()

    def probe_rss(self, worker_id: int, timeout: float = 10.0) -> int:
        """Return the worker's self-reported RSS in bytes.

        Raises:
            TimeoutError: the worker was busy or did not answer within ``timeout``.
        """
        h = self.workers[worker_id]
        _status, payload = self._control_request(h, "rss", timeout, "rss probe timed out")
        return int(payload)

    def load_large(self, worker_id: int, timeout: float = 300.0) -> dict:
        """Ask the worker to load ASR Large; return its ``{"load_time", "rss"}`` reply.

        Raises:
            RuntimeError: the worker reported a failure while loading.
            TimeoutError: the worker was busy or did not answer within ``timeout``.
        """
        h = self.workers[worker_id]
        status, payload = self._control_request(h, "load_large", timeout, "load_large timed out")
        if status == "ok":
            return payload
        raise RuntimeError(f"load_large failed: {payload}")

    def stats(self) -> dict:
        """Return a best-effort telemetry snapshot of the pool and its workers."""
        # Peek free_q without popping — derives per-worker is_busy. Queue.queue
        # is a private-but-stable deque attribute; the read is best-effort and
        # transient mismatches during pickup/release are acceptable for the
        # telemetry sampler (periodic, non-critical).
        try:
            free_ids = set(self.free_q.queue)
        except Exception:
            free_ids = set()
        n_workers = len(self.workers)
        free_count = self.free_q.qsize()
        return {
            "started": self._started,
            "n_workers": n_workers,
            "free_count": free_count,
            "busy_count": max(0, n_workers - free_count),
            "respawn_count": self._respawn_count,
            "workers": [
                {
                    "id": h.worker_id,
                    "pid": h.pid,
                    "alive": h.process is not None and h.process.is_alive(),
                    "is_busy": h.worker_id not in free_ids,
                    "total_jobs": h.total_jobs,
                    "boot_time": h.boot_time,
                    "snapshots": dict(h.snapshots),
                    "load_times": h.load_times,
                }
                for h in self.workers
            ],
        }
# ---------------------------------------------------------------------------
# Module-level singleton API
# ---------------------------------------------------------------------------
# Process-wide pool instance; created on first use by _get_pool().
_POOL: Optional[_Pool] = None
# Guards the one-time creation of _POOL.
_START_LOCK = threading.Lock()


def _get_pool() -> _Pool:
    """Return the module-level pool, creating it on first call.

    Creation happens under ``_START_LOCK`` and is re-checked inside the lock,
    so concurrent first callers end up sharing one instance. The pool is only
    constructed here, not started.
    """
    global _POOL
    pool = _POOL
    if pool is not None:
        return pool
    with _START_LOCK:
        pool = _POOL
        if pool is None:
            pool = _Pool()
            _POOL = pool
    return pool
def start_pool(n_workers: int, preload_large: bool = False):
    """Spawn the persistent worker pool. Idempotent.

    Args:
        n_workers: Number of worker processes to start.
        preload_large: Forwarded to the pool so workers also load ASR Large
            while booting.
    """
    pool = _get_pool()
    pool.start(n_workers, preload_large=preload_large)
def is_started() -> bool:
    """Report whether the module-level pool exists and has been started.

    Unlike the other accessors this never creates the pool.
    """
    pool = _POOL
    if pool is None:
        return False
    return pool._started
def stats() -> dict:
    """Return the pool's telemetry snapshot (creates the pool object if needed)."""
    pool = _get_pool()
    return pool.stats()
def probe_rss(worker_id: int) -> int:
    """Return the RSS in bytes reported by worker ``worker_id``."""
    pool = _get_pool()
    return pool.probe_rss(worker_id)
def load_large(worker_id: int) -> dict:
    """Ask worker ``worker_id`` to load ASR Large and return the worker's reply."""
    pool = _get_pool()
    return pool.load_large(worker_id)
def shutdown():
    """Shut the module-level pool down; a no-op if it was never created."""
    pool = _POOL
    if pool is None:
        return
    pool.shutdown()
def run_on_persistent_worker(func, args, kwargs, timeout: Optional[float] = None):
    """Run a function on a free persistent worker. Blocks until done.

    Caller is responsible for concurrency gating (the wrapper in zero_gpu.py
    uses the same semaphore as the spawn path).

    Args:
        func: Callable to execute; forwarded to the pool's ``run``.
        args: Positional arguments for ``func``.
        kwargs: Keyword arguments for ``func``.
        timeout: Optional time limit in seconds, forwarded to the pool.

    Returns:
        Whatever the pool's ``run`` returns for this call.
    """
    pool = _get_pool()
    return pool.run(func, args, kwargs, timeout=timeout)