Spaces:
Sleeping
Sleeping
| import copy | |
| import random | |
| from typing import Optional | |
| from schemas import Observation, Action, StepResult | |
| from reward import compute_reward | |
| from tools import extract_concepts, detect_weakness | |
class TutorEnv:
    """Episodic tutoring environment with optional tool calls and a graded final answer.

    An episode starts with ``reset(task)``. The agent may then issue ``tool``
    actions, whose output is appended to the chat history, and ends the
    episode with a ``final_answer`` action that is scored by
    ``compute_reward``.

    A task is a mapping read through the keys ``"task_id"``, ``"difficulty"``,
    ``"chat_history"``, ``"expected"`` and (optionally) ``"constraints"``.
    Chat history entries are joined with spaces when features are computed,
    so they are presumably plain strings -- confirm against the task files.
    """

    # Reward returned for every successful tool call.
    TOOL_STEP_REWARD = 0.08

    # Tool names accepted by step(); checked before any state is mutated.
    _KNOWN_TOOLS = ("extract_concepts", "detect_weakness")

    # Probability of appending one noise line to the history in stochastic mode.
    _NOISE_PROBABILITY = 0.4

    # Lines that may be appended to the chat history in stochastic mode.
    _NOISE_CANDIDATES = (
        "Reminder: Focus on understanding, not rote memorization.",
        "Distractor: Student also mentioned sleep issues before exams.",
        "Hint: Time budgeting is often the main bottleneck.",
    )

    def __init__(
        self,
        tasks,
        seed: Optional[int] = None,
        stochastic: bool = False,
        max_steps: int = 4,
    ):
        """Create an environment; no episode is active until ``reset()``.

        Args:
            tasks: Task collection; a deep copy is kept on ``self.tasks``.
            seed: Seed for the private random generator (``None`` lets
                ``random.Random`` pick its own seed).
            stochastic: When true, ``reset()`` may append one noise line to
                the task's chat history.
            max_steps: Maximum number of accepted actions per episode.
        """
        self.tasks = copy.deepcopy(tasks)
        self.seed = seed
        self.stochastic = stochastic
        self.rng = random.Random(seed)
        self.max_steps = max_steps

        # Per-episode state, re-initialised by reset().
        self.current = None
        self.current_chat_history = []
        self.step_count = 0
        self.tool_output = None
        self.episode_done = False
        self.last_action_type = None

    def _build_chat_history(self, chat_history):
        """Return a fresh list copy of ``chat_history``, optionally with noise.

        In stochastic mode one entry from ``_NOISE_CANDIDATES`` is appended
        with probability ``_NOISE_PROBABILITY``, drawn from ``self.rng`` so a
        fixed seed reproduces the same histories.
        """
        history = list(chat_history)
        if not self.stochastic:
            return history
        if self.rng.random() < self._NOISE_PROBABILITY:
            history.append(self.rng.choice(self._NOISE_CANDIDATES))
        return history

    def _extract_features(self):
        """Compute simple surface features of the current episode.

        Text features are computed over the lower-cased chat history,
        which includes any tool output appended by ``step()``.
        """
        constraints = (self.current or {}).get("constraints") or {}
        text = " ".join(self.current_chat_history).lower()
        return {
            "message_count": len(self.current_chat_history),
            "token_count": len(text.split()),
            "has_constraints": bool(constraints),
            "exam_in_days": constraints.get("exam_in_days"),
            "has_time_budget": bool(constraints.get("time_per_day")),
            "mentions_exam": "exam" in text,
            # A substring test for "time" also matches "timed".
            "mentions_time_pressure": "time" in text,
        }

    def _observation(self, session_id: Optional[str] = None) -> "Observation":
        """Build an ``Observation`` snapshot of the active episode."""
        return Observation(
            task_id=self.current["task_id"],
            difficulty=self.current["difficulty"],
            # Copy so later history growth does not alter this observation.
            chat_history=list(self.current_chat_history),
            constraints=self.current.get("constraints"),
            step_count=self.step_count,
            features=self._extract_features(),
            session_id=session_id,
        )

    def reset(
        self,
        task,
        session_id: Optional[str] = None,
        seed: Optional[int] = None,
        stochastic: Optional[bool] = None,
    ) -> "Observation":
        """Start a new episode on ``task`` and return the first observation.

        Args:
            task: Task mapping; deep-copied so the caller's object is untouched.
            session_id: Passed through to the returned observation only.
            seed: If given, replaces the stored seed and re-seeds the
                generator; otherwise the generator keeps its current state.
            stochastic: If given, overrides the stochastic flag from now on.
        """
        self.current = copy.deepcopy(task)
        if seed is not None:
            self.seed = seed
            self.rng = random.Random(seed)
        if stochastic is not None:
            self.stochastic = stochastic

        self.current_chat_history = self._build_chat_history(
            self.current["chat_history"]
        )
        self.step_count = 0
        self.tool_output = None
        self.episode_done = False
        self.last_action_type = None
        return self._observation(session_id=session_id)

    def _run_tool(self, tool_name):
        """Run the named tool on the current chat history and return its output."""
        if tool_name == "extract_concepts":
            return extract_concepts(self.current_chat_history)
        if tool_name == "detect_weakness":
            return detect_weakness(self.current_chat_history)
        # Defensive: step() rejects unknown names before calling this.
        raise ValueError(f"Unknown tool: {tool_name}")

    def step(self, action: "Action") -> "StepResult":
        """Apply one action and return the resulting ``StepResult``.

        A ``tool`` action runs the tool, appends
        ``"[tool:<name>] <output>"`` to the chat history and returns
        ``TOOL_STEP_REWARD`` with ``done=False``. A ``final_answer`` action
        is scored by ``compute_reward`` and ends the episode.

        All validation happens before any state is changed, so a rejected
        action never consumes step budget.

        NOTE: a tool call that uses up the last step still returns
        ``done=False``; the following ``step()`` call raises and marks the
        episode as done.

        Raises:
            ValueError: If no episode is active, the episode is finished,
                the step budget is exhausted, or the action is malformed
                (bad type, missing/unknown tool name, empty final answer).
        """
        if self.current is None:
            raise ValueError("Environment not initialized. Call reset() first.")
        if self.episode_done:
            raise ValueError("Episode already finished. Call reset() before calling step() again.")
        if self.step_count >= self.max_steps:
            self.episode_done = True
            raise ValueError("Maximum step limit reached. Call reset() to start a new episode.")
        if action.type not in {"tool", "final_answer"}:
            raise ValueError(f"Invalid action type: {action.type}")
        if action.type == "tool":
            if not action.tool_name:
                raise ValueError("tool_name is required when type='tool'.")
            # Reject unknown tools here, before the step counter moves.
            if action.tool_name not in self._KNOWN_TOOLS:
                raise ValueError(f"Unknown tool: {action.tool_name}")
        elif not (action.content or "").strip():
            raise ValueError("content is required when type='final_answer'.")

        self.step_count += 1
        self.last_action_type = action.type

        # --- TOOL STEP ---
        if action.type == "tool":
            self.tool_output = self._run_tool(action.tool_name)
            # Make the tool output visible in subsequent observations.
            self.current_chat_history = list(self.current_chat_history) + [
                f"[tool:{action.tool_name}] {self.tool_output}"
            ]
            return StepResult(
                observation=self._observation(),
                reward=self.TOOL_STEP_REWARD,
                done=False,
                info={
                    "tool_output": self.tool_output,
                    "action_valid": True,
                    "step_budget_remaining": self.max_steps - self.step_count,
                },
            )

        # --- FINAL STEP ---
        result = compute_reward(
            action.content,
            self.current["expected"],
            constraints=self.current.get("constraints"),
            tool_output=self.tool_output,
            step_count=self.step_count,
        )
        self.episode_done = True
        return StepResult(
            observation=self._observation(),
            reward=result["score"],
            done=True,
            info=result["breakdown"],
        )

    def state(self) -> Optional[dict]:
        """Return a summary dict of the active episode, or ``None`` before ``reset()``."""
        if self.current is None:
            return None
        return {
            "task_id": self.current["task_id"],
            "difficulty": self.current["difficulty"],
            "step_count": self.step_count,
            "episode_done": self.episode_done,
            "last_action_type": self.last_action_type,
            "seed": self.seed,
            "stochastic": self.stochastic,
            "features": self._extract_features(),
        }