"""
openenv_wrapper.py — OpenEnv-compatible environment wrapping StudentSimulator.
"""
from __future__ import annotations

import logging
import re
from dataclasses import dataclass, field
from typing import Any, Optional

try:
    import openenv  # type: ignore

    # A package named ``openenv`` that lacks ``Environment`` is treated the
    # same as a missing package, so the fallback shim below takes over.
    if not hasattr(openenv, 'Environment'):
        raise ImportError
    _OPENENV_AVAILABLE = True
except ImportError:
    _OPENENV_AVAILABLE = False

    class _EnvBase:
        """Minimal stand-in base class used when ``openenv`` is unavailable."""

        def reset(self):
            raise NotImplementedError

        def step(self, action):
            raise NotImplementedError

    # Namespace object exposing ``.Environment`` so ``openenv.Environment``
    # resolves whether or not the real package is installed.
    openenv = type("openenv", (), {"Environment": _EnvBase})()

from .student_fsm import (
    MisconceptionType,
    StudentType,
    StudentSimulator,
    TutorAction,
    StudentState,
    encode_state,
)

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Observation
# ---------------------------------------------------------------------------

@dataclass
class Observation:
    """Snapshot of the tutoring episode returned by ``reset()`` and ``step()``."""

    student_response: str
    confusion: float
    attention: float
    learning_trend: float
    turn: int
    misconception_id: MisconceptionType
    student_type: StudentType
    last_action: Optional[TutorAction] = None
    steps_taken: int = 0
    recent_actions: tuple[str, ...] = field(default_factory=tuple)

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict view with enum members replaced by ``.value``."""
        return {
            "student_response": self.student_response,
            "confusion": self.confusion,
            "attention": self.attention,
            "learning_trend": self.learning_trend,
            "turn": self.turn,
            "misconception_id": self.misconception_id.value,
            "student_type": self.student_type.value,
            # Explicit None check: "no action yet" must not depend on the
            # truthiness of an enum member.
            "last_action": (
                self.last_action.value if self.last_action is not None else None
            ),
            "steps_taken": self.steps_taken,
            "recent_actions": list(self.recent_actions),
        }

    def to_numpy(self):
        """Encode this observation by passing a ``StudentState`` to ``encode_state``.

        Only the fields listed below are forwarded; ``student_response``,
        ``steps_taken`` and ``recent_actions`` are not part of the encoding
        input.
        """
        snap = StudentState(
            misconception_id=self.misconception_id,
            student_type=self.student_type,
            confusion=self.confusion,
            attention=self.attention,
            learning_trend=self.learning_trend,
            turn=self.turn,
            last_action=self.last_action,
        )
        return encode_state(snap)


# ---------------------------------------------------------------------------
# StepInfo
# ---------------------------------------------------------------------------

@dataclass
class StepInfo:
    """Auxiliary per-step data returned as the fourth element of ``step()``."""

    done_reason: Optional[str]
    raw_action: str
    parsed_action: Optional[TutorAction]
    misconception_id: MisconceptionType


# ---------------------------------------------------------------------------
# Strategy tag parser
# ---------------------------------------------------------------------------

# NOTE(review): the heading says "tag parser" but this pattern has no tag
# delimiters; it matches the first run of letters/underscores anywhere in the
# text. Confirm against the upstream file that markup around the capture
# group was not lost.
_STRATEGY_PATTERN = re.compile(
    r"\s*([a-z_]+)\s*",
    re.IGNORECASE,
)


def _parse_action(action_str: str) -> Optional[TutorAction]:
    """Extract a ``TutorAction`` from raw agent output.

    Returns ``None`` when the pattern finds nothing or when the captured
    token (lower-cased) is not a valid ``TutorAction`` value.
    """
    match = _STRATEGY_PATTERN.search(action_str)
    if not match:
        return None
    try:
        return TutorAction(match.group(1).lower())
    except ValueError:
        return None


# ---------------------------------------------------------------------------
# Teaching phase shaping
# ---------------------------------------------------------------------------

_PHASE_BONUS = 0.2  # small — guide without dominating

_PHASE_PREFERRED: dict[int, set[TutorAction]] = {
    0: {TutorAction.QUESTION, TutorAction.EXPLAIN},            # diagnose (confusion > 6)
    1: {TutorAction.WORKED_EXAMPLE, TutorAction.ANALOGIZE},    # intervene (confusion > 3.5)
    2: {TutorAction.QUESTION, TutorAction.CORRECT_FACT},       # consolidate (confusion <= 3.5)
}


def _phase_index(confusion: float) -> int:
    """Map a confusion level to its teaching-phase key in ``_PHASE_PREFERRED``."""
    if confusion > 6.0:
        return 0
    if confusion > 3.5:
        return 1
    return 2


# ---------------------------------------------------------------------------
# Episode constants
# ---------------------------------------------------------------------------

MAX_TURNS = 15
CONFUSION_SUCCESS = 2.0
ATTENTION_FLOOR = 0.5  # Lowered — soft penalty only above, terminal only at near-zero

_HISTORY_LIMIT = 5    # number of past actions retained by the environment
_HISTORY_WINDOW = 3   # number of past actions exposed to the reward / observation


# ---------------------------------------------------------------------------
# EduForgeEnv
# ---------------------------------------------------------------------------

class EduForgeEnv(openenv.Environment):
    """
    OpenEnv-compatible tutoring environment.

    Done conditions (checked in this order after each transition)
    ---------------
    success    : confusion <= 2.0
    disengaged : attention <= 0.5  (near-zero only — no premature kills)
    timeout    : turn_count >= 15
    """

    def __init__(
        self,
        seed: Optional[int] = None,
        confusion_init: Optional[float] = None,
        attention_init: Optional[float] = None,
        misconception_init: Optional[str] = None,
    ) -> None:
        """Store episode configuration; the simulator itself is built in ``reset()``.

        Args:
            seed: Forwarded to ``StudentSimulator`` on every reset.
            confusion_init: Forwarded to ``StudentSimulator`` on every reset.
            attention_init: Forwarded to ``StudentSimulator`` on every reset.
            misconception_init: String value of a ``MisconceptionType``;
                converted in ``reset()``.
        """
        self._seed = seed
        self._confusion_init = confusion_init
        self._attention_init = attention_init
        self._misconception_init = misconception_init

        self._sim: Optional[StudentSimulator] = None

        # Imported lazily inside the constructor — presumably to avoid an
        # import cycle with the rewards package; confirm before hoisting.
        from ..rewards.engine import RewardEngine
        self._reward_engine = RewardEngine()

        self._turn_count: int = 0
        self._recent_actions: list[TutorAction] = []

    # ------------------------------------------------------------------
    # reset
    # ------------------------------------------------------------------

    def reset(self) -> Observation:
        """Start a new episode and return the initial observation.

        Builds a fresh ``StudentSimulator`` from the constructor arguments,
        clears the turn counter, action history and reward engine, and asks
        the simulator for an opening response.
        """
        m_init: Optional[MisconceptionType] = None
        if self._misconception_init is not None:
            try:
                m_init = MisconceptionType(self._misconception_init)
            except ValueError:
                # Best-effort: keep going with ``None``, but make the
                # misconfiguration visible instead of dropping it silently.
                logger.warning(
                    "Unknown misconception_init %r; passing None to "
                    "StudentSimulator instead.",
                    self._misconception_init,
                )

        self._sim = StudentSimulator(
            seed=self._seed,
            confusion_init=self._confusion_init,
            attention_init=self._attention_init,
            misconception_init=m_init,
        )
        self._turn_count = 0
        self._recent_actions = []
        self._reward_engine.reset()

        snap = self._sim.state_snapshot()
        initial_response = self._sim.generate_response()

        return Observation(
            student_response=initial_response,
            confusion=snap.confusion,
            attention=snap.attention,
            learning_trend=snap.learning_trend,
            turn=self._turn_count,
            misconception_id=snap.misconception_id,
            student_type=snap.student_type,
            last_action=None,
            steps_taken=0,
            recent_actions=tuple(),
        )

    # ------------------------------------------------------------------
    # step
    # ------------------------------------------------------------------

    def step(
        self, action_str: str
    ) -> tuple[Observation, float, bool, StepInfo]:
        """Apply one tutor action and advance the simulated student.

        Args:
            action_str: Raw agent output; parsed with ``_parse_action``.
                Unparseable input falls back to ``TutorAction.REPEAT`` and is
                reported to the reward engine as ``format_valid=False``.

        Returns:
            ``(observation, reward, done, info)``.

        Raises:
            RuntimeError: If ``reset()`` has not been called yet.
        """
        if self._sim is None:
            raise RuntimeError("Call reset() before step().")

        # 1. Parse
        parsed_action = _parse_action(action_str)
        format_valid = parsed_action is not None
        if parsed_action is None:
            parsed_action = TutorAction.REPEAT

        # 2. Pre-transition snapshot
        confusion_before = self._sim.confusion
        attention_before = self._sim.attention
        misconception_now = self._sim.misconception_id

        # 3. Phase bonus (based on confusion *before* the transition)
        phase = _phase_index(confusion_before)
        phase_bonus = _PHASE_BONUS if parsed_action in _PHASE_PREFERRED[phase] else 0.0

        # 4. Transition
        self._sim.transition(parsed_action)
        self._turn_count += 1
        snap = self._sim.state_snapshot()

        # 5. Done conditions — success outranks disengaged, which outranks timeout
        done = False
        done_reason: Optional[str] = None

        if snap.confusion <= CONFUSION_SUCCESS:
            done = True
            done_reason = "success"
        elif snap.attention <= ATTENTION_FLOOR:
            done = True
            done_reason = "disengaged"
        elif self._turn_count >= MAX_TURNS:
            done = True
            done_reason = "timeout"

        # 6. Reward — the history passed here excludes the current action
        reward, _components = self._reward_engine.compute(
            confusion_before=confusion_before,
            confusion_after=snap.confusion,
            attention_before=attention_before,
            attention_after=snap.attention,
            action_text=action_str,
            format_valid=format_valid,
            learning_trend=snap.learning_trend,
            action=parsed_action,
            action_history=[
                a.value for a in self._recent_actions[-_HISTORY_WINDOW:]
            ],
            misconception=misconception_now,
            done=done,
            done_reason=done_reason,
            phase_bonus=phase_bonus,
            episode_length=self._turn_count,
        )

        # 7. Update history, keeping only the newest _HISTORY_LIMIT entries
        self._recent_actions.append(parsed_action)
        del self._recent_actions[:-_HISTORY_LIMIT]

        obs = Observation(
            student_response=snap.last_response,
            confusion=snap.confusion,
            attention=snap.attention,
            learning_trend=snap.learning_trend,
            turn=self._turn_count,
            misconception_id=snap.misconception_id,
            student_type=snap.student_type,
            last_action=snap.last_action,
            steps_taken=self._turn_count,
            recent_actions=tuple(
                a.value for a in self._recent_actions[-_HISTORY_WINDOW:]
            ),
        )

        info = StepInfo(
            done_reason=done_reason,
            raw_action=action_str,
            parsed_action=parsed_action,
            misconception_id=snap.misconception_id,
        )

        return obs, reward, done, info

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @property
    def turn_count(self) -> int:
        """Number of ``step()`` calls since the last ``reset()``."""
        return self._turn_count

    def __repr__(self) -> str:
        if self._sim is None:
            return "EduForgeEnv(not started)"
        snap = self._sim.state_snapshot()
        return (
            f"EduForgeEnv(turn={self._turn_count}, "
            f"confusion={snap.confusion:.2f}, "
            f"attention={snap.attention:.2f}, "
            f"misconception={snap.misconception_id.value})"
        )