Spaces:
Sleeping
Sleeping
| """ | |
| openenv_wrapper.py — OpenEnv-compatible environment wrapping StudentSimulator. | |
| """ | |
| from __future__ import annotations | |
| import re | |
| from dataclasses import dataclass, field | |
| from typing import Any, Optional | |
# Optional dependency: use the real ``openenv`` package when it exposes an
# ``Environment`` attribute; otherwise install a minimal stand-in so this
# module can still be imported.
try:
    import openenv  # type: ignore

    if not hasattr(openenv, "Environment"):
        # Importable but without the expected API: treat it like a missing package.
        raise ImportError
except ImportError:
    _OPENENV_AVAILABLE = False

    class _EnvBase:
        """Fallback base class exposing the ``reset``/``step`` interface."""

        def reset(self):
            raise NotImplementedError

        def step(self, action):
            raise NotImplementedError

    # Anonymous object whose only attribute is ``Environment``, so that
    # ``openenv.Environment`` resolves exactly as it would on the real module.
    openenv = type("openenv", (), {"Environment": _EnvBase})()
else:
    _OPENENV_AVAILABLE = True
| from .student_fsm import ( | |
| MisconceptionType, StudentType, StudentSimulator, | |
| TutorAction, StudentState, encode_state, | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # Observation | |
| # --------------------------------------------------------------------------- | |
@dataclass
class Observation:
    """Snapshot of the tutoring episode handed back to the agent.

    The class must be a dataclass: it is constructed with keyword arguments
    and ``recent_actions`` relies on ``field(default_factory=...)``.
    """

    # Text of the student's latest reply.
    student_response: str
    # Scalar student-state values copied from the simulator snapshot.
    confusion: float
    attention: float
    learning_trend: float
    # Turn counter at the time the observation was produced.
    turn: int
    misconception_id: MisconceptionType
    student_type: StudentType
    # Most recent tutor action; None before any action has been taken.
    last_action: Optional[TutorAction] = None
    steps_taken: int = 0
    # String values of the most recent tutor actions (oldest first).
    recent_actions: tuple[str, ...] = field(default_factory=tuple)

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-``dict`` view with enum members replaced by ``.value``.

        ``last_action`` maps to ``None`` when no action has been taken, and
        ``recent_actions`` is returned as a list.
        """
        # Compare against None explicitly so an enum member that happens to be
        # falsy (e.g. a str-mixin member with an empty value) is still reported.
        last_action = None if self.last_action is None else self.last_action.value
        return {
            "student_response": self.student_response,
            "confusion": self.confusion,
            "attention": self.attention,
            "learning_trend": self.learning_trend,
            "turn": self.turn,
            "misconception_id": self.misconception_id.value,
            "student_type": self.student_type.value,
            "last_action": last_action,
            "steps_taken": self.steps_taken,
            "recent_actions": list(self.recent_actions),
        }

    def to_numpy(self):
        """Encode this observation with ``encode_state``.

        Only the fields that ``StudentState`` is given here take part in the
        encoding; ``student_response``, ``steps_taken`` and ``recent_actions``
        are not passed on.
        NOTE(review): the return type is whatever ``encode_state`` produces —
        presumably a NumPy array, confirm in ``student_fsm``.
        """
        snapshot = StudentState(
            misconception_id=self.misconception_id,
            student_type=self.student_type,
            confusion=self.confusion,
            attention=self.attention,
            learning_trend=self.learning_trend,
            turn=self.turn,
            last_action=self.last_action,
        )
        return encode_state(snapshot)
| # --------------------------------------------------------------------------- | |
| # StepInfo | |
| # --------------------------------------------------------------------------- | |
@dataclass
class StepInfo:
    """Diagnostic record returned as the fourth element of ``EduForgeEnv.step()``.

    Declared as a dataclass because ``step()`` builds it with keyword
    arguments; a plain annotation-only class would have no matching
    ``__init__``.
    """

    # "success", "disengaged" or "timeout" once the episode has ended,
    # None while it is still running.
    done_reason: Optional[str]
    # The action string exactly as it was passed to ``step()``.
    raw_action: str
    # The action applied to the simulator; ``step()`` substitutes
    # ``TutorAction.REPEAT`` when the raw string cannot be parsed.
    parsed_action: Optional[TutorAction]
    # Misconception reported by the simulator snapshot after the transition.
    misconception_id: MisconceptionType
| # --------------------------------------------------------------------------- | |
| # Strategy tag parser | |
| # --------------------------------------------------------------------------- | |
# Matches a ``<STRATEGY>name</STRATEGY>`` tag anywhere in the text. Matching is
# case-insensitive, whitespace around the name is ignored, and the name
# (letters and underscores only) is captured as group 1.
_STRATEGY_PATTERN = re.compile(r"<STRATEGY>\s*([a-z_]+)\s*</STRATEGY>", re.IGNORECASE)
def _parse_action(action_str: str) -> Optional[TutorAction]:
    """Extract the tutor action named in a ``<STRATEGY>...</STRATEGY>`` tag.

    Returns the ``TutorAction`` whose value equals the lower-cased name of the
    first tag found in ``action_str``. Returns ``None`` when there is no tag
    or the name is not a valid ``TutorAction`` value.
    """
    found = _STRATEGY_PATTERN.search(action_str)
    if found is None:
        return None

    strategy_name = found.group(1).lower()
    try:
        return TutorAction(strategy_name)
    except ValueError:
        # Well-formed tag, but the name is not a known action.
        return None
| # --------------------------------------------------------------------------- | |
| # Teaching phase shaping | |
| # --------------------------------------------------------------------------- | |
# Value handed to the reward engine as ``phase_bonus`` when the chosen action
# is one of the preferred actions for the current phase.
# Kept small — guide without dominating.
_PHASE_BONUS = 0.2

# Preferred tutor actions per teaching phase, keyed by the index returned by
# ``_phase_index``.
_PHASE_PREFERRED: dict[int, set[TutorAction]] = {
    # diagnose (confusion > 6)
    0: {TutorAction.QUESTION, TutorAction.EXPLAIN},
    # intervene (confusion > 3.5)
    1: {TutorAction.WORKED_EXAMPLE, TutorAction.ANALOGIZE},
    # consolidate (confusion <= 3.5)
    2: {TutorAction.QUESTION, TutorAction.CORRECT_FACT},
}
def _phase_index(confusion: float) -> int:
    """Map a confusion level to a teaching-phase index.

    Returns 0 when ``confusion`` is above 6.0, 1 when it is above 3.5, and 2
    otherwise.
    """
    # Thresholds are listed from the highest phase boundary down; the first
    # one that ``confusion`` strictly exceeds determines the phase.
    for index, lower_bound in enumerate((6.0, 3.5)):
        if confusion > lower_bound:
            return index
    return 2
| # --------------------------------------------------------------------------- | |
| # Episode constants | |
| # --------------------------------------------------------------------------- | |
# The episode ends with "timeout" once this many steps have been taken.
MAX_TURNS = 15

# The episode ends with "success" when confusion is at or below this value.
CONFUSION_SUCCESS = 2.0

# The episode ends with "disengaged" when attention is at or below this value.
# Lowered — soft penalty only above, terminal only at near-zero.
ATTENTION_FLOOR = 0.5
| # --------------------------------------------------------------------------- | |
| # EduForgeEnv | |
| # --------------------------------------------------------------------------- | |
class EduForgeEnv(openenv.Environment):
    """
    OpenEnv-compatible tutoring environment.

    Wraps a ``StudentSimulator``: ``reset()`` starts an episode and ``step()``
    parses the tutor's action string, advances the simulator, asks the reward
    engine for a reward, and reports whether the episode has ended.

    Done conditions (checked in this order after every transition)
    ---------------
    success    : confusion <= CONFUSION_SUCCESS (2.0)
    disengaged : attention <= ATTENTION_FLOOR (0.5, near-zero only — no premature kills)
    timeout    : turn_count >= MAX_TURNS (15)
    """

    def __init__(
        self,
        seed: Optional[int] = None,
        confusion_init: Optional[float] = None,
        attention_init: Optional[float] = None,
        misconception_init: Optional[str] = None,
    ) -> None:
        """Store the episode configuration; no simulator exists until ``reset()``.

        Args:
            seed: Forwarded to ``StudentSimulator`` on every ``reset()``.
            confusion_init: Forwarded to ``StudentSimulator`` on every ``reset()``.
            attention_init: Forwarded to ``StudentSimulator`` on every ``reset()``.
            misconception_init: String value of a ``MisconceptionType``; it is
                converted during ``reset()``.
        """
        self._seed = seed
        self._confusion_init = confusion_init
        self._attention_init = attention_init
        self._misconception_init = misconception_init
        self._sim: Optional[StudentSimulator] = None

        # NOTE(review): imported lazily — presumably to avoid a circular
        # import with the rewards package; confirm before moving it.
        from ..rewards.engine import RewardEngine

        self._reward_engine = RewardEngine()
        self._turn_count: int = 0
        self._recent_actions: list[TutorAction] = []

    # ------------------------------------------------------------------
    # reset
    # ------------------------------------------------------------------
    def reset(self) -> Observation:
        """Start a new episode and return the initial observation.

        A fresh ``StudentSimulator`` is built from the stored configuration
        (the same ``seed`` is reused on every call), and the turn counter,
        action history and reward engine are reset.

        An unrecognised ``misconception_init`` does not abort the reset: a
        ``RuntimeWarning`` is emitted and ``None`` is passed to the simulator
        instead.
        """
        import warnings

        m_init: Optional[MisconceptionType] = None
        if self._misconception_init is not None:
            try:
                m_init = MisconceptionType(self._misconception_init)
            except ValueError:
                # Keep the best-effort fallback, but make the bad value
                # visible instead of swallowing it silently.
                warnings.warn(
                    f"Unknown misconception_init {self._misconception_init!r}; "
                    "passing None to StudentSimulator instead.",
                    RuntimeWarning,
                    stacklevel=2,
                )

        self._sim = StudentSimulator(
            seed=self._seed,
            confusion_init=self._confusion_init,
            attention_init=self._attention_init,
            misconception_init=m_init,
        )
        self._turn_count = 0
        self._recent_actions = []
        self._reward_engine.reset()

        snap = self._sim.state_snapshot()
        initial_response = self._sim.generate_response()
        return Observation(
            student_response=initial_response,
            confusion=snap.confusion,
            attention=snap.attention,
            learning_trend=snap.learning_trend,
            turn=self._turn_count,
            misconception_id=snap.misconception_id,
            student_type=snap.student_type,
            last_action=None,
            steps_taken=0,
            recent_actions=tuple(),
        )

    # ------------------------------------------------------------------
    # step
    # ------------------------------------------------------------------
    def step(
        self, action_str: str
    ) -> tuple[Observation, float, bool, StepInfo]:
        """Apply one tutor action and advance the episode by one turn.

        Args:
            action_str: Tutor output expected to contain a
                ``<STRATEGY>name</STRATEGY>`` tag.

        Returns:
            ``(observation, reward, done, info)``.

        Raises:
            RuntimeError: If ``reset()`` has not been called yet.
        """
        if self._sim is None:
            raise RuntimeError("Call reset() before step().")

        # 1. Parse — an unparseable action is executed as REPEAT, and the
        #    reward engine is told the format was invalid.
        parsed_action = _parse_action(action_str)
        format_valid = parsed_action is not None
        if parsed_action is None:
            parsed_action = TutorAction.REPEAT

        # 2. Pre-transition snapshot
        confusion_before = self._sim.confusion
        attention_before = self._sim.attention
        misconception_now = self._sim.misconception_id

        # 3. Phase bonus — based on the confusion level *before* the action.
        phase = _phase_index(confusion_before)
        phase_bonus = _PHASE_BONUS if parsed_action in _PHASE_PREFERRED[phase] else 0.0

        # 4. Transition
        self._sim.transition(parsed_action)
        self._turn_count += 1
        snap = self._sim.state_snapshot()

        # 5. Done conditions — success takes precedence over disengagement,
        #    which takes precedence over timeout.
        done = False
        done_reason: Optional[str] = None
        if snap.confusion <= CONFUSION_SUCCESS:
            done = True
            done_reason = "success"
        elif snap.attention <= ATTENTION_FLOOR:
            done = True
            done_reason = "disengaged"
        elif self._turn_count >= MAX_TURNS:
            done = True
            done_reason = "timeout"

        # 6. Reward — the history passed here does not yet include the
        #    current action (it is appended in step 7).
        reward, _components = self._reward_engine.compute(
            confusion_before=confusion_before,
            confusion_after=snap.confusion,
            attention_before=attention_before,
            attention_after=snap.attention,
            action_text=action_str,
            format_valid=format_valid,
            learning_trend=snap.learning_trend,
            action=parsed_action,
            action_history=[a.value for a in self._recent_actions[-3:]],
            misconception=misconception_now,
            done=done,
            done_reason=done_reason,
            phase_bonus=phase_bonus,
            episode_length=self._turn_count,
        )

        # 7. Update history — keep only the five most recent actions.
        self._recent_actions.append(parsed_action)
        del self._recent_actions[:-5]

        obs = Observation(
            student_response=snap.last_response,
            confusion=snap.confusion,
            attention=snap.attention,
            learning_trend=snap.learning_trend,
            turn=self._turn_count,
            misconception_id=snap.misconception_id,
            student_type=snap.student_type,
            last_action=snap.last_action,
            steps_taken=self._turn_count,
            recent_actions=tuple(a.value for a in self._recent_actions[-3:]),
        )
        info = StepInfo(
            done_reason=done_reason,
            raw_action=action_str,
            parsed_action=parsed_action,
            misconception_id=snap.misconception_id,
        )
        return obs, reward, done, info

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    # NOTE(review): decorators are not visible in this listing, so this may
    # have been a property upstream; kept as a plain method to match what is
    # shown here — confirm against callers.
    def turn_count(self) -> int:
        """Return the number of steps taken in the current episode."""
        return self._turn_count

    def __repr__(self) -> str:
        """Return a short summary of the current episode state."""
        if self._sim is None:
            return "EduForgeEnv(not started)"
        snap = self._sim.state_snapshot()
        return (
            f"EduForgeEnv(turn={self._turn_count}, "
            f"confusion={snap.confusion:.2f}, "
            f"attention={snap.attention:.2f}, "
            f"misconception={snap.misconception_id.value})"
        )