# Source: MiniGridEnv/env/levels.py — uploaded by yashu2000 via huggingface_hub (commit a03a89b).
"""BabyAI level registry and curriculum definitions for MiniGridEnv."""
from __future__ import annotations
from dataclasses import dataclass, field
@dataclass(frozen=True)
class LevelConfig:
    """Configuration for a single BabyAI level.

    Instances are immutable and hashable. ``fallback_gym_ids`` is normalized to
    a tuple on construction, so passing a list or a single string also works.
    """

    # Primary environment ID, e.g. "BabyAI-GoToObj-v0".
    gym_id: str
    # Short name; get_level() matches it case-insensitively.
    name: str
    # Human-readable task summary.
    description: str
    # Integer difficulty rank; get_levels_by_difficulty() filters on it.
    difficulty: int
    # Presumably the per-episode step limit — confirm against the env wrapper.
    max_steps: int
    # Rough estimate of steps an optimal agent needs (hand-set value).
    expected_optimal_steps: int
    # Whether the task needs actions beyond pure navigation (pickup/open/...).
    requires_interaction: bool
    # Approximate number of objects involved in the task (hand-set value).
    num_objects: int
    # Whether instructions combine multiple sub-goals / clauses.
    involves_language_composition: bool
    # Alternate environment IDs listed after gym_id by candidate_gym_ids.
    fallback_gym_ids: tuple[str, ...] = ()

    def __post_init__(self) -> None:
        # Normalize to a tuple. A list would make this frozen dataclass
        # unhashable, and a bare string would otherwise be unpacked
        # character-by-character in candidate_gym_ids.
        fallbacks = self.fallback_gym_ids
        if fallbacks is None:
            fallbacks = ()
        elif isinstance(fallbacks, str):
            fallbacks = (fallbacks,)
        elif not isinstance(fallbacks, tuple):
            fallbacks = tuple(fallbacks)
        # Frozen dataclasses must bypass their own __setattr__ during init.
        object.__setattr__(self, "fallback_gym_ids", fallbacks)

    @property
    def candidate_gym_ids(self) -> list[str]:
        """Return ``gym_id`` followed by the fallbacks, without duplicates.

        Order is preserved (first occurrence wins), so the primary ID is
        always first.
        """
        return list(dict.fromkeys((self.gym_id, *self.fallback_gym_ids)))
