Spaces:
Build error
Build error
| """DriftCall demo Space — Gradio 5.x entrypoint. | |
| Implements ``docs/modules/deploy_demo_space.md`` (sealed). Single-file demo: | |
| mic → ASR → DriftCallEnv → Gemma 3n E2B (base | trained LoRA) → TTS → speaker | |
| with a live trace panel and a manual drift-injection dropdown. | |
| Hard rules: | |
| - Heavy deps (``gradio``, ``spaces``, ``peft``, ``torch``, ``transformers``, | |
| ``unsloth``) are imported lazily inside callables so this module imports | |
| cleanly in CI / on machines without GPUs / Gradio. | |
| - ``infer_turn`` never writes to disk and never calls push-to-hub. | |
| - Latency budget: ≤ 8 s on ZeroGPU warm, ≤ 12 s on A10G warm. | |
| - All 9 error modes (deploy_demo_space.md §5) surface as ``status_msg`` and | |
| positional safe defaults; the UI never crashes. | |
| """ | |
| from __future__ import annotations | |
| import contextlib | |
| import logging | |
| import os | |
| import threading | |
| import time | |
| import uuid | |
| from dataclasses import dataclass, field | |
| from typing import TYPE_CHECKING, Any, Literal | |
| import numpy as np | |
| from cells.step_04_models import ActionType, DriftCallAction | |
| from cells.step_06_drift_injector import list_patterns | |
| from cells.step_10_env import DriftCallEnv | |
| if TYPE_CHECKING: | |
| import pandas as pd | |
| logger = logging.getLogger(__name__) | |
| # --------------------------------------------------------------------------- | |
| # Public types + constants | |
| # --------------------------------------------------------------------------- | |
# Which weights a turn runs against: the plain base model or the LoRA-tuned one.
CheckpointId = Literal["base", "trained"]

# Session registry limits.
_MAX_SESSIONS: int = 10  # concurrent-session cap enforced by get_session
_IDLE_TTL_S: int = 900  # default idle cutoff, in seconds, for gc_sessions

# Time budgets, in seconds.
_GPU_DURATION_S: int = 60
_LATENCY_BUDGET_S: float = 8.0

# Column order of the trace-panel DataFrame.
_TRACE_COLUMNS: tuple[str, ...] = (
    "turn_idx", "actor", "action_or_event", "tool_response_preview", "reward_delta",
)

# Default Hugging Face ids used by ModelLoader.
_BASE_MODEL_ID_DEFAULT: str = "unsloth/gemma-3n-E2B-it"
_TRAINED_ADAPTER_ID_DEFAULT: str = "DGXAI/gemma-3n-e2b-driftcall-lora"

# Environment variables read by _probe_hardware.
_HARDWARE_ENV_VAR: str = "DRIFTCALL_HARDWARE"
_HARDWARE_FALLBACK_ENV_VAR: str = "DRIFTCALL_HARDWARE_FALLBACK"

# Maximum length of a trace preview cell (see _truncate_preview).
_TRACE_PREVIEW_LEN: int = 120

# Safe-default audio: one second of silence at 16 kHz.
_FALLBACK_SR_HZ: int = 16000
_FALLBACK_SILENCE_LEN: int = _FALLBACK_SR_HZ

# Ids of every drift pattern known to the injector, captured once at import time.
_DRIFT_PATTERN_IDS: tuple[str, ...] = tuple(pattern.id for pattern in list_patterns())
| # --------------------------------------------------------------------------- | |
| # Errors (deploy_demo_space.md §5) | |
| # --------------------------------------------------------------------------- | |
class TrainedAdapterMissingError(RuntimeError):
    """Error mode 5.2: the LoRA download failed at boot, or the adapter file is corrupt."""
class CheckpointMismatchError(RuntimeError):
    """Error mode 5.5: the LoRA was trained against a different ``base_model_id``."""
class SessionCapacityError(RuntimeError):
    """Error mode 5.7: more than 10 sessions are open at once."""
class EnvStepError(RuntimeError):
    """Error mode 5.8: the environment raised while executing ``step()``."""
class ZeroGPUUnavailableError(RuntimeError):
    """Error mode 5.1: the ``@spaces.GPU`` request was rejected."""
class AudioDecodeError(RuntimeError):
    """Error mode 5.6: ASR was unable to decode the microphone audio."""
| # --------------------------------------------------------------------------- | |
| # DemoSessionState (deploy_demo_space.md §4.1) | |
| # --------------------------------------------------------------------------- | |
@dataclass
class TraceRow:
    """One row of the live trace panel.

    Declared as a dataclass so it can be built with keyword arguments, which
    is how ``infer_turn`` constructs every row. Without the decorator the
    class has no generated ``__init__`` and those calls raise ``TypeError``.
    """

    # Turn counter value at the time the row was recorded.
    turn_idx: int
    # Who produced the row.
    actor: Literal["user", "agent", "env", "drift", "reward"]
    # Short description of what happened (transcript, action, status, ...).
    action_or_event: str
    # Preview of the tool/env response; may be empty.
    tool_response_preview: str
    # Reward change attributed to this row.
    reward_delta: float
