"""Manages worker assignment and time limits for live WebRTC streams.

Each active WebRTC stream holds a dedicated ``EveWorker`` for its duration
(acquiring/releasing per-frame would destroy throughput). The manager:

- Assigns workers to streams on first frame (keyed by WebRTC connection ID)
- Returns immediately with a "waiting" status when no worker is available
  (so WebRTC frame handlers never block)
- Restores the Face ID gallery on stream start
- Enforces a configurable timeout when the queue is full (other users waiting)
- Reclaims orphaned workers when a stream goes silent (no frames for
  ``_ORPHAN_SILENCE_S`` seconds, currently 10 s)
"""

import random
import threading
import time
from dataclasses import dataclass, field

from eve_worker_pool import EveWorker, EveWorkerPool, log_worker_activity
from face_id_tab import FaceEntry
from frame_utils import load_media_frames
from log_utils import setup_logger
from non_blocking_bridge import NonBlockingInferenceBridge
from usage_analytics import UsageTracker

logger = setup_logger("LiveStreamManager")

_WAIT_TIMEOUT_S = 300.0  # Max seconds to wait for a worker before giving up
_ORPHAN_SILENCE_S = 10.0  # Seconds without a frame before an entry counts as orphaned

# Status text returned once a stream has hit its absolute session lifetime.
_SESSION_ENDED_MSG = "Session ended\nclick `Start Inference` to restart."


@dataclass
class _StreamEntry:
    """Bookkeeping for one active stream that currently holds a worker."""

    worker: EveWorker
    start_time: float
    session_hash: str
    connection_id: str
    bridge: NonBlockingInferenceBridge | None = None
    last_activity: float = field(default_factory=time.monotonic)
    # When pressure is detected (all workers busy + someone waiting), record the
    # timestamp. Reset to None as soon as the pressure condition no longer holds.
    pressure_start: float | None = None
    # Per-stream jittered timeout so streams don't all expire simultaneously
    pressure_timeout: float = 0.0


@dataclass
class _WaitingEntry:
    """Tracks a stream that is waiting for a worker to become available."""

    session_hash: str
    connection_id: str
    started_waiting: float
    last_activity: float
    registry: dict[int, FaceEntry]


