"""Cell 10 — DriftCallEnv integration class. Implements ``docs/modules/env.md`` and DESIGN.md §4. ``DriftCallEnv`` is the single public surface that composes models, vendors, drift_injector, task_generator, rewards, and the optional audio boundary into an OpenEnv-compliant RL environment. Hard rules (env.md §3.8, CLAUDE.md §0): - All public dataclasses are frozen. - State transitions go through ``dataclasses.replace``; no in-place mutation. - Validation is pure: ``InvalidActionError`` raises BEFORE any state mutation. - Rewards are computed exactly once at termination and memoized. - No LLM judge anywhere; no network/disk I/O at ``__init__``. """ from __future__ import annotations import os import struct import uuid from dataclasses import dataclass, field, replace from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Literal, Protocol, cast from cells.step_04_models import ( ActionType, DriftCallAction, DriftCallObservation, DriftCallState, DriftEvent, GoalSpec, ToolResult, ) from cells.step_05_vendors import TOOLS as VENDOR_TOOLS from cells.step_05_vendors import VENDOR_REGISTRY from cells.step_06_drift_injector import ( DriftCatalogueError, DriftDomainMismatchError, DriftReapplicationError, DriftScheduleConflictError, UnknownDriftPatternError, apply_drift, build_schedule, list_patterns, ) from cells.step_07_task_generator import ( InvalidLanguageWeightError, InvalidStageError, ) from cells.step_07_task_generator import ( generate as task_generate, ) if TYPE_CHECKING: from collections.abc import Mapping # rewards is imported lazily inside _compute_rewards to keep the env importable # even before step_08_rewards.py lands; failures surface as RewardComputationError. 
_DEFAULT_LANGUAGE_WEIGHTS: dict[str, float] = {
    "en": 0.4,
    "hinglish": 0.4,
    "hi": 0.1,
    "ta": 0.05,
    "kn": 0.05,
}
_LANGUAGE_CODES: frozenset[str] = frozenset({"hi", "ta", "kn", "en", "hinglish"})
_STAGE_MAX_TURNS: dict[int, int] = {1: 8, 2: 12, 3: 16}
_VENDOR_DOMAINS: tuple[str, ...] = ("airline", "cab", "restaurant", "hotel", "payment")
_TERMINATED_VALUES: frozenset[str] = frozenset({"SUBMIT", "ABORT", "TIMEOUT", "ANTI_HACK"})
# Fixed timestamp (IST, UTC+05:30) passed to every vendor dispatch instead of
# the real wall clock.
_NOW_IST: datetime = datetime(
    2026, 4, 25, 10, 0, tzinfo=timezone(timedelta(hours=5, minutes=30))
)


# ---------------------------------------------------------------------------
# Error taxonomy (env.md §5)
# ---------------------------------------------------------------------------


class DriftCallEnvError(Exception):
    """Root for every typed env error (env.md §5)."""


class InvalidConfigError(DriftCallEnvError):
    """E1 — malformed config dict."""


class EnvNotReadyError(DriftCallEnvError):
    """E2 — operation issued before reset()."""


class EnvClosedError(DriftCallEnvError):
    """E3 — operation issued after close()."""


class InvalidActionError(DriftCallEnvError):
    """E4 — action fails the per-ActionType field matrix."""


class EpisodeAlreadyTerminalError(DriftCallEnvError):
    """E5 — step() called after termination."""


class EpisodeNotTerminalError(DriftCallEnvError):
    """E6 — episode()/rewards() called before termination."""


class ConcurrentStepError(DriftCallEnvError):
    """E7 — reentrant step() detected."""


class UnknownDomainError(DriftCallEnvError):
    """E8 — PROBE_SCHEMA on a domain that is not registered."""


class UnknownToolError(DriftCallEnvError):
    """E9 — TOOL_CALL with a tool_name not in available_tools()."""


class DriftInjectionError(DriftCallEnvError):
    """E10 — drift fold raised; surfaced as-is."""


class RewardComputationError(DriftCallEnvError):
    """E11 — compute_rewards raised; surfaced as-is."""


class AudioPipelineError(DriftCallEnvError):
    """E12 — TTS/ASR engine raised on a step()/reset() boundary."""


_ALL_ERROR_CLASSES: tuple[type[DriftCallEnvError], ...] = (
    InvalidConfigError,
    EnvNotReadyError,
    EnvClosedError,
    InvalidActionError,
    EpisodeAlreadyTerminalError,
    EpisodeNotTerminalError,
    ConcurrentStepError,
    UnknownDomainError,
    UnknownToolError,
    DriftInjectionError,
    RewardComputationError,
    AudioPipelineError,
)