@dataclass
class DemoSessionState:
    """Per-tab state. Mutated only by ``demo.app_gradio`` itself.

    Declared as a (mutable) dataclass: the ``field(default_factory=...)``
    defaults below only take effect under ``@dataclass``, and ``get_session``
    constructs instances with keyword arguments.
    """

    # Identifier the session is registered under.
    session_id: str
    # Environment instance owned by this session; closed on reset / eviction.
    env: DriftCallEnv
    # Most recent observation returned by ``env.step``; None before the first step.
    last_observation: Any | None = None
    # Rows shown in the trace panel, in insertion order.
    episode_trace: list[TraceRow] = field(default_factory=list)
    # Not read or written elsewhere in this file.
    audio_buffer: list[bytes] = field(default_factory=list)
    current_checkpoint: CheckpointId = "base"
    # Number of turns handled so far; incremented once per ``infer_turn`` call
    # that reaches the action-building step.
    turn_idx: int = 0
    # Epoch timestamps in milliseconds.
    created_at_ms: int = 0
    last_activity_ms: int = 0
| # --------------------------------------------------------------------------- | |
| # Process-wide registry | |
| # --------------------------------------------------------------------------- | |
# Live sessions keyed by session id. The functions in this file only touch the
# mapping while holding _REGISTRY_LOCK.
_REGISTRY: dict[str, DemoSessionState] = dict()
_REGISTRY_LOCK = threading.Lock()
def _now_ms() -> int:
    """Return the current wall-clock time as whole milliseconds since the epoch."""
    seconds = time.time()
    return int(seconds * 1000)
def _make_env() -> DriftCallEnv:
    """Create a brand-new environment with default constructor arguments.

    In production the audio boundary is enabled and the constructor needs real
    engines; tests monkeypatch ``_make_env`` to inject a stub instead.
    """
    fresh_env = DriftCallEnv()
    return fresh_env
def get_session(session_id: str) -> DemoSessionState:
    """Look up the session registered under ``session_id``, creating it if absent. §3.3.

    An existing session has its ``last_activity_ms`` refreshed before being
    returned. A new session gets a fresh env from ``_make_env``.

    Raises:
        SessionCapacityError: the id is unknown and ``_MAX_SESSIONS`` sessions
            are already registered.
    """
    with _REGISTRY_LOCK:
        known = _REGISTRY.get(session_id)
        if known is None:
            if len(_REGISTRY) >= _MAX_SESSIONS:
                raise SessionCapacityError(
                    f"demo at capacity ({_MAX_SESSIONS} concurrent sessions)"
                )
            created = DemoSessionState(
                session_id=session_id,
                env=_make_env(),
                created_at_ms=_now_ms(),
                last_activity_ms=_now_ms(),
            )
            _REGISTRY[session_id] = created
            return created
        known.last_activity_ms = _now_ms()
        return known
def reset_session(session_id: str) -> DemoSessionState:
    """Discard the session for ``session_id`` and hand back a fresh one. §3.5.

    The old env, if any, is closed; a failure while closing is logged and
    otherwise ignored. The replacement comes from ``get_session``, so its
    capacity check applies here too.
    """
    with _REGISTRY_LOCK:
        previous = _REGISTRY.pop(session_id, None)
    if previous is None:
        return get_session(session_id)
    try:
        previous.env.close()
    except Exception:
        logger.exception("env.close raised on reset_session for %s", session_id)
    return get_session(session_id)
def gc_sessions(max_idle_s: int = _IDLE_TTL_S) -> int:
    """Drop every session whose last activity is older than ``max_idle_s`` seconds.

    Each evicted session's env is closed; close failures are logged and do not
    stop the sweep. Returns how many sessions were evicted.
    """
    threshold_ms = _now_ms() - max_idle_s * 1000
    with _REGISTRY_LOCK:
        expired_ids = []
        for sid, session in _REGISTRY.items():
            if session.last_activity_ms < threshold_ms:
                expired_ids.append(sid)
        for sid in expired_ids:
            evicted = _REGISTRY.pop(sid)
            try:
                evicted.env.close()
            except Exception:
                logger.exception("env.close raised on gc for %s", sid)
    return len(expired_ids)
def _clear_registry_for_tests() -> None:
    """Test-only helper: close every registered env (ignoring errors) and empty the registry."""
    with _REGISTRY_LOCK:
        for session in _REGISTRY.values():
            with contextlib.suppress(Exception):
                session.env.close()
        _REGISTRY.clear()
