Spaces:
Sleeping
Sleeping
File size: 5,223 Bytes
a03a89b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 | """BabyAI level registry and curriculum definitions for MiniGridEnv."""
from __future__ import annotations
from dataclasses import dataclass, field
@dataclass(frozen=True)
class LevelConfig:
    """Immutable description of one BabyAI level.

    Attributes:
        gym_id: Primary environment id; always the first candidate.
        name: Short level name (the key ``get_level`` matches on).
        description: Human-readable summary of the task.
        difficulty: Integer rank (what ``get_levels_by_difficulty`` filters on).
        max_steps: Step budget recorded for the level (not enforced here).
        expected_optimal_steps: Reference step count recorded for the level.
        requires_interaction: Flag recorded for the level (not enforced here).
        num_objects: Object count recorded for the level.
        involves_language_composition: Flag recorded for the level.
        fallback_gym_ids: Alternate environment ids, tried after ``gym_id``
            in the order given. Defaults to an empty tuple.
    """

    gym_id: str
    name: str
    description: str
    difficulty: int
    max_steps: int
    expected_optimal_steps: int
    requires_interaction: bool
    num_objects: int
    involves_language_composition: bool
    fallback_gym_ids: tuple[str, ...] = ()

    @property
    def candidate_gym_ids(self) -> list[str]:
        """Return ``gym_id`` followed by every fallback id, as a new list."""
        candidates = [self.gym_id]
        candidates.extend(self.fallback_gym_ids)
        return candidates
# Every level known to this module, listed in non-decreasing `difficulty`
# order. get_level() and get_levels_by_difficulty() both scan this list.
# NOTE(review): the "Stage N" comments below form four thematic groups,
# whereas CurriculumConfig.stages defaults to five stages (GoToRedBall alone
# comes first there) — confirm which numbering is authoritative.
LEVEL_REGISTRY: list[LevelConfig] = [
    # Stage 0: simple navigation (none of these set requires_interaction).
    LevelConfig(
        gym_id="BabyAI-GoToRedBallGrey-v0",
        # The only entry with a fallback id; it follows gym_id in
        # candidate_gym_ids.
        fallback_gym_ids=("BabyAI-GoToRedBall-v0",),
        name="GoToRedBall",
        description="Navigate to the red ball in a single room.",
        difficulty=0,
        max_steps=64,
        expected_optimal_steps=10,
        requires_interaction=False,
        num_objects=1,
        involves_language_composition=False,
    ),
    LevelConfig(
        gym_id="BabyAI-GoToObj-v0",
        name="GoToObj",
        description="Navigate to a specific colored object.",
        difficulty=1,
        max_steps=64,
        expected_optimal_steps=12,
        requires_interaction=False,
        num_objects=2,
        involves_language_composition=False,
    ),
    LevelConfig(
        gym_id="BabyAI-GoToLocal-v0",
        name="GoToLocal",
        description="Navigate to a specific object with distractors.",
        difficulty=2,
        max_steps=64,
        expected_optimal_steps=15,
        requires_interaction=False,
        num_objects=4,
        involves_language_composition=False,
    ),
    # Stage 1: object interaction (requires_interaction is set from here on).
    LevelConfig(
        gym_id="BabyAI-PickupLoc-v0",
        name="PickupLoc",
        description="Pick up a specific object in a single room.",
        # Shares difficulty 3 with OpenDoor below.
        difficulty=3,
        max_steps=64,
        expected_optimal_steps=14,
        requires_interaction=True,
        num_objects=4,
        involves_language_composition=False,
    ),
    LevelConfig(
        gym_id="BabyAI-OpenDoor-v0",
        name="OpenDoor",
        description="Open a door of a specified color.",
        difficulty=3,
        max_steps=64,
        expected_optimal_steps=12,
        requires_interaction=True,
        num_objects=1,
        involves_language_composition=False,
    ),
    LevelConfig(
        gym_id="BabyAI-UnlockLocal-v0",
        name="UnlockLocal",
        description="Unlock a local door with the matching key.",
        difficulty=4,
        max_steps=128,
        expected_optimal_steps=25,
        requires_interaction=True,
        num_objects=3,
        involves_language_composition=False,
    ),
    # Stage 2: multi-room and compositional
    LevelConfig(
        gym_id="BabyAI-GoTo-v0",
        name="GoTo",
        description="Navigate to a specified object across rooms.",
        difficulty=5,
        max_steps=128,
        expected_optimal_steps=30,
        # NOTE(review): a navigation task flagged as interactive — presumably
        # because doors between rooms must be opened; confirm.
        requires_interaction=True,
        num_objects=6,
        involves_language_composition=False,
    ),
    LevelConfig(
        gym_id="BabyAI-PutNextLocal-v0",
        name="PutNextLocal",
        description="Place one object next to another local object.",
        difficulty=6,
        max_steps=128,
        expected_optimal_steps=20,
        requires_interaction=True,
        num_objects=4,
        # First entry that sets involves_language_composition.
        involves_language_composition=True,
    ),
    # Stage 3: hardest compositional tasks
    LevelConfig(
        gym_id="BabyAI-Synth-v0",
        name="Synth",
        description="Random compositional instructions.",
        difficulty=7,
        max_steps=256,
        expected_optimal_steps=40,
        requires_interaction=True,
        num_objects=8,
        involves_language_composition=True,
    ),
    LevelConfig(
        gym_id="BabyAI-BossLevel-v0",
        name="BossLevel",
        description="Hardest compositional BabyAI level.",
        difficulty=8,
        max_steps=512,
        expected_optimal_steps=80,
        requires_interaction=True,
        num_objects=10,
        involves_language_composition=True,
    ),
]
def get_level(name: str) -> LevelConfig:
    """Look up a registered level by its short name, ignoring case.

    Args:
        name: Short level name, e.g. ``"GoToLocal"``; compared lower-cased.

    Returns:
        The first entry of ``LEVEL_REGISTRY`` whose name matches.

    Raises:
        ValueError: If no registered level has that name. The message lists
            every available short name.
    """
    found = next(
        (cfg for cfg in LEVEL_REGISTRY if cfg.name.lower() == name.lower()),
        None,
    )
    if found is not None:
        return found
    known_names = ", ".join(cfg.name for cfg in LEVEL_REGISTRY)
    raise ValueError(f"Unknown level '{name}'. Available levels: {known_names}")
