File size: 744 Bytes
a617acd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from __future__ import annotations

from auditenv.datasets.easy_generator import generate_easy_episode
from auditenv.datasets.hard_generator import generate_hard_episode
from auditenv.datasets.medium_generator import generate_medium_episode
from auditenv.datasets.types import GeneratedEpisode


def generate_episode(task_id: str, seed: int = 42, config_path: str = "configs/datasets.yaml") -> GeneratedEpisode:
    """Build a generated episode for the requested difficulty tier.

    Args:
        task_id: Difficulty tier; one of ``"easy"``, ``"medium"`` or ``"hard"``
            (compared exactly, so the match is case-sensitive).
        seed: Seed forwarded unchanged to the selected generator.
        config_path: Dataset config path forwarded to the medium and hard
            generators; the easy generator is called without it.

    Returns:
        The ``GeneratedEpisode`` produced by the selected generator.

    Raises:
        ValueError: If ``task_id`` matches none of the supported tiers.
    """
    # Tiers are tried in order with ``==`` rather than via a dict lookup, so an
    # unrecognised task_id of any type (even an unhashable one) ends in the
    # ValueError below instead of a TypeError.
    builders = (
        ("easy", lambda: generate_easy_episode(seed=seed)),
        ("medium", lambda: generate_medium_episode(seed=seed, config_path=config_path)),
        ("hard", lambda: generate_hard_episode(seed=seed, config_path=config_path)),
    )
    for tier, build in builders:
        if task_id == tier:
            return build()
    raise ValueError(f"Unsupported task_id: {task_id}")