class LiveStreamManager:
    """Manages per-stream worker assignment with timeout enforcement.

    Streams are keyed by ``connection_id`` — a unique identifier per WebRTC
    peer connection (e.g. FastRTC's ``webrtc_id``).

    Constructing the manager starts a daemon thread that periodically runs
    orphan detection until ``shutdown()`` is called.

    Args:
        pool: The ``EveWorkerPool`` to acquire / release workers from.
        timeout_seconds: Max stream duration before reclaiming the worker,
            but **only** when the pool has no idle workers (queue is waiting).
            ``None`` disables pressure-based timeouts entirely.
        session_lifetime_seconds: Absolute max duration for any live session.
            After this time the stream is stopped regardless of pressure.
            ``None`` disables the session lifetime limit.
        orphan_check_interval_s: How often (seconds) to scan for stale streams.
        max_fps: Inference rate when the system is idle (forwarded to each
            stream's ``NonBlockingInferenceBridge``). ``None`` disables the
            FPS cap entirely.
        min_fps: Inference rate under full system pressure (forwarded to the
            bridge). Defaults to ``max_fps`` (no scaling).
        tracker: Optional ``UsageTracker``. When provided, queue enter/exit
            and stream stop events are reported through its ``log`` method.
    """

    def __init__(
        self,
        pool: EveWorkerPool,
        timeout_seconds: float | None = None,
        session_lifetime_seconds: float | None = 240.0,
        orphan_check_interval_s: float = 5.0,
        max_fps: float | None = None,
        min_fps: float | None = None,
        tracker: UsageTracker | None = None,
    ):
        self._pool = pool
        self._timeout_s = timeout_seconds
        self._session_lifetime_s = session_lifetime_seconds
        self._max_fps = max_fps
        self._min_fps = min_fps
        self._tracker = tracker

        # Non-reentrant lock guarding the three registries below. Helpers that
        # take it (``_release_entry``, ``_remove_waiting``) must therefore be
        # called with the lock NOT held.
        self._lock = threading.Lock()
        self._streams: dict[str, _StreamEntry] = {}
        self._waiting: dict[str, _WaitingEntry] = {}
        self._expired: set[str] = set()  # connections whose session lifetime ended
        self._shutting_down = False

        self._orphan_thread = threading.Thread(
            target=self._orphan_monitor,
            args=(orphan_check_interval_s,),
            daemon=True,
        )
        self._orphan_thread.start()

    # -- public API --------------------------------------------------------

    def get_or_acquire(
        self,
        connection_id: str,
        session_hash: str,
        registry: dict[int, FaceEntry],
    ) -> tuple[EveWorker | None, str | None]:
        """Return the worker for the current stream, acquiring one if needed.

        This method **never blocks**. If no worker is available on the first
        frame, it returns ``(None, "waiting")`` so the caller can draw an
        overlay and return the frame immediately (keeping WebRTC alive).

        Args:
            connection_id: Unique WebRTC connection identifier (e.g. webrtc_id).
            session_hash: Gradio session hash (for pool affinity + gallery restore).
            registry: Face ID registry ``{user_id: FaceEntry}`` used for the
                gallery restore; each entry's ``path`` is loaded as media frames.

        Returns:
            ``(worker, None)`` on success.
            ``(None, "waiting")`` when waiting for a worker.
            ``(None, reason)`` when the stream should stop (timeout / error).
        """
        # 0 + 1. Expired connections must not re-enter the queue; otherwise
        # look up an already-assigned stream. Both reads share one lock hold.
        with self._lock:
            if connection_id in self._expired:
                return None, _SESSION_ENDED_MSG
            entry = self._streams.get(connection_id)
        if entry is not None:
            return self._service_active_stream(connection_id, entry)

        # 2. Already waiting — retry non-blocking acquire
        with self._lock:
            waiting = self._waiting.get(connection_id)
        if waiting is not None:
            return self._retry_waiting_stream(connection_id, waiting)

        # 3. Brand new stream — try non-blocking acquire, else start waiting
        return self._begin_new_stream(connection_id, session_hash, registry)

    def waiting_position(self, connection_id: str) -> tuple[int, int]:
        """Return ``(position, total_waiting)`` for the given connection.

        *position* is 1-based (1 = next in line), ordered by the time each
        connection started waiting. If the connection is not in the waiting
        list, position equals ``total_waiting + 1`` (just joined).
        """
        with self._lock:
            entries = sorted(self._waiting.values(), key=lambda e: e.started_waiting)
        total = len(entries)
        for position, entry in enumerate(entries, start=1):
            if entry.connection_id == connection_id:
                return position, total
        return total + 1, total

    def release(self, connection_id: str) -> None:
        """Release the worker bound to the given connection (if any).

        Also drops the connection from the waiting queue and clears its
        "expired" marker so the same connection ID may start again.
        """
        self._release_entry(connection_id)
        self._remove_waiting(connection_id, reason="released")
        with self._lock:
            self._expired.discard(connection_id)

    def countdown_remaining(self, connection_id: str) -> float | None:
        """Seconds until this stream is reclaimed due to queue pressure.

        Returns ``None`` if no countdown is active (no pressure: idle workers
        exist or nobody is waiting) or the connection has no active stream.
        Returns the remaining seconds otherwise (clamped to >= 0).
        """
        with self._lock:
            entry = self._streams.get(connection_id)
        if entry is None or entry.pressure_start is None:
            return None
        elapsed = time.monotonic() - entry.pressure_start
        return max(0.0, entry.pressure_timeout - elapsed)

    def session_remaining(self, connection_id: str) -> float | None:
        """Seconds left in this stream's session lifetime.

        Returns ``None`` if no session lifetime is configured or if the
        connection is not an active stream. Returns the remaining seconds
        otherwise (clamped to >= 0).
        """
        if self._session_lifetime_s is None:
            return None
        with self._lock:
            entry = self._streams.get(connection_id)
        if entry is None:
            return None
        elapsed = time.monotonic() - entry.start_time
        return max(0.0, self._session_lifetime_s - elapsed)

    def estimated_wait(self, position: int) -> float | None:
        """Estimated seconds until the *position*-th worker frees up.

        Considers **all** busy workers:

        - Live-stream workers use session lifetime to estimate remaining time.
        - Video-processing workers have unknown duration (sorted last).

        Known ETAs are sorted ascending so position 1 gets the soonest-to-free
        worker, position 2 the second-soonest, etc.

        Args:
            position: 1-based queue position (1 = next in line).

        Returns:
            Estimated seconds, or ``None`` when unknown or out of range.
        """
        if position < 1:
            return None

        now = time.monotonic()

        # Known ETAs from live streams with a session lifetime
        known: list[float] = []
        if self._session_lifetime_s is not None:
            with self._lock:
                entries = list(self._streams.values())
            known = sorted(
                max(0.0, self._session_lifetime_s - (now - e.start_time)) for e in entries
            )

        # Workers doing video processing (busy, not live-stream) have
        # unknown remaining time — they could finish at any moment.
        total_busy = self._pool.worker_count - self._pool.idle_count
        video_workers = max(0, total_busy - len(known))
        all_etas: list[float | None] = list(known) + [None] * video_workers

        idx = position - 1
        if idx >= len(all_etas):
            return None
        return all_etas[idx]

    def get_bridge(self, connection_id: str) -> NonBlockingInferenceBridge | None:
        """Return the inference bridge for an active stream, or ``None``."""
        with self._lock:
            entry = self._streams.get(connection_id)
        return entry.bridge if entry is not None else None

    def shutdown(self) -> None:
        """Release all stream workers.

        Flags the orphan monitor to stop, empties the waiting queue and the
        expired set, and releases every active stream's worker.
        """
        self._shutting_down = True
        with self._lock:
            stream_cids = list(self._streams.keys())
            waiting_cids = list(self._waiting.keys())
            self._expired.clear()
        for cid in waiting_cids:
            self._remove_waiting(cid, reason="shutdown")
        for cid in stream_cids:
            self._release_entry(cid)

    # -- internals ---------------------------------------------------------

    def _service_active_stream(
        self, connection_id: str, entry: _StreamEntry
    ) -> tuple[EveWorker | None, str | None]:
        """Apply lifetime / pressure limits to a stream that already has a worker."""
        now = time.monotonic()
        entry.last_activity = now

        # Session lifetime check (absolute cap, independent of pressure)
        if (
            self._session_lifetime_s is not None
            and now - entry.start_time >= self._session_lifetime_s
        ):
            self._release_entry(connection_id, note="session-expired")
            with self._lock:
                self._expired.add(connection_id)
            return None, _SESSION_ENDED_MSG

        # Pressure-based countdown (only when enabled)
        if self._timeout_s is not None:
            under_pressure = self._pool.idle_count == 0 and (
                self._pool.waiting_count > 0 or len(self._waiting) > 0
            )
            if under_pressure:
                if entry.pressure_start is None:
                    # Jitter spreads expirations out; clamp so a small
                    # configured timeout never yields a negative countdown.
                    entry.pressure_timeout = max(
                        0.0, self._timeout_s + random.uniform(-10, 5)
                    )
                    entry.pressure_start = now
                    logger.info(
                        f"Live stream ({connection_id}) under pressure\n"
                        f"{entry.pressure_timeout:.0f}s countdown started"
                    )
                elif now - entry.pressure_start >= entry.pressure_timeout:
                    self._release_entry(connection_id, note="pressure-expired")
                    return None, (
                        "Stream ended\nother users are waiting. "
                        "Click `Start Inference` to rejoin the queue."
                    )
            else:
                # Pressure relieved — reset countdown
                entry.pressure_start = None

        return entry.worker, None

    def _retry_waiting_stream(
        self, connection_id: str, waiting: _WaitingEntry
    ) -> tuple[EveWorker | None, str | None]:
        """Retry a non-blocking acquire for a queued stream, or time it out."""
        waiting.last_activity = time.monotonic()

        worker = self._pool.try_acquire(waiting.session_hash)
        if worker is not None:
            self._remove_waiting(connection_id, reason="worker_acquired")
            result = self._setup_stream(
                connection_id, waiting.session_hash, worker, waiting.registry
            )
            log_worker_activity(
                logger,
                "acquired",
                "live-inference",
                self._pool,
                worker.worker_id,
                note="was queued",
            )
            return result

        elapsed = time.monotonic() - waiting.started_waiting
        if elapsed >= _WAIT_TIMEOUT_S:
            self._remove_waiting(connection_id, reason="timeout")
            return None, (
                f"No workers available\ntimed out after {_WAIT_TIMEOUT_S:.0f}s. "
                "Please try again later."
            )
        return None, "waiting"

    def _begin_new_stream(
        self,
        connection_id: str,
        session_hash: str,
        registry: dict[int, FaceEntry],
    ) -> tuple[EveWorker | None, str | None]:
        """Acquire a worker for a first-seen connection, or enqueue it."""
        worker = self._pool.try_acquire(session_hash)
        if worker is not None:
            result = self._setup_stream(connection_id, session_hash, worker, registry)
            log_worker_activity(
                logger, "acquired", "live-inference", self._pool, worker.worker_id
            )
            return result

        # No worker available — start waiting
        now = time.monotonic()
        with self._lock:
            self._waiting[connection_id] = _WaitingEntry(
                session_hash=session_hash,
                connection_id=connection_id,
                started_waiting=now,
                last_activity=now,
                # Copy so later caller-side mutations don't affect the restore.
                registry=dict(registry) if registry else {},
            )
        log_worker_activity(logger, "queued", "live-inference", self._pool)
        if self._tracker:
            self._tracker.log(session_hash, "live_queue_enter")
        return None, "waiting"

    def _setup_stream(
        self,
        connection_id: str,
        session_hash: str,
        worker: EveWorker,
        registry: dict[int, FaceEntry],
    ) -> tuple[EveWorker, None]:
        """Configure a newly acquired worker and register the stream.

        If the inference bridge cannot be constructed, the worker is handed
        back to the pool before the exception propagates so it is not leaked.
        """
        worker.is_live_stream = True

        # Gallery restore runs on the bridge thread so the asyncio thread
        # is never blocked by disk I/O or pipe round-trips. While the
        # bridge is setting up, submit_and_get_latest returns None and
        # the user sees their raw camera feed.
        frozen_registry = dict(registry) if registry else {}

        def _restore_gallery() -> None:
            if not frozen_registry:
                worker.send_remove_all_users()
                return
            # A non-empty registry always produces a non-empty frame list.
            worker.send_restore_gallery(
                [load_media_frames(entry.path) for entry in frozen_registry.values()]
            )

        try:
            bridge = NonBlockingInferenceBridge(
                worker,
                setup_fn=_restore_gallery,
                max_fps=self._max_fps,
                min_fps=self._min_fps,
                load_fn=lambda: self._pool._shared_load.value,
            )
        except Exception:
            # The worker is not yet tracked in ``_streams``; without this
            # release nothing would ever return it to the pool.
            self._pool.release(worker)
            raise

        now = time.monotonic()
        entry = _StreamEntry(
            worker=worker,
            start_time=now,
            session_hash=session_hash,
            connection_id=connection_id,
            bridge=bridge,
            last_activity=now,
        )
        with self._lock:
            self._streams[connection_id] = entry
        return worker, None

    def _remove_waiting(self, connection_id: str, reason: str = "") -> None:
        """Remove a connection from the waiting queue and log the exit event.

        No-op when the connection is not queued. Must be called without
        ``self._lock`` held.
        """
        with self._lock:
            entry = self._waiting.pop(connection_id, None)
        if entry is not None and self._tracker:
            wait_seconds = time.monotonic() - entry.started_waiting
            self._tracker.log(
                entry.session_hash,
                "live_queue_exit",
                wait_seconds=round(wait_seconds, 1),
                reason=reason or "unknown",
            )

    def _release_entry(self, connection_id: str, note: str = "") -> None:
        """Stop a stream's bridge and return its worker to the pool.

        No-op when the connection has no active stream. Must be called
        without ``self._lock`` held.

        Args:
            connection_id: Connection whose stream should be torn down.
            note: Reason tag for logs / analytics (empty means "normal").
        """
        with self._lock:
            entry = self._streams.pop(connection_id, None)
        if entry is None:
            return

        worker_id = entry.worker.worker_id
        # Use last_activity (last frame received) instead of now — avoids
        # inflating duration by the orphan-detection silence period.
        duration = entry.last_activity - entry.start_time
        try:
            if entry.bridge is not None:
                entry.bridge.stop()
        finally:
            # Always hand the worker back, even if stopping the bridge fails;
            # the entry is already gone from ``_streams`` so nobody else would.
            self._pool.release(entry.worker)
        log_worker_activity(
            logger, "released", "live-inference", self._pool, worker_id, note=note
        )
        if self._tracker:
            self._tracker.log(
                entry.session_hash,
                "live_stop",
                duration_seconds=round(duration, 1),
                reason=note or "normal",
            )

    def _orphan_monitor(self, interval: float) -> None:
        """Periodically release workers whose stream has gone silent."""
        while not self._shutting_down:
            time.sleep(interval)
            # Re-check after sleeping: shutdown may have been requested meanwhile.
            if self._shutting_down:
                break
            self._orphan_monitor_once()

    def _orphan_monitor_once(self) -> None:
        """Single pass of orphan detection — also callable from tests."""
        now = time.monotonic()
        # Snapshot under the lock, then act outside it: the release/remove
        # helpers acquire the (non-reentrant) lock themselves.
        with self._lock:
            stream_snapshot = list(self._streams.items())
            waiting_snapshot = list(self._waiting.items())

        for cid, entry in stream_snapshot:
            silence = now - entry.last_activity
            if silence >= _ORPHAN_SILENCE_S:
                logger.warning(
                    f"Orphaned live stream detected ({cid}, "
                    f"worker {entry.worker.worker_id}, "
                    f"silent for {silence:.0f}s) — releasing"
                )
                self._release_entry(cid, note="orphan")

        for cid, waiting in waiting_snapshot:
            silence = now - waiting.last_activity
            if silence >= _ORPHAN_SILENCE_S:
                logger.warning(
                    f"Orphaned waiting entry detected ({cid}, "
                    f"silent for {silence:.0f}s) — removing"
                )
                self._remove_waiting(cid, reason="orphan")