"""Core ESC environment: OpenEnv-style step() / reset() / state()."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from .grader import GradeBreakdown, final_task_score, grade_step
from .models import (
Action,
EnvState,
Observation,
ResetResult,
Reward,
StepResult,
)
from .seeker import (
SeekerState,
Stage,
extract_features,
resolution_score,
step_seeker,
)
from .tasks import TASKS, TaskSpec, get_task
class ESCEnv:
"""Emotional Support Conversations environment.
Usage (in-process):
env = ESCEnv()
obs = env.reset(task_id="work_stress_venting")
result = env.step(Action(message="That sounds really hard. What's weighing on you most right now?"))
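    Episodes can also be checkpointed and rebuilt elsewhere, e.g. so HTTP
    session handling can stay stateless across nodes:
        payload = env.export_state()
        env = ESCEnv.from_state(payload)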
"""
def __init__(self) -> None:
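        """Create an idle environment; call ``reset()`` before ``step()`` or ``state()``."""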
self._task: Optional[TaskSpec] = None
self._seeker: Optional[SeekerState] = None
self._turn: int = 0
self._done: bool = False
self._cumulative_reward: float = 0.0
self._transcript: List[Dict[str, str]] = []
self._agent_messages: List[str] = []
self._had_safety_reference: bool = False
self._last_obs: Optional[Observation] = None
# ------------------------------------------------------------------ reset
def reset(self, task_id: Optional[str] = None, seed: Optional[int] = None) -> ResetResult:
"""Reset to a clean initial state for the given task (default: easy)."""
task_id = task_id or "work_stress_venting"
self._task = get_task(task_id)
self._seeker = SeekerState.from_persona(self._task.persona)
self._turn = 0
self._done = False
self._cumulative_reward = 0.0
self._transcript = [
{"role": "seeker", "text": self._task.persona.surface_concern}
]
self._agent_messages = []
self._had_safety_reference = False
obs = Observation(
seeker_utterance=self._task.persona.surface_concern,
turn=0,
remaining_turns=self._task.max_turns,
stage_hint=self._seeker.stage.value,
task_id=self._task.id,
scenario_brief=self._task.persona.scenario_brief,
)
self._last_obs = obs
return ResetResult(
observation=obs,
info={
"difficulty": self._task.difficulty,
"max_turns": self._task.max_turns,
"success_threshold": self._task.success_threshold,
},
)
# ------------------------------------------------------------------- step
def step(self, action: Action) -> StepResult:
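        """Advance the conversation by one agent turn.

        Records the agent message, updates the seeker's internal state,
        grades the turn, and returns the next observation together with
        reward details, a done flag, and diagnostic info.
        """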
if self._task is None or self._seeker is None:
raise RuntimeError("env.step() called before reset()")
if self._done:
raise RuntimeError("env.step() called on a finished episode — call reset()")
# 1. Record the agent's turn.
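        # Repetition is detected against a lowercase, whitespace-normalized message history.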
normalized_message = " ".join(action.message.lower().split())
repetitive = normalized_message in self._agent_messages
self._transcript.append({"role": "agent", "text": action.message})
self._agent_messages.append(normalized_message)
# 2. Snapshot pre-action state (for reward deltas and future-oriented lookahead).
pre_state = self._seeker.snapshot()
# 3. Extract features and advance seeker dynamics.
features = extract_features(action.message)
if features.safety > 0:
self._had_safety_reference = True
transition = step_seeker(self._seeker, features)
        post_state = transition.new_state  # same SeekerState object, mutated in place by step_seeker
self._seeker = post_state
self._turn += 1
transition.flags["repetitive"] = repetitive
# 4. Grade the step.
breakdown: GradeBreakdown = grade_step(
pre_state=pre_state,
post_state=post_state,
features=features,
flags=transition.flags,
)
self._cumulative_reward += breakdown.value
# 5. Record seeker's reply.
self._transcript.append({"role": "seeker", "text": transition.seeker_utterance})
# 6. Termination check.
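        # The episode ends naturally once every task target is met,
        # or ends early on trust collapse or turn-budget exhaustion.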
reached_required_stage = post_state.stage.value == self._task.required_final_stage
met_trust_target = post_state.trust >= self._task.min_final_trust
met_distress_target = post_state.distress <= self._task.max_final_distress
revealed_if_required = (not self._task.require_reveal) or post_state.revealed
safety_if_required = (not self._task.require_safety_reference) or self._had_safety_reference
natural_done = bool(
reached_required_stage
and met_trust_target
and met_distress_target
and revealed_if_required
and safety_if_required
)
trust_collapse = post_state.trust <= 0.05
budget_exhausted = self._turn >= self._task.max_turns
done = bool(natural_done or trust_collapse or budget_exhausted)
self._done = done
# 7. Build the next observation.
obs = Observation(
seeker_utterance=transition.seeker_utterance,
turn=self._turn,
remaining_turns=max(0, self._task.max_turns - self._turn),
stage_hint=post_state.stage.value,
task_id=self._task.id,
scenario_brief=self._task.persona.scenario_brief,
)
self._last_obs = obs
info: Dict[str, Any] = {
"features": features.__dict__,
"flags": transition.flags,
"stage": post_state.stage.value,
"resolution_score": resolution_score(post_state),
"natural_done": natural_done,
"repetitive": repetitive,
"had_safety_reference": self._had_safety_reference,
"meets_trust_target": met_trust_target,
"meets_distress_target": met_distress_target,
"revealed_if_required": revealed_if_required,
"safety_if_required": safety_if_required,
"trust_collapse": trust_collapse,
"budget_exhausted": budget_exhausted,
"reward_components": breakdown.components,
}
if done:
info["final"] = final_task_score(
cumulative_reward=self._cumulative_reward,
steps_taken=self._turn,
max_turns=self._task.max_turns,
final_state=post_state,
success_threshold=self._task.success_threshold,
completed=natural_done,
)
reward_detail = Reward(
value=breakdown.value,
immediate=breakdown.immediate,
future_oriented=breakdown.future_oriented,
penalties=breakdown.penalties,
components={k: float(v) for k, v in breakdown.components.items()},
)
return StepResult(
observation=obs,
reward=breakdown.value,
reward_detail=reward_detail,
done=done,
info=info,
)
# ------------------------------------------------------------------ state
def state(self) -> EnvState:
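        """Return the public episode state: task, turn counters, done flag, reward, transcript."""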
if self._task is None:
raise RuntimeError("env.state() called before reset()")
return EnvState(
task_id=self._task.id,
turn=self._turn,
max_turns=self._task.max_turns,
done=self._done,
cumulative_reward=self._cumulative_reward,
transcript=list(self._transcript),
)
# ---------------------------------------------------------------- listing
@staticmethod
def list_tasks() -> List[Dict[str, Any]]:
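        """Summarize all available tasks (id, difficulty, turn budget, success threshold, brief)."""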
return [
{
"id": t.id,
"difficulty": t.difficulty,
"max_turns": t.max_turns,
"success_threshold": t.success_threshold,
"scenario_brief": t.persona.scenario_brief,
}
for t in TASKS.values()
]
# ------------------------------------------------------------- serialization
def export_state(self) -> Dict[str, Any]:
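        """Serialize the full episode (env bookkeeping plus seeker state) to a JSON-friendly dict.

        Together with ``from_state`` this lets an episode be handed off between
        processes or nodes, e.g. so HTTP session handling can stay stateless.
        """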
if self._task is None or self._seeker is None:
raise RuntimeError("env.export_state() called before reset()")
seeker_state = {
"distress": self._seeker.distress,
"trust": self._seeker.trust,
"openness": self._seeker.openness,
"revealed": self._seeker.revealed,
"stage": self._seeker.stage.value,
"last_line_idx_by_stage": {
stage.value: idx for stage, idx in self._seeker.last_line_idx_by_stage.items()
},
"turn": self._seeker.turn,
}
return {
"task_id": self._task.id,
"turn": self._turn,
"done": self._done,
"cumulative_reward": self._cumulative_reward,
"transcript": list(self._transcript),
"agent_messages": list(self._agent_messages),
"had_safety_reference": self._had_safety_reference,
"seeker": seeker_state,
}
@classmethod
def from_state(cls, data: Dict[str, Any]) -> "ESCEnv":
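        """Rebuild an ``ESCEnv`` from a dict previously produced by ``export_state``."""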
task = get_task(str(data["task_id"]))
seeker_data = data["seeker"]
env = cls()
env._task = task
env._turn = int(data["turn"])
env._done = bool(data["done"])
env._cumulative_reward = float(data["cumulative_reward"])
env._transcript = list(data.get("transcript", []))
env._agent_messages = list(data.get("agent_messages", []))
env._had_safety_reference = bool(data.get("had_safety_reference", False))
env._seeker = SeekerState(
persona=task.persona,
distress=float(seeker_data["distress"]),
trust=float(seeker_data["trust"]),
openness=float(seeker_data["openness"]),
revealed=bool(seeker_data["revealed"]),
stage=Stage(str(seeker_data["stage"])),
last_line_idx_by_stage={
Stage(stage_name): int(idx)
for stage_name, idx in seeker_data["last_line_idx_by_stage"].items()
},
turn=int(seeker_data["turn"]),
)
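        # Restore the last observation from the most recent seeker line in the transcript.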
if env._transcript:
last_seeker_text = next(
(entry["text"] for entry in reversed(env._transcript) if entry.get("role") == "seeker"),
task.persona.surface_concern,
)
env._last_obs = Observation(
seeker_utterance=last_seeker_text,
turn=env._turn,
remaining_turns=max(0, task.max_turns - env._turn),
stage_hint=env._seeker.stage.value,
task_id=task.id,
scenario_brief=task.persona.scenario_brief,
)
return env