Spaces:
Sleeping
Sleeping
| """Core OpenEnv environment for TemporalBench MCQ episodes.""" | |
| from __future__ import annotations | |
| import uuid | |
| from collections import defaultdict | |
| from dataclasses import replace | |
| from typing import Any, Optional | |
| import numpy as np | |
| from data.loaders import load_question_banks | |
| from data.question import TSQuestion | |
| from .config import EnvConfig | |
| from .episode_sampler import EpisodeSampler | |
| from .grading import grade_answer | |
| from .models import TemporalBenchAction, TemporalBenchObservation, TemporalBenchState | |
| from .reward import compute_episode_bonus, compute_mcq_reward | |
try:
    from openenv.core.env_server.interfaces import Environment
except ImportError:
    # OpenEnv is not importable: define a minimal stand-in so this module can
    # still be imported and the environment class below can be subclassed
    # from it. Only the three hooks this module implements are declared.
    from abc import ABC, abstractmethod
    from typing import Generic, TypeVar

    ActT = TypeVar("ActT")
    ObsT = TypeVar("ObsT")
    StateT = TypeVar("StateT")

    class Environment(ABC, Generic[ActT, ObsT, StateT]):
        """Fallback environment interface used when OpenEnv is unavailable.

        The hooks are abstract so that a subclass which forgets to implement
        one fails at instantiation time instead of silently returning
        ``None`` from an inherited stub.
        """

        @abstractmethod
        def reset(self, seed=None, episode_id=None, **kwargs):
            """Start a new episode and return its first observation."""

        @abstractmethod
        def step(self, action, timeout_s=None, **kwargs):
            """Apply ``action`` and return the resulting observation."""

        @abstractmethod
        def state(self):
            """Return a snapshot of the current episode state."""
