"""TaskRegistry — Phase 1 cancellation + live-process tracking.

This is the small piece of in-memory state that lets the kernel:
  * know which tasks are still running so it can stream updates,
  * know whether the user asked for the task to be cancelled,
  * hold a reference to the currently-active subprocess for each task so
    that cancellation can actually kill it (not just set a flag and hope).

It is intentionally process-local: HF Spaces only runs one uvicorn process,
so an in-memory registry is fine. If we ever go multi-worker we replace this
with Redis/Postgres advisory locks — the public surface stays the same.
"""

from __future__ import annotations

import os
import signal
import threading
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class TaskHandle:
    """Mutable per-task state tracked by :class:`TaskRegistry`.

    The registry takes ``lock`` whenever it writes ``cancelled`` or
    ``active_pgid``.
    """

    task_id: int
    cancelled: bool = False
    # process-group id of the active shell, if any
    active_pgid: Optional[int] = None
    lock: threading.Lock = field(default_factory=threading.Lock)


class TaskRegistry:
    """In-memory map from task id to :class:`TaskHandle`.

    Two kinds of lock are used and they are never held at the same time:
    ``self._lock`` protects the id -> handle dict, and each handle's own
    ``lock`` protects that handle's fields.
    """

    def __init__(self) -> None:
        self._tasks: Dict[int, TaskHandle] = {}
        self._lock = threading.Lock()

    def register(self, task_id: int) -> TaskHandle:
        """Return the handle for ``task_id``, creating it if needed.

        Re-registering an existing id keeps the same handle object but clears
        its cancellation flag and its active process-group id.
        """
        with self._lock:
            handle = self._tasks.get(task_id)
            if handle is None:
                handle = TaskHandle(task_id=task_id)
                self._tasks[task_id] = handle
                return handle
        # Re-running an old id: reset its state. The reset happens under the
        # handle's own lock (like every other mutation of these fields) so it
        # cannot interleave with a concurrent cancel(). The registry lock has
        # already been released, so the two locks are never nested.
        with handle.lock:
            handle.cancelled = False
            handle.active_pgid = None
        return handle

    def get(self, task_id: int) -> Optional[TaskHandle]:
        """Return the handle for ``task_id``, or ``None`` if never registered."""
        with self._lock:
            return self._tasks.get(task_id)

    def is_cancelled(self, task_id: int) -> bool:
        """Return True if ``task_id`` is registered and flagged as cancelled."""
        handle = self.get(task_id)
        return bool(handle is not None and handle.cancelled)

    def cancel(self, task_id: int) -> bool:
        """Flip the cancel flag and signal any live subprocess group.

        The recorded process group is sent SIGTERM and then, immediately,
        SIGKILL. Signalling is best-effort: any error stops the sequence and
        is otherwise ignored.

        Returns True if a handle existed (cancel was actually applied).
        """
        handle = self.get(task_id)
        if handle is None:
            return False

        with handle.lock:
            handle.cancelled = True
            pgid = handle.active_pgid

        # POSIX leaves killpg() undefined for pgrp <= 1 (on Linux, 0 targets
        # the caller's own process group and 1 broadcasts), so only real
        # process-group ids are ever signalled. The kill itself runs outside
        # the handle lock.
        if isinstance(pgid, int) and pgid > 1:
            for sig in (signal.SIGTERM, signal.SIGKILL):
                try:
                    os.killpg(pgid, sig)
                except Exception:
                    # Group already gone, not permitted, or anything else:
                    # cancellation stays best-effort and never raises.
                    break
        return True

    def set_active_pgid(self, task_id: int, pgid: Optional[int]) -> None:
        """Record (or clear, with ``None``) the live process group of a task.

        Does nothing when ``task_id`` is not registered.
        """
        handle = self.get(task_id)
        if handle is not None:
            with handle.lock:
                handle.active_pgid = pgid

    def release(self, task_id: int) -> None:
        """Optional cleanup — keep entries around so the API can report
        cancellation history; just clear pgid."""
        handle = self.get(task_id)
        if handle is not None:
            with handle.lock:
                handle.active_pgid = None


# Module-level singleton
_registry: Optional[TaskRegistry] = None
# Serialises first-time creation so concurrent callers share one registry.
_registry_lock = threading.Lock()


def get_registry() -> TaskRegistry:
    """Return the process-wide :class:`TaskRegistry`, creating it on first use."""
    global _registry
    if _registry is None:
        with _registry_lock:
            # Re-check under the lock: another thread may have created the
            # registry between the first test and acquiring the lock.
            if _registry is None:
                _registry = TaskRegistry()
    return _registry