sensAI-Generic-Object-Detection / shared /live_stream_manager.py
beaupreda's picture
Upload sensAI-Generic-Object-Detection with upload_repo.py
13170f7 verified
Raw
History Blame Contribute Delete
18.6 kB
"""Manages worker assignment and time limits for live WebRTC streams.

Each active WebRTC stream holds a dedicated ``EveWorker`` for its duration
(acquiring/releasing per-frame would destroy throughput). The manager:

- Assigns workers to streams on first frame (keyed by WebRTC connection ID)
- Returns immediately with a "waiting" status when no worker is available
  (so WebRTC frame handlers never block)
- Restores the Face ID gallery on stream start
- Enforces a configurable timeout when the queue is full (other users waiting)
- Enforces an optional absolute session lifetime for every live stream
- Reclaims orphaned workers when a stream goes silent (no frames for
  ``_ORPHAN_SILENCE_S`` seconds)
"""
import random
import threading
import time
from dataclasses import dataclass, field
from eve_worker_pool import EveWorker, EveWorkerPool, log_worker_activity
from face_id_tab import FaceEntry
from frame_utils import load_media_frames
from log_utils import setup_logger
from non_blocking_bridge import NonBlockingInferenceBridge
from usage_analytics import UsageTracker
logger = setup_logger("LiveStreamManager")

# Upper bound (seconds) a stream may sit in the waiting queue before the
# manager gives up and reports a timeout to the frame handler.
_WAIT_TIMEOUT_S = 5 * 60.0
# A stream or waiting entry that has received no frame for this many seconds
# is treated as orphaned and cleaned up by the background monitor.
_ORPHAN_SILENCE_S = 10.0
@dataclass
class _StreamEntry:
    """Bookkeeping for a live stream that currently holds a dedicated worker."""

    worker: EveWorker  # worker held for the whole lifetime of the stream
    start_time: float  # time.monotonic() when the worker was assigned
    session_hash: str  # Gradio session hash (used for analytics events)
    connection_id: str  # WebRTC connection this entry belongs to
    # Non-blocking inference bridge wrapping ``worker``; None until attached.
    bridge: NonBlockingInferenceBridge | None = None
    # time.monotonic() of the most recent frame; drives orphan detection and
    # the reported stream duration.
    last_activity: float = field(default_factory=time.monotonic)
    # When pressure is detected (all workers busy + someone waiting), record the timestamp.
    # Reset to None on the next frame that sees the pressure relieved
    # (an idle worker exists or nobody is waiting).
    pressure_start: float | None = None
    # Per-stream jittered timeout so streams don't all expire simultaneously
    pressure_timeout: float = 0.0
@dataclass
class _WaitingEntry:
    """Tracks a stream that is waiting for a worker to become available."""

    session_hash: str  # Gradio session hash, used for pool affinity on retry
    connection_id: str  # WebRTC connection that is queued
    started_waiting: float  # time.monotonic() when the stream joined the queue
    last_activity: float  # time.monotonic() of the latest frame (orphan check)
    # Snapshot copy of the caller's Face ID registry; restored onto the worker
    # once one is acquired.
    registry: dict[int, FaceEntry]
