"""Core OpenEnv environment for TemporalBench MCQ episodes.""" from __future__ import annotations import uuid from collections import defaultdict from dataclasses import replace from typing import Any, Optional import numpy as np from data.loaders import load_question_banks from data.question import TSQuestion from .config import EnvConfig from .episode_sampler import EpisodeSampler from .grading import grade_answer from .models import TemporalBenchAction, TemporalBenchObservation, TemporalBenchState from .reward import compute_episode_bonus, compute_mcq_reward try: from openenv.core.env_server.interfaces import Environment except ImportError: from abc import ABC, abstractmethod from typing import Generic, TypeVar ActT = TypeVar("ActT") ObsT = TypeVar("ObsT") StateT = TypeVar("StateT") class Environment(ABC, Generic[ActT, ObsT, StateT]): @abstractmethod def reset(self, seed=None, episode_id=None, **kwargs): ... @abstractmethod def step(self, action, timeout_s=None, **kwargs): ... @property @abstractmethod def state(self): ... 
class TemporalBenchEnvironment(
    Environment[TemporalBenchAction, TemporalBenchObservation, TemporalBenchState]
):
    """Multi-step MCQ environment over a pre-built TemporalBench question bank.

    Episode flow:

    * ``reset`` builds an ``EpisodeSampler`` and draws the episode's questions
      with ``sample_episode()``.
    * each ``step`` grades one answer with ``grade_answer`` and turns the
      score into a per-step reward with ``compute_mcq_reward``.
    * once ``config.num_questions`` answers have been graded, the value of
      ``compute_episode_bonus`` is added to that final step's reward and the
      episode is marked done.

    Invalid calls (stepping before ``reset``, stepping a finished episode,
    submitting an empty answer) do not raise; they are reported through the
    ``metadata`` field of the returned observation with a reward of ``0.0``.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self, config: Optional[EnvConfig] = None, **kwargs: Any):
        """Load the question banks and prepare an idle (not yet reset) env.

        Args:
            config: Environment configuration; ``EnvConfig()`` when omitted.
            **kwargs: Forwarded unchanged to the base ``Environment``.
        """
        super().__init__(**kwargs)
        self._config = config or EnvConfig()
        seed = self._config.seed
        self._rng = np.random.default_rng(seed)
        self._banks = load_question_banks(self._config.question_bank_path)
        self._sampler = EpisodeSampler(self._banks, self._config, self._rng)

        # `_episode_id` stays None until the first reset(); step() relies on
        # that to detect "no episode has been sampled yet".
        self._episode_id: Optional[str] = None
        self._questions: list[TSQuestion] = []
        self._clear_progress()

    def _clear_progress(self) -> None:
        """Zero every per-episode counter, the history and the metadata."""
        self._answered: int = 0
        self._history: list[dict[str, Any]] = []
        self._done: bool = False
        self._total_correct: int = 0
        self._total_reward: float = 0.0
        self._domain_correct: dict[str, int] = defaultdict(int)
        self._task_correct: dict[str, int] = defaultdict(int)
        self._task_total: dict[str, int] = defaultdict(int)
        self._last_metadata: dict[str, Any] = {}

    def _accuracy_so_far(self) -> float:
        """Return fully-correct answers / graded answers (0.0 if none yet)."""
        if self._answered == 0:
            return 0.0
        return self._total_correct / self._answered

    def _per_task_accuracy(self) -> dict[str, float]:
        """Return the fraction of fully-correct answers per task type.

        Only task types that have been graded at least once appear as keys.
        """
        # `.get` instead of indexing: `_task_correct` is a defaultdict, and
        # indexing it would insert a zero entry every time state is read.
        return {
            task: (self._task_correct.get(task, 0) / total) if total else 0.0
            for task, total in self._task_total.items()
        }

    def _build_observation(
        self,
        *,
        reward: float | None,
        done: bool,
    ) -> TemporalBenchObservation:
        """Assemble the observation for the current position in the episode.

        Args:
            reward: Reward to attach to the observation.
            done: Force a terminal observation.  The observation is also
                terminal whenever all ``num_questions`` have been answered,
                regardless of this flag.

        Returns:
            A terminal observation (empty question/options, zero steps
            remaining) or one presenting the next unanswered question.
            ``history`` and ``metadata`` are shallow copies of the
            environment's own containers.
        """
        n = self._config.num_questions
        terminal = done or self._answered >= n

        if terminal:
            prompt, options, task_type, dataset = "", [], "", ""
            steps_remaining = 0
        else:
            q = self._questions[self._answered]
            prompt, options = q.prompt, list(q.options)
            task_type, dataset = q.task_type, q.dataset
            steps_remaining = n - self._answered

        return TemporalBenchObservation(
            step_idx=self._answered,
            steps_remaining=steps_remaining,
            max_steps=n,
            question=prompt,
            options=options,
            task_type=task_type,
            dataset=dataset,
            history=list(self._history),
            accuracy_so_far=self._accuracy_so_far(),
            done=terminal,
            reward=reward,
            metadata=dict(self._last_metadata),
        )

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> TemporalBenchObservation:
        """Start a new episode and return its first observation.

        Args:
            seed: When not ``None``, the RNG is replaced by
                ``np.random.default_rng(seed)`` before sampling.
            episode_id: Identifier for the episode; a random UUID4 string is
                generated when omitted or empty.
            **kwargs: ``curriculum_stage`` (optional) is converted with
                ``int`` and overrides ``config.curriculum_stage`` for the
                sampler built by this call only; the stored config is not
                modified.  Any other keyword is ignored.

        Returns:
            The observation for the first question, with ``reward=0.0``.
        """
        curriculum_kw = kwargs.pop("curriculum_stage", None)
        if seed is not None:
            self._rng = np.random.default_rng(seed)

        cfg = self._config
        if curriculum_kw is not None:
            cfg = replace(self._config, curriculum_stage=int(curriculum_kw))
        # The sampler is rebuilt on every reset so it always sees the current
        # RNG and the (possibly overridden) curriculum stage.
        self._sampler = EpisodeSampler(self._banks, cfg, self._rng)

        self._episode_id = episode_id or str(uuid.uuid4())
        self._questions = self._sampler.sample_episode()
        self._clear_progress()
        return self._build_observation(reward=0.0, done=False)

    def step(
        self,
        action: TemporalBenchAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> TemporalBenchObservation:
        """Grade ``action.answer`` against the current question.

        Args:
            action: Action whose ``answer`` attribute is the submission.
            timeout_s: Accepted for interface compatibility; unused.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The next observation.  Its reward is the per-step reward, plus
            the episode bonus on the step that answers the last question.
            A missing/blank answer yields ``reward=0.0`` with
            ``metadata["error"]`` set and does not consume the question.
            Stepping before ``reset`` or after the episode ended yields a
            terminal observation with ``reward=0.0`` and an explanatory
            ``metadata`` entry.
        """
        del timeout_s, kwargs

        if self._episode_id is None:
            # No episode has been sampled yet, so there is no question to
            # grade; report it the same way other invalid calls are reported
            # instead of failing with an IndexError on the empty list.
            self._last_metadata = {"error": "reset() must be called before step()."}
            return self._build_observation(reward=0.0, done=True)

        if self._done:
            self._last_metadata = {"info": "Episode already done."}
            return self._build_observation(reward=0.0, done=True)

        self._last_metadata = {}
        n = self._config.num_questions
        if self._answered >= n:
            self._done = True
            self._last_metadata = {"info": "Episode already complete."}
            return self._build_observation(reward=0.0, done=True)

        q = self._questions[self._answered]
        answer = action.answer
        # `None` must be checked explicitly: str(None) is the non-empty
        # string "None" and would otherwise be graded as a real answer.
        if answer is None or not str(answer).strip():
            self._last_metadata = {"error": "answer must be a non-empty string."}
            return self._build_observation(reward=0.0, done=False)

        fully_correct, score = grade_answer(answer, q, self._config)
        r_step = compute_mcq_reward(score, alpha=self._config.alpha)

        self._history.append(
            {
                "question_id": q.question_id,
                "dataset": q.dataset,
                "task_type": q.task_type,
                "submitted": answer,
                "correct": fully_correct,
                "reward": r_step,
            }
        )
        self._task_total[q.task_type] += 1
        if fully_correct:
            self._total_correct += 1
            self._domain_correct[q.dataset] += 1
            self._task_correct[q.task_type] += 1
        self._answered += 1

        total_reward_this_step = r_step
        if self._answered >= n:
            # Last question answered: add the episode-level bonus to this
            # step's reward and expose its inputs through the metadata.
            bonus = compute_episode_bonus(
                self._total_correct,
                n,
                dict(self._domain_correct),
                all_domains=tuple(self._config.all_domains),
                lambda_ep=self._config.lambda_ep,
            )
            total_reward_this_step = r_step + bonus
            self._done = True
            self._last_metadata = {
                "episode_bonus": bonus,
                "domain_correct_counts": dict(self._domain_correct),
            }

        self._total_reward += total_reward_this_step
        return self._build_observation(
            reward=total_reward_this_step,
            done=self._done,
        )

    @property
    def state(self) -> TemporalBenchState:
        """Return a snapshot of the episode's progress and running totals."""
        return TemporalBenchState(
            episode_id=self._episode_id,
            step_count=self._answered,
            total_correct=self._total_correct,
            total_questions=self._config.num_questions,
            current_accuracy=self._accuracy_so_far(),
            primary_domain=self._config.primary_domain,
            per_task_type_accuracy=self._per_task_accuracy(),
            total_reward=self._total_reward,
        )