# ---------------------------------------------------------------------------
# Protocols (env.md §2.1)
# ---------------------------------------------------------------------------


class DriftScheduler(Protocol):
    """Callable that produces the drift schedule for an episode."""

    def __call__(
        self, stage: int, episode_seed: int, goal: GoalSpec
    ) -> tuple[DriftEvent, ...]: ...


class TTSEngine(Protocol):
    """Text-to-speech engine used on the audio boundary."""

    def synthesize(
        self,
        text: str,
        language_code: str,
        voice_pack: Any | None = None,
        *,
        seed: int = 0,
        sample_rate_hz: int = 16000,
    ) -> bytes: ...


class ASREngine(Protocol):
    """Speech-recognition engine used on the audio boundary."""

    def transcribe(
        self,
        audio_bytes: bytes,
        language_hint: str | None,
        *,
        beam_size: int = 1,
        vad_filter: bool = True,
        max_duration_s: float = 30.0,
    ) -> Any: ...


def _default_scheduler(
    stage: int, episode_seed: int, goal: GoalSpec
) -> tuple[DriftEvent, ...]:
    """Default ``DriftScheduler``: delegate to the drift injector's catalogue."""
    return build_schedule(stage, episode_seed, goal)


# ---------------------------------------------------------------------------
# Episode (env.md §4.3) — built at termination, fed to rewards.compute_rewards.
# Matches the Episode shape consumed by step_08_rewards (kw fields).
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class Episode:
    """Immutable record of one finished episode, handed to the reward module."""

    episode_id: str
    goal: GoalSpec
    actions: tuple[DriftCallAction, ...]
    # Turn number (1-based) at which each entry of ``actions`` was taken.
    action_turns: tuple[int, ...]
    tool_results: tuple[ToolResult, ...]
    # Turn number at which each entry of ``tool_results`` was produced.
    tool_result_turns: tuple[int, ...]
    # Drift events that actually fired (``DriftCallState.drift_fired``).
    drift_log: tuple[DriftEvent, ...]
    # Per-domain vendor state at termination, flattened to plain dicts.
    vendor_states_final: dict[str, dict[str, Any]]
    schema_versions_final: dict[str, str]
    max_turns: int
    # Number of actions recorded before termination.
    turns_used: int
    terminated_by: Literal["SUBMIT", "ABORT", "TIMEOUT", "ANTI_HACK"]
    stage: Literal[1, 2, 3]
    drift_pattern_overrides: dict[str, Any] = field(default_factory=dict)