| # --------------------------------------------------------------------------- | |
| # DriftToggleBridge (deploy_demo_space.md §2.5, §3.8, §7.3) | |
| # --------------------------------------------------------------------------- | |
class DriftToggleBridge:
    """Holds at most one pending drift pattern per session; the latest write wins."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        # session id -> pattern id waiting to be consumed
        self._slots: dict[str, str] = {}

    def queue(self, session_id: str, pattern_id: str | None) -> None:
        """Store ``pattern_id`` for the session, or clear the slot when it is None."""
        with self._lock:
            if pattern_id is not None:
                self._slots[session_id] = pattern_id
                return
            self._slots.pop(session_id, None)

    def consume(self, session_id: str) -> str | None:
        """Remove and return the session's pending pattern id, or None if the slot is empty."""
        with self._lock:
            pending = self._slots.pop(session_id, None)
        return pending
# Module-level bridge shared by every caller of get_drift_bridge().
_DEFAULT_BRIDGE = DriftToggleBridge()


def get_drift_bridge() -> DriftToggleBridge:
    """Return the module-level ``DriftToggleBridge`` instance."""
    bridge = _DEFAULT_BRIDGE
    return bridge
| # --------------------------------------------------------------------------- | |
| # Trace panel (deploy_demo_space.md §2.6, §4.3) | |
| # --------------------------------------------------------------------------- | |
def render_trace(state: DemoSessionState) -> pd.DataFrame:
    """Build the trace-panel DataFrame from ``state.episode_trace`` without mutating ``state``.

    Columns follow ``_TRACE_COLUMNS``; ``reward_delta`` is coerced to ``float``.
    """
    import pandas as pd

    records = []
    for entry in state.episode_trace:
        records.append(
            {
                "turn_idx": entry.turn_idx,
                "actor": entry.actor,
                "action_or_event": entry.action_or_event,
                "tool_response_preview": entry.tool_response_preview,
                "reward_delta": float(entry.reward_delta),
            }
        )
    column_order = list(_TRACE_COLUMNS)
    return pd.DataFrame(records, columns=column_order)
def _empty_trace_df() -> pd.DataFrame:
    """Return a zero-row DataFrame that carries the trace-panel columns."""
    import pandas as pd

    column_order = list(_TRACE_COLUMNS)
    return pd.DataFrame([], columns=column_order)
