Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse files- Dockerfile +2 -0
- build/lib/freeciv_env/__init__.py +10 -0
- build/lib/freeciv_env/adapter.py +335 -0
- build/lib/freeciv_env/client.py +22 -0
- build/lib/freeciv_env/grpo.py +97 -0
- build/lib/freeciv_env/models.py +112 -0
- build/lib/freeciv_env/runtime.py +401 -0
- build/lib/freeciv_env/server/__init__.py +3 -0
- build/lib/freeciv_env/server/app.py +42 -0
- build/lib/freeciv_env/server/freeciv_environment.py +163 -0
- build/lib/server/__init__.py +0 -0
- build/lib/server/app.py +10 -0
- freeciv_env.egg-info/PKG-INFO +13 -0
- freeciv_env/server/Dockerfile +2 -0
Dockerfile
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 2 |
FROM ${BASE_IMAGE} AS builder
|
| 3 |
|
|
|
|
|
|
|
| 4 |
WORKDIR /app/env
|
| 5 |
COPY . /app/env
|
| 6 |
|
|
|
|
| 1 |
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 2 |
FROM ${BASE_IMAGE} AS builder
|
| 3 |
|
| 4 |
+
RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
|
| 5 |
+
|
| 6 |
WORKDIR /app/env
|
| 7 |
COPY . /app/env
|
| 8 |
|
build/lib/freeciv_env/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from freeciv_env.client import FreecivEnv
|
| 2 |
+
from freeciv_env.models import FreecivAction, FreecivObservation, FreecivState, LegalAction
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"FreecivAction",
|
| 6 |
+
"FreecivEnv",
|
| 7 |
+
"FreecivObservation",
|
| 8 |
+
"FreecivState",
|
| 9 |
+
"LegalAction",
|
| 10 |
+
]
|
build/lib/freeciv_env/adapter.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
from freeciv_env.models import CitySummary, FreecivAction, FreecivObservation, LegalAction, UnitSummary
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
ActionLookupKey = tuple[str, int | None, int | None, str | None]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass(frozen=True)
|
| 13 |
+
class ActionRef:
|
| 14 |
+
controller: str
|
| 15 |
+
actor_id: int | str
|
| 16 |
+
raw_action_key: str
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
|
| 20 |
+
class RawSnapshot:
|
| 21 |
+
turn: int
|
| 22 |
+
state: dict[str, Any]
|
| 23 |
+
actions: dict[str, Any]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclass(frozen=True)
|
| 27 |
+
class SnapshotMetrics:
|
| 28 |
+
score: float
|
| 29 |
+
known_tiles: int
|
| 30 |
+
visible_tiles: int
|
| 31 |
+
city_count: int
|
| 32 |
+
unit_count: int
|
| 33 |
+
techs_researched: int
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class PreparedObservation:
|
| 38 |
+
observation: FreecivObservation
|
| 39 |
+
metrics: SnapshotMetrics
|
| 40 |
+
action_refs: dict[ActionLookupKey, ActionRef]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _map_status_rows(raw_state: dict[str, Any]) -> list[list[int | float]]:
|
| 44 |
+
raw_map = raw_state.get("map", {})
|
| 45 |
+
status = raw_map.get("status", [])
|
| 46 |
+
return status if isinstance(status, list) else []
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def count_known_tiles(raw_state: dict[str, Any]) -> int:
|
| 50 |
+
return sum(1 for row in _map_status_rows(raw_state) for value in row if value and value > 0)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def count_visible_tiles(raw_state: dict[str, Any]) -> int:
|
| 54 |
+
return sum(1 for row in _map_status_rows(raw_state) for value in row if value and value >= 2)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def extract_metrics(snapshot: RawSnapshot) -> SnapshotMetrics:
|
| 58 |
+
player = snapshot.state.get("player", {})
|
| 59 |
+
return SnapshotMetrics(
|
| 60 |
+
score=float(player.get("my_score", 0.0)),
|
| 61 |
+
known_tiles=count_known_tiles(snapshot.state),
|
| 62 |
+
visible_tiles=count_visible_tiles(snapshot.state),
|
| 63 |
+
city_count=len(snapshot.state.get("city", {})),
|
| 64 |
+
unit_count=len(snapshot.state.get("unit", {})),
|
| 65 |
+
techs_researched=int(player.get("my_techs_researched", 0) or 0),
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def action_lookup_key(action: FreecivAction) -> ActionLookupKey:
|
| 70 |
+
if action.action_type == "move_unit":
|
| 71 |
+
return ("move_unit", action.unit_id, action.direction, None)
|
| 72 |
+
if action.action_type == "build_city":
|
| 73 |
+
return ("build_city", action.unit_id, None, None)
|
| 74 |
+
if action.action_type == "set_city_production":
|
| 75 |
+
return ("set_city_production", action.city_id, None, action.target)
|
| 76 |
+
if action.action_type == "set_research":
|
| 77 |
+
return ("set_research", None, None, action.target)
|
| 78 |
+
return ("end_turn", None, None, None)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def _parse_target_name(raw_action_key: str, prefix: str) -> str:
|
| 82 |
+
suffix = raw_action_key.removeprefix(prefix)
|
| 83 |
+
name, _sep, _tail = suffix.rpartition("_")
|
| 84 |
+
return name or suffix
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _controller_actions(snapshot: RawSnapshot, controller: str) -> dict[str, Any]:
|
| 89 |
+
raw_actions = snapshot.actions.get(controller, {})
|
| 90 |
+
if isinstance(raw_actions, dict):
|
| 91 |
+
return raw_actions
|
| 92 |
+
if hasattr(raw_actions, "json_struct"):
|
| 93 |
+
json_actions = raw_actions.json_struct()
|
| 94 |
+
return json_actions if isinstance(json_actions, dict) else {}
|
| 95 |
+
return {}
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _extract_legal_actions(snapshot: RawSnapshot) -> tuple[list[LegalAction], dict[ActionLookupKey, ActionRef]]:
|
| 100 |
+
legal_actions: list[LegalAction] = [
|
| 101 |
+
LegalAction(
|
| 102 |
+
action_type="end_turn",
|
| 103 |
+
label="End the current turn",
|
| 104 |
+
raw_action_key="__end_turn__",
|
| 105 |
+
)
|
| 106 |
+
]
|
| 107 |
+
refs: dict[ActionLookupKey, ActionRef] = {}
|
| 108 |
+
|
| 109 |
+
for actor_id, action_map in _controller_actions(snapshot, "unit").items():
|
| 110 |
+
unit_id = int(actor_id)
|
| 111 |
+
if action_map.get("build"):
|
| 112 |
+
legal_actions.append(
|
| 113 |
+
LegalAction(
|
| 114 |
+
action_type="build_city",
|
| 115 |
+
label=f"Build a city with unit {unit_id}",
|
| 116 |
+
unit_id=unit_id,
|
| 117 |
+
raw_action_key="build",
|
| 118 |
+
)
|
| 119 |
+
)
|
| 120 |
+
refs[("build_city", unit_id, None, None)] = ActionRef(
|
| 121 |
+
controller="unit",
|
| 122 |
+
actor_id=unit_id,
|
| 123 |
+
raw_action_key="build",
|
| 124 |
+
)
|
| 125 |
+
for raw_action_key, enabled in sorted(action_map.items()):
|
| 126 |
+
if not enabled or not raw_action_key.startswith("goto_"):
|
| 127 |
+
continue
|
| 128 |
+
direction = int(raw_action_key.split("_", 1)[1])
|
| 129 |
+
legal_actions.append(
|
| 130 |
+
LegalAction(
|
| 131 |
+
action_type="move_unit",
|
| 132 |
+
label=f"Move unit {unit_id} in direction {direction}",
|
| 133 |
+
unit_id=unit_id,
|
| 134 |
+
direction=direction,
|
| 135 |
+
raw_action_key=raw_action_key,
|
| 136 |
+
)
|
| 137 |
+
)
|
| 138 |
+
refs[("move_unit", unit_id, direction, None)] = ActionRef(
|
| 139 |
+
controller="unit",
|
| 140 |
+
actor_id=unit_id,
|
| 141 |
+
raw_action_key=raw_action_key,
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
for actor_id, action_map in _controller_actions(snapshot, "city").items():
|
| 145 |
+
city_id = int(actor_id)
|
| 146 |
+
for raw_action_key, enabled in sorted(action_map.items()):
|
| 147 |
+
if not enabled:
|
| 148 |
+
continue
|
| 149 |
+
if raw_action_key.startswith("change_unit_prod_"):
|
| 150 |
+
target = _parse_target_name(raw_action_key, "change_unit_prod_")
|
| 151 |
+
elif raw_action_key.startswith("change_improve_prod_"):
|
| 152 |
+
target = _parse_target_name(raw_action_key, "change_improve_prod_")
|
| 153 |
+
else:
|
| 154 |
+
continue
|
| 155 |
+
legal_actions.append(
|
| 156 |
+
LegalAction(
|
| 157 |
+
action_type="set_city_production",
|
| 158 |
+
label=f"Set city {city_id} production to {target}",
|
| 159 |
+
city_id=city_id,
|
| 160 |
+
target=target,
|
| 161 |
+
raw_action_key=raw_action_key,
|
| 162 |
+
)
|
| 163 |
+
)
|
| 164 |
+
refs[("set_city_production", city_id, None, target)] = ActionRef(
|
| 165 |
+
controller="city",
|
| 166 |
+
actor_id=city_id,
|
| 167 |
+
raw_action_key=raw_action_key,
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
tech_actions = _controller_actions(snapshot, "tech").get("cur_player", {})
|
| 171 |
+
for raw_action_key, enabled in sorted(tech_actions.items()):
|
| 172 |
+
if not enabled or not raw_action_key.startswith("research_tech_"):
|
| 173 |
+
continue
|
| 174 |
+
target = _parse_target_name(raw_action_key, "research_tech_")
|
| 175 |
+
legal_actions.append(
|
| 176 |
+
LegalAction(
|
| 177 |
+
action_type="set_research",
|
| 178 |
+
label=f"Research {target}",
|
| 179 |
+
target=target,
|
| 180 |
+
raw_action_key=raw_action_key,
|
| 181 |
+
)
|
| 182 |
+
)
|
| 183 |
+
refs[("set_research", None, None, target)] = ActionRef(
|
| 184 |
+
controller="tech",
|
| 185 |
+
actor_id="cur_player",
|
| 186 |
+
raw_action_key=raw_action_key,
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
legal_actions.sort(
|
| 190 |
+
key=lambda item: (
|
| 191 |
+
item.action_type,
|
| 192 |
+
item.unit_id or -1,
|
| 193 |
+
item.city_id or -1,
|
| 194 |
+
item.direction or -1,
|
| 195 |
+
item.target or "",
|
| 196 |
+
)
|
| 197 |
+
)
|
| 198 |
+
return legal_actions, refs
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def _extract_unit_summaries(snapshot: RawSnapshot) -> list[UnitSummary]:
|
| 202 |
+
unit_actions = _controller_actions(snapshot, "unit")
|
| 203 |
+
units: list[UnitSummary] = []
|
| 204 |
+
for actor_id, unit in sorted(snapshot.state.get("unit", {}).items(), key=lambda item: int(item[0])):
|
| 205 |
+
action_map = unit_actions.get(str(actor_id), unit_actions.get(actor_id, {}))
|
| 206 |
+
move_directions = sorted(
|
| 207 |
+
int(raw_action_key.split("_", 1)[1])
|
| 208 |
+
for raw_action_key, enabled in action_map.items()
|
| 209 |
+
if enabled and raw_action_key.startswith("goto_")
|
| 210 |
+
)
|
| 211 |
+
units.append(
|
| 212 |
+
UnitSummary(
|
| 213 |
+
unit_id=int(actor_id),
|
| 214 |
+
unit_type=str(unit.get("type_rule_name", "Unknown")),
|
| 215 |
+
health=int(unit.get("health", 0) or 0),
|
| 216 |
+
moves_left=int(unit.get("moves_left", unit.get("movesleft", 0)) or 0),
|
| 217 |
+
home_city_id=(
|
| 218 |
+
int(unit.get("home_city"))
|
| 219 |
+
if unit.get("home_city") not in (None, -1, "")
|
| 220 |
+
else None
|
| 221 |
+
),
|
| 222 |
+
veteran_level=int(unit.get("veteran", 0) or 0),
|
| 223 |
+
can_build_city=bool(action_map.get("build", False)),
|
| 224 |
+
move_directions=move_directions,
|
| 225 |
+
)
|
| 226 |
+
)
|
| 227 |
+
return units
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def _extract_city_summaries(snapshot: RawSnapshot) -> list[CitySummary]:
|
| 231 |
+
city_actions = _controller_actions(snapshot, "city")
|
| 232 |
+
cities: list[CitySummary] = []
|
| 233 |
+
for actor_id, city in sorted(snapshot.state.get("city", {}).items(), key=lambda item: int(item[0])):
|
| 234 |
+
action_map = city_actions.get(str(actor_id), city_actions.get(actor_id, {}))
|
| 235 |
+
production_options = [
|
| 236 |
+
_parse_target_name(raw_action_key, "change_unit_prod_")
|
| 237 |
+
for raw_action_key, enabled in sorted(action_map.items())
|
| 238 |
+
if enabled and raw_action_key.startswith("change_unit_prod_")
|
| 239 |
+
] + [
|
| 240 |
+
_parse_target_name(raw_action_key, "change_improve_prod_")
|
| 241 |
+
for raw_action_key, enabled in sorted(action_map.items())
|
| 242 |
+
if enabled and raw_action_key.startswith("change_improve_prod_")
|
| 243 |
+
]
|
| 244 |
+
cities.append(
|
| 245 |
+
CitySummary(
|
| 246 |
+
city_id=int(actor_id),
|
| 247 |
+
size=int(city.get("size", 0) or 0),
|
| 248 |
+
prod_food=int(city.get("prod_food", 0) or 0),
|
| 249 |
+
prod_shield=int(city.get("prod_shield", 0) or 0),
|
| 250 |
+
prod_trade=int(city.get("prod_trade", 0) or 0),
|
| 251 |
+
surplus_food=int(city.get("surplus_food", 0) or 0),
|
| 252 |
+
surplus_shield=int(city.get("surplus_shield", 0) or 0),
|
| 253 |
+
surplus_trade=int(city.get("surplus_trade", 0) or 0),
|
| 254 |
+
production_kind=(
|
| 255 |
+
int(city.get("production_kind"))
|
| 256 |
+
if city.get("production_kind") is not None
|
| 257 |
+
else None
|
| 258 |
+
),
|
| 259 |
+
production_value=(
|
| 260 |
+
int(city.get("production_value"))
|
| 261 |
+
if city.get("production_value") is not None
|
| 262 |
+
else None
|
| 263 |
+
),
|
| 264 |
+
turns_to_complete=(
|
| 265 |
+
float(city.get("turns_to_prod_complete"))
|
| 266 |
+
if city.get("turns_to_prod_complete") is not None
|
| 267 |
+
else None
|
| 268 |
+
),
|
| 269 |
+
production_options=production_options,
|
| 270 |
+
)
|
| 271 |
+
)
|
| 272 |
+
return cities
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def _build_summary(
|
| 276 |
+
snapshot: RawSnapshot,
|
| 277 |
+
metrics: SnapshotMetrics,
|
| 278 |
+
units: list[UnitSummary],
|
| 279 |
+
cities: list[CitySummary],
|
| 280 |
+
legal_actions: list[LegalAction],
|
| 281 |
+
) -> str:
|
| 282 |
+
player = snapshot.state.get("player", {})
|
| 283 |
+
lines = [
|
| 284 |
+
f"Turn {snapshot.turn}",
|
| 285 |
+
f"Score {metrics.score:.1f}",
|
| 286 |
+
f"Map: {metrics.known_tiles} known tiles, {metrics.visible_tiles} visible tiles",
|
| 287 |
+
f"Economy: {player.get('my_gold', 0)} gold, science rate {player.get('my_science', 0)}%",
|
| 288 |
+
f"Cities: {metrics.city_count}",
|
| 289 |
+
]
|
| 290 |
+
for city in cities[:5]:
|
| 291 |
+
lines.append(
|
| 292 |
+
f"- City {city.city_id}: size {city.size}, food {city.prod_food}/{city.surplus_food:+d}, "
|
| 293 |
+
f"shields {city.prod_shield}/{city.surplus_shield:+d}, trade {city.prod_trade}/{city.surplus_trade:+d}"
|
| 294 |
+
)
|
| 295 |
+
lines.append(f"Units: {metrics.unit_count}")
|
| 296 |
+
for unit in units[:8]:
|
| 297 |
+
lines.append(
|
| 298 |
+
f"- Unit {unit.unit_id}: {unit.unit_type}, hp {unit.health}, moves_left {unit.moves_left}, "
|
| 299 |
+
f"build_city={str(unit.can_build_city).lower()}, move_dirs={unit.move_directions}"
|
| 300 |
+
)
|
| 301 |
+
lines.append(f"Techs researched: {metrics.techs_researched}")
|
| 302 |
+
lines.append(f"Legal actions exposed: {len(legal_actions)}")
|
| 303 |
+
return "\n".join(lines)
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def prepare_observation(
|
| 307 |
+
snapshot: RawSnapshot,
|
| 308 |
+
*,
|
| 309 |
+
reward: float,
|
| 310 |
+
done: bool,
|
| 311 |
+
status: str,
|
| 312 |
+
metadata: dict[str, Any] | None = None,
|
| 313 |
+
) -> PreparedObservation:
|
| 314 |
+
legal_actions, action_refs = _extract_legal_actions(snapshot)
|
| 315 |
+
metrics = extract_metrics(snapshot)
|
| 316 |
+
units = _extract_unit_summaries(snapshot)
|
| 317 |
+
cities = _extract_city_summaries(snapshot)
|
| 318 |
+
observation = FreecivObservation(
|
| 319 |
+
turn=snapshot.turn,
|
| 320 |
+
score=metrics.score,
|
| 321 |
+
known_tiles=metrics.known_tiles,
|
| 322 |
+
visible_tiles=metrics.visible_tiles,
|
| 323 |
+
city_count=metrics.city_count,
|
| 324 |
+
unit_count=metrics.unit_count,
|
| 325 |
+
techs_researched=metrics.techs_researched,
|
| 326 |
+
status=status,
|
| 327 |
+
summary=_build_summary(snapshot, metrics, units, cities, legal_actions),
|
| 328 |
+
units=units,
|
| 329 |
+
cities=cities,
|
| 330 |
+
legal_actions=legal_actions,
|
| 331 |
+
reward=reward,
|
| 332 |
+
done=done,
|
| 333 |
+
metadata=metadata or {},
|
| 334 |
+
)
|
| 335 |
+
return PreparedObservation(observation=observation, metrics=metrics, action_refs=action_refs)
|
build/lib/freeciv_env/client.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from openenv.core.client_types import StepResult
|
| 4 |
+
from openenv.core.env_client import EnvClient
|
| 5 |
+
|
| 6 |
+
from freeciv_env.models import FreecivAction, FreecivObservation, FreecivState
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class FreecivEnv(EnvClient[FreecivAction, FreecivObservation, FreecivState]):
|
| 10 |
+
def _step_payload(self, action: FreecivAction) -> dict:
|
| 11 |
+
return action.model_dump(exclude_none=True)
|
| 12 |
+
|
| 13 |
+
def _parse_result(self, payload: dict) -> StepResult[FreecivObservation]:
|
| 14 |
+
observation = FreecivObservation(**payload["observation"])
|
| 15 |
+
return StepResult(
|
| 16 |
+
observation=observation,
|
| 17 |
+
reward=payload.get("reward"),
|
| 18 |
+
done=payload.get("done", False),
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
def _parse_state(self, payload: dict) -> FreecivState:
|
| 22 |
+
return FreecivState(**payload)
|
build/lib/freeciv_env/grpo.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
from typing import Iterable
|
| 5 |
+
|
| 6 |
+
from freeciv_env.models import FreecivAction, FreecivObservation, LegalAction
|
| 7 |
+
|
| 8 |
+
SYSTEM_PROMPT = (
|
| 9 |
+
"You are choosing the next action for a Freeciv agent. "
|
| 10 |
+
"Return only the integer index of the best legal action. "
|
| 11 |
+
"Do not output words, punctuation, JSON, or explanations."
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
TASK_PROMPT = (
|
| 15 |
+
"Pick the legal action index that maximizes immediate reward. "
|
| 16 |
+
"Invalid actions are penalized. Shorter outputs are better."
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def format_action_line(index: int, action: LegalAction) -> str:
|
| 21 |
+
return f"{index}: {action.label}"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def build_turn_prompt(observation: FreecivObservation, task_prompt: str = TASK_PROMPT) -> str:
|
| 25 |
+
action_lines = [format_action_line(index, action) for index, action in enumerate(observation.legal_actions)]
|
| 26 |
+
return (
|
| 27 |
+
f"{task_prompt}\n\n"
|
| 28 |
+
f"State:\n{observation.summary}\n\n"
|
| 29 |
+
f"Legal actions:\n" + "\n".join(action_lines) + "\n\n"
|
| 30 |
+
"Return exactly one integer index."
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def parse_action_choice(completion_text: str, legal_actions: Iterable[LegalAction]) -> FreecivAction | None:
|
| 35 |
+
legal_actions = list(legal_actions)
|
| 36 |
+
match = re.search(r"-?\d+", completion_text)
|
| 37 |
+
if match is None:
|
| 38 |
+
return None
|
| 39 |
+
index = int(match.group(0))
|
| 40 |
+
if index < 0 or index >= len(legal_actions):
|
| 41 |
+
return None
|
| 42 |
+
action = legal_actions[index]
|
| 43 |
+
if action.action_type == "end_turn":
|
| 44 |
+
return FreecivAction(action_type="end_turn")
|
| 45 |
+
if action.action_type == "move_unit":
|
| 46 |
+
return FreecivAction(action_type="move_unit", unit_id=action.unit_id, direction=action.direction)
|
| 47 |
+
if action.action_type == "build_city":
|
| 48 |
+
return FreecivAction(action_type="build_city", unit_id=action.unit_id)
|
| 49 |
+
if action.action_type == "set_city_production":
|
| 50 |
+
return FreecivAction(action_type="set_city_production", city_id=action.city_id, target=action.target)
|
| 51 |
+
if action.action_type == "set_research":
|
| 52 |
+
return FreecivAction(action_type="set_research", target=action.target)
|
| 53 |
+
raise ValueError(f"unsupported action_type: {action.action_type}")
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def action_priority(action: LegalAction) -> tuple[int, int]:
|
| 57 |
+
if action.action_type == "build_city":
|
| 58 |
+
return (500, 0)
|
| 59 |
+
if action.action_type == "set_research":
|
| 60 |
+
return (400, 0)
|
| 61 |
+
if action.action_type == "set_city_production":
|
| 62 |
+
bonus = 50 if (action.target or "") == "Settlers" else 0
|
| 63 |
+
return (300 + bonus, 0)
|
| 64 |
+
if action.action_type == "move_unit":
|
| 65 |
+
return (200, -(action.direction or 0))
|
| 66 |
+
if action.action_type == "end_turn":
|
| 67 |
+
return (0, 0)
|
| 68 |
+
return (-1000, 0)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def oracle_action_index(legal_actions: Iterable[LegalAction]) -> int:
|
| 73 |
+
legal_actions = list(legal_actions)
|
| 74 |
+
if not legal_actions:
|
| 75 |
+
raise ValueError("no legal actions available")
|
| 76 |
+
best_index = 0
|
| 77 |
+
best_priority = action_priority(legal_actions[0])
|
| 78 |
+
for index, action in enumerate(legal_actions[1:], start=1):
|
| 79 |
+
priority = action_priority(action)
|
| 80 |
+
if priority > best_priority:
|
| 81 |
+
best_index = index
|
| 82 |
+
best_priority = priority
|
| 83 |
+
return best_index
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def reward_from_oracle(completions, best_index, **kwargs):
|
| 88 |
+
del kwargs
|
| 89 |
+
rewards = []
|
| 90 |
+
for completion, expected in zip(completions, best_index):
|
| 91 |
+
match = re.search(r"-?\d+", completion if isinstance(completion, str) else str(completion))
|
| 92 |
+
if match is None:
|
| 93 |
+
rewards.append(-0.25)
|
| 94 |
+
continue
|
| 95 |
+
chosen = int(match.group(0))
|
| 96 |
+
rewards.append(1.0 if chosen == int(expected) else 0.0)
|
| 97 |
+
return rewards
|
build/lib/freeciv_env/models.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Literal
|
| 4 |
+
|
| 5 |
+
from pydantic import BaseModel, Field, model_validator
|
| 6 |
+
|
| 7 |
+
from openenv.core.env_server.types import Action, Observation, State
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class UnitSummary(BaseModel):
|
| 11 |
+
unit_id: int = Field(..., description="Freeciv unit id")
|
| 12 |
+
unit_type: str = Field(..., description="Ruleset unit type name")
|
| 13 |
+
health: int = Field(0, description="Current health")
|
| 14 |
+
moves_left: int = Field(0, description="Movement points remaining")
|
| 15 |
+
home_city_id: int | None = Field(None, description="Home city id, if any")
|
| 16 |
+
veteran_level: int = Field(0, description="Veteran level")
|
| 17 |
+
can_build_city: bool = Field(False, description="Whether the unit can found a city now")
|
| 18 |
+
move_directions: list[int] = Field(default_factory=list, description="Legal move direction indexes")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class CitySummary(BaseModel):
|
| 22 |
+
city_id: int = Field(..., description="Freeciv city id")
|
| 23 |
+
size: int = Field(..., description="Population size")
|
| 24 |
+
prod_food: int = Field(0, description="Gross food output")
|
| 25 |
+
prod_shield: int = Field(0, description="Gross shield output")
|
| 26 |
+
prod_trade: int = Field(0, description="Gross trade output")
|
| 27 |
+
surplus_food: int = Field(0, description="Net food surplus")
|
| 28 |
+
surplus_shield: int = Field(0, description="Net shield surplus")
|
| 29 |
+
surplus_trade: int = Field(0, description="Net trade surplus")
|
| 30 |
+
production_kind: int | None = Field(None, description="Current production kind enum from Freeciv")
|
| 31 |
+
production_value: int | None = Field(None, description="Current production value id from Freeciv")
|
| 32 |
+
turns_to_complete: float | None = Field(None, description="Turns until current production completes")
|
| 33 |
+
production_options: list[str] = Field(default_factory=list, description="Legal production targets")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class LegalAction(BaseModel):
|
| 37 |
+
action_type: Literal[
|
| 38 |
+
"end_turn",
|
| 39 |
+
"move_unit",
|
| 40 |
+
"build_city",
|
| 41 |
+
"set_city_production",
|
| 42 |
+
"set_research",
|
| 43 |
+
]
|
| 44 |
+
label: str = Field(..., description="Human-readable action label")
|
| 45 |
+
unit_id: int | None = Field(None, description="Target unit id")
|
| 46 |
+
city_id: int | None = Field(None, description="Target city id")
|
| 47 |
+
direction: int | None = Field(None, description="Freeciv direction index 0..7")
|
| 48 |
+
target: str | None = Field(None, description="Production or tech target name")
|
| 49 |
+
raw_action_key: str | None = Field(None, description="Underlying freeciv-bot action key")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class FreecivAction(Action):
|
| 53 |
+
action_type: Literal[
|
| 54 |
+
"end_turn",
|
| 55 |
+
"move_unit",
|
| 56 |
+
"build_city",
|
| 57 |
+
"set_city_production",
|
| 58 |
+
"set_research",
|
| 59 |
+
]
|
| 60 |
+
unit_id: int | None = None
|
| 61 |
+
city_id: int | None = None
|
| 62 |
+
direction: int | None = None
|
| 63 |
+
target: str | None = None
|
| 64 |
+
|
| 65 |
+
@model_validator(mode="after")
|
| 66 |
+
def validate_shape(self) -> "FreecivAction":
|
| 67 |
+
if self.action_type == "end_turn":
|
| 68 |
+
return self
|
| 69 |
+
if self.action_type == "move_unit":
|
| 70 |
+
if self.unit_id is None or self.direction is None:
|
| 71 |
+
raise ValueError("move_unit requires unit_id and direction")
|
| 72 |
+
return self
|
| 73 |
+
if self.action_type == "build_city":
|
| 74 |
+
if self.unit_id is None:
|
| 75 |
+
raise ValueError("build_city requires unit_id")
|
| 76 |
+
return self
|
| 77 |
+
if self.action_type == "set_city_production":
|
| 78 |
+
if self.city_id is None or not self.target:
|
| 79 |
+
raise ValueError("set_city_production requires city_id and target")
|
| 80 |
+
return self
|
| 81 |
+
if self.action_type == "set_research":
|
| 82 |
+
if not self.target:
|
| 83 |
+
raise ValueError("set_research requires target")
|
| 84 |
+
return self
|
| 85 |
+
raise ValueError(f"unsupported action_type: {self.action_type}")
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class FreecivObservation(Observation):
|
| 89 |
+
turn: int = Field(..., description="Current game turn")
|
| 90 |
+
score: float = Field(..., description="Current player score")
|
| 91 |
+
known_tiles: int = Field(..., description="Tiles known to the player")
|
| 92 |
+
visible_tiles: int = Field(..., description="Tiles currently visible to the player")
|
| 93 |
+
city_count: int = Field(..., description="Number of owned cities")
|
| 94 |
+
unit_count: int = Field(..., description="Number of owned units")
|
| 95 |
+
techs_researched: int = Field(..., description="Number of researched techs")
|
| 96 |
+
status: str = Field("ok", description="High-level environment status")
|
| 97 |
+
summary: str = Field(..., description="Compact text summary for LLMs")
|
| 98 |
+
units: list[UnitSummary] = Field(default_factory=list, description="Compact unit summaries")
|
| 99 |
+
cities: list[CitySummary] = Field(default_factory=list, description="Compact city summaries")
|
| 100 |
+
legal_actions: list[LegalAction] = Field(default_factory=list, description="Legal actions exposed by the environment")
|
| 101 |
+
reward: float = Field(0.0, description="Reward from the last action")
|
| 102 |
+
done: bool = Field(False, description="Whether the episode is done")
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class FreecivState(State):
|
| 106 |
+
turn: int = Field(0, description="Current game turn")
|
| 107 |
+
score: float = Field(0.0, description="Current player score")
|
| 108 |
+
known_tiles: int = Field(0, description="Known tiles")
|
| 109 |
+
visible_tiles: int = Field(0, description="Visible tiles")
|
| 110 |
+
city_count: int = Field(0, description="Owned city count")
|
| 111 |
+
unit_count: int = Field(0, description="Owned unit count")
|
| 112 |
+
techs_researched: int = Field(0, description="Researched tech count")
|
build/lib/freeciv_env/runtime.py
ADDED
|
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import json
|
| 5 |
+
import threading
|
| 6 |
+
import time
|
| 7 |
+
from typing import Protocol
|
| 8 |
+
from urllib.parse import urlencode, urlparse
|
| 9 |
+
from urllib.request import Request, urlopen
|
| 10 |
+
|
| 11 |
+
from freeciv_env.adapter import ActionRef, RawSnapshot
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class FreecivSession(Protocol):
|
| 15 |
+
def reset(self, seed: int | None = None) -> RawSnapshot: ...
|
| 16 |
+
|
| 17 |
+
def apply_action(self, action_ref: ActionRef) -> RawSnapshot: ...
|
| 18 |
+
|
| 19 |
+
def end_turn(self) -> RawSnapshot: ...
|
| 20 |
+
|
| 21 |
+
def close(self) -> None: ...
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class _InteractiveBot:
|
| 25 |
+
def __init__(self, session: "LiveFreecivSession"):
|
| 26 |
+
from freecivbot.bot.base_bot import BaseBot
|
| 27 |
+
|
| 28 |
+
class InteractiveBotImpl(BaseBot):
|
| 29 |
+
def __init__(self, owner: "LiveFreecivSession"):
|
| 30 |
+
super().__init__()
|
| 31 |
+
self._owner = owner
|
| 32 |
+
|
| 33 |
+
def conduct_turn(self, pplayer, info_controls, end_turn_hook):
|
| 34 |
+
super().conduct_turn(pplayer, info_controls, end_turn_hook)
|
| 35 |
+
self._publish_snapshot()
|
| 36 |
+
|
| 37 |
+
def calculate_next_move(self):
|
| 38 |
+
if self._turn_active:
|
| 39 |
+
self._publish_snapshot()
|
| 40 |
+
|
| 41 |
+
def _publish_snapshot(self):
|
| 42 |
+
self._acquire_state()
|
| 43 |
+
self._owner._publish_snapshot(
|
| 44 |
+
RawSnapshot(
|
| 45 |
+
turn=self.turn,
|
| 46 |
+
state=self._turn_state,
|
| 47 |
+
actions=self._turn_opts,
|
| 48 |
+
)
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
self.impl = InteractiveBotImpl(session)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class _ConfiguredCivClient:
|
| 55 |
+
def __init__(self, bot, user_name: str, *, client_port: int, visual_monitor: bool = False):
|
| 56 |
+
from freecivbot.civclient import CivClient
|
| 57 |
+
|
| 58 |
+
class ConfiguredCivClientImpl(CivClient):
|
| 59 |
+
def init_control(self, ws_client):
|
| 60 |
+
self.ws_client = ws_client
|
| 61 |
+
self.init_controller()
|
| 62 |
+
if self.visual_monitor:
|
| 63 |
+
self.monitor.start_monitor()
|
| 64 |
+
login_message = {
|
| 65 |
+
"pid": 4,
|
| 66 |
+
"username": self.user_name,
|
| 67 |
+
"capability": "+Freeciv.Web.Devel-3.2",
|
| 68 |
+
"version_label": "-dev",
|
| 69 |
+
"major_version": 3,
|
| 70 |
+
"minor_version": 1,
|
| 71 |
+
"patch_version": 90,
|
| 72 |
+
"port": self.client_port,
|
| 73 |
+
"password": None,
|
| 74 |
+
"subject": None,
|
| 75 |
+
}
|
| 76 |
+
self.ws_client.send(login_message)
|
| 77 |
+
|
| 78 |
+
def handle_chat_msg(self, packet):
|
| 79 |
+
from freecivbot.utils.fc_events import E_UNDEFINED
|
| 80 |
+
|
| 81 |
+
message = packet["message"]
|
| 82 |
+
conn_id = packet["conn_id"]
|
| 83 |
+
event = packet["event"]
|
| 84 |
+
|
| 85 |
+
if message is None:
|
| 86 |
+
return
|
| 87 |
+
if event is None or event < 0 or event >= E_UNDEFINED:
|
| 88 |
+
print("Undefined message event type")
|
| 89 |
+
print(packet)
|
| 90 |
+
print("\r\n")
|
| 91 |
+
packet["event"] = event = E_UNDEFINED
|
| 92 |
+
|
| 93 |
+
if conn_id in self.clstate.connections:
|
| 94 |
+
message = "<b>" + self.clstate.connections[conn_id]["username"] + ":</b>" + message
|
| 95 |
+
else:
|
| 96 |
+
if "/metamessage" in message:
|
| 97 |
+
return
|
| 98 |
+
if "Metaserver message string" in message:
|
| 99 |
+
return
|
| 100 |
+
|
| 101 |
+
packet["message"] = message
|
| 102 |
+
print(packet)
|
| 103 |
+
print("\r\n")
|
| 104 |
+
|
| 105 |
+
if "You are logged in as" in message:
|
| 106 |
+
self.ws_client.send_message("/set minplayers 1")
|
| 107 |
+
self.prepare_game()
|
| 108 |
+
|
| 109 |
+
def handle_conn_info(self, packet):
|
| 110 |
+
from freecivbot.connectivity.client_state import C_S_PREPARING
|
| 111 |
+
from freecivbot.utils.freecivlog import freelog
|
| 112 |
+
|
| 113 |
+
pconn = self.clstate.find_conn_by_id(packet["id"])
|
| 114 |
+
|
| 115 |
+
if not packet["used"]:
|
| 116 |
+
if pconn is None:
|
| 117 |
+
freelog(f"Server removed unknown connection {packet['id']}")
|
| 118 |
+
return
|
| 119 |
+
self.clstate.client_remove_cli_conn(pconn)
|
| 120 |
+
pconn = None
|
| 121 |
+
else:
|
| 122 |
+
pplayer = self.player_ctrl.valid_player_by_number(packet["player_num"])
|
| 123 |
+
if pplayer is None:
|
| 124 |
+
return
|
| 125 |
+
packet["playing"] = pplayer
|
| 126 |
+
|
| 127 |
+
if self.clstate.has_id(packet["id"]):
|
| 128 |
+
self.clstate.init_state(packet)
|
| 129 |
+
|
| 130 |
+
self.clstate.conn_list_append(packet)
|
| 131 |
+
|
| 132 |
+
if self.clstate.has_id(packet["id"]) and self.clstate.cur_player() != packet["playing"]:
|
| 133 |
+
self.clstate.set_client_state(C_S_PREPARING)
|
| 134 |
+
|
| 135 |
+
self.impl = ConfiguredCivClientImpl(
|
| 136 |
+
bot,
|
| 137 |
+
user_name,
|
| 138 |
+
client_port=client_port,
|
| 139 |
+
visual_monitor=visual_monitor,
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class _ConfiguredCivConnection:
|
| 144 |
+
def __init__(self, civ_client, base_url: str, *, owner: "LiveFreecivSession", wait_for_server: int = 120, retry_interval: int = 5):
|
| 145 |
+
from math import ceil
|
| 146 |
+
|
| 147 |
+
import websocket
|
| 148 |
+
|
| 149 |
+
self._websocket = websocket
|
| 150 |
+
self.client = civ_client
|
| 151 |
+
self.base_url = base_url
|
| 152 |
+
self._owner = owner
|
| 153 |
+
self._loop = None
|
| 154 |
+
self._owner._connection = self
|
| 155 |
+
self.civserverport = self._reserve_client_port(base_url, civ_client.client_port)
|
| 156 |
+
self.client.client_port = self.civserverport
|
| 157 |
+
self.proxyport = 1000 + self.civserverport
|
| 158 |
+
self._retry_interval = retry_interval
|
| 159 |
+
self._num_retries = int(ceil(wait_for_server / retry_interval))
|
| 160 |
+
self._cur_retry = 0
|
| 161 |
+
self._ws_url = self._build_ws_url(base_url)
|
| 162 |
+
self.network_init()
|
| 163 |
+
|
| 164 |
+
def _build_ws_url(self, base_url: str) -> str:
|
| 165 |
+
parsed = urlparse(base_url)
|
| 166 |
+
scheme = "wss" if parsed.scheme == "https" else "ws"
|
| 167 |
+
host = parsed.hostname or "localhost"
|
| 168 |
+
port = parsed.port
|
| 169 |
+
if port is None:
|
| 170 |
+
port = 443 if scheme == "wss" else 80
|
| 171 |
+
return f"{scheme}://{host}:{port}/civsocket/{self.proxyport}"
|
| 172 |
+
|
| 173 |
+
def _reserve_client_port(self, base_url: str, requested_port: int) -> int:
|
| 174 |
+
parsed = urlparse(base_url)
|
| 175 |
+
scheme = parsed.scheme or "http"
|
| 176 |
+
host = parsed.hostname or "localhost"
|
| 177 |
+
port = parsed.port
|
| 178 |
+
if port is None:
|
| 179 |
+
port = 443 if scheme == "https" else 80
|
| 180 |
+
query = urlencode({"civserverport": requested_port})
|
| 181 |
+
launcher_url = f"{scheme}://{host}:{port}/civclientlauncher?{query}"
|
| 182 |
+
request = Request(launcher_url, method="POST")
|
| 183 |
+
with urlopen(request, timeout=10) as response:
|
| 184 |
+
result = response.headers.get("result")
|
| 185 |
+
reserved_port = response.headers.get("port")
|
| 186 |
+
if result != "success" or reserved_port is None:
|
| 187 |
+
raise RuntimeError(f"failed to reserve freeciv client port via {launcher_url}")
|
| 188 |
+
return int(reserved_port)
|
| 189 |
+
|
| 190 |
+
def _retry(self):
|
| 191 |
+
self._cur_retry += 1
|
| 192 |
+
time.sleep(self._retry_interval)
|
| 193 |
+
return self._detect_server_up()
|
| 194 |
+
|
| 195 |
+
def _detect_server_up(self):
|
| 196 |
+
ws = self._websocket.WebSocket()
|
| 197 |
+
try:
|
| 198 |
+
ws.connect(self._ws_url, timeout=10)
|
| 199 |
+
return True
|
| 200 |
+
except Exception as err:
|
| 201 |
+
print("Connect not successful:", err, " retrying in %s seconds." % self._retry_interval)
|
| 202 |
+
if self._cur_retry < self._num_retries:
|
| 203 |
+
return self._retry()
|
| 204 |
+
return False
|
| 205 |
+
finally:
|
| 206 |
+
try:
|
| 207 |
+
ws.close()
|
| 208 |
+
except Exception:
|
| 209 |
+
pass
|
| 210 |
+
|
| 211 |
+
def network_init(self):
|
| 212 |
+
self._cur_retry = 0
|
| 213 |
+
print("Connecting to server at %s ..." % self.base_url)
|
| 214 |
+
if self._detect_server_up():
|
| 215 |
+
self.websocket_init()
|
| 216 |
+
else:
|
| 217 |
+
print("Connection could not be established!")
|
| 218 |
+
|
| 219 |
+
def websocket_init(self):
|
| 220 |
+
from tornado import ioloop
|
| 221 |
+
|
| 222 |
+
from freecivbot.connectivity.clinet import CivWSClient
|
| 223 |
+
|
| 224 |
+
asyncio.set_event_loop(asyncio.new_event_loop())
|
| 225 |
+
ioloop.IOLoop.clear_current()
|
| 226 |
+
self._loop = ioloop.IOLoop.current()
|
| 227 |
+
|
| 228 |
+
client = CivWSClient(self.client)
|
| 229 |
+
|
| 230 |
+
def send_json(data):
|
| 231 |
+
if not client._ws_connection:
|
| 232 |
+
raise RuntimeError("Web socket connection is closed.")
|
| 233 |
+
msg = json.dumps(data, separators=(",", ":"))
|
| 234 |
+
client._ws_connection.write_message(msg)
|
| 235 |
+
|
| 236 |
+
client.send = send_json
|
| 237 |
+
client.connect(self._ws_url)
|
| 238 |
+
|
| 239 |
+
try:
|
| 240 |
+
self._loop.start()
|
| 241 |
+
except KeyboardInterrupt:
|
| 242 |
+
client.close()
|
| 243 |
+
|
| 244 |
+
def submit(self, fn) -> None:
|
| 245 |
+
if self._loop is None:
|
| 246 |
+
raise RuntimeError("freeciv connection loop is not ready")
|
| 247 |
+
done = threading.Event()
|
| 248 |
+
error: BaseException | None = None
|
| 249 |
+
|
| 250 |
+
def run():
|
| 251 |
+
nonlocal error
|
| 252 |
+
try:
|
| 253 |
+
fn()
|
| 254 |
+
except BaseException as exc:
|
| 255 |
+
error = exc
|
| 256 |
+
finally:
|
| 257 |
+
done.set()
|
| 258 |
+
|
| 259 |
+
self._loop.add_callback(run)
|
| 260 |
+
if not done.wait(timeout=10):
|
| 261 |
+
raise TimeoutError("timed out dispatching action to freeciv loop")
|
| 262 |
+
if error is not None:
|
| 263 |
+
raise error
|
| 264 |
+
|
| 265 |
+
def close(self) -> None:
|
| 266 |
+
if self._loop is None:
|
| 267 |
+
return
|
| 268 |
+
self.submit(self.client.close)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class LiveFreecivSession:
|
| 272 |
+
def __init__(
|
| 273 |
+
self,
|
| 274 |
+
*,
|
| 275 |
+
username: str = "openenvbot",
|
| 276 |
+
client_port: int = 6000,
|
| 277 |
+
base_url: str = "http://localhost",
|
| 278 |
+
turn_timeout_s: float = 60.0,
|
| 279 |
+
):
|
| 280 |
+
self.username = username
|
| 281 |
+
self.client_port = client_port
|
| 282 |
+
self.base_url = base_url
|
| 283 |
+
self.turn_timeout_s = turn_timeout_s
|
| 284 |
+
|
| 285 |
+
self._bot_wrapper: _InteractiveBot | None = None
|
| 286 |
+
self._client = None
|
| 287 |
+
self._connection: _ConfiguredCivConnection | None = None
|
| 288 |
+
self._thread: threading.Thread | None = None
|
| 289 |
+
self._ready = threading.Event()
|
| 290 |
+
self._snapshot_lock = threading.Lock()
|
| 291 |
+
self._snapshot: RawSnapshot | None = None
|
| 292 |
+
self._thread_error: BaseException | None = None
|
| 293 |
+
self._reset_counter = 0
|
| 294 |
+
self._session_seed = time.monotonic_ns() % 1_000_000
|
| 295 |
+
|
| 296 |
+
def reset(self, seed: int | None = None) -> RawSnapshot:
|
| 297 |
+
del seed
|
| 298 |
+
self.close()
|
| 299 |
+
self._reset_counter += 1
|
| 300 |
+
username = self._next_username()
|
| 301 |
+
client_port = self.client_port + ((self._session_seed + self._reset_counter - 1) % 3)
|
| 302 |
+
|
| 303 |
+
self._ready.clear()
|
| 304 |
+
self._thread_error = None
|
| 305 |
+
self._snapshot = None
|
| 306 |
+
|
| 307 |
+
self._bot_wrapper = _InteractiveBot(self)
|
| 308 |
+
self._client = _ConfiguredCivClient(
|
| 309 |
+
self._bot_wrapper.impl,
|
| 310 |
+
username,
|
| 311 |
+
client_port=client_port,
|
| 312 |
+
visual_monitor=False,
|
| 313 |
+
).impl
|
| 314 |
+
|
| 315 |
+
def run() -> None:
|
| 316 |
+
try:
|
| 317 |
+
_ConfiguredCivConnection(self._client, self.base_url, owner=self)
|
| 318 |
+
except BaseException as exc: # pragma: no cover - surfaced in waiters
|
| 319 |
+
self._thread_error = exc
|
| 320 |
+
self._ready.set()
|
| 321 |
+
|
| 322 |
+
self._thread = threading.Thread(target=run, name="freeciv-live-session", daemon=True)
|
| 323 |
+
self._thread.start()
|
| 324 |
+
return self._wait_for_snapshot("reset")
|
| 325 |
+
|
| 326 |
+
def apply_action(self, action_ref: ActionRef) -> RawSnapshot:
|
| 327 |
+
snapshot = self._require_snapshot()
|
| 328 |
+
action_list = snapshot.actions[action_ref.controller]
|
| 329 |
+
valid_actions = action_list.get_actions(action_ref.actor_id, valid_only=True)
|
| 330 |
+
action = None if valid_actions is None else valid_actions.get(action_ref.raw_action_key)
|
| 331 |
+
if action is None:
|
| 332 |
+
raise ValueError(
|
| 333 |
+
f"action {action_ref.raw_action_key} is no longer valid for {action_ref.controller}:{action_ref.actor_id}"
|
| 334 |
+
)
|
| 335 |
+
self._ready.clear()
|
| 336 |
+
connection = self._require_connection()
|
| 337 |
+
connection.submit(lambda: action_list.trigger_validated_action(action))
|
| 338 |
+
return self._wait_for_snapshot(action_ref.raw_action_key)
|
| 339 |
+
|
| 340 |
+
def end_turn(self) -> RawSnapshot:
|
| 341 |
+
if self._bot_wrapper is None:
|
| 342 |
+
raise RuntimeError("session has not been reset")
|
| 343 |
+
self._ready.clear()
|
| 344 |
+
connection = self._require_connection()
|
| 345 |
+
connection.submit(self._bot_wrapper.impl.end_turn)
|
| 346 |
+
return self._wait_for_snapshot("end_turn")
|
| 347 |
+
|
| 348 |
+
def close(self) -> None:
|
| 349 |
+
if self._connection is not None:
|
| 350 |
+
try:
|
| 351 |
+
self._connection.close()
|
| 352 |
+
except Exception:
|
| 353 |
+
pass
|
| 354 |
+
elif self._client is not None:
|
| 355 |
+
try:
|
| 356 |
+
self._client.close()
|
| 357 |
+
except Exception:
|
| 358 |
+
pass
|
| 359 |
+
if self._thread is not None and self._thread.is_alive():
|
| 360 |
+
self._thread.join(timeout=5)
|
| 361 |
+
self._bot_wrapper = None
|
| 362 |
+
self._client = None
|
| 363 |
+
self._connection = None
|
| 364 |
+
self._thread = None
|
| 365 |
+
self._snapshot = None
|
| 366 |
+
self._thread_error = None
|
| 367 |
+
self._ready.clear()
|
| 368 |
+
|
| 369 |
+
def _publish_snapshot(self, snapshot: RawSnapshot) -> None:
|
| 370 |
+
with self._snapshot_lock:
|
| 371 |
+
self._snapshot = snapshot
|
| 372 |
+
self._ready.set()
|
| 373 |
+
|
| 374 |
+
def _next_username(self) -> str:
|
| 375 |
+
suffix = str(self._session_seed + self._reset_counter)
|
| 376 |
+
prefix_len = max(1, 31 - len(suffix))
|
| 377 |
+
return f"{self.username[:prefix_len]}{suffix}"
|
| 378 |
+
|
| 379 |
+
def _require_connection(self) -> _ConfiguredCivConnection:
|
| 380 |
+
if self._connection is None:
|
| 381 |
+
raise RuntimeError("freeciv connection is not ready")
|
| 382 |
+
return self._connection
|
| 383 |
+
|
| 384 |
+
def _require_snapshot(self) -> RawSnapshot:
|
| 385 |
+
with self._snapshot_lock:
|
| 386 |
+
if self._snapshot is None:
|
| 387 |
+
raise RuntimeError("no live snapshot is available")
|
| 388 |
+
return self._snapshot
|
| 389 |
+
|
| 390 |
+
def _wait_for_snapshot(self, reason: str) -> RawSnapshot:
|
| 391 |
+
deadline = time.monotonic() + self.turn_timeout_s
|
| 392 |
+
while time.monotonic() < deadline:
|
| 393 |
+
if self._thread_error is not None:
|
| 394 |
+
raise RuntimeError(f"freeciv session failed during {reason}") from self._thread_error
|
| 395 |
+
if self._ready.wait(timeout=0.1):
|
| 396 |
+
if self._thread_error is not None:
|
| 397 |
+
raise RuntimeError(f"freeciv session failed during {reason}") from self._thread_error
|
| 398 |
+
snapshot = self._require_snapshot()
|
| 399 |
+
if snapshot is not None:
|
| 400 |
+
return snapshot
|
| 401 |
+
raise TimeoutError(f"timed out waiting for freeciv snapshot during {reason}")
|
build/lib/freeciv_env/server/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from freeciv_env.server.freeciv_environment import FreecivEnvironment
|
| 2 |
+
|
| 3 |
+
__all__ = ["FreecivEnvironment"]
|
build/lib/freeciv_env/server/app.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
from openenv.core.env_server import create_app
|
| 6 |
+
|
| 7 |
+
from freeciv_env.models import FreecivAction, FreecivObservation
|
| 8 |
+
from freeciv_env.runtime import LiveFreecivSession
|
| 9 |
+
from freeciv_env.server.freeciv_environment import FreecivEnvironment
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def create_live_session() -> LiveFreecivSession:
|
| 13 |
+
return LiveFreecivSession(
|
| 14 |
+
username=os.getenv("FREECIV_USERNAME", "openenvbot"),
|
| 15 |
+
client_port=int(os.getenv("FREECIV_CLIENT_PORT", "6000")),
|
| 16 |
+
base_url=os.getenv("FREECIV_SERVER_URL", "http://localhost"),
|
| 17 |
+
turn_timeout_s=float(os.getenv("FREECIV_TURN_TIMEOUT_S", "60")),
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def create_freeciv_app(*, session_factory=create_live_session, max_turns: int | None = None):
|
| 22 |
+
if max_turns is None:
|
| 23 |
+
max_turns = int(os.getenv("FREECIV_MAX_TURNS", "50"))
|
| 24 |
+
return create_app(
|
| 25 |
+
lambda: FreecivEnvironment(session_factory=session_factory, max_turns=max_turns),
|
| 26 |
+
FreecivAction,
|
| 27 |
+
FreecivObservation,
|
| 28 |
+
env_name="freeciv_env",
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
app = create_freeciv_app()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def main() -> None:
|
| 36 |
+
import uvicorn
|
| 37 |
+
|
| 38 |
+
uvicorn.run(app, host="0.0.0.0", port=8000, ws_ping_interval=300, ws_ping_timeout=300)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
if __name__ == "__main__":
|
| 42 |
+
main()
|
build/lib/freeciv_env/server/freeciv_environment.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Callable
|
| 4 |
+
from uuid import uuid4
|
| 5 |
+
|
| 6 |
+
from openenv.core.env_server.interfaces import Environment
|
| 7 |
+
|
| 8 |
+
from freeciv_env.adapter import (
|
| 9 |
+
ActionLookupKey,
|
| 10 |
+
ActionRef,
|
| 11 |
+
PreparedObservation,
|
| 12 |
+
RawSnapshot,
|
| 13 |
+
SnapshotMetrics,
|
| 14 |
+
action_lookup_key,
|
| 15 |
+
prepare_observation,
|
| 16 |
+
)
|
| 17 |
+
from freeciv_env.models import FreecivAction, FreecivObservation, FreecivState
|
| 18 |
+
from freeciv_env.runtime import FreecivSession
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class FreecivEnvironment(Environment[FreecivAction, FreecivObservation, FreecivState]):
|
| 22 |
+
SUPPORTS_CONCURRENT_SESSIONS = False
|
| 23 |
+
|
| 24 |
+
def __init__(self, session_factory: Callable[[], FreecivSession], max_turns: int = 50):
|
| 25 |
+
super().__init__()
|
| 26 |
+
self._session_factory = session_factory
|
| 27 |
+
self.max_turns = max_turns
|
| 28 |
+
self._session: FreecivSession | None = None
|
| 29 |
+
self._snapshot: RawSnapshot | None = None
|
| 30 |
+
self._metrics: SnapshotMetrics | None = None
|
| 31 |
+
self._action_refs: dict[ActionLookupKey, ActionRef] = {}
|
| 32 |
+
self._state = FreecivState(episode_id=str(uuid4()), step_count=0)
|
| 33 |
+
|
| 34 |
+
def reset(
|
| 35 |
+
self,
|
| 36 |
+
seed: int | None = None,
|
| 37 |
+
episode_id: str | None = None,
|
| 38 |
+
**kwargs,
|
| 39 |
+
) -> FreecivObservation:
|
| 40 |
+
del kwargs
|
| 41 |
+
self.close()
|
| 42 |
+
self._session = self._session_factory()
|
| 43 |
+
snapshot = self._session.reset(seed=seed)
|
| 44 |
+
prepared = prepare_observation(
|
| 45 |
+
snapshot,
|
| 46 |
+
reward=0.0,
|
| 47 |
+
done=self._is_done(snapshot),
|
| 48 |
+
status="ready",
|
| 49 |
+
metadata={},
|
| 50 |
+
)
|
| 51 |
+
self._commit(snapshot, prepared, episode_id=episode_id or str(uuid4()))
|
| 52 |
+
return prepared.observation
|
| 53 |
+
|
| 54 |
+
def step(
|
| 55 |
+
self,
|
| 56 |
+
action: FreecivAction,
|
| 57 |
+
timeout_s: float | None = None,
|
| 58 |
+
**kwargs,
|
| 59 |
+
) -> FreecivObservation:
|
| 60 |
+
del timeout_s, kwargs
|
| 61 |
+
if self._session is None or self._snapshot is None or self._metrics is None:
|
| 62 |
+
raise RuntimeError("environment must be reset before step")
|
| 63 |
+
|
| 64 |
+
self._state.step_count += 1
|
| 65 |
+
if action.action_type == "end_turn":
|
| 66 |
+
next_snapshot = self._session.end_turn()
|
| 67 |
+
reward = self._reward_for_transition(action, self._metrics, next_snapshot)
|
| 68 |
+
prepared = prepare_observation(
|
| 69 |
+
next_snapshot,
|
| 70 |
+
reward=reward,
|
| 71 |
+
done=self._is_done(next_snapshot),
|
| 72 |
+
status="ok",
|
| 73 |
+
metadata={},
|
| 74 |
+
)
|
| 75 |
+
self._commit(next_snapshot, prepared, episode_id=self._state.episode_id)
|
| 76 |
+
return prepared.observation
|
| 77 |
+
|
| 78 |
+
ref = self._action_refs.get(action_lookup_key(action))
|
| 79 |
+
if ref is None:
|
| 80 |
+
prepared = prepare_observation(
|
| 81 |
+
self._snapshot,
|
| 82 |
+
reward=-0.25,
|
| 83 |
+
done=self._is_done(self._snapshot),
|
| 84 |
+
status="invalid_action",
|
| 85 |
+
metadata={"error": "action is not currently legal"},
|
| 86 |
+
)
|
| 87 |
+
self._commit(self._snapshot, prepared, episode_id=self._state.episode_id, replace_snapshot=False)
|
| 88 |
+
return prepared.observation
|
| 89 |
+
|
| 90 |
+
next_snapshot = self._session.apply_action(ref)
|
| 91 |
+
reward = self._reward_for_transition(action, self._metrics, next_snapshot)
|
| 92 |
+
prepared = prepare_observation(
|
| 93 |
+
next_snapshot,
|
| 94 |
+
reward=reward,
|
| 95 |
+
done=self._is_done(next_snapshot),
|
| 96 |
+
status="ok",
|
| 97 |
+
metadata={},
|
| 98 |
+
)
|
| 99 |
+
self._commit(next_snapshot, prepared, episode_id=self._state.episode_id)
|
| 100 |
+
return prepared.observation
|
| 101 |
+
|
| 102 |
+
@property
|
| 103 |
+
def state(self) -> FreecivState:
|
| 104 |
+
return self._state
|
| 105 |
+
|
| 106 |
+
def close(self) -> None:
|
| 107 |
+
if self._session is not None:
|
| 108 |
+
self._session.close()
|
| 109 |
+
self._session = None
|
| 110 |
+
self._snapshot = None
|
| 111 |
+
self._metrics = None
|
| 112 |
+
self._action_refs = {}
|
| 113 |
+
|
| 114 |
+
def _commit(
|
| 115 |
+
self,
|
| 116 |
+
snapshot: RawSnapshot,
|
| 117 |
+
prepared: PreparedObservation,
|
| 118 |
+
*,
|
| 119 |
+
episode_id: str,
|
| 120 |
+
replace_snapshot: bool = True,
|
| 121 |
+
) -> None:
|
| 122 |
+
if replace_snapshot:
|
| 123 |
+
self._snapshot = snapshot
|
| 124 |
+
self._metrics = prepared.metrics
|
| 125 |
+
self._action_refs = prepared.action_refs
|
| 126 |
+
self._state = FreecivState(
|
| 127 |
+
episode_id=episode_id,
|
| 128 |
+
step_count=self._state.step_count,
|
| 129 |
+
turn=prepared.observation.turn,
|
| 130 |
+
score=prepared.observation.score,
|
| 131 |
+
known_tiles=prepared.observation.known_tiles,
|
| 132 |
+
visible_tiles=prepared.observation.visible_tiles,
|
| 133 |
+
city_count=prepared.observation.city_count,
|
| 134 |
+
unit_count=prepared.observation.unit_count,
|
| 135 |
+
techs_researched=prepared.observation.techs_researched,
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
def _reward_for_transition(
|
| 139 |
+
self,
|
| 140 |
+
action: FreecivAction,
|
| 141 |
+
previous: SnapshotMetrics,
|
| 142 |
+
next_snapshot: RawSnapshot,
|
| 143 |
+
) -> float:
|
| 144 |
+
from freeciv_env.adapter import extract_metrics
|
| 145 |
+
|
| 146 |
+
current = extract_metrics(next_snapshot)
|
| 147 |
+
reward = {
|
| 148 |
+
"end_turn": 0.0,
|
| 149 |
+
"move_unit": 0.01,
|
| 150 |
+
"build_city": 0.10,
|
| 151 |
+
"set_city_production": 0.05,
|
| 152 |
+
"set_research": 0.05,
|
| 153 |
+
}[action.action_type]
|
| 154 |
+
reward += max(current.score - previous.score, 0.0) * 0.02
|
| 155 |
+
reward += max(current.known_tiles - previous.known_tiles, 0) * 0.01
|
| 156 |
+
reward += max(current.city_count - previous.city_count, 0) * 0.50
|
| 157 |
+
reward += max(current.techs_researched - previous.techs_researched, 0) * 0.25
|
| 158 |
+
return float(reward)
|
| 159 |
+
|
| 160 |
+
def _is_done(self, snapshot: RawSnapshot) -> bool:
|
| 161 |
+
player = snapshot.state.get("player", {})
|
| 162 |
+
alive = bool(player.get("my_is_alive", True))
|
| 163 |
+
return (not alive) or snapshot.turn >= self.max_turns
|
build/lib/server/__init__.py
ADDED
|
File without changes
|
build/lib/server/app.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from freeciv_env.server.app import app as app
|
| 2 |
+
from freeciv_env.server.app import main as _main
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def main() -> None:
|
| 6 |
+
_main()
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
if __name__ == "__main__":
|
| 10 |
+
main()
|
freeciv_env.egg-info/PKG-INFO
CHANGED
|
@@ -17,6 +17,19 @@ Requires-Dist: datasets>=4.0.0; extra == "train"
|
|
| 17 |
Requires-Dist: trl>=0.24.0; extra == "train"
|
| 18 |
Requires-Dist: unsloth>=2026.3.4; extra == "train"
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
# freeciv-env
|
| 21 |
|
| 22 |
OpenEnv environment for Freeciv, built on top of `freeciv-bot`.
|
|
|
|
| 17 |
Requires-Dist: trl>=0.24.0; extra == "train"
|
| 18 |
Requires-Dist: unsloth>=2026.3.4; extra == "train"
|
| 19 |
|
| 20 |
+
---
|
| 21 |
+
title: Freeciv Environment Server
|
| 22 |
+
emoji: 🎮
|
| 23 |
+
colorFrom: blue
|
| 24 |
+
colorTo: indigo
|
| 25 |
+
sdk: docker
|
| 26 |
+
pinned: false
|
| 27 |
+
app_port: 8000
|
| 28 |
+
base_path: /web
|
| 29 |
+
tags:
|
| 30 |
+
- openenv
|
| 31 |
+
---
|
| 32 |
+
|
| 33 |
# freeciv-env
|
| 34 |
|
| 35 |
OpenEnv environment for Freeciv, built on top of `freeciv-bot`.
|
freeciv_env/server/Dockerfile
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 2 |
FROM ${BASE_IMAGE} AS builder
|
| 3 |
|
|
|
|
|
|
|
| 4 |
WORKDIR /app/env
|
| 5 |
COPY . /app/env
|
| 6 |
|
|
|
|
| 1 |
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 2 |
FROM ${BASE_IMAGE} AS builder
|
| 3 |
|
| 4 |
+
RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
|
| 5 |
+
|
| 6 |
WORKDIR /app/env
|
| 7 |
COPY . /app/env
|
| 8 |
|