Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import re | |
| from typing import Iterable | |
| from freeciv_env.models import FreecivAction, FreecivObservation, LegalAction | |
# Instruction text telling the model to reply with nothing but the integer
# index of a legal action.
SYSTEM_PROMPT = "".join(
    (
        "You are choosing the next action for a Freeciv agent. ",
        "Return only the integer index of the best legal action. ",
        "Do not output words, punctuation, JSON, or explanations.",
    )
)
# Default task description placed at the top of each turn prompt
# (it is the default value of build_turn_prompt's task_prompt argument).
TASK_PROMPT = "".join(
    (
        "Pick the legal action index that maximizes immediate reward. ",
        "Invalid actions are penalized. Shorter outputs are better.",
    )
)
def format_action_line(index: int, action: LegalAction) -> str:
    """Return one menu line of the form ``"<index>: <label>"``.

    Args:
        index: Position of the action in the legal-action list.
        action: Object whose ``label`` attribute is shown after the index.
    """
    label = action.label
    return f"{index}: {label}"
def build_turn_prompt(observation: FreecivObservation, task_prompt: str = TASK_PROMPT) -> str:
    """Render the text prompt for a single turn.

    The prompt consists of four sections separated by blank lines: the task
    prompt, the state summary, the numbered list of legal actions, and a
    closing instruction to answer with one integer index.

    Args:
        observation: Provides ``summary`` and ``legal_actions``.
        task_prompt: Text placed at the top of the prompt.

    Returns:
        The assembled prompt string.
    """
    menu = "\n".join(
        format_action_line(position, legal)
        for position, legal in enumerate(observation.legal_actions)
    )
    sections = (
        f"{task_prompt}",
        f"State:\n{observation.summary}",
        f"Legal actions:\n{menu}",
        "Return exactly one integer index.",
    )
    return "\n\n".join(sections)
def parse_action_choice(completion_text: str, legal_actions: Iterable[LegalAction]) -> FreecivAction | None:
    """Turn a model completion into a ``FreecivAction``.

    The first (optionally negative) integer found in ``completion_text`` is
    taken as an index into ``legal_actions``.

    Args:
        completion_text: Raw text produced by the model.
        legal_actions: The actions the index refers to, in prompt order.

    Returns:
        A ``FreecivAction`` carrying the fields of the selected legal action,
        or ``None`` when the text contains no integer or the integer is not a
        valid index.

    Raises:
        ValueError: If the selected action has an unrecognized ``action_type``.
    """
    candidates = list(legal_actions)
    found = re.search(r"-?\d+", completion_text)
    if found is None:
        return None

    position = int(found.group(0))
    if not 0 <= position < len(candidates):
        return None

    chosen = candidates[position]
    # For each supported action type, the attributes copied onto the result.
    fields_by_type = (
        ("end_turn", ()),
        ("move_unit", ("unit_id", "direction")),
        ("build_city", ("unit_id",)),
        ("set_city_production", ("city_id", "target")),
        ("set_research", ("target",)),
    )
    for kind, field_names in fields_by_type:
        if chosen.action_type == kind:
            copied = {name: getattr(chosen, name) for name in field_names}
            return FreecivAction(action_type=kind, **copied)
    raise ValueError(f"unsupported action_type: {chosen.action_type}")
def action_priority(action: LegalAction) -> tuple[int, int]:
    """Return a sortable ``(score, tiebreak)`` pair for a legal action.

    Scores by ``action_type``: ``build_city`` 500, ``set_research`` 400,
    ``set_city_production`` 300 (350 when the target is ``"Settlers"``),
    ``move_unit`` 200, ``end_turn`` 0, anything else -1000. The second
    element is the negated direction for ``move_unit`` and 0 otherwise.
    """
    kind = action.action_type

    if kind == "set_city_production":
        targets_settlers = (action.target or "") == "Settlers"
        return (350 if targets_settlers else 300, 0)

    if kind == "move_unit":
        # A missing direction counts as 0.
        return (200, -(action.direction or 0))

    fixed_scores = (("build_city", 500), ("set_research", 400), ("end_turn", 0))
    for name, score in fixed_scores:
        if kind == name:
            return (score, 0)

    return (-1000, 0)
def oracle_action_index(legal_actions: Iterable[LegalAction]) -> int:
    """Return the index of the legal action with the highest priority.

    Priorities come from ``action_priority``. When several actions share the
    highest priority, the earliest one is chosen.

    Raises:
        ValueError: If ``legal_actions`` is empty.
    """
    candidates = list(legal_actions)
    if not candidates:
        raise ValueError("no legal actions available")
    # max() keeps the first maximal item, so ties resolve to the lowest index.
    return max(range(len(candidates)), key=lambda position: action_priority(candidates[position]))
def reward_from_oracle(completions, best_index, **kwargs):
    """Score completions against the expected action indices.

    Completions and expected indices are paired positionally. For each pair
    the first (optionally negative) integer in the completion text is
    compared with the expected index; non-string completions are converted
    with ``str`` first.

    Args:
        completions: Model outputs to score.
        best_index: Expected index for each completion.
        **kwargs: Accepted and ignored.

    Returns:
        A list with one reward per pair: 1.0 for a match, 0.0 for a different
        integer, and -0.25 when no integer is found.
    """
    del kwargs

    def score(completion, expected):
        text = completion if isinstance(completion, str) else str(completion)
        found = re.search(r"-?\d+", text)
        if found is None:
            return -0.25
        return 1.0 if int(found.group(0)) == int(expected) else 0.0

    return [score(completion, expected) for completion, expected in zip(completions, best_index)]