# ArbitrAgent / diplomacy_env.py
# Author: AbeBhatti — last update before doc changes (commit c2097f9)
import random
from typing import Any, Dict, Tuple
import numpy as np
from diplomacy import Game
from openenv.env import Env
from sentence_transformers import SentenceTransformer
class DiplomacyNegotiationEnv(Env):
    """
    OpenEnv-compatible wrapper around the diplomacy.Game engine.

    Observation: 384-dim MiniLM embedding of a textual game-state description
    from the perspective of a single power (e.g. ENGLAND).
    Action: free-form text describing strategic intent (logged but not yet parsed).
    """

    def __init__(self, power_name: str = "ENGLAND", seed: int | None = None):
        """
        Args:
            power_name: The power whose perspective drives observations and
                rewards; normalized to upper case.
            seed: Optional seed. NOTE: seeds the module-level `random` RNG
                (shared process-wide), not a per-env generator.
        """
        self.power_name = power_name.upper()
        # MiniLM sentence encoder; produces the 384-dim observation vectors.
        self.encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
        self.game: Game | None = None
        self.current_phase: int = 0   # phases processed since reset()
        self.prev_sc_count: int = 0   # SC baseline used for delta/reward shaping
        self.max_phases: int = 50     # hard episode-length cap
        if seed is not None:
            random.seed(seed)

    def reset(self) -> Tuple[np.ndarray, Dict[str, Any]]:
        """Reset the underlying Diplomacy game and return initial observation + info."""
        self.game = Game()
        self.current_phase = 0
        state = self.game.get_state()
        centers = state.get("centers", {})
        self.prev_sc_count = len(centers.get(self.power_name, []))
        obs = self._get_observation()
        info = {"phase": state.get("name"), "sc_count": self.prev_sc_count}
        return obs, info

    def step(self, action: str):
        """
        Advance one phase.

        - Currently ignores the semantic content of `action` and instead
          submits random legal orders for all powers.
        - Logs the provided action in the returned info for later analysis.

        Returns:
            Classic gym-style tuple ``(obs, reward, done, info)``.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        if self.game is None:
            raise RuntimeError("Environment must be reset() before step().")
        # Submit random legal orders for all powers.
        all_possible = self.game.get_all_possible_orders()
        for power, locs in self.game.get_orderable_locations().items():
            orders = []
            for loc in locs:
                loc_orders = all_possible.get(loc.upper(), [])
                if loc_orders:
                    orders.append(random.choice(list(loc_orders)))
            if orders:
                self.game.set_orders(power, orders)
        self.game.process()
        self.current_phase += 1

        # BUG FIX: _compute_reward() overwrites self.prev_sc_count, which
        # previously made info["sc_delta"] — and the "Supply center delta"
        # embedded in the observation text — always 0. Snapshot the baseline
        # and build the observation BEFORE rewarding.
        prev_sc = self.prev_sc_count
        obs = self._get_observation()
        reward = self._compute_reward()

        done = self.game.is_game_done or self.current_phase >= self.max_phases
        state = self.game.get_state()
        curr_sc = len(state.get("centers", {}).get(self.power_name, []))
        if self.game.is_game_done:
            done_reason = "game_complete"
        elif self.current_phase >= self.max_phases:
            done_reason = "max_phases"
        else:
            done_reason = None
        info = {
            "phase": state.get("name"),
            "sc_count": curr_sc,
            "sc_delta": curr_sc - prev_sc,  # true change this phase (was always 0)
            "action_logged": action,
            "done_reason": done_reason,
        }
        return obs, reward, done, info

    def _compute_reward(self) -> float:
        """Shaped reward based on SC changes, relative rank, and game outcome.

        Side effect: updates ``self.prev_sc_count`` to the current SC count,
        so callers needing the pre-step baseline must snapshot it first.
        """
        if self.game is None:
            return 0.0
        state = self.game.get_state()
        centers = state.get("centers", {})
        curr_sc = len(centers.get(self.power_name, []))
        all_counts = {p: len(c) for p, c in centers.items()}
        delta = curr_sc - self.prev_sc_count
        self.prev_sc_count = curr_sc
        reward = 0.0
        # +/-1 for gaining/losing centers; extra -2 on elimination.
        if delta > 0:
            reward += 1.0
        if delta < 0:
            reward -= 1.0
        if curr_sc == 0:
            reward -= 2.0
        # Relative position bonuses/penalties: compares our COUNT against the
        # two highest / two lowest counts on the board (ties share the bonus).
        if all_counts:
            sorted_counts = sorted(all_counts.values(), reverse=True)
            top_two = sorted_counts[:2]
            bottom_two = sorted_counts[-2:]
            if curr_sc in top_two:
                reward += 0.3
            if curr_sc in bottom_two and curr_sc > 0:
                reward -= 0.2
        # Game outcome bonus when completed. Assumes outcome[0] is the final
        # phase and outcome[1:] are the winning powers — TODO confirm against
        # the diplomacy package's Game.outcome format.
        if self.game.is_game_done:
            outcome = getattr(self.game, "outcome", [])
            if isinstance(outcome, list) and len(outcome) > 1:
                if self.power_name in [w.upper() for w in outcome[1:]]:
                    reward += 2.0
        return float(reward)

    def _get_observation(self) -> np.ndarray:
        """Return a 384-dim MiniLM embedding of the current game state text."""
        text = self._get_state_text()
        embedding = self.encoder.encode(text, convert_to_numpy=True)
        # Ensure consistent dtype for downstream RL code.
        return embedding.astype(np.float32)

    def _get_state_text(self) -> str:
        """Human-readable textual description of the current game state."""
        if self.game is None:
            return "Environment not initialized."
        state = self.game.get_state()
        centers = state.get("centers", {})
        units = state.get("units", {})
        phase = state.get("name", "UNKNOWN")
        my_scs = centers.get(self.power_name, [])
        my_units = units.get(self.power_name, [])
        curr_sc = len(my_scs)
        # Delta vs. the stored baseline; meaningful only when called before
        # _compute_reward() refreshes prev_sc_count (see step()).
        delta = curr_sc - self.prev_sc_count
        # Coarse strategic position label.
        if curr_sc > 10:
            position = "dominant"
        elif curr_sc >= 7:
            position = "strong"
        elif curr_sc >= 4:
            position = "stable"
        elif curr_sc >= 2:
            position = "weak"
        else:
            position = "critical"
        lines: list[str] = [
            "DIPLOMACY GAME STATE",
            f"Phase: {phase}",
            f"Playing as: {self.power_name}",
            "",
            f"My units: {', '.join(my_units) or 'None'}",
            f"My supply centers: {', '.join(my_scs) or 'None'} ({curr_sc} centers)",
            "",
            "Other powers:",
        ]
        for power in sorted(centers.keys()):
            if power == self.power_name:
                continue
            sc_count = len(centers.get(power, []))
            unit_list = units.get(power, [])
            lines.append(
                f"  {power}: {sc_count} SCs | Units: {', '.join(unit_list) or 'None'}"
            )
        lines += [
            "",
            f"Strategic position: {position}",
            f"Supply center delta: {delta:+d}",
        ]
        return "\n".join(lines)

    def render(self):
        """Print and return the current state text."""
        text = self._get_state_text()
        print(text)
        return text

    def close(self):
        """Clean up the underlying game."""
        self.game = None
        print("Environment closed.")

    @property
    def observation_space(self) -> Dict[str, Any]:
        # Fixed by the MiniLM encoder output and the float32 cast above.
        return {"type": "continuous", "shape": (384,), "dtype": "float32"}

    @property
    def action_space(self) -> Dict[str, Any]:
        return {"type": "text", "description": "Natural language strategic intent"}