| # --------------------------------------------------------------------------- | |
| # ModelLoader (deploy_demo_space.md §2.3, §3.2) | |
| # --------------------------------------------------------------------------- | |
class ModelLoader:
    """Lazy holder for the base model, its tokenizer, and the optional ``driftcall`` LoRA.

    Nothing is loaded at construction time; ``ensure_loaded`` performs the
    one-time load under ``self._lock``. The process-wide instance is handed
    out by ``get_model_loader``.

    NOTE(review): the base model is described as 4-bit, but ``_load_base``
    passes no quantization arguments — confirm the checkpoint config supplies it.
    """

    def __init__(
        self,
        *,
        base_model_id: str = _BASE_MODEL_ID_DEFAULT,
        trained_adapter_id: str = _TRAINED_ADAPTER_ID_DEFAULT,
        max_seq_length: int = 4096,
    ) -> None:
        """Record the model / adapter ids; no weights are loaded here."""
        self._base_model_id = base_model_id
        self._trained_adapter_id = trained_adapter_id
        # Stored for configuration; not read by any method in this class.
        self._max_seq_length = max_seq_length
        self._model: Any | None = None
        self._tokenizer: Any | None = None
        self._trained_available: bool = False
        self._lock = threading.Lock()

    def _load_base(self) -> tuple[Any, Any]:
        """Load the base model and tokenizer; returns ``(model, tokenizer)``.

        Patched by tests via ``_load_base``.
        """
        # Imported lazily so the module imports without transformers installed.
        from transformers import AutoModelForCausalLM, AutoTokenizer

        tok_cls: Any = AutoTokenizer
        model_cls: Any = AutoModelForCausalLM
        tokenizer = tok_cls.from_pretrained(self._base_model_id)
        model = model_cls.from_pretrained(self._base_model_id)
        return model, tokenizer

    def _try_mount_adapter(self, model: Any) -> Any | None:
        """Mount ``self._trained_adapter_id`` as the ``driftcall`` LoRA.

        Returns the wrapped model, or None when ``peft`` cannot be imported or
        the mount fails for any reason (each case is logged as a warning).
        """
        try:
            from peft import PeftModel
        except ImportError as exc:
            logger.warning("peft import failed: %s", exc)
            return None
        try:
            return PeftModel.from_pretrained(
                model, self._trained_adapter_id, adapter_name="driftcall"
            )
        except CheckpointMismatchError as exc:
            logger.warning("checkpoint mismatch on LoRA mount: %s", exc)
            return None
        except Exception as exc:
            # Captures EntryNotFoundError, HTTPError(404), generic peft failures.
            logger.warning("LoRA mount failed: %s", exc)
            return None

    def ensure_loaded(self) -> None:
        """Load model + tokenizer once; later calls return immediately.

        If the adapter mounts, the wrapped model is kept and the trained
        checkpoint is marked available; otherwise the bare base model is kept.
        """
        with self._lock:
            if self._model is not None:
                return
            base_model, tokenizer = self._load_base()
            self._tokenizer = tokenizer
            wrapped = self._try_mount_adapter(base_model)
            if wrapped is None:
                self._model = base_model
                self._trained_available = False
            else:
                self._model = wrapped
                self._trained_available = True

    def is_trained_available(self) -> bool:
        """Return True once the LoRA adapter has been mounted successfully."""
        return self._trained_available

    def generate(
        self,
        messages: list[dict[str, str]],
        *,
        checkpoint: CheckpointId,
        max_new_tokens: int = 256,
        temperature: float = 0.2,
        top_p: float = 0.95,
        seed: int = 0,
    ) -> str:
        """Run the requested checkpoint on ``messages`` and return the completion text.

        Raises:
            TrainedAdapterMissingError: ``checkpoint == "trained"`` but the
                adapter was not mounted.
            RuntimeError: the model is still None after loading, or
                ``_run`` failed.
        """
        self.ensure_loaded()
        if checkpoint == "trained" and not self._trained_available:
            raise TrainedAdapterMissingError(
                "Trained adapter unavailable; falling back to base"
            )
        model = self._model
        if model is None:
            raise RuntimeError("model unexpectedly None")
        if checkpoint == "base":
            # With an adapter-wrapped model, base output means running with the
            # adapter disabled; a bare model has nothing to disable.
            ctx = model.disable_adapter() if hasattr(model, "disable_adapter") else _NullCtx()
            with ctx:
                return self._run(model, messages, max_new_tokens, temperature, top_p, seed)
        # trained
        if hasattr(model, "set_adapter"):
            model.set_adapter("driftcall")
        if hasattr(model, "enable_adapter_layers"):
            model.enable_adapter_layers()
        return self._run(model, messages, max_new_tokens, temperature, top_p, seed)

    def _run(
        self,
        model: Any,
        messages: list[dict[str, str]],
        max_new_tokens: int,
        temperature: float,
        top_p: float,
        seed: int,
    ) -> str:
        """Call ``model.generate`` and return its result as ``str``. Tests stub via ``_run``.

        ``ZeroGPUUnavailableError`` and ``TimeoutError`` propagate unchanged so
        the typed handlers in ``_generate_with_retries`` can act on them; any
        other failure is re-raised as ``RuntimeError`` with the original as its
        cause.

        NOTE(review): chat messages and a seed are passed straight to
        ``model.generate``; confirm the loaded model accepts these keywords.
        """
        try:
            return str(
                model.generate(
                    messages=messages,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    seed=seed,
                )
            )
        except (ZeroGPUUnavailableError, TimeoutError):
            # Wrapping these in a plain RuntimeError would hide them from the
            # retry / timeout policy that dispatches on the exception type.
            raise
        except Exception as exc:
            raise RuntimeError(f"model.generate raised: {exc}") from exc
class _NullCtx:
    """No-op context manager used when the model has no ``disable_adapter``."""

    def __enter__(self) -> _NullCtx:
        # Nothing to set up; hand back the manager itself.
        return self

    def __exit__(self, *exc_info: Any) -> Literal[False]:
        # Returning False lets any in-flight exception propagate.
        suppress: Literal[False] = False
        return suppress
# Lazily created process-wide loader; guarded by _MODEL_LOADER_LOCK.
_MODEL_LOADER: ModelLoader | None = None
_MODEL_LOADER_LOCK = threading.Lock()


def get_model_loader() -> ModelLoader:
    """Return the shared ``ModelLoader``, constructing it on the first call."""
    global _MODEL_LOADER
    with _MODEL_LOADER_LOCK:
        loader = _MODEL_LOADER
        if loader is None:
            loader = ModelLoader()
            _MODEL_LOADER = loader
        return loader
def _reset_model_loader_for_tests() -> None:
    """Test-only helper: forget the shared loader so the next lookup builds a new one."""
    global _MODEL_LOADER
    _MODEL_LOADER_LOCK.acquire()
    try:
        _MODEL_LOADER = None
    finally:
        _MODEL_LOADER_LOCK.release()
| # --------------------------------------------------------------------------- | |
| # Hardware probing (deploy_demo_space.md §3.1) | |
| # --------------------------------------------------------------------------- | |
@dataclass
class HardwareProbe:
    """Result of probing the two deployment targets.

    Declared as a dataclass so ``_probe_hardware`` can build it with keyword
    arguments; without the decorator the class has no generated ``__init__``
    and that call raises ``TypeError``.
    """

    # True when the ZeroGPU target is selected / available.
    zerogpu: bool
    # True when the A10G fallback is selected / available.
    a10g: bool
