mini-rl-env / grid_env /models.py
sohambose98's picture
updated the tests and graders
eaa79f0
"""
Pydantic models for the warehouse fulfillment environment.
"""
from __future__ import annotations
from typing import Any, Dict, List, Literal, Optional, Tuple
try:
    from pydantic import BaseModel, ConfigDict, Field
except ImportError:
    # pydantic is optional. The stand-ins below provide only the small part
    # of its API that this module relies on; they perform no validation.
    import copy

    class BaseModel:
        """Minimal stand-in for ``pydantic.BaseModel``.

        Keyword arguments are stored as attributes without validation or
        coercion. Annotated fields that carry a class-level default and are
        not passed explicitly receive a per-instance deep copy of that
        default, so mutable defaults are never shared between instances and
        ``model_dump()`` reports defaulted fields as well.
        """

        def __init__(self, **data: Any) -> None:
            """Store defaults for omitted fields, then the supplied ``data``."""
            cls = type(self)
            # Walk base classes first so fields are materialized in
            # declaration order (parents before children).
            for klass in reversed(cls.__mro__):
                for name in getattr(klass, "__annotations__", {}):
                    if name not in data and hasattr(cls, name):
                        # Deep copy: the fallback ``Field`` evaluates
                        # ``default_factory`` once per class, so the class
                        # attribute must not be handed out directly.
                        setattr(self, name, copy.deepcopy(getattr(cls, name)))
            for key, value in data.items():
                setattr(self, key, value)

        def model_dump(self) -> Dict[str, Any]:
            """Return the instance attributes as a plain ``dict``.

            Nested objects exposing ``model_dump`` are dumped recursively,
            including those found inside lists, tuples and dict values.
            """

            def _convert(value: Any) -> Any:
                if hasattr(value, "model_dump"):
                    return value.model_dump()
                if isinstance(value, list):
                    return [_convert(item) for item in value]
                if isinstance(value, tuple):
                    return tuple(_convert(item) for item in value)
                if isinstance(value, dict):
                    return {key: _convert(item) for key, item in value.items()}
                return value

            return {key: _convert(value) for key, value in self.__dict__.items()}

        @classmethod
        def model_json_schema(cls) -> Dict[str, Any]:
            """Return a placeholder schema: title only, no properties."""
            return {"title": cls.__name__, "type": "object", "properties": {}}

    def Field(default: Any = None, **kwargs: Any) -> Any:
        """Stand-in for ``pydantic.Field`` that returns the default value.

        If ``default_factory`` is given it is called immediately and its
        result returned; otherwise ``default`` is returned. All other
        keyword arguments are ignored.
        """
        if "default_factory" in kwargs:
            return kwargs["default_factory"]()
        return default

    # With the fallback, a model config is just a plain dict and is not enforced.
    ConfigDict = dict
# openenv is optional. When it cannot be imported, empty local stand-ins are
# defined so the models below can still subclass Action/Observation/State.
try:
    from openenv.core.env_server.interfaces import Action, Observation, State
except ImportError:
    # The stand-ins add no fields or behavior of their own on top of BaseModel.
    class Action(BaseModel):
        pass

    class Observation(BaseModel):
        pass

    class State(BaseModel):
        pass
# Pair of integers; used below for cell positions and for grid_size.
# NOTE(review): axis order (x/y vs. row/col) is not fixed by this file — confirm.
Position = Tuple[int, int]

# One of four single-letter headings (presumably compass directions).
Heading = Literal["N", "E", "S", "W"]

# Closed set of command strings accepted by WarehouseAction and BaselineCommand.
Command = Literal[
    "turn_left",
    "turn_right",
    "move_forward",
    "scan_bin",
    "pick_item",
    "pack_item",
    "recharge",
    "rest",
    "wait",
]
# Common base for the models in this module.
class OpenEnvModel(BaseModel):
    # With real pydantic, extra="forbid" makes unknown fields a validation
    # error. The stdlib fallback BaseModel defined earlier in this file
    # treats ConfigDict as a plain dict and does not enforce it.
    model_config = ConfigDict(extra="forbid")
# Action model carrying a single command string.
class WarehouseAction(Action, OpenEnvModel):
    # Required; must be one of the Command literals.
    command: Command
# One order entry: a SKU string paired with an integer quantity.
# Both fields are required.
class OrderLine(OpenEnvModel):
    sku: str
    quantity: int
# Description of one bin: its id, position, the SKU it holds and a quantity.
class BinState(OpenEnvModel):
    bin_id: str
    position: Position
    sku: str
    quantity: int
    # Optional; defaults to 1. Presumably per-item weight — confirm with the env.
    weight: int = 1
    # Optional; defaults to 0.0. Presumably monetary value — confirm with the env.
    value: float = 0.0
