# TemporalBenchEnv/env/temporal_bench_env.py
# Uploaded by yashu2000 via huggingface_hub (commit d954568).
"""Core OpenEnv environment for TemporalBench MCQ episodes."""
from __future__ import annotations
import uuid
from collections import defaultdict
from dataclasses import replace
from typing import Any, Optional
import numpy as np
from data.loaders import load_question_banks
from data.question import TSQuestion
from .config import EnvConfig
from .episode_sampler import EpisodeSampler
from .grading import grade_answer
from .models import TemporalBenchAction, TemporalBenchObservation, TemporalBenchState
from .reward import compute_episode_bonus, compute_mcq_reward
try:
    from openenv.core.env_server.interfaces import Environment
except ImportError:
    # openenv is optional: when it is missing, define a minimal abstract
    # stand-in that declares the same reset/step/state contract so the
    # concrete environment below can still be imported and subclassed.
    from abc import ABC, abstractmethod
    from typing import Generic, TypeVar

    ActT, ObsT, StateT = TypeVar("ActT"), TypeVar("ObsT"), TypeVar("StateT")

    class Environment(ABC, Generic[ActT, ObsT, StateT]):
        """Fallback abstract environment interface (reset / step / state)."""

        @abstractmethod
        def reset(self, seed=None, episode_id=None, **kwargs):
            """Start a new episode and return its first observation."""

        @abstractmethod
        def step(self, action, timeout_s=None, **kwargs):
            """Apply one action and return the resulting observation."""

        @property
        @abstractmethod
        def state(self):
            """Return a snapshot of the current episode state."""