class TemporalBenchEnvironment(
    Environment[TemporalBenchAction, TemporalBenchObservation, TemporalBenchState]
):
    """Multi-step MCQ environment over a pre-built TemporalBench question bank.

    ``reset`` asks an :class:`EpisodeSampler` for the episode's questions,
    ``step`` grades one submitted answer and moves to the next question, and
    the step that answers the last of ``config.num_questions`` questions also
    adds an episode-level bonus to its reward.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self, config: Optional[EnvConfig] = None, **kwargs: Any):
        """Load the question banks and initialise empty episode bookkeeping.

        Args:
            config: Environment settings; ``EnvConfig()`` is used when omitted.
            **kwargs: Forwarded unchanged to the ``Environment`` base class.
        """
        super().__init__(**kwargs)
        self._config = config or EnvConfig()
        self._rng = np.random.default_rng(self._config.seed)
        self._banks = load_question_banks(self._config.question_bank_path)
        self._sampler = EpisodeSampler(self._banks, self._config, self._rng)

        # Per-episode bookkeeping; all of it is re-initialised by reset().
        self._episode_id: Optional[str] = None
        self._questions: list[TSQuestion] = []
        self._answered: int = 0
        self._history: list[dict[str, Any]] = []
        self._done: bool = False
        self._total_correct: int = 0
        self._total_reward: float = 0.0
        self._domain_correct: dict[str, int] = defaultdict(int)
        self._task_correct: dict[str, int] = defaultdict(int)
        self._task_total: dict[str, int] = defaultdict(int)
        self._last_metadata: dict[str, Any] = {}

    def _accuracy_so_far(self) -> float:
        """Return the fraction of graded answers that were fully correct.

        Returns ``0.0`` while nothing has been answered yet.
        """
        if self._answered == 0:
            return 0.0
        return self._total_correct / self._answered

    def _per_task_accuracy(self) -> dict[str, float]:
        """Return accuracy per task type for the task types seen so far."""
        return {
            task: (self._task_correct[task] / total) if total else 0.0
            for task, total in self._task_total.items()
        }

    def _build_observation(
        self,
        *,
        reward: float | None,
        done: bool,
    ) -> TemporalBenchObservation:
        """Assemble the observation for the current position in the episode.

        Args:
            reward: Reward to report on the observation.
            done: Force a terminal observation. A terminal observation is also
                produced whenever every configured question has been answered.

        Returns:
            A terminal observation with empty question fields, or an
            observation carrying the next unanswered question.
        """
        n = self._config.num_questions
        # Fields that are identical for terminal and non-terminal observations.
        common: dict[str, Any] = {
            "step_idx": self._answered,
            "max_steps": n,
            # Copy every entry so a consumer mutating the observation cannot
            # alter the environment's own history log.
            "history": [dict(entry) for entry in self._history],
            "accuracy_so_far": self._accuracy_so_far(),
            "reward": reward,
            "metadata": dict(self._last_metadata),
        }

        if done or self._answered >= n:
            return TemporalBenchObservation(
                steps_remaining=0,
                question="",
                options=[],
                task_type="",
                dataset="",
                done=True,
                **common,
            )

        q = self._questions[self._answered]
        return TemporalBenchObservation(
            steps_remaining=n - self._answered,
            question=q.prompt,
            options=list(q.options),
            task_type=q.task_type,
            dataset=q.dataset,
            done=False,
            **common,
        )

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> TemporalBenchObservation:
        """Sample a fresh episode and return its first observation.

        Args:
            seed: When given, the environment RNG is re-created from this seed.
            episode_id: Identifier for the episode; a random UUID4 string is
                generated when omitted.
            **kwargs: ``curriculum_stage`` (optional) overrides the configured
                curriculum stage for this episode only; other keys are ignored.

        Returns:
            The observation for the first question, with ``reward=0.0``.
        """
        curriculum_kw = kwargs.pop("curriculum_stage", None)
        if seed is not None:
            self._rng = np.random.default_rng(seed)

        cfg = self._config
        if curriculum_kw is not None:
            # Per-episode override: the stored config itself is left untouched.
            cfg = replace(self._config, curriculum_stage=int(curriculum_kw))
        # Rebuild the sampler so it uses the current RNG and effective config.
        self._sampler = EpisodeSampler(self._banks, cfg, self._rng)

        self._episode_id = episode_id or str(uuid.uuid4())
        self._questions = self._sampler.sample_episode()
        self._answered = 0
        self._history = []
        self._done = False
        self._total_correct = 0
        self._total_reward = 0.0
        self._domain_correct = defaultdict(int)
        self._task_correct = defaultdict(int)
        self._task_total = defaultdict(int)
        self._last_metadata = {}
        return self._build_observation(reward=0.0, done=False)

    def step(
        self,
        action: TemporalBenchAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> TemporalBenchObservation:
        """Grade ``action.answer`` against the current question and advance.

        A missing or blank answer is not graded: the same question is returned
        again with ``reward=0.0`` and an ``"error"`` entry in the metadata.
        Stepping a finished episode returns a terminal observation with
        ``reward=0.0`` and an ``"info"`` entry in the metadata.

        Args:
            action: The submitted answer.
            timeout_s: Accepted for interface compatibility; unused.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The observation for the next question, or a terminal observation
            whose reward includes the episode bonus after the last question.

        Raises:
            IndexError: If no question exists at the current position, e.g.
                ``step`` was called before ``reset``.
        """
        del timeout_s, kwargs

        if self._done:
            self._last_metadata = {"info": "Episode already done."}
            return self._build_observation(reward=0.0, done=True)

        self._last_metadata = {}
        n = self._config.num_questions
        if self._answered >= n:
            self._done = True
            self._last_metadata = {"info": "Episode already complete."}
            return self._build_observation(reward=0.0, done=True)

        if self._answered >= len(self._questions):
            # Same exception type the bare list lookup would raise, but with a
            # message that tells the caller what went wrong.
            raise IndexError(
                f"No question available at index {self._answered}: call reset() "
                f"before step() (sampled {len(self._questions)} question(s); "
                f"config.num_questions={n})."
            )
        q = self._questions[self._answered]

        answer = action.answer
        # Check None explicitly: str(None) is the non-empty string "None" and
        # would otherwise slip through the blank-answer test and be graded.
        if answer is None or not str(answer).strip():
            self._last_metadata = {"error": "answer must be a non-empty string."}
            return self._build_observation(reward=0.0, done=False)

        fully_correct, score = grade_answer(answer, q, self._config)
        r_step = compute_mcq_reward(score, alpha=self._config.alpha)
        self._history.append(
            {
                "question_id": q.question_id,
                "dataset": q.dataset,
                "task_type": q.task_type,
                "submitted": answer,
                "correct": fully_correct,
                "reward": r_step,
            }
        )

        self._task_total[q.task_type] += 1
        if fully_correct:
            self._total_correct += 1
            self._domain_correct[q.dataset] += 1
            self._task_correct[q.task_type] += 1
        self._answered += 1

        total_reward_this_step = r_step
        if self._answered >= n:
            # Last question answered: add the episode-level bonus to this step.
            bonus = compute_episode_bonus(
                self._total_correct,
                n,
                dict(self._domain_correct),
                all_domains=tuple(self._config.all_domains),
                lambda_ep=self._config.lambda_ep,
            )
            total_reward_this_step = r_step + bonus
            self._done = True
            self._last_metadata = {
                "episode_bonus": bonus,
                "domain_correct_counts": dict(self._domain_correct),
            }

        self._total_reward += total_reward_this_step
        return self._build_observation(
            reward=total_reward_this_step,
            done=self._done,
        )

    # NOTE(review): defined as a plain method here; confirm whether the OpenEnv
    # base class expects ``state`` to be a property before changing callers.
    def state(self) -> TemporalBenchState:
        """Return a snapshot of the current episode's progress and totals."""
        return TemporalBenchState(
            episode_id=self._episode_id,
            step_count=self._answered,
            total_correct=self._total_correct,
            total_questions=self._config.num_questions,
            current_accuracy=self._accuracy_so_far(),
            primary_domain=self._config.primary_domain,
            per_task_type_accuracy=self._per_task_accuracy(),
            total_reward=self._total_reward,
        )