"""Episode-level tests for ``TemporalBenchEnvironment`` driven by fixture question banks."""

from pathlib import Path

import pytest

from env.config import EnvConfig
from env.temporal_bench_env import TemporalBenchEnvironment
from models import TemporalBenchAction


def _bank_dir() -> Path:
    """Return the ``fixtures/banks`` directory located next to this test file."""
    return Path(__file__).resolve().parent / "fixtures" / "banks"


def test_full_episode_all_correct():
    """Answer every question correctly and check step, done and reward bookkeeping."""
    cfg = EnvConfig(
        question_bank_path=str(_bank_dir()),
        seed=123,
        lambda_ep=0.5,
        alpha=1.0,
    )
    env = TemporalBenchEnvironment(config=cfg)

    obs0 = env.reset(seed=123)
    assert not obs0.done
    assert obs0.step_idx == 0
    assert len(obs0.options) >= 2

    n_questions = cfg.num_questions
    total_r = 0.0
    obs = obs0
    for i in range(n_questions):
        # Peek at the env's private state to obtain the ground-truth answer.
        cur = env._questions[env._answered]  # noqa: SLF001
        obs = env.step(TemporalBenchAction(answer=cur.answer))
        # A missing (None) reward contributes nothing to the running total.
        total_r += obs.reward or 0.0
        if i < n_questions - 1:
            assert not obs.done
            assert obs.step_idx == i + 1
        else:
            assert obs.done
            assert obs.step_idx == n_questions

    st = env.state
    assert st.total_correct == n_questions
    assert st.step_count == n_questions
    assert st.total_reward == pytest.approx(total_r)


def test_episode_with_one_wrong_answer():
    """Answer only the first question wrongly and check the final episode bonus."""
    cfg = EnvConfig(
        question_bank_path=str(_bank_dir()),
        seed=0,
        lambda_ep=0.5,
        alpha=1.0,
    )
    env = TemporalBenchEnvironment(config=cfg)
    env.reset(seed=0)

    n_questions = cfg.num_questions
    # Exactly one answer (the first) is wrong, so all but one should be correct.
    # Derived from the config rather than hard-coded so it tracks the episode
    # length the same way the other tests in this module do.
    expected_correct = n_questions - 1

    for i in range(n_questions):
        cur = env._questions[env._answered]  # noqa: SLF001
        wrong = next(o for o in cur.options if o != cur.answer)
        act = TemporalBenchAction(answer=wrong if i == 0 else cur.answer)
        obs = env.step(act)
        if i == n_questions - 1:
            assert obs.done
            meta = obs.metadata or {}
            assert "episode_bonus" in meta
            bonus = float(meta["episode_bonus"])
            # Accepted band: 0.5 * fraction-correct * [0.8, 1.0].
            # NOTE(review): 0.5 presumably mirrors ``lambda_ep`` and the
            # 0.8-1.0 factor comes from the env's bonus formula -- confirm
            # against the environment implementation.
            frac_correct = expected_correct / n_questions
            lo = 0.5 * frac_correct * 0.8
            hi = 0.5 * frac_correct * 1.0
            assert lo - 1e-9 <= bonus <= hi + 1e-9

    assert env.state.total_correct == expected_correct


def test_double_step_after_done():
    """Stepping again after the episode has finished still reports ``done``."""
    cfg = EnvConfig(question_bank_path=str(_bank_dir()), seed=1)
    env = TemporalBenchEnvironment(config=cfg)
    env.reset(seed=1)

    for _ in range(cfg.num_questions):
        cur = env._questions[env._answered]  # noqa: SLF001
        env.step(TemporalBenchAction(answer=cur.answer))

    # One extra step beyond the end of the episode.
    obs = env.step(TemporalBenchAction(answer="x"))
    assert obs.done


def test_empty_answer_no_advance():
    """A whitespace-only answer leaves the answered counter unchanged and reports an error."""
    cfg = EnvConfig(question_bank_path=str(_bank_dir()), seed=2)
    env = TemporalBenchEnvironment(config=cfg)
    env.reset(seed=2)

    before = env._answered  # noqa: SLF001
    obs = env.step(TemporalBenchAction(answer="   "))

    assert env._answered == before  # noqa: SLF001
    assert "error" in (obs.metadata or {})