# Scraped page header, not module code: "suvasis's picture" / "code add" / e4d7d50
"""
openenv/models.py
─────────────────
Pydantic schemas that exactly match the OpenEnv 0.1 HTTP spec.
POST /reset → ResetResponse
POST /step → StepResponse
GET /state → StateResponse
All three wrap a shared ChessObservation object. Chess-specific data lives in
that observation (typed fields plus an extensible `info` dict), so the outer
response contract stays generic.
"""
from __future__ import annotations
from typing import Any, Optional
from pydantic import BaseModel, Field
# ── Request bodies ─────────────────────────────────────────────────────────────
class StepRequest(BaseModel):
    """Request payload carrying the one move the RL trainer wants to play."""

    # Required. Typed as a plain string: this schema itself does not check
    # that the text is a well-formed or legal move.
    action: str = Field(
        default=...,
        description="Chess move in UCI notation (e.g. 'e2e4') or SAN (e.g. 'e4')",
        examples=["e2e4", "Nf3", "O-O"],
    )
class ResetRequest(BaseModel):
    """Request payload for a reset; both fields default to ``None``."""

    seed: Optional[int] = Field(
        default=None,
        description="RNG seed for reproducibility",
    )
    # Free-form mapping; its keys are not constrained by this schema.
    config: Optional[dict[str, Any]] = Field(
        default=None,
        description="Override environment config for this episode",
    )
# ── Core observation ───────────────────────────────────────────────────────────
class ChessObservation(BaseModel):
    """
    Snapshot of the game, embedded in every response under `observation`.

    Board state, wallet balances and agent identities are typed fields on this
    model; `info` is an open-ended dict for any additional data.
    """

    # Board state
    fen: str = Field(
        default=..., description="Current board position in FEN notation"
    )
    turn: str = Field(default=..., description="'white' or 'black'")
    move_number: int = Field(
        default=..., description="Full-move number (1-indexed)"
    )
    # Both last-move fields default to None.
    last_move_uci: Optional[str] = Field(
        default=None, description="Last move in UCI notation"
    )
    last_move_san: Optional[str] = Field(
        default=None, description="Last move in SAN notation"
    )
    legal_moves_uci: list[str] = Field(
        default=..., description="All legal moves in UCI notation"
    )
    is_check: bool = Field(
        default=False, description="Whether the current side is in check"
    )

    # Economy
    wallet_white: float = Field(
        default=..., description="White agent wallet balance (units)"
    )
    wallet_black: float = Field(
        default=..., description="Black agent wallet balance (units)"
    )

    # Agent identities
    white_model: str = Field(default=..., description="Model ID playing White")
    black_model: str = Field(default=..., description="Model ID playing Black")

    # Auxiliary / extensible data; a fresh empty dict when not supplied.
    info: dict[str, Any] = Field(default_factory=dict)
# ── OpenEnv response bodies ────────────────────────────────────────────────────
class ResetResponse(BaseModel):
    """
    Payload returned by POST /reset.

    OpenEnv spec shape: { observation, info }
    """

    observation: ChessObservation
    # Response-level extras; defaults to a new empty dict.
    info: dict[str, Any] = Field(default_factory=dict)
class StepResponse(BaseModel):
    """
    Payload returned by POST /step.

    OpenEnv spec shape: { observation, reward, terminated, truncated, info }
    """

    observation: ChessObservation
    reward: float = Field(default=..., description="Per-step reward signal")
    # `terminated` and `truncated` are separate required flags.
    terminated: bool = Field(
        default=...,
        description="True if the episode ended naturally (checkmate/stalemate/draw)",
    )
    truncated: bool = Field(
        default=...,
        description="True if the episode was cut short (move limit)",
    )
    # Response-level extras; defaults to a new empty dict.
    info: dict[str, Any] = Field(default_factory=dict)
class StateResponse(BaseModel):
    """
    Payload returned by GET /state.

    OpenEnv spec shape: { observation, info, episode_id, step_count, status }
    """

    observation: ChessObservation
    # Response-level extras; defaults to a new empty dict.
    info: dict[str, Any] = Field(default_factory=dict)
    episode_id: str = Field(
        default=..., description="Unique identifier for the current episode"
    )
    step_count: int = Field(
        default=..., description="Number of moves played so far"
    )
    # Plain string; the expected values are listed in the description only.
    status: str = Field(
        default=...,
        description="'active' | 'terminated' | 'truncated' | 'idle'",
    )
# ── Environment info ──────────────────────────────────────────────────────────
class EnvInfo(BaseModel):
    """
    Capability descriptor for the environment, returned by GET /env_info.

    Every field has a default, so the model can be built with no arguments.
    Container-valued defaults are supplied through `default_factory`.
    """

    name: str = "chessecon"
    version: str = "1.0.0"
    description: str = (
        "Two-agent chess economy environment. "
        "White plays Qwen2.5-0.5B-Instruct, Black plays Llama-3.2-1B-Instruct. "
        "Agents earn/lose economic units based on game outcomes. "
        "Compatible with OpenEnv 0.1 spec."
    )
    openenv_version: str = "0.1"

    # Actions are free text (a move string).
    action_space: dict = Field(
        default_factory=lambda: {
            "type": "text",
            "description": "Chess move in UCI (e2e4) or SAN (e4) notation",
        }
    )

    # Names a subset of the observation's fields.
    observation_space: dict = Field(
        default_factory=lambda: {
            "type": "structured",
            "fields": [
                "fen",
                "turn",
                "move_number",
                "legal_moves_uci",
                "wallet_white",
                "wallet_black",
                "is_check",
            ],
        }
    )

    reward_range: list[float] = Field(default_factory=lambda: [-1.0, 1.0])
    max_episode_steps: int = 300

    # One entry per side, in white-then-black order.
    agents: list[dict] = Field(
        default_factory=lambda: [
            {
                "id": "white",
                "model": "Qwen/Qwen2.5-0.5B-Instruct",
                "role": "White player",
            },
            {
                "id": "black",
                "model": "meta-llama/Llama-3.2-1B-Instruct",
                "role": "Black player",
            },
        ]
    )

    tags: list[str] = Field(
        default_factory=lambda: [
            "chess",
            "multi-agent",
            "rl",
            "grpo",
            "economy",
            "openenv",
            "two-player",
            "game",
        ]
    )