class LiveStreamManager:
    """Manages per-stream worker assignment with timeout enforcement.

    Streams are keyed by ``connection_id`` — a unique identifier per WebRTC
    peer connection (e.g. FastRTC's ``webrtc_id``).

    Once the manager stops a stream (session lifetime reached, reclaimed
    under queue pressure, or timed out while waiting), the ``connection_id``
    is remembered as expired. Every later frame gets the same stop reason
    back, so frames still in flight cannot silently re-enter the queue or
    grab the worker that was just freed. :meth:`release` clears that state.

    Args:
        pool: The ``EveWorkerPool`` to acquire / release workers from.
        timeout_seconds: Max stream duration before reclaiming the worker,
            but **only** when the pool has no idle workers (queue is waiting).
            Each stream gets a jittered countdown drawn from
            ``[timeout_seconds - 10, timeout_seconds + 5]`` seconds (clamped
            at 0) so streams don't all expire at once. ``None`` disables
            pressure-based timeouts entirely.
        session_lifetime_seconds: Absolute max duration for any live session.
            After this time the stream is stopped regardless of pressure.
            ``None`` disables the session lifetime limit.
        orphan_check_interval_s: How often to scan for stale streams.
        max_fps: Inference rate when the system is idle. ``None``
            disables the FPS cap entirely.
        min_fps: Inference rate under full system pressure. Defaults
            to ``max_fps`` (no scaling).
        tracker: Optional ``UsageTracker`` that receives queue enter/exit and
            stream-stop analytics events.
    """

    # Stop reasons handed back to the frame handler.
    _SESSION_ENDED_MSG = "Session ended\nclick `Start Inference` to restart."
    _PRESSURE_ENDED_MSG = (
        "Stream ended\nother users are waiting. "
        "Click `Start Inference` to rejoin the queue."
    )

    def __init__(
        self,
        pool: EveWorkerPool,
        timeout_seconds: float | None = None,
        session_lifetime_seconds: float | None = 240.0,
        orphan_check_interval_s: float = 5.0,
        max_fps: float | None = None,
        min_fps: float | None = None,
        tracker: UsageTracker | None = None,
    ):
        self._pool = pool
        self._timeout_s = timeout_seconds
        self._session_lifetime_s = session_lifetime_seconds
        self._max_fps = max_fps
        self._min_fps = min_fps
        self._tracker = tracker
        self._lock = threading.Lock()
        self._streams: dict[str, _StreamEntry] = {}
        self._waiting: dict[str, _WaitingEntry] = {}
        # Connections the manager stopped, mapped to the reason returned on
        # every later frame. Cleared by release() and shutdown().
        self._expired: dict[str, str] = {}
        self._shutting_down = False
        # Lets shutdown() wake the orphan monitor immediately instead of
        # waiting for its current sleep interval to elapse.
        self._stop_event = threading.Event()
        self._orphan_thread = threading.Thread(
            target=self._orphan_monitor,
            args=(orphan_check_interval_s,),
            daemon=True,
        )
        self._orphan_thread.start()

    # -- public API --------------------------------------------------------

    def get_or_acquire(
        self,
        connection_id: str,
        session_hash: str,
        registry: dict[int, FaceEntry],
    ) -> tuple[EveWorker | None, str | None]:
        """Return the worker for the current stream, acquiring one if needed.

        This method **never blocks**. If no worker is available on the first
        frame, it returns ``(None, "waiting")`` so the caller can draw an
        overlay and return the frame immediately (keeping WebRTC alive).

        Args:
            connection_id: Unique WebRTC connection identifier (e.g. webrtc_id).
            session_hash: Gradio session hash (for pool affinity + gallery restore).
            registry: Face ID registry ``{user_id: FaceEntry}`` for gallery restore.

        Returns:
            ``(worker, None)`` on success.
            ``(None, "waiting")`` when waiting for a worker.
            ``(None, reason)`` when the stream should stop (timeout / error).
            The same reason is returned for every later frame on this
            connection until :meth:`release` is called.
        """
        with self._lock:
            expired_reason = self._expired.get(connection_id)
            entry = self._streams.get(connection_id)
            waiting = self._waiting.get(connection_id)

        # 0. Already stopped by the manager — don't re-enter the queue.
        if expired_reason is not None:
            return None, expired_reason
        # 1. Already has a worker — check session lifetime, then pressure.
        if entry is not None:
            return self._check_active_stream(connection_id, entry)
        # 2. Already waiting — retry a non-blocking acquire.
        if waiting is not None:
            return self._retry_waiting(connection_id, waiting)
        # 3. Brand-new stream — try to acquire, otherwise join the queue.
        return self._start_new_stream(connection_id, session_hash, registry)

    def waiting_position(self, connection_id: str) -> tuple[int, int]:
        """Return ``(position, total_waiting)`` for the given connection.

        *position* is 1-based (1 = next in line). If the connection is not in
        the waiting list, position equals ``total_waiting + 1`` (just joined).
        """
        with self._lock:
            entries = sorted(self._waiting.values(), key=lambda e: e.started_waiting)
        total = len(entries)
        for i, entry in enumerate(entries):
            if entry.connection_id == connection_id:
                return i + 1, total
        return total + 1, total

    def release(self, connection_id: str) -> None:
        """Release the worker bound to the given connection (if any).

        Also removes the connection from the waiting queue and forgets any
        expired state, so the same ``connection_id`` can start afresh.
        """
        self._release_entry(connection_id)
        self._remove_waiting(connection_id, reason="released")
        with self._lock:
            self._expired.pop(connection_id, None)

    def countdown_remaining(self, connection_id: str) -> float | None:
        """Seconds until this stream is reclaimed due to queue pressure.

        Returns ``None`` if no countdown is active (no pressure: idle workers
        exist or nobody is waiting). Returns the remaining seconds otherwise.
        """
        with self._lock:
            entry = self._streams.get(connection_id)
            if entry is None or entry.pressure_start is None:
                return None
            elapsed = time.monotonic() - entry.pressure_start
            return max(0.0, entry.pressure_timeout - elapsed)

    def session_remaining(self, connection_id: str) -> float | None:
        """Seconds left in this stream's session lifetime.

        Returns ``None`` if no session lifetime is configured or if the
        connection is not an active stream. Returns the remaining seconds
        otherwise (clamped to >= 0).
        """
        if self._session_lifetime_s is None:
            return None
        with self._lock:
            entry = self._streams.get(connection_id)
            if entry is None:
                return None
            elapsed = time.monotonic() - entry.start_time
        return max(0.0, self._session_lifetime_s - elapsed)

    def estimated_wait(self, position: int) -> float | None:
        """Estimated seconds until the *position*-th worker frees up.

        Considers **all** busy workers:
        - Live-stream workers use session lifetime to estimate remaining time.
        - Video-processing workers have unknown duration (sorted last).

        Known ETAs are sorted ascending so position 1 gets the soonest-to-free
        worker, position 2 the second-soonest, etc.

        Args:
            position: 1-based queue position (1 = next in line).

        Returns:
            Estimated seconds, or ``None`` when unknown or out of range.
        """
        if position < 1:
            return None
        now = time.monotonic()
        # Known ETAs from live streams with a session lifetime.
        known: list[float] = []
        if self._session_lifetime_s is not None:
            with self._lock:
                entries = list(self._streams.values())
            known = sorted(
                max(0.0, self._session_lifetime_s - (now - e.start_time)) for e in entries
            )
        # Workers doing video processing (busy, not live-stream) have
        # unknown remaining time — they could finish at any moment.
        total_busy = self._pool.worker_count - self._pool.idle_count
        video_workers = max(0, total_busy - len(known))
        all_etas: list[float | None] = [*known, *([None] * video_workers)]
        idx = position - 1
        if idx >= len(all_etas):
            return None
        return all_etas[idx]

    def get_bridge(self, connection_id: str) -> NonBlockingInferenceBridge | None:
        """Return the inference bridge for an active stream, or ``None``."""
        with self._lock:
            entry = self._streams.get(connection_id)
        return entry.bridge if entry is not None else None

    def shutdown(self) -> None:
        """Stop the orphan monitor and release all stream workers."""
        self._shutting_down = True
        self._stop_event.set()
        with self._lock:
            stream_cids = list(self._streams)
            waiting_cids = list(self._waiting)
            self._expired.clear()
        for cid in waiting_cids:
            self._remove_waiting(cid, reason="shutdown")
        for cid in stream_cids:
            self._release_entry(cid)

    # -- internals ---------------------------------------------------------

    def _check_active_stream(
        self, connection_id: str, entry: _StreamEntry
    ) -> tuple[EveWorker | None, str | None]:
        """Refresh an active stream and enforce session lifetime / pressure limits."""
        now = time.monotonic()
        entry.last_activity = now

        # Session lifetime check (absolute cap, independent of pressure).
        if (
            self._session_lifetime_s is not None
            and now - entry.start_time >= self._session_lifetime_s
        ):
            self._expire(connection_id, self._SESSION_ENDED_MSG, note="session-expired")
            return None, self._SESSION_ENDED_MSG

        # Pressure-based countdown (only when enabled).
        if self._timeout_s is not None:
            if self._under_pressure():
                if entry.pressure_start is None:
                    # Jitter so concurrent streams don't all expire at once;
                    # clamp so small timeouts never produce a negative countdown.
                    entry.pressure_timeout = max(
                        0.0, self._timeout_s + random.uniform(-10, 5)
                    )
                    entry.pressure_start = now
                    logger.info(
                        f"Live stream ({connection_id}) under pressure\n"
                        f"{entry.pressure_timeout:.0f}s countdown started"
                    )
                elif now - entry.pressure_start >= entry.pressure_timeout:
                    # Mark expired so in-flight frames can't grab the worker
                    # we are freeing for the users who are waiting.
                    self._expire(
                        connection_id, self._PRESSURE_ENDED_MSG, note="pressure-expired"
                    )
                    return None, self._PRESSURE_ENDED_MSG
            else:
                # Pressure relieved — reset the countdown.
                entry.pressure_start = None
        return entry.worker, None

    def _under_pressure(self) -> bool:
        """True when no worker is idle and someone (pool or live queue) is waiting."""
        with self._lock:
            live_waiting = len(self._waiting)
        return self._pool.idle_count == 0 and (
            self._pool.waiting_count > 0 or live_waiting > 0
        )

    def _retry_waiting(
        self, connection_id: str, waiting: _WaitingEntry
    ) -> tuple[EveWorker | None, str | None]:
        """Retry acquiring a worker for a queued stream, enforcing the wait timeout."""
        waiting.last_activity = time.monotonic()
        worker = self._pool.try_acquire(waiting.session_hash)
        if worker is not None:
            self._remove_waiting(connection_id, reason="worker_acquired")
            result = self._setup_stream(
                connection_id, waiting.session_hash, worker, waiting.registry
            )
            log_worker_activity(
                logger,
                "acquired",
                "live-inference",
                self._pool,
                worker.worker_id,
                note="was queued",
            )
            return result

        if time.monotonic() - waiting.started_waiting >= _WAIT_TIMEOUT_S:
            message = (
                f"No workers available\ntimed out after {_WAIT_TIMEOUT_S:.0f}s. "
                "Please try again later."
            )
            # Mark expired first so the next frame doesn't silently re-queue.
            with self._lock:
                self._expired[connection_id] = message
            self._remove_waiting(connection_id, reason="timeout")
            return None, message
        return None, "waiting"

    def _start_new_stream(
        self,
        connection_id: str,
        session_hash: str,
        registry: dict[int, FaceEntry],
    ) -> tuple[EveWorker | None, str | None]:
        """Acquire a worker for a brand-new stream, or add it to the waiting queue."""
        worker = self._pool.try_acquire(session_hash)
        if worker is not None:
            result = self._setup_stream(connection_id, session_hash, worker, registry)
            log_worker_activity(
                logger, "acquired", "live-inference", self._pool, worker.worker_id
            )
            return result

        # No worker available — start waiting.
        now = time.monotonic()
        with self._lock:
            self._waiting[connection_id] = _WaitingEntry(
                session_hash=session_hash,
                connection_id=connection_id,
                started_waiting=now,
                last_activity=now,
                registry=dict(registry) if registry else {},
            )
        log_worker_activity(logger, "queued", "live-inference", self._pool)
        if self._tracker:
            self._tracker.log(session_hash, "live_queue_enter")
        return None, "waiting"

    def _expire(self, connection_id: str, message: str, note: str) -> None:
        """Mark *connection_id* as stopped, then hand its worker back to the pool.

        The expiry is recorded before the release, so any later frame for this
        connection sees *message* instead of re-acquiring a worker.
        """
        with self._lock:
            self._expired[connection_id] = message
        self._release_entry(connection_id, note=note)

    def _setup_stream(
        self,
        connection_id: str,
        session_hash: str,
        worker: EveWorker,
        registry: dict[int, FaceEntry],
    ) -> tuple[EveWorker, None]:
        """Configure a newly acquired worker and register the stream."""
        worker.is_live_stream = True

        # Gallery restore runs on the bridge thread so the asyncio thread
        # is never blocked by disk I/O or pipe round-trips. While the
        # bridge is setting up, submit_and_get_latest returns None and
        # the user sees their raw camera feed.
        frozen_registry = dict(registry) if registry else {}

        def _restore_gallery() -> None:
            if not frozen_registry:
                worker.send_remove_all_users()
                return
            frames_per_user = [
                load_media_frames(face.path) for face in frozen_registry.values()
            ]
            worker.send_restore_gallery(frames_per_user)

        bridge = NonBlockingInferenceBridge(
            worker,
            setup_fn=_restore_gallery,
            max_fps=self._max_fps,
            min_fps=self._min_fps,
            load_fn=lambda: self._pool._shared_load.value,
        )
        now = time.monotonic()
        entry = _StreamEntry(
            worker=worker,
            start_time=now,
            session_hash=session_hash,
            connection_id=connection_id,
            bridge=bridge,
            last_activity=now,
        )
        with self._lock:
            self._streams[connection_id] = entry
        return worker, None

    def _remove_waiting(self, connection_id: str, reason: str = "") -> None:
        """Remove a connection from the waiting queue and log the exit event."""
        with self._lock:
            entry = self._waiting.pop(connection_id, None)
        if entry is not None and self._tracker:
            wait_seconds = time.monotonic() - entry.started_waiting
            self._tracker.log(
                entry.session_hash,
                "live_queue_exit",
                wait_seconds=round(wait_seconds, 1),
                reason=reason or "unknown",
            )

    def _release_entry(self, connection_id: str, note: str = "") -> None:
        """Stop the stream's bridge and return its worker to the pool (if any)."""
        with self._lock:
            entry = self._streams.pop(connection_id, None)
        if entry is None:
            return
        worker_id = entry.worker.worker_id
        # Use last_activity (last frame received) instead of now — avoids
        # inflating duration by the orphan-detection silence period.
        duration = entry.last_activity - entry.start_time
        try:
            if entry.bridge is not None:
                entry.bridge.stop()
        finally:
            # Always hand the worker back, even if stopping the bridge failed;
            # otherwise the pool would permanently lose a worker.
            self._pool.release(entry.worker)
        log_worker_activity(
            logger, "released", "live-inference", self._pool, worker_id, note=note
        )
        if self._tracker:
            self._tracker.log(
                entry.session_hash,
                "live_stop",
                duration_seconds=round(duration, 1),
                reason=note or "normal",
            )

    def _orphan_monitor(self, interval: float) -> None:
        """Periodically release workers whose stream has gone silent.

        Runs on a daemon thread until :meth:`shutdown`. A failing pass is
        logged and the loop keeps running; otherwise a single error would
        disable orphan reclamation for the rest of the process.
        """
        while not self._stop_event.wait(interval):
            if self._shutting_down:
                break
            try:
                self._orphan_monitor_once()
            except Exception as exc:  # thread boundary: log and keep monitoring
                logger.warning(
                    f"Orphan monitor pass failed ({exc!r}); "
                    f"retrying in {interval:.0f}s"
                )

    def _orphan_monitor_once(self) -> None:
        """Single pass of orphan detection — also callable from tests."""
        now = time.monotonic()
        with self._lock:
            stream_snapshot = list(self._streams.items())
            waiting_snapshot = list(self._waiting.items())
        for cid, entry in stream_snapshot:
            silence = now - entry.last_activity
            if silence >= _ORPHAN_SILENCE_S:
                logger.warning(
                    f"Orphaned live stream detected ({cid}, "
                    f"worker {entry.worker.worker_id}, "
                    f"silent for {silence:.0f}s) — releasing"
                )
                self._release_entry(cid, note="orphan")
        for cid, waiting in waiting_snapshot:
            silence = now - waiting.last_activity
            if silence >= _ORPHAN_SILENCE_S:
                logger.warning(
                    f"Orphaned waiting entry detected ({cid}, "
                    f"silent for {silence:.0f}s) — removing"
                )
                self._remove_waiting(cid, reason="orphan")