# All known levels, listed in non-decreasing difficulty order. Names must be
# unique: get_level() returns the first case-insensitive match.
# Group comments below refer to the default stages in CurriculumConfig.
LEVEL_REGISTRY: list[LevelConfig] = [
    # Navigation only (default curriculum stages 0 and 1)
    LevelConfig(
        gym_id="BabyAI-GoToRedBallGrey-v0",
        fallback_gym_ids=("BabyAI-GoToRedBall-v0",),
        name="GoToRedBall",
        description="Navigate to the red ball in a single room.",
        difficulty=0,
        max_steps=64,
        expected_optimal_steps=10,
        requires_interaction=False,
        num_objects=1,
        involves_language_composition=False,
    ),
    LevelConfig(
        gym_id="BabyAI-GoToObj-v0",
        name="GoToObj",
        description="Navigate to a specific colored object.",
        difficulty=1,
        max_steps=64,
        expected_optimal_steps=12,
        requires_interaction=False,
        num_objects=2,
        involves_language_composition=False,
    ),
    LevelConfig(
        gym_id="BabyAI-GoToLocal-v0",
        name="GoToLocal",
        description="Navigate to a specific object with distractors.",
        difficulty=2,
        max_steps=64,
        expected_optimal_steps=15,
        requires_interaction=False,
        num_objects=4,
        involves_language_composition=False,
    ),
    # Object interaction (default curriculum stage 2)
    LevelConfig(
        gym_id="BabyAI-PickupLoc-v0",
        name="PickupLoc",
        description="Pick up a specific object in a single room.",
        difficulty=3,
        max_steps=64,
        expected_optimal_steps=14,
        requires_interaction=True,
        num_objects=4,
        involves_language_composition=False,
    ),
    LevelConfig(
        gym_id="BabyAI-OpenDoor-v0",
        name="OpenDoor",
        description="Open a door of a specified color.",
        difficulty=3,
        max_steps=64,
        expected_optimal_steps=12,
        requires_interaction=True,
        num_objects=1,
        involves_language_composition=False,
    ),
    LevelConfig(
        gym_id="BabyAI-UnlockLocal-v0",
        name="UnlockLocal",
        description="Unlock a local door with the matching key.",
        difficulty=4,
        max_steps=128,
        expected_optimal_steps=25,
        requires_interaction=True,
        num_objects=3,
        involves_language_composition=False,
    ),
    # Multi-room and compositional (default curriculum stage 3)
    LevelConfig(
        gym_id="BabyAI-GoTo-v0",
        name="GoTo",
        description="Navigate to a specified object across rooms.",
        difficulty=5,
        max_steps=128,
        expected_optimal_steps=30,
        # NOTE(review): a navigation task marked as interactive — presumably
        # because doors between rooms start closed and must be opened; confirm.
        requires_interaction=True,
        num_objects=6,
        involves_language_composition=False,
    ),
    LevelConfig(
        gym_id="BabyAI-PutNextLocal-v0",
        name="PutNextLocal",
        description="Place one object next to another local object.",
        difficulty=6,
        max_steps=128,
        expected_optimal_steps=20,
        requires_interaction=True,
        num_objects=4,
        involves_language_composition=True,
    ),
    # Hardest compositional tasks (default curriculum stage 4)
    LevelConfig(
        gym_id="BabyAI-Synth-v0",
        name="Synth",
        description="Random compositional instructions.",
        difficulty=7,
        max_steps=256,
        expected_optimal_steps=40,
        requires_interaction=True,
        num_objects=8,
        involves_language_composition=True,
    ),
    LevelConfig(
        gym_id="BabyAI-BossLevel-v0",
        name="BossLevel",
        description="Hardest compositional BabyAI level.",
        difficulty=8,
        max_steps=512,
        expected_optimal_steps=80,
        requires_interaction=True,
        num_objects=10,
        involves_language_composition=True,
    ),
]
def get_level(name: str) -> LevelConfig:
    """Look up a registered level by its short name, ignoring case.

    Returns the first entry of ``LEVEL_REGISTRY`` whose ``name`` matches.
    Raises ``ValueError`` listing every registered name when nothing matches.
    """
    wanted = name.lower()
    found = next(
        (cfg for cfg in LEVEL_REGISTRY if cfg.name.lower() == wanted),
        None,
    )
    if found is not None:
        return found
    known_names = ", ".join(cfg.name for cfg in LEVEL_REGISTRY)
    raise ValueError(f"Unknown level '{name}'. Available levels: {known_names}")
def get_levels_by_difficulty(max_difficulty: int) -> list[LevelConfig]:
    """Collect registered levels whose difficulty is at most *max_difficulty*.

    Registry order is preserved. An empty list is returned when no level
    qualifies.
    """
    selected: list[LevelConfig] = []
    for cfg in LEVEL_REGISTRY:
        if cfg.difficulty <= max_difficulty:
            selected.append(cfg)
    return selected
# Immutable template for the default curriculum, easiest stage first.
_DEFAULT_CURRICULUM_STAGES: tuple[tuple[str, ...], ...] = (
    ("GoToRedBall",),
    ("GoToObj", "GoToLocal"),
    ("PickupLoc", "OpenDoor", "UnlockLocal"),
    ("GoTo", "PutNextLocal"),
    ("Synth", "BossLevel"),
)


def _default_curriculum_stages() -> list[list[str]]:
    """Build a fresh, independently mutable copy of the default stages."""
    return [list(stage) for stage in _DEFAULT_CURRICULUM_STAGES]


@dataclass
class CurriculumConfig:
    """Settings for a staged BabyAI curriculum.

    ``stages`` holds, per stage, the level names (matching ``LevelConfig.name``)
    to train on. ``advance_threshold`` and ``episodes_per_eval`` control
    advancement; presumably a success rate measured over that many evaluation
    episodes — confirm against the trainer that consumes this config.
    """

    stages: list[list[str]] = field(default_factory=_default_curriculum_stages)
    advance_threshold: float = 0.8
    episodes_per_eval: int = 100