# multi-agent-strat / client.py
# Uploaded by Avnishjain ("Upload 21 files", commit 6888575).
"""Typed client for strategy environment."""
from __future__ import annotations
from typing import Any, Dict
from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient
from .models import StrategyAction, StrategyObservation, StrategyState
def _as_str(value: Any, default: str) -> str:
    """Return ``str(value)``, or *default* when *value* is ``None``.

    Guards against an explicit JSON ``null`` turning into the string ``"None"``.
    """
    return default if value is None else str(value)


def _as_int(value: Any, default: int) -> int:
    """Return ``int(value)``, or *default* when *value* is ``None`` (``int(None)`` raises)."""
    return default if value is None else int(value)


def _as_float(value: Any, default: float) -> float:
    """Return ``float(value)``, or *default* when *value* is ``None`` (``float(None)`` raises)."""
    return default if value is None else float(value)


def _as_list(value: Any) -> list:
    """Return ``list(value)``, or an empty list when *value* is ``None`` (``list(None)`` raises)."""
    return [] if value is None else list(value)


class StrategyEnv(EnvClient[StrategyAction, StrategyObservation, StrategyState]):
    """Typed client converting between wire payloads and the strategy models.

    Fields that are missing from a server payload, or present but ``null``,
    fall back to the defaults used below instead of raising or being
    stringified to ``"None"``.
    """

    # Fallbacks shared by observation and state parsing, kept in one place so
    # the two parsers cannot drift apart.
    _DEFAULT_DIFFICULTY = "easy"
    _DEFAULT_MAX_TURNS = 12
    _DEFAULT_ACTIVE_RULE = "expansion"

    def _step_payload(self, action: StrategyAction) -> Dict[str, Any]:
        """Serialize *action* into the step request body, omitting ``None`` fields."""
        return action.model_dump(exclude_none=True)

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult:
        """Build a ``StepResult`` wrapping a ``StrategyObservation`` from a step/reset payload.

        Args:
            payload: Response dict; the observation fields are read from its
                ``"observation"`` entry (treated as ``{}`` when absent or null).

        Returns:
            A ``StepResult`` whose ``reward``/``done`` match the observation's.
        """
        obs_data = payload.get("observation") or {}

        # Top-level reward/done take precedence over copies nested in the
        # observation; an explicit top-level ``None`` reward is kept as-is.
        reward = payload.get("reward", obs_data.get("reward"))
        done = bool(payload.get("done", obs_data.get("done", False)))

        obs = StrategyObservation(
            done=done,
            reward=reward,
            task_id=_as_str(obs_data.get("task_id"), ""),
            difficulty=_as_str(obs_data.get("difficulty"), self._DEFAULT_DIFFICULTY),
            objective=_as_str(obs_data.get("objective"), ""),
            turn=_as_int(obs_data.get("turn"), 0),
            max_turns=_as_int(obs_data.get("max_turns"), self._DEFAULT_MAX_TURNS),
            active_rule=_as_str(obs_data.get("active_rule"), self._DEFAULT_ACTIVE_RULE),
            rule_hint=_as_str(obs_data.get("rule_hint"), ""),
            own_resources=_as_int(obs_data.get("own_resources"), 0),
            own_defense=_as_int(obs_data.get("own_defense"), 0),
            own_intel=_as_int(obs_data.get("own_intel"), 0),
            visible_opponent_resources=_as_int(obs_data.get("visible_opponent_resources"), 0),
            visible_opponent_defense=_as_int(obs_data.get("visible_opponent_defense"), 0),
            last_public_event=_as_str(obs_data.get("last_public_event"), ""),
            last_agent_action=_as_str(obs_data.get("last_agent_action"), "none"),
            last_opponent_action=_as_str(obs_data.get("last_opponent_action"), "none"),
            recent_rule_history=_as_list(obs_data.get("recent_rule_history")),
            allowed_actions=_as_list(obs_data.get("allowed_actions")),
        )
        # NOTE(review): confirm the installed openenv ``StepResult`` accepts an
        # ``info`` keyword; upstream examples construct it with only
        # observation/reward/done.
        return StepResult(
            observation=obs,
            reward=reward,
            done=done,
            info=payload.get("info") or {},
        )

    def _parse_state(self, payload: Dict[str, Any]) -> StrategyState:
        """Build a ``StrategyState`` from a state payload.

        Args:
            payload: Flat dict of state fields; missing or null entries fall
                back to the same defaults used for observations.

        Returns:
            The parsed ``StrategyState``.
        """
        return StrategyState(
            episode_id=payload.get("episode_id"),
            step_count=_as_int(payload.get("step_count"), 0),
            task_id=_as_str(payload.get("task_id"), ""),
            difficulty=_as_str(payload.get("difficulty"), self._DEFAULT_DIFFICULTY),
            objective=_as_str(payload.get("objective"), ""),
            turn=_as_int(payload.get("turn"), 0),
            max_turns=_as_int(payload.get("max_turns"), self._DEFAULT_MAX_TURNS),
            active_rule=_as_str(payload.get("active_rule"), self._DEFAULT_ACTIVE_RULE),
            own_resources=_as_int(payload.get("own_resources"), 0),
            own_defense=_as_int(payload.get("own_defense"), 0),
            own_intel=_as_int(payload.get("own_intel"), 0),
            visible_opponent_resources=_as_int(payload.get("visible_opponent_resources"), 0),
            visible_opponent_defense=_as_int(payload.get("visible_opponent_defense"), 0),
            cumulative_reward=_as_float(payload.get("cumulative_reward"), 0.0),
            done=bool(payload.get("done", False)),
            history=_as_list(payload.get("history")),
        )