"""BabyAI level registry and curriculum definitions for MiniGridEnv.""" from __future__ import annotations from dataclasses import dataclass, field @dataclass(frozen=True) class LevelConfig: """Configuration for a single BabyAI level.""" gym_id: str name: str description: str difficulty: int max_steps: int expected_optimal_steps: int requires_interaction: bool num_objects: int involves_language_composition: bool fallback_gym_ids: tuple[str, ...] = () @property def candidate_gym_ids(self) -> list[str]: return [self.gym_id, *self.fallback_gym_ids] LEVEL_REGISTRY: list[LevelConfig] = [ # Stage 0: simple navigation LevelConfig( gym_id="BabyAI-GoToRedBallGrey-v0", fallback_gym_ids=("BabyAI-GoToRedBall-v0",), name="GoToRedBall", description="Navigate to the red ball in a single room.", difficulty=0, max_steps=64, expected_optimal_steps=10, requires_interaction=False, num_objects=1, involves_language_composition=False, ), LevelConfig( gym_id="BabyAI-GoToObj-v0", name="GoToObj", description="Navigate to a specific colored object.", difficulty=1, max_steps=64, expected_optimal_steps=12, requires_interaction=False, num_objects=2, involves_language_composition=False, ), LevelConfig( gym_id="BabyAI-GoToLocal-v0", name="GoToLocal", description="Navigate to a specific object with distractors.", difficulty=2, max_steps=64, expected_optimal_steps=15, requires_interaction=False, num_objects=4, involves_language_composition=False, ), # Stage 1: object interaction LevelConfig( gym_id="BabyAI-PickupLoc-v0", name="PickupLoc", description="Pick up a specific object in a single room.", difficulty=3, max_steps=64, expected_optimal_steps=14, requires_interaction=True, num_objects=4, involves_language_composition=False, ), LevelConfig( gym_id="BabyAI-OpenDoor-v0", name="OpenDoor", description="Open a door of a specified color.", difficulty=3, max_steps=64, expected_optimal_steps=12, requires_interaction=True, num_objects=1, involves_language_composition=False, ), LevelConfig( gym_id="BabyAI-UnlockLocal-v0", name="UnlockLocal", description="Unlock a local door with the matching key.", difficulty=4, max_steps=128, expected_optimal_steps=25, requires_interaction=True, num_objects=3, involves_language_composition=False, ), # Stage 2: multi-room and compositional LevelConfig( gym_id="BabyAI-GoTo-v0", name="GoTo", description="Navigate to a specified object across rooms.", difficulty=5, max_steps=128, expected_optimal_steps=30, requires_interaction=True, num_objects=6, involves_language_composition=False, ), LevelConfig( gym_id="BabyAI-PutNextLocal-v0", name="PutNextLocal", description="Place one object next to another local object.", difficulty=6, max_steps=128, expected_optimal_steps=20, requires_interaction=True, num_objects=4, involves_language_composition=True, ), # Stage 3: hardest compositional tasks LevelConfig( gym_id="BabyAI-Synth-v0", name="Synth", description="Random compositional instructions.", difficulty=7, max_steps=256, expected_optimal_steps=40, requires_interaction=True, num_objects=8, involves_language_composition=True, ), LevelConfig( gym_id="BabyAI-BossLevel-v0", name="BossLevel", description="Hardest compositional BabyAI level.", difficulty=8, max_steps=512, expected_optimal_steps=80, requires_interaction=True, num_objects=10, involves_language_composition=True, ), ] def get_level(name: str) -> LevelConfig: """Return a level by short name (case-insensitive).""" for level in LEVEL_REGISTRY: if level.name.lower() == name.lower(): return level available = ", ".join(level.name for level in LEVEL_REGISTRY) raise ValueError(f"Unknown level '{name}'. Available levels: {available}") def get_levels_by_difficulty(max_difficulty: int) -> list[LevelConfig]: """Return levels with difficulty <= max_difficulty.""" return [level for level in LEVEL_REGISTRY if level.difficulty <= max_difficulty] @dataclass class CurriculumConfig: """Curriculum stages and advancement settings.""" stages: list[list[str]] = field( default_factory=lambda: [ ["GoToRedBall"], ["GoToObj", "GoToLocal"], ["PickupLoc", "OpenDoor", "UnlockLocal"], ["GoTo", "PutNextLocal"], ["Synth", "BossLevel"], ] ) advance_threshold: float = 0.8 episodes_per_eval: int = 100