File size: 2,764 Bytes
d954568
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from pathlib import Path

import pytest

from env.config import EnvConfig
from env.temporal_bench_env import TemporalBenchEnvironment
from models import TemporalBenchAction


def _bank_dir() -> Path:
    return Path(__file__).resolve().parent / "fixtures" / "banks"


def test_full_episode_all_correct():
    """Answer every question correctly and check per-step and final state.

    The reset observation must be open, at step 0, with at least two options.
    Each step then submits the current question's ``answer``. Every step before
    the last must leave the episode open with ``step_idx`` advanced by one, and
    the final step must close it. Afterwards the environment state must report
    a perfect score, one step per question, and a total reward matching the
    summed per-step rewards (a ``None`` reward counts as 0).
    """
    cfg = EnvConfig(question_bank_path=str(_bank_dir()), seed=123, lambda_ep=0.5, alpha=1.0)
    env = TemporalBenchEnvironment(config=cfg)

    first = env.reset(seed=123)
    assert not first.done
    assert first.step_idx == 0
    assert len(first.options) >= 2

    n_questions = cfg.num_questions
    reward_sum = 0.0
    for step in range(n_questions):
        question = env._questions[env._answered]  # noqa: SLF001
        result = env.step(TemporalBenchAction(answer=question.answer))
        reward_sum += result.reward or 0.0
        if step == n_questions - 1:
            assert result.done
            assert result.step_idx == n_questions
        else:
            assert not result.done
            assert result.step_idx == step + 1

    state = env.state
    assert state.total_correct == n_questions
    assert state.step_count == n_questions
    assert state.total_reward == pytest.approx(reward_sum)


def test_episode_with_one_wrong_answer():
    """Miss only the first question and check the end-of-episode bonus.

    The first step submits an option that differs from the correct answer, and
    every later step submits the correct answer. After the last step the
    observation must be ``done`` and carry an ``episode_bonus`` in its
    metadata, bounded relative to ``lambda_ep`` and the achieved accuracy.
    The environment must also record exactly one miss.
    """
    lambda_ep = 0.5
    cfg = EnvConfig(question_bank_path=str(_bank_dir()), seed=0, lambda_ep=lambda_ep, alpha=1.0)
    env = TemporalBenchEnvironment(config=cfg)
    env.reset(seed=0)

    n_questions = cfg.num_questions
    # Need at least one question to miss; also guards the accuracy division.
    assert n_questions >= 1

    obs = None
    for i in range(n_questions):
        cur = env._questions[env._answered]  # noqa: SLF001
        if i == 0:
            # Any option other than the correct one counts as a miss. Only
            # computed when needed, so later questions need not have a
            # second option.
            answer = next(o for o in cur.options if o != cur.answer)
        else:
            answer = cur.answer
        obs = env.step(TemporalBenchAction(answer=answer))

    # Checked after the loop so these assertions can never be skipped.
    assert obs is not None
    assert obs.done
    meta = obs.metadata or {}
    assert "episode_bonus" in meta
    bonus = float(meta["episode_bonus"])

    accuracy = (n_questions - 1) / n_questions
    # NOTE(review): the [0.8, 1.0] scale range is carried over from the
    # original bounds; the term it corresponds to in the bonus formula is not
    # visible in this file — confirm against the environment implementation.
    min_scale, max_scale = 0.8, 1.0
    lo = lambda_ep * accuracy * min_scale
    hi = lambda_ep * accuracy * max_scale
    assert lo - 1e-9 <= bonus <= hi + 1e-9

    assert env.state.total_correct == n_questions - 1


def test_double_step_after_done():
    """An extra step submitted after the episode has finished still reports ``done``."""
    cfg = EnvConfig(question_bank_path=str(_bank_dir()), seed=1)
    env = TemporalBenchEnvironment(config=cfg)
    env.reset(seed=1)

    # Play the whole episode by submitting the correct answer each time.
    for _step in range(cfg.num_questions):
        question = env._questions[env._answered]  # noqa: SLF001
        env.step(TemporalBenchAction(answer=question.answer))

    extra = env.step(TemporalBenchAction(answer="x"))
    assert extra.done


def test_empty_answer_no_advance():
    """A whitespace-only answer must not advance the episode.

    The environment is expected to leave its answered-question counter
    unchanged and to report an ``error`` key in the observation metadata.
    """
    cfg = EnvConfig(question_bank_path=str(_bank_dir()), seed=2)
    env = TemporalBenchEnvironment(config=cfg)
    env.reset(seed=2)

    answered_before = env._answered  # noqa: SLF001
    result = env.step(TemporalBenchAction(answer="  "))
    answered_after = env._answered  # noqa: SLF001

    assert answered_after == answered_before
    metadata = result.metadata or {}
    assert "error" in metadata