Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
"""Manages worker assignment and time limits for live WebRTC streams.

Each active WebRTC stream holds a dedicated ``EveWorker`` for its duration
(acquiring/releasing per-frame would destroy throughput). The manager:

- Assigns workers to streams on first frame (keyed by WebRTC connection ID)
- Returns immediately with a "waiting" status when no worker is available
  (so WebRTC frame handlers never block)
- Restores the Face ID gallery on stream start
- Enforces a configurable timeout while the pool is saturated (no idle
  workers and other users waiting), plus an optional absolute session lifetime
- Reclaims orphaned workers when a stream goes silent (no frames for
  ``_ORPHAN_SILENCE_S`` seconds)
"""
| import random | |
| import threading | |
| import time | |
| from dataclasses import dataclass, field | |
| from eve_worker_pool import EveWorker, EveWorkerPool, log_worker_activity | |
| from face_id_tab import FaceEntry | |
| from frame_utils import load_media_frames | |
| from log_utils import setup_logger | |
| from non_blocking_bridge import NonBlockingInferenceBridge | |
| from usage_analytics import UsageTracker | |
logger = setup_logger("LiveStreamManager")
_WAIT_TIMEOUT_S = 300.0  # Max seconds to wait for a worker before giving up
# Seconds without any frame before an active stream (or a waiting entry) is
# treated as orphaned: the background monitor releases the stream's worker /
# drops the waiting entry.
_ORPHAN_SILENCE_S = 10.0
@dataclass
class _StreamEntry:
    """State for one live stream that currently holds a dedicated worker.

    Must be a dataclass: ``LiveStreamManager._setup_stream`` constructs it
    with keyword arguments and ``field(default_factory=...)`` is only
    honoured by ``@dataclass``.
    """

    worker: EveWorker
    # ``time.monotonic()`` when the worker was assigned; basis for the
    # session-lifetime limit and the reported stream duration.
    start_time: float
    session_hash: str
    connection_id: str
    # Non-blocking inference bridge created in ``_setup_stream``.
    bridge: NonBlockingInferenceBridge | None = None
    # Monotonic time of the most recent frame; used for orphan detection.
    last_activity: float = field(default_factory=time.monotonic)
    # Monotonic time at which pressure (no idle workers + someone waiting) was
    # first observed for this stream; reset to None once pressure is relieved.
    pressure_start: float | None = None
    # Per-stream jittered timeout so streams don't all expire simultaneously.
    pressure_timeout: float = 0.0
@dataclass
class _WaitingEntry:
    """Tracks a stream that is waiting for a worker to become available.

    Must be a dataclass: ``LiveStreamManager`` constructs it with keyword
    arguments.
    """

    session_hash: str
    connection_id: str
    # Monotonic time the stream joined the queue; orders the queue and drives
    # the wait timeout.
    started_waiting: float
    # Monotonic time of the most recent frame; used for orphan detection.
    last_activity: float
    # Snapshot of the Face ID registry, restored once a worker is acquired.
    registry: dict[int, FaceEntry]
