File size: 1,064 Bytes
d954568
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import numpy as np

from data.loaders import load_question_banks
from env.config import EnvConfig
from env.episode_sampler import EpisodeSampler


def test_sample_episode_shape_and_domains(fixture_bank_dir):
    """A sampled episode has ``cfg.num_questions`` items spanning every domain.

    The config names ``"PSML"`` as the primary domain. The check is that the
    set of ``dataset`` values seen in one episode equals ``cfg.all_domains``.
    """
    bank_path = str(fixture_bank_dir)
    banks = load_question_banks(bank_path)
    config = EnvConfig(
        question_bank_path=bank_path,
        primary_domain="PSML",
        seed=42,
    )
    sampler = EpisodeSampler(banks, config, np.random.default_rng(42))
    episode = sampler.sample_episode()

    assert len(episode) == config.num_questions
    seen_domains = set(question.dataset for question in episode)
    assert seen_domains == set(config.all_domains)


def test_curriculum_stage_filters_task_types(fixture_bank_dir):
    """Curriculum stage 1 should restrict sampled episodes to ``"T1U"`` questions.

    Five episodes are drawn from the same sampler (seeded RNG), so the filter
    is checked across successive draws rather than only the first one.
    """
    banks = load_question_banks(str(fixture_bank_dir))
    cfg = EnvConfig(
        question_bank_path=str(fixture_bank_dir),
        curriculum_stage=1,
        seed=0,
    )
    rng = np.random.default_rng(0)
    sampler = EpisodeSampler(banks, cfg, rng)
    for _ in range(5):
        ep = sampler.sample_episode()
        # all()/set checks are vacuously satisfied by an empty episode, so a
        # sampler that filtered out every question would otherwise pass.
        assert len(ep) > 0, "stage 1 sampled an empty episode"
        unexpected = {q.task_type for q in ep} - {"T1U"}
        assert not unexpected, f"stage 1 sampled non-T1U task types: {unexpected}"