def _probe_hardware() -> HardwareProbe:
    """Derive both hardware flags from environment variables. Patched in tests.

    The primary variable defaults to ``"zero-gpu"`` when unset; the fallback
    variable defaults to the empty string.
    """
    env = os.environ
    primary = env.get(_HARDWARE_ENV_VAR, "zero-gpu")
    secondary = env.get(_HARDWARE_FALLBACK_ENV_VAR, "")
    return HardwareProbe(
        zerogpu=(primary == "zero-gpu"),
        a10g=(secondary == "a10g"),
    )
class DeploymentAbortedError(RuntimeError):
    """Signals that neither the ZeroGPU target nor the A10G fallback is available."""
def deploy_check() -> str:
    """Pick the README front-matter ``hardware:`` value.

    Returns ``"zero-gpu"`` when ZeroGPU is available, otherwise
    ``"a10g-small"`` when the A10G fallback is.

    Raises:
        DeploymentAbortedError: with message ``"both-gpus-unavailable"`` when
            neither target is available; the pitch then reverts to a
            pre-recorded video.
    """
    probe = _probe_hardware()
    if not probe.zerogpu:
        if not probe.a10g:
            logger.warning("Fall back to pre-recorded video — see risk_book.md")
            raise DeploymentAbortedError("both-gpus-unavailable")
        logger.info("zero-gpu unavailable; falling back to A10G small")
        return "a10g-small"
    return "zero-gpu"
| # --------------------------------------------------------------------------- | |
| # Audio helpers | |
| # --------------------------------------------------------------------------- | |
def _safe_silence() -> tuple[int, np.ndarray]:
    """Return ``(sample_rate, samples)`` for the safe-default audio: 1 s of float32 zeros at 16 kHz."""
    silent_samples = np.zeros(_FALLBACK_SILENCE_LEN, dtype=np.float32)
    return _FALLBACK_SR_HZ, silent_samples
def _safe_defaults() -> tuple[str, tuple[int, np.ndarray], pd.DataFrame, dict[str, float], str]:
    """Return the positional safe defaults for every ``infer_turn`` output slot.

    Order: transcript, audio, trace DataFrame, rewards, status message.
    """
    transcript = ""
    rewards: dict[str, float] = {}
    status = ""
    return transcript, _safe_silence(), _empty_trace_df(), rewards, status
def _truncate_preview(payload: Any) -> str:
    """Stringify ``payload`` and cap it at ``_TRACE_PREVIEW_LEN`` characters.

    Over-long text is cut and terminated with an ellipsis so the result is
    exactly ``_TRACE_PREVIEW_LEN`` characters long.
    """
    rendered = str(payload)
    if len(rendered) > _TRACE_PREVIEW_LEN:
        keep = _TRACE_PREVIEW_LEN - 1
        return rendered[:keep] + "…"
    return rendered