class LiveStreamManager:
    """Manages per-stream worker assignment with timeout enforcement.

    Streams are keyed by ``connection_id`` — a unique identifier per WebRTC
    peer connection (e.g. FastRTC's ``webrtc_id``).

    Args:
        pool: The ``EveWorkerPool`` to acquire / release workers from.
        timeout_seconds: How long a stream may keep its worker once the pool
            is under pressure (no idle workers and someone waiting). Each
            stream's countdown is jittered by -10 s / +5 s (clamped at 0) so
            streams don't all expire at once. ``None`` disables
            pressure-based timeouts entirely.
        session_lifetime_seconds: Absolute max duration for any live session.
            After this time the stream is stopped regardless of pressure.
            ``None`` disables the session lifetime limit.
        orphan_check_interval_s: How often to scan for stale streams.
        max_fps: Inference rate when the system is idle. ``None``
            disables the FPS cap entirely.
        min_fps: Inference rate under full system pressure. Defaults
            to ``max_fps`` (no scaling).
        tracker: Optional ``UsageTracker`` that receives queue enter/exit and
            stream-stop events.
    """

    # User-facing stop messages. Once returned for a connection they are
    # repeated on every later frame until ``release()`` is called.
    _SESSION_ENDED_MSG = "Session ended\nclick `Start Inference` to restart."
    _PRESSURE_ENDED_MSG = (
        "Stream ended\nother users are waiting. "
        "Click `Start Inference` to rejoin the queue."
    )

    def __init__(
        self,
        pool: EveWorkerPool,
        timeout_seconds: float | None = None,
        session_lifetime_seconds: float | None = 240.0,
        orphan_check_interval_s: float = 5.0,
        max_fps: float | None = None,
        min_fps: float | None = None,
        tracker: UsageTracker | None = None,
    ):
        self._pool = pool
        self._timeout_s = timeout_seconds
        self._session_lifetime_s = session_lifetime_seconds
        self._max_fps = max_fps
        self._min_fps = min_fps
        self._tracker = tracker
        self._lock = threading.Lock()
        self._streams: dict[str, _StreamEntry] = {}
        self._waiting: dict[str, _WaitingEntry] = {}
        # Connections that reached a terminal state (session lifetime, pressure
        # reclaim, or queue timeout) mapped to the message to keep returning.
        # Without this, the next frame would silently re-acquire a worker or
        # re-join the queue.
        self._expired: dict[str, str] = {}
        self._shutting_down = False
        self._orphan_thread = threading.Thread(
            target=self._orphan_monitor,
            args=(orphan_check_interval_s,),
            daemon=True,
        )
        self._orphan_thread.start()

    # -- public API --------------------------------------------------------

    def get_or_acquire(
        self,
        connection_id: str,
        session_hash: str,
        registry: dict[int, FaceEntry],
    ) -> tuple[EveWorker | None, str | None]:
        """Return the worker for the current stream, acquiring one if needed.

        This method **never blocks**. If no worker is available on the first
        frame, it returns ``(None, "waiting")`` so the caller can draw an
        overlay and return the frame immediately (keeping WebRTC alive).

        Once a connection receives a terminal reason (session lifetime
        reached, worker reclaimed under pressure, or queue wait timed out),
        every later call for that ``connection_id`` returns the same reason
        until :meth:`release` is called.

        Args:
            connection_id: Unique WebRTC connection identifier (e.g. webrtc_id).
            session_hash: Gradio session hash (for pool affinity + gallery restore).
            registry: Face ID registry ``{user_id: FaceEntry}`` for gallery restore.

        Returns:
            ``(worker, None)`` on success.
            ``(None, "waiting")`` when waiting for a worker.
            ``(None, reason)`` when the stream should stop (timeout / error).
        """
        # 0. Already ended — don't re-enter the queue until release() clears it.
        with self._lock:
            ended_reason = self._expired.get(connection_id)
        if ended_reason is not None:
            return None, ended_reason

        # 1. Already has a worker — check session lifetime, then pressure.
        result = self._check_active_stream(connection_id)
        if result is not None:
            return result

        # 2. Already waiting — retry non-blocking acquire.
        result = self._retry_waiting(connection_id)
        if result is not None:
            return result

        # 3. Brand new stream — try non-blocking acquire, else join the queue.
        return self._acquire_or_enqueue(connection_id, session_hash, registry)

    def waiting_position(self, connection_id: str) -> tuple[int, int]:
        """Return ``(position, total_waiting)`` for the given connection.

        *position* is 1-based (1 = next in line). If the connection is not in
        the waiting list, position equals ``total_waiting + 1`` (just joined).
        """
        with self._lock:
            queue = sorted(self._waiting.values(), key=lambda w: w.started_waiting)
        total = len(queue)
        for position, waiting in enumerate(queue, start=1):
            if waiting.connection_id == connection_id:
                return position, total
        return total + 1, total

    def release(self, connection_id: str) -> None:
        """Release the worker bound to the given connection (if any).

        Also drops any queue entry and clears a recorded terminal reason, so
        the same ``connection_id`` may start a fresh stream afterwards.
        """
        self._release_entry(connection_id)
        self._remove_waiting(connection_id, reason="released")
        with self._lock:
            self._expired.pop(connection_id, None)

    def countdown_remaining(self, connection_id: str) -> float | None:
        """Seconds until this stream is reclaimed due to queue pressure.

        Returns ``None`` if no countdown is active (no pressure: idle workers
        exist or nobody is waiting). Returns the remaining seconds otherwise.
        """
        with self._lock:
            entry = self._streams.get(connection_id)
            if entry is None or entry.pressure_start is None:
                return None
            elapsed = time.monotonic() - entry.pressure_start
            return max(0.0, entry.pressure_timeout - elapsed)

    def session_remaining(self, connection_id: str) -> float | None:
        """Seconds left in this stream's session lifetime.

        Returns ``None`` if no session lifetime is configured or if the
        connection is not an active stream. Returns the remaining seconds
        otherwise (clamped to >= 0).
        """
        if self._session_lifetime_s is None:
            return None
        with self._lock:
            entry = self._streams.get(connection_id)
            if entry is None:
                return None
            elapsed = time.monotonic() - entry.start_time
        return max(0.0, self._session_lifetime_s - elapsed)

    def estimated_wait(self, position: int) -> float | None:
        """Estimated seconds until the *position*-th worker frees up.

        Considers **all** busy workers:
        - Live-stream workers use session lifetime to estimate remaining time.
        - Video-processing workers have unknown duration (sorted last).

        Known ETAs are sorted ascending so position 1 gets the soonest-to-free
        worker, position 2 the second-soonest, etc.

        Args:
            position: 1-based queue position (1 = next in line).

        Returns:
            Estimated seconds, or ``None`` when unknown or out of range.
        """
        if position < 1:
            return None
        now = time.monotonic()
        # Known ETAs from live streams with a session lifetime.
        known: list[float] = []
        if self._session_lifetime_s is not None:
            with self._lock:
                start_times = [e.start_time for e in self._streams.values()]
            known = sorted(
                max(0.0, self._session_lifetime_s - (now - start))
                for start in start_times
            )
        # Workers doing video processing (busy, not live-stream) have
        # unknown remaining time — they could finish at any moment.
        total_busy = self._pool.worker_count - self._pool.idle_count
        video_workers = max(0, total_busy - len(known))
        all_etas: list[float | None] = [*known, *([None] * video_workers)]
        idx = position - 1
        return all_etas[idx] if idx < len(all_etas) else None

    def get_bridge(self, connection_id: str) -> NonBlockingInferenceBridge | None:
        """Return the inference bridge for an active stream, or ``None``."""
        with self._lock:
            entry = self._streams.get(connection_id)
        return entry.bridge if entry is not None else None

    def shutdown(self) -> None:
        """Release all stream workers, drop the queue and stop the monitor loop."""
        self._shutting_down = True
        with self._lock:
            stream_cids = list(self._streams)
            waiting_cids = list(self._waiting)
            self._expired.clear()
        for cid in waiting_cids:
            self._remove_waiting(cid, reason="shutdown")
        for cid in stream_cids:
            self._release_entry(cid)

    # -- internals ---------------------------------------------------------

    def _check_active_stream(
        self, connection_id: str
    ) -> tuple[EveWorker | None, str | None] | None:
        """Step 1: enforce limits on an existing stream.

        Returns ``None`` when *connection_id* has no active stream, otherwise
        the result ``get_or_acquire`` should return.
        """
        with self._lock:
            entry = self._streams.get(connection_id)
            if entry is None:
                return None
            now = time.monotonic()
            entry.last_activity = now
            # Session lifetime check (absolute cap, independent of pressure).
            lifetime_over = (
                self._session_lifetime_s is not None
                and now - entry.start_time >= self._session_lifetime_s
            )
            if lifetime_over:
                # Record before releasing so a racing frame can't re-acquire.
                self._expired[connection_id] = self._SESSION_ENDED_MSG
        if lifetime_over:
            self._release_entry(connection_id, note="session-expired")
            return None, self._SESSION_ENDED_MSG

        # Pressure-based countdown (only when enabled).
        if self._timeout_s is None:
            return entry.worker, None

        # Pool counters are read without holding our lock.
        under_pressure = self._pool.idle_count == 0 and (
            self._pool.waiting_count > 0 or len(self._waiting) > 0
        )
        with self._lock:
            if not under_pressure:
                # Pressure relieved — reset countdown.
                entry.pressure_start = None
                return entry.worker, None
            countdown_started = entry.pressure_start is None
            if countdown_started:
                # Jitter so streams don't all expire at the same instant; clamp
                # so a small timeout_seconds can't produce a negative countdown.
                entry.pressure_timeout = max(
                    0.0, self._timeout_s + random.uniform(-10, 5)
                )
                entry.pressure_start = now
                reclaim = False
            else:
                reclaim = now - entry.pressure_start >= entry.pressure_timeout
                if reclaim:
                    self._expired[connection_id] = self._PRESSURE_ENDED_MSG
            pressure_timeout = entry.pressure_timeout

        if countdown_started:
            logger.info(
                f"Live stream ({connection_id}) under pressure\n"
                f"{pressure_timeout:.0f}s countdown started"
            )
        if reclaim:
            self._release_entry(connection_id, note="pressure-expired")
            return None, self._PRESSURE_ENDED_MSG
        return entry.worker, None

    def _retry_waiting(
        self, connection_id: str
    ) -> tuple[EveWorker | None, str | None] | None:
        """Step 2: retry a non-blocking acquire for a queued stream.

        Returns ``None`` when *connection_id* is not in the queue, otherwise
        the result ``get_or_acquire`` should return.
        """
        with self._lock:
            waiting = self._waiting.get(connection_id)
            if waiting is None:
                return None
            waiting.last_activity = time.monotonic()

        worker = self._pool.try_acquire(waiting.session_hash)
        if worker is not None:
            self._remove_waiting(connection_id, reason="worker_acquired")
            result = self._setup_stream(
                connection_id, waiting.session_hash, worker, waiting.registry
            )
            log_worker_activity(
                logger,
                "acquired",
                "live-inference",
                self._pool,
                worker.worker_id,
                note="was queued",
            )
            return result

        if time.monotonic() - waiting.started_waiting < _WAIT_TIMEOUT_S:
            return None, "waiting"

        message = (
            f"No workers available\ntimed out after {_WAIT_TIMEOUT_S:.0f}s. "
            "Please try again later."
        )
        # Sticky: otherwise the next frame would re-queue with a fresh timer.
        with self._lock:
            self._expired[connection_id] = message
        self._remove_waiting(connection_id, reason="timeout")
        return None, message

    def _acquire_or_enqueue(
        self,
        connection_id: str,
        session_hash: str,
        registry: dict[int, FaceEntry],
    ) -> tuple[EveWorker | None, str | None]:
        """Step 3: first frame of a new stream — acquire a worker or join the queue."""
        worker = self._pool.try_acquire(session_hash)
        if worker is not None:
            result = self._setup_stream(connection_id, session_hash, worker, registry)
            log_worker_activity(
                logger, "acquired", "live-inference", self._pool, worker.worker_id
            )
            return result

        # No worker available — start waiting.
        now = time.monotonic()
        with self._lock:
            self._waiting[connection_id] = _WaitingEntry(
                session_hash=session_hash,
                connection_id=connection_id,
                started_waiting=now,
                last_activity=now,
                registry=dict(registry) if registry else {},
            )
        log_worker_activity(logger, "queued", "live-inference", self._pool)
        if self._tracker:
            self._tracker.log(session_hash, "live_queue_enter")
        return None, "waiting"

    def _setup_stream(
        self,
        connection_id: str,
        session_hash: str,
        worker: EveWorker,
        registry: dict[int, FaceEntry],
    ) -> tuple[EveWorker, None]:
        """Configure a newly acquired worker and register the stream."""
        worker.is_live_stream = True

        # Gallery restore runs on the bridge thread so the asyncio thread
        # is never blocked by disk I/O or pipe round-trips. While the
        # bridge is setting up, submit_and_get_latest returns None and
        # the user sees their raw camera feed.
        frozen_registry = dict(registry) if registry else {}

        def _restore_gallery() -> None:
            if frozen_registry:
                # One frame list per registered user (never empty here).
                frames_per_user = [
                    load_media_frames(face.path) for face in frozen_registry.values()
                ]
                worker.send_restore_gallery(frames_per_user)
            else:
                worker.send_remove_all_users()

        bridge = NonBlockingInferenceBridge(
            worker,
            setup_fn=_restore_gallery,
            max_fps=self._max_fps,
            min_fps=self._min_fps,
            load_fn=lambda: self._pool._shared_load.value,
        )
        now = time.monotonic()
        entry = _StreamEntry(
            worker=worker,
            start_time=now,
            session_hash=session_hash,
            connection_id=connection_id,
            bridge=bridge,
            last_activity=now,
        )
        with self._lock:
            self._streams[connection_id] = entry
        return worker, None

    def _remove_waiting(self, connection_id: str, reason: str = "") -> None:
        """Remove a connection from the waiting queue and log the exit event."""
        with self._lock:
            entry = self._waiting.pop(connection_id, None)
        if entry is not None and self._tracker:
            wait_seconds = time.monotonic() - entry.started_waiting
            self._tracker.log(
                entry.session_hash,
                "live_queue_exit",
                wait_seconds=round(wait_seconds, 1),
                reason=reason or "unknown",
            )

    def _release_entry(
        self,
        connection_id: str,
        note: str = "",
        expected: _StreamEntry | None = None,
    ) -> None:
        """Stop and release the stream bound to *connection_id*, if any.

        Args:
            connection_id: Connection whose stream should be released.
            note: Reason recorded in the activity log and usage tracker.
            expected: When given, release only if the registered entry is this
                exact object — guards against releasing a newer stream that
                reused the same connection ID.
        """
        with self._lock:
            entry = self._streams.get(connection_id)
            if entry is None or (expected is not None and entry is not expected):
                return
            del self._streams[connection_id]

        worker_id = entry.worker.worker_id
        # Use last_activity (last frame received) instead of now — avoids
        # inflating duration by the orphan-detection silence period.
        duration = entry.last_activity - entry.start_time
        try:
            if entry.bridge is not None:
                entry.bridge.stop()
        finally:
            # Always hand the worker back, even if stopping the bridge failed;
            # otherwise it would be lost to the pool permanently.
            self._pool.release(entry.worker)
        log_worker_activity(
            logger, "released", "live-inference", self._pool, worker_id, note=note
        )
        if self._tracker:
            self._tracker.log(
                entry.session_hash,
                "live_stop",
                duration_seconds=round(duration, 1),
                reason=note or "normal",
            )

    def _orphan_monitor(self, interval: float) -> None:
        """Periodically release workers whose stream has gone silent."""
        while not self._shutting_down:
            time.sleep(interval)
            if self._shutting_down:
                break
            try:
                self._orphan_monitor_once()
            except Exception:
                # Keep the monitor alive: if this thread died, silent streams
                # would hold their workers forever.
                logger.exception("Orphan monitor pass failed")

    def _orphan_monitor_once(self) -> None:
        """Single pass of orphan detection — also callable from tests.

        Releases streams and drops waiting entries that have not received a
        frame for at least ``_ORPHAN_SILENCE_S`` seconds.
        """
        now = time.monotonic()
        with self._lock:
            stale_streams = [
                (cid, entry, now - entry.last_activity)
                for cid, entry in self._streams.items()
                if now - entry.last_activity >= _ORPHAN_SILENCE_S
            ]
            stale_waiting = [
                (cid, now - waiting.last_activity)
                for cid, waiting in self._waiting.items()
                if now - waiting.last_activity >= _ORPHAN_SILENCE_S
            ]
        for cid, entry, silence in stale_streams:
            logger.warning(
                f"Orphaned live stream detected ({cid}, "
                f"worker {entry.worker.worker_id}, "
                f"silent for {silence:.0f}s) — releasing"
            )
            self._release_entry(cid, note="orphan", expected=entry)
        for cid, silence in stale_waiting:
            logger.warning(
                f"Orphaned waiting entry detected ({cid}, "
                f"silent for {silence:.0f}s) — removing"
            )
            self._remove_waiting(cid, reason="orphan")