import numpy as np

from data.loaders import load_question_banks
from env.config import EnvConfig
from env.episode_sampler import EpisodeSampler


def test_sample_episode_shape_and_domains(fixture_bank_dir):
    """Check episode length and dataset coverage for a sampled episode.

    The episode must contain ``num_questions`` questions, and the set of
    ``dataset`` values across its questions must equal the set of the
    config's ``all_domains``.
    """
    bank_path = str(fixture_bank_dir)
    question_banks = load_question_banks(bank_path)
    config = EnvConfig(
        question_bank_path=bank_path,
        primary_domain="PSML",
        seed=42,
    )
    episode_sampler = EpisodeSampler(
        question_banks,
        config,
        np.random.default_rng(42),
    )

    episode = episode_sampler.sample_episode()

    assert len(episode) == config.num_questions

    # Collect the distinct source datasets seen in this episode.
    seen_datasets = set()
    for question in episode:
        seen_datasets.add(question.dataset)
    assert seen_datasets == set(config.all_domains)


def test_curriculum_stage_filters_task_types(fixture_bank_dir):
    """Check that ``curriculum_stage=1`` yields only ``"T1U"`` questions.

    Five episodes are drawn from the same sampler; every question in each
    of them must have ``task_type == "T1U"``.
    """
    bank_path = str(fixture_bank_dir)
    question_banks = load_question_banks(bank_path)
    config = EnvConfig(
        question_bank_path=bank_path,
        curriculum_stage=1,
        seed=0,
    )
    episode_sampler = EpisodeSampler(
        question_banks,
        config,
        np.random.default_rng(0),
    )

    draws_remaining = 5
    while draws_remaining > 0:
        episode = episode_sampler.sample_episode()
        for question in episode:
            assert question.task_type == "T1U"
        draws_remaining -= 1