def get_levels_by_difficulty(max_difficulty: int) -> list[LevelConfig]:
    """Collect every registered level at or below a difficulty ceiling.

    Args:
        max_difficulty: Inclusive upper bound on ``LevelConfig.difficulty``.

    Returns:
        A new list of matching levels, in ``LEVEL_REGISTRY`` order. The list
        is empty when nothing qualifies.
    """
    selected: list[LevelConfig] = []
    for cfg in LEVEL_REGISTRY:
        if cfg.difficulty <= max_difficulty:
            selected.append(cfg)
    return selected
def _default_stages() -> list[list[str]]:
    """Build a fresh default stage list so instances never share it."""
    return [
        ["GoToRedBall"],
        ["GoToObj", "GoToLocal"],
        ["PickupLoc", "OpenDoor", "UnlockLocal"],
        ["GoTo", "PutNextLocal"],
        ["Synth", "BossLevel"],
    ]


@dataclass
class CurriculumConfig:
    """Ordered curriculum stages plus the knobs controlling advancement.

    Attributes:
        stages: One list of short level names per stage, in order. Defaults
            to five stages, from ``GoToRedBall`` up to ``Synth``/``BossLevel``.
        advance_threshold: Defaults to 0.8. Presumably the score needed to
            move on to the next stage — confirm against the consumer.
        episodes_per_eval: Defaults to 100. Presumably the number of episodes
            per evaluation round — confirm against the consumer.
    """

    stages: list[list[str]] = field(default_factory=_default_stages)
    advance_threshold: float = 0.8
    episodes_per_eval: int = 100
|