Spaces:
Sleeping
Sleeping
| from pathlib import Path | |
| import pytest | |
| from env.config import EnvConfig | |
| from env.temporal_bench_env import TemporalBenchEnvironment | |
| from models import TemporalBenchAction | |
def _bank_dir() -> Path:
    """Return the absolute path of the ``fixtures/banks`` directory next to this file."""
    # Resolve first so the result does not depend on the current working directory.
    tests_root = Path(__file__).resolve().parent
    return tests_root.joinpath("fixtures", "banks")
def test_full_episode_all_correct():
    """Play a whole episode giving the stored answer to every question.

    Checks the initial observation, the ``done`` flag and ``step_idx`` after
    each step, and the totals exposed by ``env.state`` once the loop ends.
    """
    config = EnvConfig(
        question_bank_path=str(_bank_dir()),
        seed=123,
        lambda_ep=0.5,
        alpha=1.0,
    )
    env = TemporalBenchEnvironment(config=config)

    first_obs = env.reset(seed=123)
    assert not first_obs.done
    assert first_obs.step_idx == 0
    assert len(first_obs.options) >= 2

    n_questions = config.num_questions
    reward_sum = 0.0
    for idx in range(n_questions):
        # Peek at the pending question through the private state so the test
        # can submit its stored answer.
        pending = env._questions[env._answered]  # noqa: SLF001
        step_obs = env.step(TemporalBenchAction(answer=pending.answer))
        # A missing/falsy reward counts as zero in the running sum.
        reward_sum += step_obs.reward or 0.0

        # Only the final step may report ``done``; the step index is the
        # number of steps taken so far (equal to ``n_questions`` on the last).
        is_last = idx == n_questions - 1
        assert bool(step_obs.done) == is_last
        assert step_obs.step_idx == idx + 1

    final_state = env.state
    assert final_state.total_correct == n_questions
    assert final_state.step_count == n_questions
    assert final_state.total_reward == pytest.approx(reward_sum)
def test_episode_with_one_wrong_answer():
    """Miss only the first question and check the end-of-episode bonus.

    The first step submits an option that differs from the stored answer;
    every later step submits the stored answer. On the final step the
    observation must be ``done``, carry an ``episode_bonus`` in its metadata
    within the expected bounds, and the state must report one miss.

    The expected counts are derived from ``cfg.num_questions`` rather than
    hard-coded, so the test stays valid if the configured episode length
    changes (the original literals ``8`` and ``8 / 9`` assumed 9 questions).
    """
    lambda_ep = 0.5
    cfg = EnvConfig(
        question_bank_path=str(_bank_dir()),
        seed=0,
        lambda_ep=lambda_ep,
        alpha=1.0,
    )
    env = TemporalBenchEnvironment(config=cfg)
    env.reset(seed=0)

    n = cfg.num_questions
    expected_correct = n - 1  # exactly one deliberate miss (the first step)
    for i in range(n):
        cur = env._questions[env._answered]  # noqa: SLF001
        if i == 0:
            # Only the first step needs a wrong option. Looking one up on
            # every step would raise StopIteration for any later question
            # that happens to have no alternative option.
            chosen = next(o for o in cur.options if o != cur.answer)
        else:
            chosen = cur.answer
        obs = env.step(TemporalBenchAction(answer=chosen))

        if i == n - 1:
            assert obs.done
            meta = obs.metadata or {}
            assert "episode_bonus" in meta
            bonus = float(meta["episode_bonus"])
            # NOTE(review): bounds mirror the original literals
            # 0.5 * (8 / 9) * [0.8, 1.0]; presumably
            # lambda_ep * accuracy * <factor in [0.8, 1.0]> -- confirm the
            # meaning of the factor against the environment implementation.
            accuracy = expected_correct / n
            lo = lambda_ep * accuracy * 0.8
            hi = lambda_ep * accuracy * 1.0
            assert lo - 1e-9 <= bonus <= hi + 1e-9
            assert env.state.total_correct == expected_correct
def test_double_step_after_done():
    """An extra step taken after the episode is finished still reports ``done``."""
    config = EnvConfig(question_bank_path=str(_bank_dir()), seed=1)
    env = TemporalBenchEnvironment(config=config)
    env.reset(seed=1)

    # Exhaust the episode by submitting the stored answer for every question.
    for _step in range(config.num_questions):
        pending = env._questions[env._answered]  # noqa: SLF001
        action = TemporalBenchAction(answer=pending.answer)
        env.step(action)

    # One more step beyond the configured number of questions.
    extra_obs = env.step(TemporalBenchAction(answer="x"))
    assert extra_obs.done
def test_empty_answer_no_advance():
    """A whitespace-only answer leaves the answered counter unchanged and flags an error."""
    config = EnvConfig(question_bank_path=str(_bank_dir()), seed=2)
    env = TemporalBenchEnvironment(config=config)
    env.reset(seed=2)

    answered_before = env._answered  # noqa: SLF001
    blank_action = TemporalBenchAction(answer=" ")
    result = env.step(blank_action)

    # The private counter must not have moved ...
    assert env._answered == answered_before  # noqa: SLF001
    # ... and the observation metadata (treated as empty when falsy) must
    # contain an "error" entry.
    metadata = result.metadata or {}
    assert "error" in metadata