Corp_AI / src /auditenv /datasets /factory.py
Arpit Deep
feat: initial AuditEnv submission
a617acd
raw
history blame contribute delete
744 Bytes
from __future__ import annotations
from auditenv.datasets.easy_generator import generate_easy_episode
from auditenv.datasets.hard_generator import generate_hard_episode
from auditenv.datasets.medium_generator import generate_medium_episode
from auditenv.datasets.types import GeneratedEpisode
def generate_episode(
    task_id: str,
    seed: int = 42,
    config_path: str = "configs/datasets.yaml",
) -> GeneratedEpisode:
    """Build an episode by delegating to the generator for ``task_id``.

    Args:
        task_id: Difficulty selector. Compared exactly (case-sensitive)
            against ``"easy"``, ``"medium"`` and ``"hard"``.
        seed: Forwarded as ``seed=`` to whichever generator is selected.
        config_path: Forwarded as ``config_path=`` to the medium and hard
            generators only; the easy generator is called with ``seed``
            alone, so this argument is ignored for ``"easy"``.

    Returns:
        Whatever the selected generator returns.

    Raises:
        ValueError: If ``task_id`` is not one of the three supported values.
    """
    if task_id == "easy":
        # The easy generator is invoked without a config path.
        return generate_easy_episode(seed=seed)
    if task_id == "medium":
        return generate_medium_episode(seed=seed, config_path=config_path)
    if task_id == "hard":
        return generate_hard_episode(seed=seed, config_path=config_path)
    # No branch matched: fail loudly rather than returning None.
    raise ValueError(f"Unsupported task_id: {task_id}")