# ---------------------------------------------------------------------------
# EnvConfig (env.md §4.1)
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class EnvConfig:
    """Validated, frozen view of the user-supplied config dict."""

    curriculum_stage: Literal[1, 2, 3]
    language_weights: dict[str, float]
    audio_boundary_enabled: bool
    max_turns_override: int | None
    scheduler: DriftScheduler
    tts_engine: TTSEngine | None
    asr_engine: ASREngine | None

    @classmethod
    def from_mapping(cls, raw: Mapping[str, Any] | None) -> EnvConfig:
        """Validate ``raw`` and build an ``EnvConfig``.

        Args:
            raw: Config dict (or ``None`` for all defaults). Only the keys
                named in ``allowed`` below are accepted.

        Returns:
            A new ``EnvConfig``; ``language_weights`` is a fresh dict copy.

        Raises:
            InvalidConfigError: on any unknown key, wrong type, out-of-range
                value, non-finite or non-normalised language weights, or an
                audio-engine / ``audio_boundary_enabled`` mismatch.
        """
        import math

        allowed = {
            "curriculum_stage",
            "language_weights",
            "audio_boundary_enabled",
            "max_turns_override",
            "scheduler",
            "tts_engine",
            "asr_engine",
        }
        if raw is None:
            raw = {}
        if not isinstance(raw, dict):
            raise InvalidConfigError(
                f"config must be a dict or None, got {type(raw).__name__}"
            )
        unknown = set(raw.keys()) - allowed
        if unknown:
            raise InvalidConfigError(
                f"unknown config key(s): {sorted(unknown)}; "
                f"allowed: {sorted(allowed)}"
            )

        # --- curriculum stage (bool is an int subclass; reject it explicitly)
        stage_raw = raw.get("curriculum_stage", 1)
        if isinstance(stage_raw, bool) or not isinstance(stage_raw, int):
            raise InvalidConfigError(
                f"curriculum_stage must be int in {{1,2,3}}, got "
                f"{type(stage_raw).__name__}"
            )
        if stage_raw not in (1, 2, 3):
            raise InvalidConfigError(
                f"curriculum_stage must be 1, 2, or 3; got {stage_raw!r}"
            )
        stage = cast("Literal[1, 2, 3]", stage_raw)

        # --- language weights
        weights_raw = raw.get("language_weights", _DEFAULT_LANGUAGE_WEIGHTS)
        if not isinstance(weights_raw, dict) or not weights_raw:
            raise InvalidConfigError("language_weights must be a non-empty dict")
        for k, v in weights_raw.items():
            if k not in _LANGUAGE_CODES:
                raise InvalidConfigError(
                    f"language_weights: unknown language {k!r}; "
                    f"allowed: {sorted(_LANGUAGE_CODES)}"
                )
            if isinstance(v, bool) or not isinstance(v, (int, float)):
                raise InvalidConfigError(
                    f"language_weights[{k!r}] must be numeric, got "
                    f"{type(v).__name__}"
                )
            # NaN slips through both the ``< 0`` test and the sum tolerance
            # test (every comparison with NaN is False), so reject it here.
            if not math.isfinite(v):
                raise InvalidConfigError(
                    f"language_weights[{k!r}]={v} is not finite"
                )
            if v < 0:
                raise InvalidConfigError(f"language_weights[{k!r}]={v} is negative")
        total = sum(float(v) for v in weights_raw.values())
        if abs(total - 1.0) > 1e-6:
            raise InvalidConfigError(
                f"language_weights sum {total!r} not within 1.0 ± 1e-6"
            )
        # Private copy so later mutation of the caller's dict has no effect.
        weights: dict[str, float] = {k: float(v) for k, v in weights_raw.items()}

        # --- audio boundary flag
        audio_enabled_raw = raw.get("audio_boundary_enabled", False)
        if not isinstance(audio_enabled_raw, bool):
            raise InvalidConfigError(
                f"audio_boundary_enabled must be bool, got "
                f"{type(audio_enabled_raw).__name__}"
            )
        audio_enabled = audio_enabled_raw

        # --- turn budget override
        max_turns_override = raw.get("max_turns_override")
        if max_turns_override is not None:
            if isinstance(max_turns_override, bool) or not isinstance(
                max_turns_override, int
            ):
                raise InvalidConfigError(
                    f"max_turns_override must be int or None, got "
                    f"{type(max_turns_override).__name__}"
                )
            if max_turns_override < 1:
                raise InvalidConfigError(
                    f"max_turns_override must be >= 1, got {max_turns_override}"
                )

        # --- scheduler
        scheduler = raw.get("scheduler", _default_scheduler)
        if not callable(scheduler):
            raise InvalidConfigError("scheduler must be callable")

        # --- audio engines must be present iff the audio boundary is on
        tts_engine = raw.get("tts_engine")
        asr_engine = raw.get("asr_engine")
        if audio_enabled:
            if tts_engine is None:
                raise InvalidConfigError(
                    "tts_engine is required when audio_boundary_enabled is True"
                )
            if asr_engine is None:
                raise InvalidConfigError(
                    "asr_engine is required when audio_boundary_enabled is True"
                )
        else:
            if tts_engine is not None:
                raise InvalidConfigError(
                    "tts_engine must be None when audio_boundary_enabled is False"
                )
            if asr_engine is not None:
                raise InvalidConfigError(
                    "asr_engine must be None when audio_boundary_enabled is False"
                )

        return cls(
            curriculum_stage=stage,
            language_weights=weights,
            audio_boundary_enabled=audio_enabled,
            max_turns_override=max_turns_override,
            scheduler=cast("DriftScheduler", scheduler),
            tts_engine=cast("TTSEngine | None", tts_engine),
            asr_engine=cast("ASREngine | None", asr_engine),
        )


# ---------------------------------------------------------------------------
# DriftCallEnv
# ---------------------------------------------------------------------------


def _make_seed_from_urandom() -> int:
    """Return a fresh unsigned 64-bit seed read from ``os.urandom``.

    NOTE(review): this body was reconstructed (the original text was garbled);
    it yields values in ``[0, 2**64)``. Confirm downstream seed consumers
    accept that range.
    """
    raw = os.urandom(8)
    (value,) = struct.unpack("<Q", raw)
    return int(value)


def _vendor_state_to_dict(state: Any) -> dict[str, Any]:
    """Coerce a frozen vendor dataclass (or already-dict) into a plain dict."""
    if isinstance(state, dict):
        return dict(state)
    # All vendor states are frozen dataclasses.
    import dataclasses as _dc

    # ``is_dataclass`` is also True for dataclass *classes*; only convert
    # instances.
    if _dc.is_dataclass(state) and not isinstance(state, type):
        return _dc.asdict(state)
    # Defensive: best-effort fallback.
    return {"_raw": repr(state)}


