TemporalBenchEnv / tests /test_sampler.py
yashu2000's picture
Upload folder using huggingface_hub
d954568 verified
import numpy as np
from data.loaders import load_question_banks
from env.config import EnvConfig
from env.episode_sampler import EpisodeSampler
def test_sample_episode_shape_and_domains(fixture_bank_dir):
    """A sampled episode has the configured length and spans every domain.

    Loads the fixture question banks, builds a config with ``PSML`` as the
    primary domain, draws one episode, and checks that:

    * the episode contains exactly ``config.num_questions`` questions, and
    * the set of ``dataset`` values across its questions equals
      ``config.all_domains``.
    """
    bank_dir = str(fixture_bank_dir)
    question_banks = load_question_banks(bank_dir)
    config = EnvConfig(
        question_bank_path=bank_dir,
        primary_domain="PSML",
        seed=42,
    )
    generator = np.random.default_rng(42)

    episode = EpisodeSampler(question_banks, config, generator).sample_episode()

    assert len(episode) == config.num_questions
    assert {question.dataset for question in episode} == set(config.all_domains)
def test_curriculum_stage_filters_task_types(fixture_bank_dir):
    """Curriculum stage 1 restricts every sampled question to task type ``T1U``.

    Draws five episodes from a single ``EpisodeSampler`` built with
    ``curriculum_stage=1`` and checks that each one is non-empty. It also
    checks that every question in each episode has ``task_type == "T1U"``.
    """
    banks = load_question_banks(str(fixture_bank_dir))
    cfg = EnvConfig(
        question_bank_path=str(fixture_bank_dir),
        curriculum_stage=1,
        seed=0,
    )
    rng = np.random.default_rng(0)
    sampler = EpisodeSampler(banks, cfg, rng)
    for draw in range(5):
        ep = sampler.sample_episode()
        # all(...) over an empty episode is True, so a filter that dropped
        # every question would pass silently; require at least one question.
        assert ep, f"episode {draw} is empty under curriculum_stage=1"
        task_types = {q.task_type for q in ep}
        assert task_types == {"T1U"}, (
            f"episode {draw} contains task types {task_types!r}, expected only 'T1U'"
        )