| # --------------------------------------------------------------------------- | |
| # infer_turn (deploy_demo_space.md §2.2) | |
| # --------------------------------------------------------------------------- | |
def infer_turn(
    audio_tuple: tuple[int, np.ndarray] | None,
    checkpoint: CheckpointId,
    manual_drift: str | None,
    session_id: str,
    *,
    text_input: str | None = None,
) -> tuple[str, tuple[int, np.ndarray], pd.DataFrame, dict[str, float], str]:
    """Process a single mic-to-speaker turn for ``session_id``.

    Returns ``(transcript, audio_out, trace_df, rewards, status_msg)``. The
    failure modes handled below are reported through ``status_msg`` with safe
    defaults in the other slots rather than by raising.
    """
    # Step 1 — session lookup (error 5.7).
    try:
        state = get_session(session_id)
    except SessionCapacityError:
        blank, quiet, frame, no_rewards, _ = _safe_defaults()
        return blank, quiet, frame, no_rewards, "Demo at capacity — try again in a minute."

    def _no_input(message: str) -> Any:
        # Safe defaults everywhere except the trace, which shows the session so far.
        blank_text, quiet_audio, _unused_df, no_reward, _ = _safe_defaults()
        return blank_text, quiet_audio, render_trace(state), no_reward, message

    def _record(actor: Any, event: str, preview: str = "") -> None:
        # Append one trace row tagged with the current turn index.
        state.episode_trace.append(
            TraceRow(
                turn_idx=state.turn_idx,
                actor=actor,
                action_or_event=event,
                tool_response_preview=preview,
                reward_delta=0.0,
            )
        )

    # Step 2 — nothing to work with (error 5.3).
    typed = (text_input or "").strip()
    if audio_tuple is None and not typed:
        return _no_input("No audio received; press mic or type a brief.")

    # Step 3 — transcript from ASR (error 5.6) or from the typed text.
    transcript: str = ""
    if audio_tuple is None:
        transcript = typed
    else:
        try:
            transcript = _run_asr(audio_tuple)
        except AudioDecodeError:
            return _no_input("Could not decode mic audio; please try again.")

    def _abort(message: str) -> Any:
        # Failure after the transcript is known: keep it, silence the speaker.
        return transcript, _safe_silence(), render_trace(state), {}, message

    # Step 4 — pending drift; an explicit manual choice overrides the queued one.
    queued = get_drift_bridge().consume(session_id)
    forced = queued if manual_drift is None else manual_drift

    # Step 5 — trace the user turn and build the SPEAK action.
    state.turn_idx += 1
    _record("user", transcript)
    if forced is not None:
        _record("drift", f"manual:{forced}")
    action = DriftCallAction(
        action_type=ActionType.SPEAK,
        message=transcript or "(empty)",
    )

    # Step 6 — env step (error 5.8).
    try:
        step_kwargs = {} if forced is None else {"force_drift_pattern": forced}
        obs = state.env.step(action, **step_kwargs)
        state.last_observation = obs
        _record("env", "200 OK", _truncate_preview(obs.last_transcript))
    except Exception as exc:
        _record(
            "env",
            f"rejected: {exc.__class__.__name__}",
            _truncate_preview(str(exc)),
        )
        return _abort(f"Env rejected action: {exc}; episode unchanged.")

    # Step 7 — generation (errors 5.1 / 5.2 / 5.4).
    loader = get_model_loader()
    use_checkpoint: CheckpointId = checkpoint
    status_warning = ""
    if checkpoint == "trained" and not loader.is_trained_available():
        use_checkpoint = "base"
        status_warning = "Trained adapter unavailable; showing base model only."
    reply: str
    try:
        reply = _generate_with_retries(loader, transcript, use_checkpoint)
    except TrainedAdapterMissingError:
        # Adapter vanished between the availability check and generate: retry on base.
        use_checkpoint = "base"
        status_warning = "Trained adapter unavailable; showing base model only."
        try:
            reply = _generate_with_retries(loader, transcript, use_checkpoint)
        except Exception as exc2:
            return _abort(f"Generate failed: {exc2}")
    except _OOMRetryFailure:
        return _abort("GPU out of memory this turn; reducing context and retrying.")
    except _ZeroGPUFailure:
        return _abort("GPU unavailable; the demo is running on CPU and will be slow.")
    except _TimeoutFailure:
        return _abort("Turn timed out after 60 s — the model was slow; try again.")
    except Exception as exc:
        return _abort(f"Generate failed: {exc}")
    _record("agent", f"SPEAK \"{reply[:60]}\"")

    # Step 8 — TTS; any failure degrades to silence.
    try:
        audio_out = _run_tts(reply, lang_hint="en")
    except Exception:
        audio_out = _safe_silence()
    rewards: dict[str, float] = {}
    state.last_activity_ms = _now_ms()
    return transcript, audio_out, render_trace(state), rewards, status_warning
| # --------------------------------------------------------------------------- | |
| # Subroutines (kept module-level so tests can patch each one) | |
| # --------------------------------------------------------------------------- | |
class _OOMRetryFailure(RuntimeError):
    """Raised by ``_generate_with_retries`` when the reduced retry also runs out of memory."""
class _ZeroGPUFailure(RuntimeError):
    """Raised by ``_generate_with_retries`` when the retry is also refused a GPU."""
class _TimeoutFailure(RuntimeError):
    """Raised by ``_generate_with_retries`` when generation raises ``TimeoutError``."""
def _run_asr(audio_tuple: tuple[int, np.ndarray]) -> str:
    """Transcribe one mic clip with the shared ASR engine. Tests patch this.

    ``audio_tuple`` is ``(sample_rate, samples)``. The clip is first converted
    to mono float32 at 16 kHz, so recordings at other sample rates are
    transcribed instead of silently producing an empty transcript.

    Returns the engine's transcript text, or ``""`` when the clip holds no
    samples or the sample rate is not positive.
    """
    sr, wav = audio_tuple
    samples = _to_mono_float32_16k(sr, wav)
    if samples.size == 0:
        return ""
    pcm_bytes = samples.tobytes()
    # Imported lazily so the module imports without the audio stack installed.
    from cells.step_09_audio import get_asr_engine

    asr = get_asr_engine()
    result = asr.transcribe(pcm_bytes, "en")
    return result.text