class DriftCallEnv:
    """OpenEnv-compliant RL environment for DriftCall (env.md §1).

    Lifecycle: ``reset()`` -> ``step()``... until ``done()`` -> ``episode()``
    / ``rewards()``. ``close()`` is idempotent and keeps the terminal record
    available for audits.
    """

    # -- construction --------------------------------------------------------

    def __init__(self, config: dict[str, Any] | None = None) -> None:
        """Validate ``config``; performs no episode setup (see ``reset``)."""
        self._config: EnvConfig = EnvConfig.from_mapping(config)
        self._state: DriftCallState | None = None
        self._rewards: Any | None = None
        self._episode: Episode | None = None
        self._closed: bool = False
        self._seed: int | None = None
        self._episode_id: str | None = None
        # Pending side-channel notices keyed by domain (env.md §3.3).
        self._side_channel_pending: dict[str, str] = {}
        # Per-vendor-state cache (frozen dataclass or dict). Kept on the env
        # because DriftCallState.vendor_states is a dict[str, dict] for
        # compatibility with the design dataclass.
        self._vendor_state_objects: dict[str, Any] = {}
        # Per-episode histories that DriftCallState does not carry; they are
        # cleared on every reset() so nothing leaks between episodes.
        self._tool_results: tuple[ToolResult, ...] = ()
        self._tool_result_turns: tuple[int, ...] = ()
        self._action_turns: tuple[int, ...] = ()
        # Re-entrancy guard (E7).
        self._step_in_progress: bool = False

    # -- internal helpers ----------------------------------------------------

    @property
    def _max_turns(self) -> int:
        """Turn budget: the config override if set, else the stage default."""
        if self._config.max_turns_override is not None:
            return int(self._config.max_turns_override)
        return _STAGE_MAX_TURNS[self._config.curriculum_stage]

    def _available_tools(self) -> tuple[str, ...]:
        return VENDOR_TOOLS

    def _ensure_ready_for_step(self) -> None:
        """Raise E3/E2/E5 if the env cannot accept a step right now."""
        if self._closed:
            raise EnvClosedError("env is closed")
        if self._state is None:
            raise EnvNotReadyError("reset() must be called before step()")
        if self._state.done:
            raise EpisodeAlreadyTerminalError(
                f"episode already terminated (terminated_by={self._terminated_by()})"
            )

    def _terminated_by(self) -> str | None:
        return self._episode.terminated_by if self._episode is not None else None

    # -- OpenEnv primitives --------------------------------------------------

    def reset(self, seed: int | None = None) -> DriftCallObservation:
        """Start a new episode and return its first observation.

        Args:
            seed: Episode seed; drawn from ``os.urandom`` when ``None``.

        Raises:
            EnvClosedError: after ``close()``.
            InvalidActionError: if ``seed`` is not an int (bools rejected).
            InvalidConfigError: if task generation or the scheduler rejects
                the configured stage / weights / schedule.
            AudioPipelineError: if the TTS engine fails; the env is left
                unready.
        """
        if self._closed:
            raise EnvClosedError("env is closed")
        if seed is None:
            seed = _make_seed_from_urandom()
        if isinstance(seed, bool) or not isinstance(seed, int):
            raise InvalidActionError(
                f"seed must be int or None, got {type(seed).__name__}"
            )
        self._seed = int(seed)

        # Reset memoization and every per-episode history; legacy state is
        # dropped before any propagatable exception can leak (env.md §2.2).
        self._state = None
        self._rewards = None
        self._episode = None
        self._side_channel_pending = {}
        self._vendor_state_objects = {}
        self._episode_id = None
        self._tool_results = ()
        self._tool_result_turns = ()
        self._action_turns = ()

        try:
            goal = task_generate(
                self._seed,
                self._config.curriculum_stage,
                cast("dict[Any, float]", self._config.language_weights),
            )
        except (InvalidLanguageWeightError, InvalidStageError) as exc:
            # E1-class reset failure (env.md §2.2 raises clause).
            raise InvalidConfigError(str(exc)) from exc

        # Initial per-domain vendor state objects (frozen dataclasses).
        vendor_state_objects: dict[str, Any] = {}
        vendor_states_dict: dict[str, dict[str, Any]] = {}
        for domain in _VENDOR_DOMAINS:
            vs = VENDOR_REGISTRY[domain].initial_state(self._seed, goal)
            vendor_state_objects[domain] = vs
            vendor_states_dict[domain] = _vendor_state_to_dict(vs)
        schema_versions = {d: "v1" for d in _VENDOR_DOMAINS}

        try:
            schedule = self._config.scheduler(
                self._config.curriculum_stage, self._seed, goal
            )
        except (
            DriftScheduleConflictError,
            DriftCatalogueError,
            UnknownDriftPatternError,
            DriftDomainMismatchError,
        ) as exc:
            # Bad scheduler at reset is an E1 (env.md §7.4).
            raise InvalidConfigError(f"scheduler failure: {exc}") from exc

        self._episode_id = uuid.uuid4().hex
        self._state = DriftCallState(
            episode_id=self._episode_id,
            goal=goal,
            vendor_states=vendor_states_dict,
            schema_versions=schema_versions,
            drift_schedule=tuple(schedule),
            drift_fired=(),
            turn=0,
            max_turns=self._max_turns,
            actions=(),
            done=False,
        )
        self._vendor_state_objects = vendor_state_objects

        if self._config.audio_boundary_enabled:
            tts = self._config.tts_engine
            assert tts is not None  # validated in EnvConfig
            try:
                tts.synthesize(goal.seed_utterance, goal.language)
            except Exception as exc:  # noqa: BLE001 — surface as E12-class
                # Audio failure on reset leaves env unready (env.md §5 E12).
                self._state = None
                self._vendor_state_objects = {}
                self._episode_id = None
                raise AudioPipelineError(f"TTS reset failure: {exc}") from exc

        return self._build_observation()

    def step(
        self,
        action: DriftCallAction,
        *,
        force_drift_pattern: str | None = None,
    ) -> DriftCallObservation:
        """Advance the episode by one action and return the new observation.

        All validation (readiness, action field matrix, forced pattern id)
        happens before any state is touched.

        Args:
            action: The agent's action for this turn.
            force_drift_pattern: Optional catalogue pattern id to fire this
                turn instead of the scheduled drifts.

        Raises:
            EnvClosedError, EnvNotReadyError, EpisodeAlreadyTerminalError,
            InvalidActionError, UnknownToolError, UnknownDomainError,
            ConcurrentStepError, DriftInjectionError, RewardComputationError.
        """
        # 1a. Pure validation — must raise before any state mutation.
        self._ensure_ready_for_step()
        self._validate_action(action)
        if force_drift_pattern is not None:
            # Non-str values (e.g. lists) would otherwise escape as a bare
            # TypeError from the set-membership test below.
            if not isinstance(force_drift_pattern, str):
                raise InvalidActionError(
                    f"force_drift_pattern must be str or None, got "
                    f"{type(force_drift_pattern).__name__}"
                )
            valid_ids = {p.id for p in list_patterns()}
            if force_drift_pattern not in valid_ids:
                raise InvalidActionError(
                    f"force_drift_pattern {force_drift_pattern!r} not a known "
                    f"pattern_id"
                )
        if self._step_in_progress:
            raise ConcurrentStepError("reentrant step() detected")
        self._step_in_progress = True
        try:
            return self._step_inner(action, force_drift_pattern)
        finally:
            self._step_in_progress = False

    def _step_inner(
        self,
        action: DriftCallAction,
        force_drift_pattern: str | None,
    ) -> DriftCallObservation:
        """Mutating half of ``step()``; runs only after validation passed."""
        assert self._state is not None  # ensured by step()

        # 2. Increment turn counter.
        turn_current = self._state.turn + 1
        self._state = replace(self._state, turn=turn_current)

        # 3. Fire drifts for this turn.
        self._fire_drifts(turn_current, force_drift_pattern)

        # 4. Side-channel emit pass — refresh pending notices for any vendor
        #    whose state just mutated.
        self._emit_side_channel()

        # 5. Dispatch action.
        new_tool_result, terminate, terminated_by = self._dispatch(action)

        # 6. Record action (and ToolResult, if any) with its turn number.
        if new_tool_result is not None:
            self._tool_results = self._tool_results + (new_tool_result,)
            self._tool_result_turns = self._tool_result_turns + (turn_current,)
        self._action_turns = self._action_turns + (turn_current,)
        self._state = replace(self._state, actions=self._state.actions + (action,))

        # 7. Budget check — only if action did not already terminate.
        if not terminate and turn_current >= self._state.max_turns:
            terminate = True
            terminated_by = "TIMEOUT"

        # 8. If terminal, build Episode + compute rewards.
        if terminate:
            assert terminated_by is not None
            self._terminate(terminated_by)

        # 9. Build observation.
        return self._build_observation()

    def state(self) -> DriftCallState:
        """Return the current (frozen) state; raises E2 before reset()."""
        if self._state is None:
            raise EnvNotReadyError("reset() must be called before state()")
        return self._state

    def close(self) -> None:
        """Mark the env closed. Idempotent."""
        self._closed = True
        # Per env.md §9 Q7: never invoke close on shared audio engines.
        # Only drop per-env state.
        self._side_channel_pending = {}
        self._vendor_state_objects = {}
        # Note: we keep self._state, self._rewards, self._episode so post-close
        # audits still work (env.md §7.11).

    def episode(self) -> Episode:
        """Return the terminal Episode record; raises E6 before termination."""
        if self._episode is None:
            raise EpisodeNotTerminalError("episode is not terminal")
        return self._episode

    def rewards(self) -> Any:
        """Return the memoized rewards; raises E6 before termination."""
        if self._rewards is None:
            raise EpisodeNotTerminalError("episode is not terminal")
        return self._rewards

    def done(self) -> bool:
        """True once the current episode has terminated; False before reset."""
        if self._state is None:
            return False
        return bool(self._state.done)

    # -- validation ----------------------------------------------------------

    def _validate_action(self, action: DriftCallAction) -> None:
        """Enforce the per-ActionType field matrix (env.md §3.1); no mutation."""
        if not isinstance(action, DriftCallAction):
            raise InvalidActionError(
                f"action must be DriftCallAction, got {type(action).__name__}"
            )
        atype = action.action_type
        if not isinstance(atype, ActionType):
            raise InvalidActionError(
                f"action_type must be ActionType, got {type(atype).__name__}"
            )
        # rationale length cap (env.md §3.1).
        if action.rationale is not None and len(action.rationale) > 200:
            raise InvalidActionError(
                f"rationale length {len(action.rationale)} exceeds 200"
            )

        if atype == ActionType.TOOL_CALL:
            if not action.tool_name or not isinstance(action.tool_name, str):
                raise InvalidActionError("TOOL_CALL requires non-empty tool_name")
            if action.tool_args is None or not isinstance(action.tool_args, dict):
                raise InvalidActionError(
                    "TOOL_CALL requires tool_args dict (may be empty)"
                )
            if action.message is not None or action.confidence is not None:
                raise InvalidActionError("TOOL_CALL forbids message/confidence")
            if action.tool_name not in self._available_tools():
                raise UnknownToolError(
                    f"tool_name {action.tool_name!r} not in available_tools()"
                )
            # JSON-serializability is only checked shallowly (must be a dict).
            return

        if atype == ActionType.SPEAK or atype == ActionType.CLARIFY:
            if not isinstance(action.message, str):
                raise InvalidActionError(f"{atype.value} requires str message")
            if not (1 <= len(action.message) <= 2000):
                raise InvalidActionError(
                    f"{atype.value} message length must be in [1, 2000], "
                    f"got {len(action.message)}"
                )
            if "\x00" in action.message:
                raise InvalidActionError(f"{atype.value} message contains NUL byte")
            if (
                action.tool_name is not None
                or action.tool_args is not None
                or action.confidence is not None
            ):
                raise InvalidActionError(
                    f"{atype.value} forbids tool_name/tool_args/confidence"
                )
            return

        if atype == ActionType.PROBE_SCHEMA:
            if not action.tool_name or not isinstance(action.tool_name, str):
                raise InvalidActionError(
                    "PROBE_SCHEMA requires tool_name (domain string)"
                )
            if (
                action.tool_args is not None
                or action.message is not None
                or action.confidence is not None
            ):
                raise InvalidActionError(
                    "PROBE_SCHEMA forbids tool_args/message/confidence"
                )
            assert self._state is not None
            if action.tool_name not in self._state.vendor_states:
                raise UnknownDomainError(
                    f"PROBE_SCHEMA: domain {action.tool_name!r} not registered"
                )
            return

        if atype == ActionType.SUBMIT:
            if (
                action.confidence is None
                or not isinstance(action.confidence, (int, float))
                or isinstance(action.confidence, bool)
            ):
                raise InvalidActionError("SUBMIT requires float confidence")
            conf = float(action.confidence)
            # NaN fails this chained comparison and is rejected too.
            if not (0.0 <= conf <= 1.0):
                raise InvalidActionError(
                    f"SUBMIT confidence {conf!r} outside [0.0, 1.0]"
                )
            if action.tool_name is not None or action.tool_args is not None:
                raise InvalidActionError("SUBMIT forbids tool_name/tool_args")
            if action.message is not None and not isinstance(action.message, str):
                raise InvalidActionError("SUBMIT message must be str if present")
            return

        if atype == ActionType.ABORT:
            if (
                action.tool_name is not None
                or action.tool_args is not None
                or action.confidence is not None
            ):
                raise InvalidActionError(
                    "ABORT forbids tool_name/tool_args/confidence"
                )
            return

        # Unreachable — all six ActionType members handled above.
        raise InvalidActionError(f"unhandled action_type {atype!r}")

    # -- drift firing --------------------------------------------------------

    def _apply_drift_event(self, event: DriftEvent) -> None:
        """Fold one drift event into the state, mapping injector errors to E10."""
        assert self._state is not None
        try:
            self._state = apply_drift(self._state, event)
        except (
            UnknownDriftPatternError,
            DriftDomainMismatchError,
            DriftReapplicationError,
        ) as exc:
            raise DriftInjectionError(str(exc)) from exc

    def _fire_drifts(self, turn_current: int, force_pattern: str | None) -> None:
        """Fire the forced pattern, or else every scheduled drift for this turn.

        NOTE(review): when a pattern is forced, scheduled drifts for the same
        turn are skipped and (being keyed on ``turn == turn_current``) never
        fire later — confirm this override semantics is intended.
        """
        assert self._state is not None
        if force_pattern is not None:
            patterns_by_id = {p.id: p for p in list_patterns()}
            pattern = patterns_by_id[force_pattern]
            if pattern.domain not in self._state.vendor_states:
                raise DriftInjectionError(
                    f"force_drift_pattern {force_pattern!r}: domain "
                    f"{pattern.domain!r} not registered"
                )
            self._apply_drift_event(
                DriftEvent(
                    turn=turn_current,
                    drift_type=pattern.drift_type,
                    domain=pattern.domain,
                    description=pattern.description,
                    from_version=pattern.from_version,
                    to_version=pattern.to_version,
                    pattern_id=pattern.id,
                )
            )
            return

        # Schedule-driven fold, in deterministic (turn, pattern_id) order.
        pending = tuple(
            e
            for e in self._state.drift_schedule
            if e.turn == turn_current and e not in self._state.drift_fired
        )
        for event in sorted(pending, key=lambda e: (e.turn, e.pattern_id)):
            self._apply_drift_event(event)

    def _emit_side_channel(self) -> None:
        """Refresh pending side-channel notices per env.md §3.3 clause 3."""
        assert self._state is not None
        new_pending = dict(self._side_channel_pending)
        for domain in _VENDOR_DOMAINS:
            vs_obj = self._vendor_state_objects.get(domain)
            if vs_obj is None:
                continue
            try:
                notice, new_state = VENDOR_REGISTRY[
                    domain
                ].emit_side_channel_if_pending(vs_obj)
            except Exception as exc:  # noqa: BLE001 — defensive
                raise DriftInjectionError(
                    f"side-channel emit failed for {domain}: {exc}"
                ) from exc
            if notice is not None:
                # Notices accumulate until the next TOOL_CALL on the domain.
                existing = new_pending.get(domain)
                new_pending[domain] = (
                    f"{existing}\n---\n{notice}" if existing else notice
                )
            self._vendor_state_objects[domain] = new_state
        self._side_channel_pending = new_pending

    # -- dispatch ------------------------------------------------------------

    def _dispatch(
        self, action: DriftCallAction
    ) -> tuple[ToolResult | None, bool, str | None]:
        """Return (tool_result, terminate?, terminated_by?)."""
        assert self._state is not None
        atype = action.action_type

        if atype == ActionType.SUBMIT:
            return None, True, "SUBMIT"
        if atype == ActionType.ABORT:
            return None, True, "ABORT"
        if atype == ActionType.SPEAK or atype == ActionType.CLARIFY:
            return None, False, None

        if atype == ActionType.PROBE_SCHEMA:
            assert action.tool_name is not None
            domain = action.tool_name
            schema_version = self._state.schema_versions[domain]
            schema = VENDOR_REGISTRY[domain].describe_schema(
                self._vendor_state_objects[domain], schema_version
            )
            tr = ToolResult(
                tool_name=f"probe:{domain}",
                status="ok",
                response=dict(schema),
                schema_version=schema_version,
                latency_ms=0,
            )
            return tr, False, None

        if atype == ActionType.TOOL_CALL:
            assert action.tool_name is not None and action.tool_args is not None
            tool_name = action.tool_name
            # Tool names are "<domain>.<op>"; the prefix selects the vendor.
            domain = tool_name.split(".", 1)[0]
            if domain not in self._state.vendor_states:
                raise UnknownDomainError(
                    f"tool {tool_name!r} targets unknown domain {domain!r}"
                )
            ns = VENDOR_REGISTRY[domain]
            vs_obj = self._vendor_state_objects[domain]
            schema_version = self._state.schema_versions[domain]
            try:
                if domain == "payment":
                    tr, new_vs = ns.dispatch(
                        tool_name,
                        action.tool_args,
                        vs_obj,
                        schema_version,
                        self._seed,
                        _NOW_IST,
                    )
                    payment_state = new_vs
                else:
                    # Non-payment vendors may also update the payment state.
                    payment_state = self._vendor_state_objects.get("payment")
                    tr, new_vs, payment_state = ns.dispatch(
                        tool_name,
                        action.tool_args,
                        vs_obj,
                        schema_version,
                        self._seed,
                        _NOW_IST,
                        payment_state,
                    )
            except ValueError as exc:
                # Vendor rejected the call (e.g. unknown tool inside a known
                # domain); surfaced as E9.
                raise UnknownToolError(str(exc)) from exc

            self._vendor_state_objects[domain] = new_vs
            if payment_state is not None:
                self._vendor_state_objects["payment"] = payment_state

            # Refresh state.vendor_states snapshot.
            new_vendor_states = dict(self._state.vendor_states)
            new_vendor_states[domain] = _vendor_state_to_dict(new_vs)
            if domain != "payment" and payment_state is not None:
                new_vendor_states["payment"] = _vendor_state_to_dict(payment_state)
            self._state = replace(self._state, vendor_states=new_vendor_states)

            # Attach pending side-channel notice (one-shot per domain).
            notice = self._side_channel_pending.pop(domain, None)
            if notice is not None:
                merged_response = dict(tr.response)
                merged_response["_notice"] = notice
                tr = ToolResult(
                    tool_name=tr.tool_name,
                    status=tr.status,
                    response=merged_response,
                    schema_version=tr.schema_version,
                    latency_ms=tr.latency_ms,
                )
            return tr, False, None

        # Unreachable.
        raise InvalidActionError(f"unhandled action_type {atype!r}")

    # -- termination ---------------------------------------------------------

    def _terminate(self, terminated_by: str) -> None:
        """Mark the state done, freeze the Episode and compute rewards once."""
        assert self._state is not None
        if terminated_by not in _TERMINATED_VALUES:
            raise RewardComputationError(
                f"unknown terminated_by sentinel {terminated_by!r}"
            )
        self._state = replace(self._state, done=True)
        episode = Episode(
            episode_id=self._state.episode_id,
            goal=self._state.goal,
            actions=self._state.actions,
            action_turns=self._action_turns,
            tool_results=self._tool_results,
            tool_result_turns=self._tool_result_turns,
            drift_log=self._state.drift_fired,
            vendor_states_final={
                d: _vendor_state_to_dict(self._vendor_state_objects[d])
                for d in _VENDOR_DOMAINS
            },
            schema_versions_final=dict(self._state.schema_versions),
            max_turns=self._state.max_turns,
            turns_used=len(self._state.actions),
            terminated_by=cast(
                "Literal['SUBMIT','ABORT','TIMEOUT','ANTI_HACK']", terminated_by
            ),
            stage=self._config.curriculum_stage,
        )
        self._episode = episode
        self._rewards = self._compute_rewards(episode)

    @staticmethod
    def _compute_rewards(episode: Episode) -> Any:
        """Lazily import step_08_rewards and score ``episode`` (errors -> E11)."""
        import importlib

        try:
            mod = importlib.import_module("cells.step_08_rewards")
        except ImportError as exc:
            raise RewardComputationError(
                f"rewards module unavailable: {exc}"
            ) from exc
        compute = getattr(mod, "compute_rewards", None)
        if compute is None:
            raise RewardComputationError(
                "cells.step_08_rewards has no compute_rewards"
            )
        try:
            return compute(episode)
        except Exception as exc:
            raise RewardComputationError(str(exc)) from exc

    # -- observation builder -------------------------------------------------

    def _build_observation(self) -> DriftCallObservation:
        """Project the current state + histories into an observation."""
        assert self._state is not None
        st = self._state
        # NOTE(review): the transcript fields always echo the goal's seed
        # utterance with confidence 1.0 — the per-turn (ASR) transcript is not
        # wired in yet.
        return DriftCallObservation(
            turn=st.turn,
            goal=st.goal,
            last_transcript=st.goal.seed_utterance,
            last_lang=st.goal.language,
            last_confidence=1.0,
            tool_results=self._tool_results,
            drift_log=st.drift_fired,
            budget_remaining=max(0, st.max_turns - st.turn),
            available_tools=self._available_tools(),
        )


__all__ = [
    "ASREngine",
    "AudioPipelineError",
    "ConcurrentStepError",
    "DriftCallEnv",
    "DriftCallEnvError",
    "DriftInjectionError",
    "DriftScheduler",
    "EnvClosedError",
    "EnvConfig",
    "EnvNotReadyError",
    "Episode",
    "EpisodeAlreadyTerminalError",
    "EpisodeNotTerminalError",
    "InvalidActionError",
    "InvalidConfigError",
    "RewardComputationError",
    "TTSEngine",
    "UnknownDomainError",
    "UnknownToolError",
]