# TemporalBenchEnv / tests/test_env.py
# Uploaded by yashu2000 via huggingface_hub (commit d954568, verified).
from pathlib import Path
import pytest
from env.config import EnvConfig
from env.temporal_bench_env import TemporalBenchEnvironment
from models import TemporalBenchAction
def _bank_dir() -> Path:
return Path(__file__).resolve().parent / "fixtures" / "banks"
def test_full_episode_all_correct():
    """Answer every question correctly and check per-step and final state.

    Checks that:
    - ``done`` is false on every step except the last;
    - ``step_idx`` equals the number of questions answered so far;
    - the environment state's ``total_correct``, ``step_count`` and
      ``total_reward`` match what was observed step by step.
    """
    cfg = EnvConfig(question_bank_path=str(_bank_dir()), seed=123, lambda_ep=0.5, alpha=1.0)
    env = TemporalBenchEnvironment(config=cfg)

    obs0 = env.reset(seed=123)
    assert not obs0.done
    assert obs0.step_idx == 0
    assert len(obs0.options) >= 2

    last = cfg.num_questions - 1
    total_r = 0.0
    for i in range(cfg.num_questions):
        # Peek at the current question's key so every answer is correct.
        cur = env._questions[env._answered]  # noqa: SLF001
        obs = env.step(TemporalBenchAction(answer=cur.answer))
        # obs.reward may be falsy (e.g. None); count it as 0.
        total_r += obs.reward or 0.0
        assert obs.step_idx == i + 1
        # The episode must end exactly on the final question, not before.
        assert bool(obs.done) == (i == last)

    st = env.state
    assert st.total_correct == cfg.num_questions
    assert st.step_count == cfg.num_questions
    assert st.total_reward == pytest.approx(total_r)
def test_episode_with_one_wrong_answer():
    """Answer the first question wrong and the rest right; check the episode bonus.

    The expected accuracy is derived from ``cfg.num_questions`` rather than
    assuming a fixed episode length. The bonus bounds assume the terminal
    bonus is ``lambda_ep * accuracy * f`` with ``f`` in ``[0.8, 1.0]``.
    NOTE(review): that form is inferred from the previously hard-coded bounds
    (with ``alpha=1.0``); confirm against the environment's reward code.
    """
    lambda_ep = 0.5
    cfg = EnvConfig(question_bank_path=str(_bank_dir()), seed=0, lambda_ep=lambda_ep, alpha=1.0)
    env = TemporalBenchEnvironment(config=cfg)
    env.reset(seed=0)

    n = cfg.num_questions
    assert n >= 1, "test requires at least one question"
    for i in range(n):
        cur = env._questions[env._answered]  # noqa: SLF001
        if i == 0:
            # Any option other than the key is a wrong answer.
            answer = next(o for o in cur.options if o != cur.answer)
        else:
            answer = cur.answer
        obs = env.step(TemporalBenchAction(answer=answer))

    # After the final question the episode must be over and report its bonus.
    assert obs.done
    meta = obs.metadata or {}
    assert "episode_bonus" in meta
    bonus = float(meta["episode_bonus"])

    expected_correct = n - 1
    accuracy = expected_correct / n
    lo = lambda_ep * accuracy * 0.8
    hi = lambda_ep * accuracy * 1.0
    assert lo - 1e-9 <= bonus <= hi + 1e-9
    assert env.state.total_correct == expected_correct
def test_double_step_after_done():
    """An extra step taken after the last question still yields a done observation."""
    config = EnvConfig(question_bank_path=str(_bank_dir()), seed=1)
    env = TemporalBenchEnvironment(config=config)
    env.reset(seed=1)

    # Finish the episode by answering every question with its key.
    for _step in range(config.num_questions):
        question = env._questions[env._answered]  # noqa: SLF001
        env.step(TemporalBenchAction(answer=question.answer))

    # "x" is an arbitrary answer; the episode is already finished.
    extra = env.step(TemporalBenchAction(answer="x"))
    assert extra.done
def test_empty_answer_no_advance():
    """A whitespace-only answer does not consume a question and reports an error."""
    config = EnvConfig(question_bank_path=str(_bank_dir()), seed=2)
    env = TemporalBenchEnvironment(config=config)
    env.reset(seed=2)

    answered_before = env._answered  # noqa: SLF001
    result = env.step(TemporalBenchAction(answer=" "))

    # The answered-question counter must stay where it was.
    assert env._answered == answered_before  # noqa: SLF001
    metadata = result.metadata or {}
    assert "error" in metadata