def _to_mono_float32_16k(sr: int, wav: np.ndarray) -> np.ndarray:
    """Convert raw mic samples to a 1-D float32 array sampled at 16 kHz."""
    target_sr = 16000
    samples = np.asarray(wav)
    if samples.size == 0 or sr <= 0:
        return np.zeros(0, dtype=np.float32)
    if np.issubdtype(samples.dtype, np.integer):
        # Integer PCM (e.g. int16) is scaled into [-1, 1).
        # NOTE(review): assumes the ASR engine expects normalized float PCM — confirm.
        full_scale = float(np.iinfo(samples.dtype).max) + 1.0
        samples = samples.astype(np.float32) / full_scale
    else:
        samples = samples.astype(np.float32)
    if samples.ndim > 1:
        # (samples, channels) -> mono by averaging the channels.
        samples = samples.reshape(samples.shape[0], -1).mean(axis=1)
    if sr != target_sr:
        # Linear-interpolation resample; adequate for a speech demo and needs
        # nothing beyond numpy.
        n_in = samples.shape[0]
        n_out = max(1, int(round(n_in * target_sr / sr)))
        src_t = np.arange(n_in, dtype=np.float64) / sr
        dst_t = np.arange(n_out, dtype=np.float64) / target_sr
        samples = np.interp(dst_t, src_t, samples)
    return np.ascontiguousarray(samples, dtype=np.float32)
def _run_tts(text: str, *, lang_hint: str = "en") -> tuple[int, np.ndarray]:
    """Synthesize ``text`` with the shared TTS engine and return its Gradio-ready audio."""
    from cells.step_09_audio import get_tts_engine

    engine = get_tts_engine()
    # Typed as Any because the engine's language parameter type is declared elsewhere.
    language: Any = lang_hint
    synthesized: tuple[int, np.ndarray] = engine.synthesize_to_gradio(text, language)
    return synthesized
def _generate_with_retries(
    loader: ModelLoader, transcript: str, checkpoint: CheckpointId
) -> str:
    """Call ``loader.generate`` once, applying the ZeroGPU / timeout / OOM policy on failure.

    * ``ZeroGPUUnavailableError`` — one retry; a second refusal raises ``_ZeroGPUFailure``.
    * ``TimeoutError`` — raised as ``_TimeoutFailure``.
    * Out-of-memory (by message or class name) — clear the CUDA cache and retry
      with 128 new tokens; a second OOM raises ``_OOMRetryFailure``.
    * Anything else is re-raised unchanged.
    """

    def _looks_like_oom(err: BaseException) -> bool:
        # Matches real CUDA OOMs as well as test stubs that mimic them.
        text = str(err).lower()
        return "out of memory" in text or err.__class__.__name__ == "OutOfMemoryError"

    prompt = transcript or "(empty)"
    messages = [{"role": "user", "content": prompt}]
    try:
        return loader.generate(messages, checkpoint=checkpoint, max_new_tokens=256)
    except ZeroGPUUnavailableError:
        # 5.1 — retry once, then fall back.
        time.sleep(0)  # advisory; tests patch
        try:
            return loader.generate(messages, checkpoint=checkpoint, max_new_tokens=256)
        except ZeroGPUUnavailableError as second_refusal:
            raise _ZeroGPUFailure() from second_refusal
    except TimeoutError as timed_out:
        raise _TimeoutFailure() from timed_out
    except Exception as first_err:
        # 5.4 applies only to out-of-memory failures; everything else propagates.
        if not _looks_like_oom(first_err):
            raise
        try:
            _empty_cuda_cache()
            # Shrink context: drop oldest message + reduce tokens.
            shrunk = messages if len(messages) <= 1 else messages[1:]
            return loader.generate(
                shrunk, checkpoint=checkpoint, max_new_tokens=128
            )
        except Exception as retry_err:
            if _looks_like_oom(retry_err):
                raise _OOMRetryFailure() from retry_err
            raise
def _empty_cuda_cache() -> None:
    """Release cached CUDA memory if torch is importable; any failure is ignored. Tests patch this."""
    with contextlib.suppress(Exception):
        import torch

        if hasattr(torch, "cuda") and hasattr(torch.cuda, "empty_cache"):
            torch.cuda.empty_cache()
| # --------------------------------------------------------------------------- | |
| # Warmup + UI builder | |
| # --------------------------------------------------------------------------- | |
def warmup_on_boot() -> None:
    """Warm the audio engines and the model before serving.

    Runs two independent stages — ASR/TTS warmup, then model load plus a tiny
    generate on the base checkpoint. A failure in either stage is logged and
    does not prevent the other from running.
    """
    try:
        from cells.step_09_audio import get_asr_engine, get_tts_engine

        engines = (get_asr_engine(), get_tts_engine())
        for engine in engines:
            # Engines without a warmup hook are simply skipped.
            if hasattr(engine, "warmup"):
                engine.warmup()
    except Exception:
        logger.exception("audio warmup failed")

    try:
        loader = get_model_loader()
        loader.ensure_loaded()
        warmup_messages = [{"role": "user", "content": "warmup"}]
        loader.generate(warmup_messages, checkpoint="base", max_new_tokens=4)
    except Exception:
        logger.exception("model warmup failed")
