"""Smoke-test script for the ``crisis_management`` warehouse task.

Resets a ``WarehouseEnv`` on the task, steps it a fixed number of times with
every robot issuing a ``wait`` action, prints the final observation and its
reward, then runs the task's registered grader and prints the score.
"""

from warehouse_env.env import WarehouseEnv
from warehouse_env.models import WarehouseAction, RobotAction
from warehouse_env.graders import GRADER_REGISTRY

# Number of robots addressed by the action (robot_id 0..NUM_ROBOTS-1).
# NOTE(review): assumes the task exposes exactly this many robots -- confirm.
NUM_ROBOTS = 5
# Number of environment steps to simulate before grading.
NUM_STEPS = 15

env = WarehouseEnv()

print("Available tasks:", env.list_tasks())

task_id = "crisis_management"
obs = env.reset(task_id=task_id)

print(f"\nStarted task: {task_id}")
print(f"Initial description: {obs.description}")

# The same all-robots-wait action is reused for every step.
action = WarehouseAction(
    robots=[
        RobotAction(robot_id=robot_id, action_type="wait")
        for robot_id in range(NUM_ROBOTS)
    ]
)

print(f"\n--- Simulating {NUM_STEPS} steps ---")
for _ in range(NUM_STEPS):
    obs = env.step(action)

# Only the observation from the last step is reported, so the reward
# breakdown is looked up once here instead of on every iteration.
wasted_reward = obs.metadata.get("reward_breakdown", {}).get("wasted_step", 0.0)

print(f"\nStep {NUM_STEPS} Observation Summary:")
print(obs.description)
print(f"Step Reward: {obs.reward} (Wasted step penalty: {wasted_reward})")

# NOTE(review): this sets the done flag on the env's private episode state
# before grading -- presumably the grader expects a finished episode. Prefer
# a public API for ending the episode if the environment provides one.
env._episode.done = True
score = GRADER_REGISTRY[task_id](env)
print(f"\nFinal Grader Score for {task_id}: {score} (out of 1.0)")