class TemporalBenchEnvironment(
    Environment[TemporalBenchAction, TemporalBenchObservation, TemporalBenchState]
):
    """Multi-step MCQ environment over a pre-built TemporalBench question bank.

    Each episode samples questions with :class:`EpisodeSampler` and lasts
    ``config.num_questions`` steps. Every :meth:`step` grades one answer
    against the current question and returns a per-question reward from
    ``compute_mcq_reward``. The step that answers the final question also
    adds the episode bonus from ``compute_episode_bonus`` and ends the episode.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self, config: Optional[EnvConfig] = None, **kwargs: Any):
        """Load the question banks and set up an empty (not yet reset) episode.

        Args:
            config: Environment configuration; ``EnvConfig()`` is used when
                omitted.
            **kwargs: Forwarded unchanged to the base ``Environment`` initializer.
        """
        super().__init__(**kwargs)
        self._config = config if config is not None else EnvConfig()
        self._rng = np.random.default_rng(self._config.seed)
        self._banks = load_question_banks(self._config.question_bank_path)
        self._sampler = EpisodeSampler(self._banks, self._config, self._rng)
        self._episode_id: Optional[str] = None
        self._questions: list[TSQuestion] = []
        self._clear_episode_state()

    def _clear_episode_state(self) -> None:
        """Zero all per-episode progress counters, history and metadata."""
        self._answered: int = 0
        self._history: list[dict[str, Any]] = []
        self._done: bool = False
        self._total_correct: int = 0
        self._total_reward: float = 0.0
        self._domain_correct: dict[str, int] = defaultdict(int)
        self._task_correct: dict[str, int] = defaultdict(int)
        self._task_total: dict[str, int] = defaultdict(int)
        self._last_metadata: dict[str, Any] = {}

    def _accuracy_so_far(self) -> float:
        """Fraction of answered questions graded fully correct (0.0 if none)."""
        if self._answered == 0:
            return 0.0
        return self._total_correct / self._answered

    def _per_task_accuracy(self) -> dict[str, float]:
        """Accuracy per task type, over the task types answered this episode."""
        # .get avoids inserting keys into the defaultdict while merely reading.
        return {
            task: (self._task_correct.get(task, 0) / total) if total else 0.0
            for task, total in self._task_total.items()
        }

    def _build_observation(
        self,
        *,
        reward: float | None,
        done: bool,
    ) -> TemporalBenchObservation:
        """Build the observation for the current step.

        When ``done`` is true, or every question has been answered, a terminal
        observation with empty question fields is returned. Otherwise the
        observation presents the next unanswered question.

        Args:
            reward: Reward to attach to the observation.
            done: Whether the caller considers the episode finished.
        """
        n = self._config.num_questions
        terminal = done or self._answered >= n
        if terminal:
            question, options, task_type, dataset = "", [], "", ""
            steps_remaining = 0
        else:
            q = self._questions[self._answered]
            question, options = q.prompt, list(q.options)
            task_type, dataset = q.task_type, q.dataset
            steps_remaining = n - self._answered
        return TemporalBenchObservation(
            step_idx=self._answered,
            steps_remaining=steps_remaining,
            max_steps=n,
            question=question,
            options=options,
            task_type=task_type,
            dataset=dataset,
            history=list(self._history),
            accuracy_so_far=self._accuracy_so_far(),
            done=terminal,
            reward=reward,
            metadata=dict(self._last_metadata),
        )

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> TemporalBenchObservation:
        """Start a new episode and return the observation for its first question.

        Args:
            seed: If given, reseed the environment RNG before sampling.
            episode_id: Identifier for the episode; a random UUID4 string is
                generated when omitted or empty.
            **kwargs: ``curriculum_stage`` (int-convertible) overrides
                ``config.curriculum_stage`` for this episode only; other keys
                are ignored.
        """
        curriculum_kw = kwargs.pop("curriculum_stage", None)
        if seed is not None:
            self._rng = np.random.default_rng(seed)
        cfg = self._config
        if curriculum_kw is not None:
            # self._config is not modified, so the override lasts one episode.
            cfg = replace(self._config, curriculum_stage=int(curriculum_kw))
        # Rebuild the sampler every reset so it sees a reseeded RNG and/or the
        # per-episode curriculum override.
        self._sampler = EpisodeSampler(self._banks, cfg, self._rng)
        self._episode_id = episode_id or str(uuid.uuid4())
        self._questions = self._sampler.sample_episode()
        self._clear_episode_state()
        return self._build_observation(reward=0.0, done=False)

    def step(
        self,
        action: TemporalBenchAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> TemporalBenchObservation:
        """Grade ``action.answer`` against the current question.

        If the episode is already over, returns a terminal observation with
        zero reward. If the answer is missing or blank, returns the same
        question with zero reward and an ``"error"`` metadata entry, without
        advancing. Otherwise the answer is graded and recorded. Answering the
        last question adds the episode bonus and ends the episode.

        Raises:
            RuntimeError: If there is no sampled question for the current step
                (e.g. ``reset()`` was never called).
        """
        del timeout_s, kwargs
        if self._done:
            self._last_metadata = {"info": "Episode already done."}
            return self._build_observation(reward=0.0, done=True)
        self._last_metadata = {}
        n = self._config.num_questions
        if self._answered >= n:
            self._done = True
            self._last_metadata = {"info": "Episode already complete."}
            return self._build_observation(reward=0.0, done=True)
        if self._answered >= len(self._questions):
            # Previously an opaque IndexError. This happens when reset() was
            # never called, or the sampled episode is shorter than num_questions.
            raise RuntimeError(
                "No question available for this step; call reset() before step()."
            )
        q = self._questions[self._answered]
        answer = action.answer
        # str(None) == "None" would otherwise be graded as a real submission.
        if answer is None or not str(answer).strip():
            self._last_metadata = {"error": "answer must be a non-empty string."}
            return self._build_observation(reward=0.0, done=False)
        fully_correct, score = grade_answer(answer, q, self._config)
        r_step = compute_mcq_reward(score, alpha=self._config.alpha)
        self._record_answer(q, answer, fully_correct, r_step)
        total_reward_this_step = r_step
        if self._answered >= n:
            total_reward_this_step = r_step + self._finalize_episode()
        self._total_reward += total_reward_this_step
        return self._build_observation(
            reward=total_reward_this_step,
            done=self._done,
        )

    def _record_answer(
        self,
        q: TSQuestion,
        answer: Any,
        fully_correct: bool,
        reward: float,
    ) -> None:
        """Append a graded answer to history and update the running tallies."""
        self._history.append(
            {
                "question_id": q.question_id,
                "dataset": q.dataset,
                "task_type": q.task_type,
                "submitted": answer,
                "correct": fully_correct,
                "reward": reward,
            }
        )
        self._task_total[q.task_type] += 1
        if fully_correct:
            self._total_correct += 1
            self._domain_correct[q.dataset] += 1
            self._task_correct[q.task_type] += 1
        self._answered += 1

    def _finalize_episode(self) -> float:
        """Compute the episode bonus, mark the episode done and return the bonus."""
        bonus = compute_episode_bonus(
            self._total_correct,
            self._config.num_questions,
            dict(self._domain_correct),
            all_domains=tuple(self._config.all_domains),
            lambda_ep=self._config.lambda_ep,
        )
        self._done = True
        self._last_metadata = {
            "episode_bonus": bonus,
            "domain_correct_counts": dict(self._domain_correct),
        }
        return bonus

    @property
    def state(self) -> TemporalBenchState:
        """Snapshot of the current episode: progress, accuracy and total reward."""
        return TemporalBenchState(
            episode_id=self._episode_id,
            step_count=self._answered,
            total_correct=self._total_correct,
            total_questions=self._config.num_questions,
            current_accuracy=self._accuracy_so_far(),
            primary_domain=self._config.primary_domain,
            per_task_type_accuracy=self._per_task_accuracy(),
            total_reward=self._total_reward,
        )