Spaces:
Running
Running
| """Mobile-facing WebSocket session for the narration engine. | |
| MomentEnvelope in, ControlFrame + SceneAudio out, per docs/contracts. | |
| Backpressure is D8 (queue depth <= 1, coalesce-to-newest); freshness is D9 | |
| (`play_by` = created_at + 60 s). Pipeline failures become error frames — | |
| the socket itself never crashes on a bad moment. | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import base64 | |
| import contextlib | |
| import inspect | |
| import io | |
| import json | |
| import os | |
| import sys | |
| import time | |
| import uuid | |
| import wave | |
| from collections import OrderedDict | |
| from collections.abc import Callable | |
| from dataclasses import dataclass, field | |
| from datetime import datetime, timedelta, timezone | |
| from importlib import resources | |
| from pathlib import Path | |
| from typing import Any | |
| import jsonschema | |
| import numpy as np | |
| from fastapi import WebSocket, WebSocketDisconnect | |
| from PIL import Image | |
| from small_cuts import narrator, tts | |
| from small_cuts.styles import DEFAULT_STYLE_KEY | |
CONTRACT_VERSION = "1.1.0"
PLAY_BY_SECONDS = 60
MAX_FRAME_SIDE = 1024  # contract cap: decoded longest side <= 1024 px (moment.schema.json)
SEEN_MOMENTS_CAP = 4096  # a day of moments is far less
CONTRACTS_DIR_ENV = "SMALL_CUTS_CONTRACTS_DIR"


def _source_contracts_dir() -> Path:
    """Return `<repo root>/docs/contracts` as seen from a source checkout.

    In a checkout the repo root is three directories above this file's own
    directory (`parents[3]`). A shallower install has no such ancestor, and
    indexing `parents[3]` would raise IndexError at import time. In that case
    this returns `docs/contracts` under the filesystem anchor instead. That
    path is not expected to exist, so `_contract_text` falls through to the
    contracts packaged inside `small_cuts`.
    """
    here = Path(__file__).resolve()
    root = here.parents[3] if len(here.parents) > 3 else Path(here.anchor)
    return root / "docs" / "contracts"


_SOURCE_CONTRACTS = _source_contracts_dir()

# Receives every successful scene; Task 2 plugs the library/SSE fan-out here.
SceneSink = Callable[[dict[str, Any]], Any]
def _contract_text(name: str) -> str:
    """Return the text of the contract file `name`.

    Lookup order:
      1. the directory named by the `CONTRACTS_DIR_ENV` environment variable, if set;
      2. the repo's docs/contracts directory (source checkout), if the file exists there;
      3. the copy packaged inside the `small_cuts` distribution.

    Files are decoded as UTF-8 explicitly. JSON is UTF-8 by specification, and
    the platform's locale encoding (the `read_text()` default) may differ.
    """
    configured = os.environ.get(CONTRACTS_DIR_ENV)
    if configured:
        return (Path(configured) / name).read_text(encoding="utf-8")
    source_contract = _SOURCE_CONTRACTS / name
    if source_contract.exists():
        return source_contract.read_text(encoding="utf-8")
    return (resources.files("small_cuts") / "contracts" / name).read_text(encoding="utf-8")
def _validator(name: str) -> jsonschema.Draft202012Validator:
    """Parse the contract schema file `name` and wrap it in a Draft 2020-12 validator."""
    schema = json.loads(_contract_text(name))
    return jsonschema.Draft202012Validator(schema)
# Built once at import time: inbound MomentEnvelope and outbound SceneAudio schemas.
_MOMENT = _validator("moment.schema.json")
_SCENE_AUDIO = _validator("scene-audio.schema.json")

# Strong references that keep shielded storage tasks alive after their socket closes
# (asyncio itself only keeps weak references to tasks).
_BACKGROUND_STORAGE_TASKS: set[asyncio.Task[None]] = set()


def _noop_sink(scene: dict[str, Any]) -> None:
    """Default scene sink: accept the scene and discard it."""
class MomentIdLRU:
    """Fixed-capacity id set that forgets its oldest insertions first.

    Re-adding an id refreshes it to the newest position. Once more than `cap`
    ids are held, the oldest ones are evicted.
    """

    def __init__(self, cap: int = SEEN_MOMENTS_CAP) -> None:
        self._cap = cap
        self._ids: OrderedDict[str, None] = OrderedDict()  # insertion order == recency

    def __contains__(self, moment_id: object) -> bool:
        return moment_id in self._ids

    def add(self, moment_id: str) -> None:
        """Record `moment_id` as the newest entry, evicting past capacity."""
        ids = self._ids
        if moment_id in ids:
            ids.move_to_end(moment_id)  # refresh recency of a known id
        else:
            ids[moment_id] = None
        for _ in range(len(ids) - self._cap):
            ids.popitem(last=False)  # oldest first

    def discard(self, moment_id: str) -> None:
        """Forget `moment_id`; ids that are not present are ignored."""
        self._ids.pop(moment_id, None)
@dataclass
class EngineState:
    """Process-lifetime state shared across session sockets.

    This must be a dataclass. `field(default_factory=...)` only produces a
    per-instance `MomentIdLRU` under `@dataclass`; without it the class
    attribute is a bare `dataclasses.Field`. The generated `__init__` also
    stores the sinks as instance attributes, so a plain function such as
    `_noop_sink` is not turned into a bound method when read as
    `state.sink`.
    """

    sink: SceneSink = _noop_sink  # receives every successful scene
    error_sink: SceneSink | None = None  # receives every error ControlFrame (viewer fan-out, D9)
    # "session_id:moment_id" keys admitted for processing; shared by every socket.
    seen_moment_ids: MomentIdLRU = field(default_factory=MomentIdLRU)
@dataclass(frozen=True)
class _Queued:
    """An admitted envelope waiting for (or entering) the pipeline.

    Without `@dataclass` this class has no `__init__` accepting fields, so the
    positional construction `_Queued(envelope, time.perf_counter())` would
    raise TypeError. It is frozen because a queued item is never mutated.
    """

    envelope: dict[str, Any]  # the schema-valid MomentEnvelope
    queued_at: float  # time.perf_counter() at admission; feeds latency_ms.queue
class _ValidationFailure(Exception):
    """Raised after admission when a frame is undecodable or over the size cap.

    Never retryable: resending the same envelope would fail the same way.
    `code` carries the machine-readable error code for the error frame.
    """

    def __init__(self, code: str, message: str) -> None:
        self.code = code
        super().__init__(message)
def _log_worker_failure(task: asyncio.Task) -> None:
    """Done-callback: report a task that died with an exception, traceback included.

    A drain-task bug must fail loudly, not strand moments as unretrieved
    exceptions. Reading `task.exception()` retrieves the exception. The full
    traceback is printed because a one-line repr rarely locates the bug.
    Cancelled tasks (e.g. the socket closed) stay silent.
    """
    import traceback  # local: only needed on this rare failure path

    if task.cancelled():
        return
    exc = task.exception()
    if exc is None:
        return
    print(f"small_cuts.engine: session worker task crashed: {exc!r}", file=sys.stderr)
    traceback.print_exception(type(exc), exc, exc.__traceback__, file=sys.stderr)
def _retain_background_storage(task: asyncio.Task[None]) -> None:
    """Hold a strong reference to `task` until it finishes.

    Shielded scene storage must outlive the client WebSocket. The module-level
    set keeps the task reachable, and the done-callback releases it.
    """
    registry = _BACKGROUND_STORAGE_TASKS
    registry.add(task)
    task.add_done_callback(registry.discard)
class SessionRunner:
    """One connected capture app: admission, the single queue slot, the pipeline.

    `run` is the receive loop and performs admission (`_admit`). A single
    worker task (`_drain`) runs the pipeline. While it is busy, at most one
    further moment waits in `_pending`, and newer admissions replace it (D8).
    Every outbound frame goes through `_send_json`, serialized by `_send_lock`.
    """

    def __init__(self, ws: WebSocket, state: EngineState) -> None:
        self._ws = ws
        self._state = state
        self._send_lock = asyncio.Lock()  # receive loop and worker both send frames
        self._pending: _Queued | None = None  # the single un-started slot (D8)
        self._worker: asyncio.Task | None = None
        self._processing = False  # True while a _drain task owns the pipeline
        self._last_status: tuple[bool, int] | None = None  # last (busy, depth) sent

    async def run(self) -> None:
        """Admit text frames until the client disconnects, then cancel the worker."""
        try:
            while True:
                message = await self._ws.receive()
                if message["type"] == "websocket.disconnect":
                    break
                text = message.get("text")
                if text is None:  # binary frame: not in the contract, but don't drop the socket
                    await self._send_ack(None, "rejected", "binary frames not supported")
                    continue
                await self._admit(text)
        except WebSocketDisconnect:
            pass
        finally:
            if self._worker is not None:
                self._worker.cancel()

    # -- admission (every envelope is acked at once; a coalesced one is acked again) --

    async def _admit(self, raw: str) -> None:
        """Validate, dedupe and enqueue one raw envelope, acking the outcome.

        Invalid JSON or schema failures are acked `rejected`. A dedupe key already
        in the shared LRU is acked `duplicate`. Otherwise the moment does one of:
        - starts a worker when idle;
        - fills the empty pending slot;
        - replaces the pending moment, which is then acked `dropped_coalesced`.
        """
        try:
            envelope = json.loads(raw)
        except json.JSONDecodeError as exc:
            await self._send_ack(None, "rejected", f"invalid JSON: {exc}")
            return
        # Best-effort id for the ack, even when the envelope fails the schema.
        moment_id = envelope.get("moment_id") if isinstance(envelope, dict) else None
        if not isinstance(moment_id, str):
            moment_id = None
        error = jsonschema.exceptions.best_match(_MOMENT.iter_errors(envelope))
        if error is not None:
            await self._send_ack(moment_id, "rejected", error.message)
            return
        dedupe_key = _moment_dedupe_key(envelope)
        if dedupe_key in self._state.seen_moment_ids:
            await self._send_ack(moment_id, "duplicate")
            return
        self._state.seen_moment_ids.add(dedupe_key)
        queued = _Queued(envelope, time.perf_counter())
        if not self._processing:
            # Claim the pipeline before any await so a concurrent path sees it busy.
            self._processing = True
            await self._send_ack(moment_id, "accepted")
            self._worker = asyncio.create_task(self._drain(queued))
            self._worker.add_done_callback(_log_worker_failure)
        elif self._pending is None:
            self._pending = queued
            await self._send_ack(moment_id, "accepted")
        else:  # D8: replace the un-started moment; stale narration is worse than none
            dropped = self._pending
            self._pending = queued
            await self._send_ack(dropped.envelope["moment_id"], "dropped_coalesced")
            await self._send_ack(moment_id, "accepted")
        await self._emit_status()

    # -- processing ------------------------------------------------------------

    async def _drain(self, queued: _Queued) -> None:
        """Process `queued`, then keep taking the pending slot until it is empty.

        There is deliberately no `await` between finding the slot empty and
        clearing `_processing`. If the worker suspended there, `_admit` could
        park a moment in the slot (it still sees the worker as busy), and that
        moment would sit un-started until a later admission.
        """
        current: _Queued | None = queued
        try:
            while current is not None:
                await self._process(current)
                current, self._pending = self._pending, None
                if current is not None:
                    await self._emit_status()  # slot emptied; still busy
        finally:
            self._processing = False
            # On cancellation the socket is already gone: the send is still attempted
            # and _send_json swallows the failure.
            await self._emit_status()

    async def _process(self, item: _Queued) -> None:
        """Run one moment through narration and TTS, send SceneAudio, then store it.

        Failures become error frames (never a crashed socket). The moment's
        dedupe key is released on failure so a resend is processed again.
        """
        envelope = item.envelope
        moment_id: str = envelope["moment_id"]
        context = envelope.get("context") or {}
        style_key = context.get("style_key") or DEFAULT_STYLE_KEY
        started = time.perf_counter()
        queue_ms = _ms(started - item.queued_at)
        stage = "narration"  # tracks which stage an exception is attributed to
        try:
            image, narration = await asyncio.to_thread(
                _decode_and_narrate,
                envelope,
                style_key,
                context.get("user_hint", ""),
            )
            narration_ms = _ms(time.perf_counter() - started)
            stage = "tts"
            tts_started = time.perf_counter()
            speech = await asyncio.to_thread(tts.speak, narration.text)
            audio_b64 = base64.b64encode(_wav_bytes(speech.audio, speech.sample_rate)).decode()
            tts_ms = _ms(time.perf_counter() - tts_started)
            stage = "storage"  # the outgoing SceneAudio is the engine's stored artifact
            created_at = datetime.now(timezone.utc)
            payload = {
                "contract_version": CONTRACT_VERSION,
                "scene_id": str(uuid.uuid4()),
                "moment_id": moment_id,
                "created_at": created_at.isoformat(),
                "play_by": (created_at + timedelta(seconds=PLAY_BY_SECONDS)).isoformat(),
                "format": "wav_complete",
                "audio_b64": audio_b64,
                "sample_rate": speech.sample_rate,
                "narration": narration.text,
            }
            _SCENE_AUDIO.validate(payload)  # outgoing drift becomes an error frame, never silence
            await self._send_json(payload)
        except _ValidationFailure as exc:
            # The resend would fail the same way, but dedupe only what produced a scene.
            self._state.seen_moment_ids.discard(_moment_dedupe_key(envelope))
            await self._send_error(moment_id, "validation", exc, code=exc.code, retryable=False)
            return
        except Exception as exc:
            # Drop the id so a client resend is genuinely re-processed (honest retryable).
            self._state.seen_moment_ids.discard(_moment_dedupe_key(envelope))
            retryable = stage in ("narration", "tts")
            code = "scene_audio_schema_drift" if stage == "storage" else None
            await self._send_error(moment_id, stage, exc, code=code, retryable=retryable)
            return
        storage_task = asyncio.create_task(
            self._finish_scene_storage(
                envelope=envelope,
                image=image,
                scene_audio=payload,
                narration_text=narration.text,
                title=narration.title,
                speech=speech,
                style_key=style_key,
                queue_ms=queue_ms,
                narration_ms=narration_ms,
                tts_ms=tts_ms,
            )
        )
        _retain_background_storage(storage_task)
        try:
            # Shield: a disconnect cancels the worker, not the storage of a sent scene.
            await asyncio.shield(storage_task)
        except asyncio.CancelledError:
            storage_task.add_done_callback(_log_worker_failure)
            raise

    async def _finish_scene_storage(
        self,
        *,
        envelope: dict[str, Any],
        image: Image.Image,
        scene_audio: dict[str, Any],
        narration_text: str,
        title: str,
        speech: tts.Speech,
        style_key: str,
        queue_ms: int,
        narration_ms: int,
        tts_ms: int,
    ) -> None:
        """Decode supplemental clip frames off-loop and hand the scene record to the sink."""
        clip_frames = await asyncio.to_thread(
            _decode_clip_frames_for_storage, envelope, image, scene_audio["scene_id"]
        )
        await self._hand_to_sink(
            self._state.sink,
            {
                "scene_id": scene_audio["scene_id"],
                "moment_id": envelope["moment_id"],
                "session_id": envelope["session_id"],
                "captured_at": envelope["captured_at"],
                "created_at": scene_audio["created_at"],
                "style_key": style_key,
                "title": title,
                "narration": narration_text,
                "image": image,
                "clip_frames": clip_frames,
                "audio": speech.audio,
                "sample_rate": speech.sample_rate,
                "latency_ms": {
                    "queue": queue_ms,
                    "narration": narration_ms,
                    "tts": tts_ms,
                    "total": queue_ms + narration_ms + tts_ms,
                },
            },
        )

    async def _hand_to_sink(self, sink: SceneSink | None, payload: dict[str, Any]) -> None:
        """Best-effort delivery to a scene or error sink; sync and async sinks both work.

        A sink bug must not kill the session, but it must not vanish either.
        Failures are reported on stderr and otherwise ignored. Cancellation is
        not an `Exception`, so it still propagates.
        """
        if sink is None:
            return
        try:
            result = sink(payload)
            if inspect.isawaitable(result):
                await result
        except Exception as exc:
            sink_name = getattr(sink, "__qualname__", type(sink).__name__)
            print(f"small_cuts.engine: sink {sink_name} failed: {exc!r}", file=sys.stderr)

    # -- outbound frames ---------------------------------------------------------

    async def _send_ack(self, moment_id: str | None, result: str, detail: str = "") -> None:
        """Send an `ack` ControlFrame; `detail` is truncated to 200 characters."""
        ack: dict[str, Any] = {"result": result}
        if detail:
            ack["detail"] = detail[:200]
        await self._send_json(
            {
                "contract_version": CONTRACT_VERSION,
                "kind": "ack",
                "moment_id": moment_id,
                "ack": ack,
            }
        )

    async def _send_error(
        self,
        moment_id: str,
        stage: str,
        exc: Exception,
        *,
        retryable: bool,
        code: str | None = None,
    ) -> None:
        """Send an `error` ControlFrame and fan the same frame out to the error sink.

        `code` defaults to the exception's class name (truncated to 60 characters);
        the message is `str(exc)` truncated to 300 characters.
        """
        frame = {
            "contract_version": CONTRACT_VERSION,
            "kind": "error",
            "moment_id": moment_id,
            "error": {
                "stage": stage,
                "code": (code or type(exc).__name__)[:60],
                "message": str(exc)[:300],
                "retryable": retryable,
            },
        }
        await self._send_json(frame)
        # D9 honest timeline: the same failure fans out to the viewer stream.
        await self._hand_to_sink(self._state.error_sink, frame)

    async def _emit_status(self) -> None:
        """Send a `status` frame, but only when (busy, queue_depth) changed.

        `_last_status` is updated before the send, so concurrent callers see
        the new snapshot and do not send it twice.
        """
        snapshot = (self._processing, int(self._pending is not None))
        if snapshot == self._last_status:
            return
        self._last_status = snapshot
        await self._send_json(
            {
                "contract_version": CONTRACT_VERSION,
                "kind": "status",
                "moment_id": None,
                "status": {"busy": snapshot[0], "queue_depth": snapshot[1]},
            }
        )

    async def _send_json(self, payload: dict[str, Any]) -> None:
        """Serialize `payload` and send it as one text frame under the send lock."""
        text = json.dumps(payload)  # serialization bugs must surface, not be swallowed
        async with self._send_lock:
            with contextlib.suppress(Exception):  # client gone mid-send; run() closes out
                await self._ws.send_text(text)
def _decode_and_narrate(
    envelope: dict[str, Any], style_key: str, scene_hint: str
) -> tuple[Image.Image, narrator.Narration]:
    """Decode frame 0, enforce the size cap, and narrate it.

    Both steps happen in a single call, so the caller needs only one
    worker-thread hop. Any decode failure that is not already a
    `_ValidationFailure` is wrapped as
    `_ValidationFailure("frame_decode_failed", ...)`. Narrator errors
    propagate unchanged.
    """
    try:
        first_frame = _decode_frame(envelope["frames"][0])
        _validate_frame_size(first_frame)
    except Exception as exc:
        if isinstance(exc, _ValidationFailure):
            raise
        raise _ValidationFailure("frame_decode_failed", f"undecodable frame: {exc}") from exc
    narration = narrator.narrate(first_frame, style_key=style_key, scene_hint=scene_hint)
    return first_frame, narration
def _decode_frame(frame: dict[str, Any]) -> Image.Image:
    """Decode one envelope frame's base64 JPEG into a fully loaded PIL image.

    `Image.open` reads only the header. The contract's size cap is therefore
    enforced before `load()` decodes pixel data, so an oversized frame is
    rejected without first allocating its whole bitmap. Callers may re-check
    the size after this returns; the check is cheap and idempotent.

    Raises:
        _ValidationFailure: code "frame_exceeds_cap" when the longest side
            exceeds MAX_FRAME_SIDE.
        Exception: whatever base64 or PIL raise for malformed data.
    """
    data = base64.b64decode(frame["jpeg_b64"])
    image = Image.open(io.BytesIO(data))
    _validate_frame_size(image)  # header-only dimensions: reject before decoding pixels
    image.load()
    return image
def _moment_dedupe_key(envelope: dict[str, Any]) -> str:
    """Build the cross-socket dedupe key: the moment id scoped by its session id."""
    session_id = envelope["session_id"]
    moment_id = envelope["moment_id"]
    return "{}:{}".format(session_id, moment_id)
def _validate_frame_size(image: Image.Image) -> None:
    """Reject an image whose longest side exceeds the MAX_FRAME_SIDE contract cap.

    Raises:
        _ValidationFailure: code "frame_exceeds_cap".
    """
    width, height = image.size
    longest = max(width, height)
    if longest <= MAX_FRAME_SIDE:
        return
    raise _ValidationFailure(
        "frame_exceeds_cap",
        f"decoded longest side {longest} px exceeds the {MAX_FRAME_SIDE} px contract cap",
    )
def _decode_clip_frames(envelope: dict[str, Any], selected: Image.Image) -> list[Image.Image]:
    """Return every envelope frame as an image, ordered by `ts_offset_ms`.

    Frame 0 reuses the already-decoded `selected` image; the rest are decoded
    and size-checked here. A missing offset defaults to 0 for frame 0 and to
    the frame's index for the others. Ties keep envelope order.
    """
    frames = envelope["frames"]
    keyed = [(int(frames[0].get("ts_offset_ms", 0)), 0, selected)]
    for position, frame in enumerate(frames[1:], start=1):
        picture = _decode_frame(frame)
        _validate_frame_size(picture)
        keyed.append((int(frame.get("ts_offset_ms", position)), position, picture))
    keyed.sort(key=lambda entry: entry[:2])
    return [picture for *_, picture in keyed]
def _decode_clip_frames_for_storage(
    envelope: dict[str, Any], selected: Image.Image, scene_id: str
) -> list[Image.Image]:
    """Decode viewer-only supplemental frames after SceneAudio has been sent.

    Best-effort: with a single frame, or on any decode or size failure (logged
    to stderr), the clip is just the selected frame.
    """
    if len(envelope["frames"]) >= 2:
        try:
            return _decode_clip_frames(envelope, selected)
        except Exception as exc:
            print(
                f"small_cuts.engine: clip frame decode failed for scene {scene_id}: {exc!r}",
                file=sys.stderr,
            )
    return [selected]
def _wav_bytes(audio: np.ndarray, sample_rate: int) -> bytes:
    """Encode float audio as a complete 16-bit PCM WAV file.

    Args:
        audio: float samples, nominally in [-1.0, 1.0]; 1-D for mono or
            (frames, channels) for multichannel. NOTE(review): confirm the
            shape `tts.speak` actually returns.
        sample_rate: frames per second written to the WAV header.

    Returns:
        The WAV file bytes (header plus PCM data).

    Uses `soundfile` when it is usable and the stdlib `wave` module otherwise.
    `soundfile` raises OSError (not ImportError) at import time when the native
    libsndfile is missing, so both errors mean "unavailable". The fallback
    derives the channel count from the array shape, as soundfile does, instead
    of assuming mono.
    """
    try:
        import soundfile as sf
    except (ImportError, OSError):
        sf = None

    buffer = io.BytesIO()
    if sf is not None:
        sf.write(buffer, audio, sample_rate, format="WAV", subtype="PCM_16")
        return buffer.getvalue()

    pcm = (np.clip(np.asarray(audio), -1.0, 1.0) * 32767.0).astype("<i2")
    channels = 1 if pcm.ndim == 1 else pcm.shape[1]
    with wave.open(buffer, "wb") as wav:
        wav.setnchannels(channels)
        wav.setsampwidth(2)  # 16-bit samples
        wav.setframerate(sample_rate)
        wav.writeframes(pcm.tobytes())  # C-order (frames, channels) == interleaved
    return buffer.getvalue()
def _ms(seconds: float) -> int:
    """Convert a duration in seconds to whole milliseconds, clamped at zero."""
    millis = round(seconds * 1000)
    return millis if millis > 0 else 0