def build_ui() -> Any:
    """Construct and return the Gradio Blocks graph for the demo.

    Layout: mic + speaker, checkpoint radio + manual-drift dropdown, a text
    fallback box, then transcript, trace table, rewards JSON, status line and
    a reset button. The ``trained`` radio choice is only offered when the
    model loader reports the adapter as available at build time.

    Events:
        * stopping a mic recording runs ``infer_turn`` on the audio;
        * submitting the text box runs ``infer_turn`` on the typed text;
        * the reset button replaces the session via ``reset_session``.
    """
    import gradio as gr

    loader = get_model_loader()
    try:
        trained_ok = loader.is_trained_available()
    except Exception:
        trained_ok = False
    radio_choices = ["base", "trained"] if trained_ok else ["base"]
    radio_label = (
        "Checkpoint"
        if trained_ok
        else "Checkpoint (Trained adapter unavailable at boot — base only)"
    )
    drift_choices: list[Any] = [*_DRIFT_PATTERN_IDS, None]

    with gr.Blocks(title="DriftCall Demo") as demo:
        # A callable initial value is evaluated on each app load, so every
        # browser tab gets its own UUID. A pre-computed string would be shared
        # by all visitors, collapsing them into one session.
        session_state = gr.State(value=lambda: str(uuid.uuid4()))
        with gr.Row():
            mic = gr.Audio(
                sources=["microphone"],
                type="numpy",
                label="Speak your brief",
            )
            speaker = gr.Audio(
                type="numpy",
                label="Speaker",
                interactive=False,
            )
        with gr.Row():
            checkpoint = gr.Radio(
                choices=radio_choices,
                value="base",
                label=radio_label,
            )
            drift = gr.Dropdown(
                choices=drift_choices,
                value=None,
                label="Manual drift",
            )
        textbox = gr.Textbox(
            placeholder="Or type a brief here",
            label="Text fallback",
        )
        transcript = gr.Textbox(label="Transcript", interactive=False)
        trace = gr.DataFrame(
            value=_empty_trace_df(),
            wrap=True,
            max_height=400,
            interactive=False,
            headers=list(_TRACE_COLUMNS),
        )
        rewards = gr.JSON(label="Rewards")
        status = gr.Textbox(label="Status", interactive=False)
        reset_btn = gr.Button("New episode")

        turn_outputs = [transcript, speaker, trace, rewards, status]

        def _on_submit(
            mic_in: Any,
            ckpt: Any,
            drift_pat: Any,
            sid: Any,
            text_in: Any,
        ) -> Any:
            return infer_turn(mic_in, ckpt, drift_pat, sid, text_input=text_in)

        mic.stop_recording(
            _on_submit,
            inputs=[mic, checkpoint, drift, session_state, textbox],
            outputs=turn_outputs,
        )

        def _on_text_submit(ckpt: Any, drift_pat: Any, sid: Any, text_in: Any) -> Any:
            # No audio is passed, so a stale recording left in the mic widget
            # cannot override the typed brief.
            return infer_turn(None, ckpt, drift_pat, sid, text_input=text_in)

        # Without this event the text fallback box could never start a turn.
        textbox.submit(
            _on_text_submit,
            inputs=[checkpoint, drift, session_state, textbox],
            outputs=turn_outputs,
        )

        def _on_reset(sid: Any) -> Any:
            try:
                reset_session(sid)
            except SessionCapacityError:
                # reset_session re-registers through get_session, which can hit
                # the capacity cap; report it instead of crashing the UI.
                return "", _empty_trace_df(), {}, "Demo at capacity — try again in a minute."
            return "", _empty_trace_df(), {}, "Episode reset."

        reset_btn.click(
            _on_reset,
            inputs=[session_state],
            outputs=[transcript, trace, rewards, status],
        )
    return demo
def _launch_for_production() -> None:
    """Warm up, build the UI, and serve it on all interfaces at port 7860 with SSR disabled."""
    warmup_on_boot()
    app = build_ui()
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860, "ssr_mode": False}
    app.launch(**launch_options)
# Script entry point: only runs when the file is executed directly.
if __name__ == "__main__":
    _launch_for_production()

# Public API exported by a star-import of this module.
__all__ = [
    # Types and errors.
    "AudioDecodeError",
    "CheckpointId",
    "CheckpointMismatchError",
    "DemoSessionState",
    "DeploymentAbortedError",
    "DriftToggleBridge",
    "EnvStepError",
    "HardwareProbe",
    "ModelLoader",
    "SessionCapacityError",
    "TraceRow",
    "TrainedAdapterMissingError",
    "ZeroGPUUnavailableError",
    # Functions.
    "build_ui",
    "deploy_check",
    "gc_sessions",
    "get_drift_bridge",
    "get_model_loader",
    "get_session",
    "infer_turn",
    "render_trace",
    "reset_session",
    "warmup_on_boot",
]