"""Core ESC environment: OpenEnv-style step() / reset() / state()."""

from __future__ import annotations

from typing import Any, Dict, List, Optional

from .grader import GradeBreakdown, final_task_score, grade_step
from .models import (
    Action,
    EnvState,
    Observation,
    ResetResult,
    Reward,
    StepResult,
)
from .seeker import (
    SeekerState,
    Stage,
    extract_features,
    resolution_score,
    step_seeker,
)
from .tasks import TASKS, TaskSpec, get_task
class ESCEnv:
    """Emotional Support Conversations environment.

    OpenEnv-style interface: ``reset()`` starts an episode for a task,
    ``step()`` advances one agent turn, ``state()`` exposes episode
    bookkeeping, and ``export_state()`` / ``from_state()`` round-trip a
    running episode for serialization.

    Usage (in-process):
        env = ESCEnv()
        obs = env.reset(task_id="work_stress_venting")
        result = env.step(Action(message="That sounds really hard. What's weighing on you most right now?"))
    """

    def __init__(self) -> None:
        # All episode state lives on the instance; reset() (re)initializes it.
        self._task: Optional[TaskSpec] = None
        self._seeker: Optional[SeekerState] = None
        self._turn: int = 0
        self._done: bool = False
        self._cumulative_reward: float = 0.0
        self._transcript: List[Dict[str, str]] = []
        # Normalized agent messages, kept to detect verbatim repetition.
        self._agent_messages: List[str] = []
        self._had_safety_reference: bool = False
        self._last_obs: Optional[Observation] = None

    # ------------------------------------------------------------------ reset
    def reset(self, task_id: Optional[str] = None, seed: Optional[int] = None) -> ResetResult:
        """Reset to a clean initial state for the given task (default: easy).

        NOTE(review): ``seed`` is accepted for interface compatibility but is
        currently unused by this environment — confirm whether seeker
        dynamics are meant to be seedable.
        """
        task_id = task_id or "work_stress_venting"
        self._task = get_task(task_id)
        self._seeker = SeekerState.from_persona(self._task.persona)
        self._turn = 0
        self._done = False
        self._cumulative_reward = 0.0
        # The episode opens with the seeker voicing their surface concern.
        self._transcript = [
            {"role": "seeker", "text": self._task.persona.surface_concern}
        ]
        self._agent_messages = []
        self._had_safety_reference = False
        obs = Observation(
            seeker_utterance=self._task.persona.surface_concern,
            turn=0,
            remaining_turns=self._task.max_turns,
            stage_hint=self._seeker.stage.value,
            task_id=self._task.id,
            scenario_brief=self._task.persona.scenario_brief,
        )
        self._last_obs = obs
        return ResetResult(
            observation=obs,
            info={
                "difficulty": self._task.difficulty,
                "max_turns": self._task.max_turns,
                "success_threshold": self._task.success_threshold,
            },
        )

    # ------------------------------------------------------------------- step
    def step(self, action: Action) -> StepResult:
        """Advance one agent turn: grade the message, move the seeker, and
        return the next observation plus reward detail.

        Raises:
            RuntimeError: if called before reset() or after the episode ended.
        """
        if self._task is None or self._seeker is None:
            raise RuntimeError("env.step() called before reset()")
        if self._done:
            raise RuntimeError("env.step() called on a finished episode — call reset()")

        # 1. Record the agent's turn. Repetition is checked on a
        #    whitespace/case-normalized form of the message.
        normalized_message = " ".join(action.message.lower().split())
        repetitive = normalized_message in self._agent_messages
        self._transcript.append({"role": "agent", "text": action.message})
        self._agent_messages.append(normalized_message)

        # 2. Snapshot pre-action state (for reward deltas and future-oriented lookahead).
        pre_state = self._seeker.snapshot()

        # 3. Extract features and advance seeker dynamics.
        features = extract_features(action.message)
        if features.safety > 0:
            # Sticky flag: once the agent references safety, it counts for
            # the rest of the episode.
            self._had_safety_reference = True
        transition = step_seeker(self._seeker, features)
        post_state = transition.new_state  # same object, mutated
        self._seeker = post_state
        self._turn += 1
        transition.flags["repetitive"] = repetitive

        # 4. Grade the step.
        breakdown: GradeBreakdown = grade_step(
            pre_state=pre_state,
            post_state=post_state,
            features=features,
            flags=transition.flags,
        )
        self._cumulative_reward += breakdown.value

        # 5. Record seeker's reply.
        self._transcript.append({"role": "seeker", "text": transition.seeker_utterance})

        # 6. Termination check. A "natural" finish requires every task target
        #    at once; otherwise the episode can still end on trust collapse
        #    or turn-budget exhaustion.
        reached_required_stage = post_state.stage.value == self._task.required_final_stage
        met_trust_target = post_state.trust >= self._task.min_final_trust
        met_distress_target = post_state.distress <= self._task.max_final_distress
        revealed_if_required = (not self._task.require_reveal) or post_state.revealed
        safety_if_required = (not self._task.require_safety_reference) or self._had_safety_reference
        natural_done = bool(
            reached_required_stage
            and met_trust_target
            and met_distress_target
            and revealed_if_required
            and safety_if_required
        )
        trust_collapse = post_state.trust <= 0.05
        budget_exhausted = self._turn >= self._task.max_turns
        done = bool(natural_done or trust_collapse or budget_exhausted)
        self._done = done

        # 7. Build the next observation.
        obs = Observation(
            seeker_utterance=transition.seeker_utterance,
            turn=self._turn,
            remaining_turns=max(0, self._task.max_turns - self._turn),
            stage_hint=post_state.stage.value,
            task_id=self._task.id,
            scenario_brief=self._task.persona.scenario_brief,
        )
        self._last_obs = obs

        info: Dict[str, Any] = {
            "features": features.__dict__,
            "flags": transition.flags,
            "stage": post_state.stage.value,
            "resolution_score": resolution_score(post_state),
            "natural_done": natural_done,
            "repetitive": repetitive,
            "had_safety_reference": self._had_safety_reference,
            "meets_trust_target": met_trust_target,
            "meets_distress_target": met_distress_target,
            "revealed_if_required": revealed_if_required,
            "safety_if_required": safety_if_required,
            "trust_collapse": trust_collapse,
            "budget_exhausted": budget_exhausted,
            "reward_components": breakdown.components,
        }
        if done:
            # Episode-level score is only attached on the terminal step.
            info["final"] = final_task_score(
                cumulative_reward=self._cumulative_reward,
                steps_taken=self._turn,
                max_turns=self._task.max_turns,
                final_state=post_state,
                success_threshold=self._task.success_threshold,
                completed=natural_done,
            )

        reward_detail = Reward(
            value=breakdown.value,
            immediate=breakdown.immediate,
            future_oriented=breakdown.future_oriented,
            penalties=breakdown.penalties,
            components={k: float(v) for k, v in breakdown.components.items()},
        )
        return StepResult(
            observation=obs,
            reward=breakdown.value,
            reward_detail=reward_detail,
            done=done,
            info=info,
        )

    # ------------------------------------------------------------------ state
    def state(self) -> EnvState:
        """Current episode bookkeeping (task, turn counter, transcript copy).

        Raises:
            RuntimeError: if called before reset().
        """
        if self._task is None:
            raise RuntimeError("env.state() called before reset()")
        return EnvState(
            task_id=self._task.id,
            turn=self._turn,
            max_turns=self._task.max_turns,
            done=self._done,
            cumulative_reward=self._cumulative_reward,
            transcript=list(self._transcript),
        )

    # ---------------------------------------------------------------- listing
    @staticmethod
    def list_tasks() -> List[Dict[str, Any]]:
        """Summaries of every registered task (id, difficulty, budget, threshold, brief).

        Fixed: the original definition took no ``self`` and no decorator,
        so ``env.list_tasks()`` raised TypeError; it is now a staticmethod.
        """
        return [
            {
                "id": t.id,
                "difficulty": t.difficulty,
                "max_turns": t.max_turns,
                "success_threshold": t.success_threshold,
                "scenario_brief": t.persona.scenario_brief,
            }
            for t in TASKS.values()
        ]

    # ------------------------------------------------------------- serialization
    def export_state(self) -> Dict[str, Any]:
        """Serialize the full episode (env bookkeeping + seeker state) to plain dicts.

        Raises:
            RuntimeError: if called before reset().
        """
        if self._task is None or self._seeker is None:
            raise RuntimeError("env.export_state() called before reset()")
        seeker_state = {
            "distress": self._seeker.distress,
            "trust": self._seeker.trust,
            "openness": self._seeker.openness,
            "revealed": self._seeker.revealed,
            "stage": self._seeker.stage.value,
            # Enum keys are flattened to their string values for JSON-friendliness.
            "last_line_idx_by_stage": {
                stage.value: idx for stage, idx in self._seeker.last_line_idx_by_stage.items()
            },
            "turn": self._seeker.turn,
        }
        return {
            "task_id": self._task.id,
            "turn": self._turn,
            "done": self._done,
            "cumulative_reward": self._cumulative_reward,
            "transcript": list(self._transcript),
            "agent_messages": list(self._agent_messages),
            "had_safety_reference": self._had_safety_reference,
            "seeker": seeker_state,
        }

    @classmethod
    def from_state(cls, data: Dict[str, Any]) -> "ESCEnv":
        """Alternate constructor: rebuild an env from export_state() output.

        Fixed: the original method took ``cls`` but lacked ``@classmethod``,
        so ``ESCEnv.from_state(data)`` bound ``data`` to ``cls`` and crashed.
        """
        task = get_task(str(data["task_id"]))
        seeker_data = data["seeker"]
        env = cls()
        env._task = task
        env._turn = int(data["turn"])
        env._done = bool(data["done"])
        env._cumulative_reward = float(data["cumulative_reward"])
        env._transcript = list(data.get("transcript", []))
        env._agent_messages = list(data.get("agent_messages", []))
        env._had_safety_reference = bool(data.get("had_safety_reference", False))
        env._seeker = SeekerState(
            persona=task.persona,
            distress=float(seeker_data["distress"]),
            trust=float(seeker_data["trust"]),
            openness=float(seeker_data["openness"]),
            revealed=bool(seeker_data["revealed"]),
            stage=Stage(str(seeker_data["stage"])),
            last_line_idx_by_stage={
                Stage(stage_name): int(idx)
                for stage_name, idx in seeker_data["last_line_idx_by_stage"].items()
            },
            turn=int(seeker_data["turn"]),
        )
        if env._transcript:
            # Rebuild the last observation from the most recent seeker line,
            # falling back to the persona's opening concern.
            last_seeker_text = next(
                (entry["text"] for entry in reversed(env._transcript) if entry.get("role") == "seeker"),
                task.persona.surface_concern,
            )
            env._last_obs = Observation(
                seeker_utterance=last_seeker_text,
                turn=env._turn,
                remaining_turns=max(0, task.max_turns - env._turn),
                stage_hint=env._seeker.stage.value,
                task_id=task.id,
                scenario_brief=task.persona.scenario_brief,
            )
        return env