# Static definition of a task: identity, limits, layout positions, bins and
# the order. Fields up to and including `order` are required;
# everything after it has a default and may be omitted.
class TaskDefinition(OpenEnvModel):
    task_id: str
    # Restricted to these four literal values.
    difficulty: Literal["easy", "medium", "hard", "expert"]
    title: str
    description: str
    max_steps: int
    battery_capacity: int
    low_battery_threshold: int
    agent_start: Position
    agent_heading: Heading
    dock_position: Position
    pack_station_position: Position
    charger_position: Position
    bins: List[BinState]
    order: List[OrderLine]
    # --- optional fields (each list defaults to a new empty list) ---
    # NOTE(review): presumably bin ids that must be scanned — confirm.
    required_scans: List[str] = Field(default_factory=list)
    # List of string-to-string mappings; the expected keys are not defined here.
    rubric_criteria: List[Dict[str, str]] = Field(default_factory=list)
    obstacles: List[Position] = Field(default_factory=list)
    carry_capacity: int = 99
    stamina_capacity: int = 0
    stamina_move_cost: int = 1
    # None when the task defines no rest position.
    rest_position: Optional[Position] = None
    profit_target: float = 0.0
# Order entry expressed as a SKU plus an integer `remaining` count.
# Used for WarehouseObservation.pending_order.
class PendingOrderLine(OpenEnvModel):
    sku: str
    remaining: int
# Order entry expressed as a SKU plus an integer `packed` count.
# Used for WarehouseObservation.packed_order.
class PackedOrderLine(OpenEnvModel):
    sku: str
    packed: int
# Observation model for the warehouse environment.
# Required and defaulted fields are interleaved below; pydantic models are
# built by keyword, so declaration order does not constrain which are required.
class WarehouseObservation(Observation, OpenEnvModel):
    task_id: str
    mission: str
    narrative: str
    agent_position: Position
    heading: Heading
    # Free-form string; its vocabulary is not defined in this file.
    front_cell: str
    # Required but nullable: must be passed explicitly, may be None.
    carrying: Optional[str]
    carrying_weight: int = 0
    battery_level: int
    stamina_level: int = 0
    money: float = 0.0
    # List of strings; presumably bin ids or descriptions — confirm with the env.
    visible_bins: List[str]
    pending_order: List[PendingOrderLine]
    packed_order: List[PackedOrderLine]
    # NOTE(review): presumably in [0, 1]; no range is enforced here.
    progress_ratio: float
# Reward record: a float value, a reason string and a completion ratio.
# All three fields are required.
class WarehouseReward(OpenEnvModel):
    value: float
    reason: str
    # NOTE(review): presumably in [0, 1]; no range is enforced here.
    completion_ratio: float
# Counters and money totals. Every field has a zero default, so
# WarehouseMetrics() with no arguments is valid.
class WarehouseMetrics(OpenEnvModel):
    correct_scans: int = 0
    wrong_scans: int = 0
    correct_picks: int = 0
    wrong_picks: int = 0
    correct_packs: int = 0
    invalid_actions: int = 0
    recharges: int = 0
    battery_depletion_events: int = 0
    distance_travelled: int = 0
    stamina_depletion_events: int = 0
    rest_events: int = 0
    obstacle_collisions: int = 0
    # The two money totals are floats; all other metrics are integer counts.
    money_earned: float = 0.0
    money_lost: float = 0.0
    overweight_attempts: int = 0
# State model for the warehouse environment: episode bookkeeping,
# agent status, layout, bins/order data, metrics and reward totals.
class WarehouseState(State, OpenEnvModel):
    episode_id: str
    task_id: str
    # Plain str here, whereas TaskDefinition.difficulty is limited to four literals.
    difficulty: str
    step_count: int
    done: bool
    success: bool
    max_steps: int
    # Uses the Position alias, i.e. a pair of ints.
    # NOTE(review): (width, height) vs. (rows, cols) is not stated — confirm.
    grid_size: Position
    agent_position: Position
    heading: Heading
    # Required but nullable: must be passed explicitly, may be None.
    carrying: Optional[str]
    carrying_weight: int = 0
    battery_level: int
    battery_capacity: int
    stamina_level: int = 0
    stamina_capacity: int = 0
    money: float = 0.0
    profit_target: float = 0.0
    dock_position: Position
    pack_station_position: Position
    charger_position: Position
    # Optional; defaults to a new empty list.
    obstacles: List[Position] = Field(default_factory=list)
    bins: List[BinState]
    order: List[OrderLine]
    # Typed as OrderLine here; WarehouseObservation.packed_order uses PackedOrderLine.
    packed_order: List[OrderLine]
    scanned_bins: List[str]
    metrics: WarehouseMetrics
    action_history: List[str]
    total_reward: float
    completion_ratio: float
    task_description: str
# A command together with an optional free-text rationale.
class BaselineCommand(OpenEnvModel):
    # Required; must be one of the Command literals.
    command: Command
    # Optional; defaults to the empty string.
    rationale: str = ""
def model_to_dict(model: Any) -> Dict[str, Any]:
    """Convert a model-like object into a dictionary.

    Objects exposing ``model_dump`` are dumped through it. A plain ``dict``
    is returned unchanged (the same object, not a copy).

    Raises:
        TypeError: if ``model`` has no ``model_dump`` and is not a ``dict``.
    """
    if not hasattr(model, "model_dump"):
        if not isinstance(model, dict):
            raise TypeError(f"Unsupported model type: {type(model)!r}")
        return model
    return model.model_dump()