# Page-scrape residue kept for provenance (not part of the module):
# author: thomasm6m6
# commit 8dc7642 (verified): "Initial Freeciv OpenEnv Space"
from __future__ import annotations
import re
from typing import Iterable
from freeciv_env.models import FreecivAction, FreecivObservation, LegalAction
# System-level instruction text for the model: reply with nothing but the
# integer index of the chosen legal action (no words, punctuation, or JSON).
SYSTEM_PROMPT = (
"You are choosing the next action for a Freeciv agent. "
"Return only the integer index of the best legal action. "
"Do not output words, punctuation, JSON, or explanations."
)
# Default task description; build_turn_prompt() places it at the top of the
# per-turn prompt when the caller does not pass its own task_prompt.
TASK_PROMPT = (
"Pick the legal action index that maximizes immediate reward. "
"Invalid actions are penalized. Shorter outputs are better."
)
def format_action_line(index: int, action: LegalAction) -> str:
    """Render one legal action as a ``"<index>: <label>"`` menu line.

    Args:
        index: Number shown in front of the action.
        action: Action to render; only its ``label`` attribute is read.

    Returns:
        The index and the label separated by a colon and a space.
    """
    label = action.label
    return f"{index}: {label}"
def build_turn_prompt(observation: FreecivObservation, task_prompt: str = TASK_PROMPT) -> str:
    """Assemble the text prompt shown to the model for a single turn.

    The prompt has four sections separated by blank lines: the task
    description, the state summary, the numbered list of legal actions, and a
    closing reminder to answer with exactly one integer index.

    Args:
        observation: Source of ``summary`` (state text) and ``legal_actions``
            (numbered from 0 in iteration order).
        task_prompt: Text for the first section; defaults to ``TASK_PROMPT``.

    Returns:
        The complete prompt string.
    """
    # Build the action menu first: one "<index>: <label>" line per action.
    menu = "\n".join(
        format_action_line(position, legal_action)
        for position, legal_action in enumerate(observation.legal_actions)
    )
    sections = (
        f"{task_prompt}",
        f"State:\n{observation.summary}",
        f"Legal actions:\n{menu}",
        "Return exactly one integer index.",
    )
    return "\n\n".join(sections)
# For each supported action type, the attributes copied from the chosen
# LegalAction into the resulting FreecivAction (as same-named keyword args).
_ACTION_FIELDS = (
    ("end_turn", ()),
    ("move_unit", ("unit_id", "direction")),
    ("build_city", ("unit_id",)),
    ("set_city_production", ("city_id", "target")),
    ("set_research", ("target",)),
)


def parse_action_choice(completion_text: str, legal_actions: Iterable[LegalAction]) -> FreecivAction | None:
    """Turn a model completion into the action it selects, if any.

    The first integer found in ``completion_text`` (an optional leading minus
    sign is honoured) is used as an index into ``legal_actions``.

    Args:
        completion_text: Raw text produced by the model.
        legal_actions: Actions the index refers to; materialised into a list.

    Returns:
        A ``FreecivAction`` built from the selected legal action, or ``None``
        when the text contains no integer or the integer is out of range
        (negative indices count as out of range).

    Raises:
        ValueError: If the selected action has an ``action_type`` that is not
            one of the supported types.
    """
    candidates = list(legal_actions)

    found = re.search(r"-?\d+", completion_text)
    if found is None:
        return None

    choice = int(found.group(0))
    if not 0 <= choice < len(candidates):
        return None

    chosen = candidates[choice]
    for kind, field_names in _ACTION_FIELDS:
        if chosen.action_type == kind:
            payload = {name: getattr(chosen, name) for name in field_names}
            return FreecivAction(action_type=kind, **payload)

    raise ValueError(f"unsupported action_type: {chosen.action_type}")
def action_priority(action: LegalAction) -> tuple[int, int]:
    """Score a legal action for the heuristic oracle; larger tuples win.

    Ranking by the first element, from highest to lowest: ``build_city`` (500),
    ``set_research`` (400), ``set_city_production`` (350 when the target is
    ``"Settlers"``, otherwise 300), ``move_unit`` (200), ``end_turn`` (0), and
    any other action type (-1000).

    Args:
        action: Action to score. ``action_type`` is always read; ``target`` is
            read for city production and ``direction`` for unit moves.

    Returns:
        A ``(score, tie_breaker)`` pair. The tie-breaker is the negated
        direction for ``move_unit`` (a missing direction counts as 0) and 0
        for everything else.
    """
    kind = action.action_type

    if kind == "set_city_production":
        wants_settlers = (action.target or "") == "Settlers"
        return (350 if wants_settlers else 300, 0)

    if kind == "move_unit":
        # Negating the direction makes lower direction values rank higher.
        return (200, -(action.direction or 0))

    fixed_scores = {"build_city": 500, "set_research": 400, "end_turn": 0}
    return (fixed_scores.get(kind, -1000), 0)
def oracle_action_index(legal_actions: Iterable[LegalAction]) -> int:
    """Return the index of the action the heuristic oracle ranks highest.

    Args:
        legal_actions: Candidate actions; materialised into a list.

    Returns:
        Index of the action with the greatest ``action_priority``. When
        several actions share the top priority, the earliest one is returned.

    Raises:
        ValueError: If ``legal_actions`` is empty.
    """
    candidates = list(legal_actions)
    if not candidates:
        raise ValueError("no legal actions available")

    scores = [action_priority(candidate) for candidate in candidates]
    # max() yields the first maximal item, so ties resolve to the lowest index.
    return max(range(len(scores)), key=scores.__getitem__)
def reward_from_oracle(completions, best_index, **kwargs):
    """Score completions by whether they name the expected action index.

    Completions and expected indices are paired positionally; pairing stops at
    the shorter of the two sequences.

    Args:
        completions: Model outputs. Values that are not ``str`` are converted
            with ``str()`` before being searched.
        best_index: Expected index for each completion; each value is passed
            through ``int()`` before comparison.
        **kwargs: Accepted and ignored.

    Returns:
        One float per pair: ``1.0`` when the first integer in the completion
        equals the expected index, ``0.0`` when it differs, and ``-0.25`` when
        the completion contains no integer.
    """

    def score(completion, expected):
        text = completion if isinstance(completion, str) else str(completion)
        found = re.search(r"-?\d+", text)
        if found is None:
            return -0.25
        return 1.0 if int(found.group(0)) == int(expected) else 0.0

    return [score(completion, expected) for completion, expected in zip(completions, best_index)]