| """LLM agent that plays Red Alert using any OpenAI-compatible model. |
| |
| Supports OpenRouter, Ollama, LM Studio, or any local/remote endpoint |
| that implements the OpenAI Chat Completions API with tool calling. |
| """ |
|
|
| import asyncio |
| import json |
| import logging |
| import time |
|
|
| from collections import defaultdict |
|
|
| import httpx |
| from openra_env.config import LLMConfig |
| from openra_env.game_data import get_building_stats, get_faction_info, get_tech_tree, get_unit_stats |
| from openra_env.mcp_ws_client import OpenRAMCPClient |
|
|
# Module-level logger; uses the fixed name "llm_agent" rather than __name__.
logger = logging.getLogger("llm_agent")
|
|
|
|
| def _looks_like_tool_capability_error(error_text: str) -> bool: |
| """Best-effort detection of provider errors indicating no tool support.""" |
| text = error_text.lower() |
| |
| |
| if "no endpoints found" in text and "tool" in text: |
| return True |
| markers = ( |
| "support tool use", |
| "does not support tool", |
| "tool calling", |
| "tools are not supported", |
| ) |
| return any(m in text for m in markers) |
|
|
|
|
| def _bench_export_policy(encountered_agent_error: bool) -> tuple[bool, bool, str]: |
| """Decide whether bench export and upload should run for this match. |
| |
| Returns: |
| (should_export, should_upload, reason) |
| Local export always happens (useful for debugging). |
| Upload is skipped when runtime errors occurred. |
| """ |
| if encountered_agent_error: |
| return True, False, "runtime [ERROR] occurred during the match" |
| return True, True, "" |
|
|
|
|
def _format_llm_api_error(status_code: int, error_text: str, llm_config: LLMConfig) -> str:
    """Map a raw provider error to a clear, actionable runtime message.

    Rate-limit and generic errors keep their HTTP status code in the text.
    The agent loop in this module decides whether to retry by searching the
    message for codes such as "429" or "503", so dropping the code would make
    the error non-retriable.

    Args:
        status_code: HTTP status returned by the chat-completions endpoint.
        error_text: (Possibly truncated) response body from the provider.
        llm_config: Active LLM config; its ``model`` and ``base_url`` are
            used in the hints.

    Returns:
        A human-readable error message.
    """
    error_lower = error_text.lower()

    if status_code in (401, 403):
        return (
            f"Authentication failed ({status_code}). "
            "Check your API key: openra-rl config"
        )

    # Heuristic: any 400 whose body mentions "model" is reported as a bad model ID.
    if status_code == 400 and "model" in error_lower:
        return (
            f"Invalid model ID '{llm_config.model}'. "
            "Update with: openra-rl config"
        )

    if status_code == 429:
        # The "(429)" marker is what lets the caller's retry check recognize
        # rate limiting; without it, rate-limited calls were never retried.
        return "Rate limited by LLM provider (429). Wait a minute and retry."

    if status_code == 404 and _looks_like_tool_capability_error(error_text):
        is_openrouter = "openrouter.ai" in llm_config.base_url.lower()
        if is_openrouter:
            return (
                f"Model '{llm_config.model}' has no OpenRouter route that supports tool calling. "
                "OpenRA-RL requires tool-calling models. "
                "Use a tool-capable model/route (often not ':free'), or use Ollama "
                "with qwen3:32b or qwen3:4b."
            )
        return (
            f"Model '{llm_config.model}' does not support tool calling on this endpoint. "
            "OpenRA-RL requires tool-calling models."
        )

    return f"LLM API error {status_code}: {error_text}"
|
|
|
|
async def _preflight_tool_calling_support(llm_config: LLMConfig) -> tuple[bool, str]:
    """Probe an OpenRouter model route for tool-calling support before a match.

    Only endpoints whose ``base_url`` contains ``openrouter.ai`` are probed;
    every other endpoint passes immediately. The probe is a one-token chat
    completion that advertises a single dummy tool.

    Returns:
        ``(True, "")`` when the probe succeeds or does not apply, and
        ``(False, reason)`` when the provider's error indicates that tools are
        unsupported.

    Raises:
        RuntimeError: re-raised from ``chat_completion`` when the error does not
            look like a tool-capability problem. Other exceptions from
            ``chat_completion`` propagate unchanged.
    """
    if "openrouter.ai" in llm_config.base_url.lower():
        # Keep the probe cheap: a single output token and a timeout of at most 30s.
        probe_cfg = llm_config.model_copy(
            update={
                "max_tokens": 1,
                "request_timeout_s": min(llm_config.request_timeout_s, 30.0),
            }
        )
        probe_tool = {
            "type": "function",
            "function": {
                "name": "preflight_ping",
                "description": "Preflight-only tool for capability check.",
                "parameters": {"type": "object", "properties": {}},
            },
        }
        probe_messages = [
            {"role": "user", "content": "Tool-calling preflight check. Reply briefly."},
        ]
        try:
            await chat_completion(probe_messages, [probe_tool], probe_cfg, verbose=False, prompts=None)
        except RuntimeError as exc:
            reason = str(exc)
            if not _looks_like_tool_capability_error(reason):
                raise
            return False, reason
    return True, ""
|
|
|
|
def _load_default_prompt() -> str:
    """Return the default system prompt bundled with the ``openra_env`` package.

    The import is deferred to call time, so ``openra_env.prompts`` is only
    loaded when a prompt is actually requested.
    """
    from openra_env.prompts import load_default_prompt as bundled_loader

    prompt_text = bundled_loader()
    return prompt_text
|
|
|
|
# Built-in default system prompt, resolved once when this module is imported.
SYSTEM_PROMPT = _load_default_prompt()
|
|
|
|
def load_system_prompt(config) -> str:
    """Return the system prompt selected by *config*.

    Sources are tried in order, and the first non-empty one wins:
      1. ``config.prompts.system_prompt`` (inline string)
      2. ``config.prompts.system_prompt_file`` (path to a text file)
      3. ``config.agent.system_prompt`` (deprecated, backward compat)
      4. ``config.agent.system_prompt_file`` (deprecated, backward compat)
      5. The built-in default, ``SYSTEM_PROMPT``

    If *config* has no ``agent`` attribute, *config* itself is treated as the
    agent section for steps 3-4. File contents are returned with surrounding
    whitespace stripped; inline strings are returned as-is.

    Raises:
        FileNotFoundError: if a configured prompt file is not an existing file.
    """
    from pathlib import Path

    def _prompt_from(section):
        """Return the section's inline prompt or file prompt, else None."""
        inline = getattr(section, "system_prompt", "")
        if inline:
            return inline
        file_setting = getattr(section, "system_prompt_file", "")
        if not file_setting:
            return None
        path = Path(file_setting).expanduser()
        if not path.is_file():
            raise FileNotFoundError(f"system_prompt_file not found: {path}")
        return path.read_text(encoding="utf-8").strip()

    prompts_cfg = getattr(config, "prompts", None)
    if prompts_cfg:
        chosen = _prompt_from(prompts_cfg)
        if chosen is not None:
            return chosen

    # Deprecated location; fall back to treating config itself as the agent section.
    agent_cfg = config.agent if hasattr(config, "agent") else config
    chosen = _prompt_from(agent_cfg)
    return SYSTEM_PROMPT if chosen is None else chosen
|
|
|
|
def compose_pregame_briefing(state: dict) -> str:
    """Compose a strategic briefing from the initial game state plus static game data.

    Sent once at game start so the LLM knows the map, its base position,
    faction, tech tree, and available units/buildings without extra tool calls.

    Args:
        state: Result of the ``get_game_state`` tool. Reads ``map``,
            ``buildings_summary``, ``units_summary``, ``faction``,
            ``available_production`` and ``building_types`` when present.

    Returns:
        A multi-line briefing string.
    """
    map_info = state.get("map", {})
    map_w = map_info.get("width", 0)
    map_h = map_info.get("height", 0)
    map_name = map_info.get("map_name", "?")

    # Base position: centroid of all own buildings and units (map centre if none).
    buildings = state.get("buildings_summary", [])
    units = state.get("units_summary", [])
    all_positions = [(b["cell_x"], b["cell_y"]) for b in buildings] + [
        (u["cell_x"], u["cell_y"]) for u in units
    ]
    if all_positions:
        base_x = sum(p[0] for p in all_positions) // len(all_positions)
        base_y = sum(p[1] for p in all_positions) // len(all_positions)
    else:
        base_x, base_y = map_w // 2, map_h // 2

    # Enemy guess: our base mirrored through the map centre, clamped 2 cells
    # inside the map edges.
    enemy_x = max(2, min(map_w - 2, map_w - base_x))
    enemy_y = max(2, min(map_h - 2, map_h - base_y))

    # Side and barracks come from the faction name. For unknown factions, fall
    # back to whether the Allied barracks ("tent") is buildable or already built.
    faction = state.get("faction", "")
    allied_factions = {"england", "france", "germany"}
    soviet_factions = {"russia", "ukraine"}
    if faction in allied_factions:
        side, barracks = "Allied", "tent"
    elif faction in soviet_factions:
        side, barracks = "Soviet", "barr"
    else:
        avail = state.get("available_production", [])
        bldg_types = state.get("building_types", [])
        if "tent" in avail or "tent" in bldg_types:
            side, barracks = "Allied", "tent"
        else:
            side, barracks = "Soviet", "barr"

    tech = get_tech_tree(side.lower())
    tech_order = tech.get(side.lower(), tech.get("build_order", []))

    faction_info = get_faction_info(faction) if faction else get_faction_info(side.lower())
    avail_units = faction_info.get("available_units", []) if faction_info else []
    avail_buildings = faction_info.get("available_buildings", []) if faction_info else []

    unit_lines = []
    for utype in avail_units[:12]:
        stats = get_unit_stats(utype)
        if stats:
            unit_lines.append(f" {utype}: {stats['name']} — ${stats['cost']}, {stats.get('category', '?')}")

    bldg_lines = []
    for btype in avail_buildings[:10]:
        stats = get_building_stats(btype)
        if stats:
            power = stats.get("power", 0)
            power_str = f", {power:+d} power" if power else ""
            bldg_lines.append(f" {btype}: {stats['name']} — ${stats['cost']}{power_str}")

    # Compass direction from our base toward the likely enemy position. A
    # component only counts once it exceeds 1/6 of that map dimension. The
    # threshold is computed first and then negated: writing ``-map_h // 6``
    # floors the negative value, which made North/West one cell stricter than
    # South/East.
    dx = enemy_x - base_x
    dy = enemy_y - base_y
    thr_x = map_w // 6
    thr_y = map_h // 6
    dir_parts = []
    if dy < -thr_y:
        dir_parts.append("North")
    elif dy > thr_y:
        dir_parts.append("South")
    if dx > thr_x:
        dir_parts.append("East")
    elif dx < -thr_x:
        dir_parts.append("West")
    defense_direction = "".join(dir_parts) if dir_parts else "Center"

    parts = [
        "## Strategic Briefing",
        f"Map: {map_name} ({map_w}x{map_h})",
        f"Your faction: {faction or side} ({side})",
        f"Your base: ({base_x}, {base_y})",
        f"Enemy likely near: ({enemy_x}, {enemy_y})",
        f"Enemy approach direction: {defense_direction}",
        "",
        f"Tech tree: {' → '.join(tech_order[:8])}{'...' if len(tech_order) > 8 else ''}",
        f"Barracks type: {barracks}",
        "",
        "Available units:",
        *unit_lines,
        "",
        "Available buildings:",
        *bldg_lines,
    ]
    return "\n".join(parts)
|
|
|
|
def format_state_briefing(state: dict) -> str:
    """Render a ``get_game_state`` result as a compact per-turn briefing.

    Returns an empty string when *state* is not a dict or has no ``tick``.
    Otherwise the briefing covers, in order:
      - funds and power;
      - the optional minimap;
      - the base centre;
      - own units, with positions, move targets and short activity tags;
      - buildings;
      - visible enemies and their centroid;
      - the production queue;
      - buildable items;
      - alerts;
      - a game-over line once ``done`` is set.
    """
    if not isinstance(state, dict) or "tick" not in state:
        return ""

    economy = state.get("economy", {})
    tick = state["tick"]
    cash = economy.get("cash", 0)
    ore = economy.get("ore", 0)
    total_funds = cash + ore

    lines = [
        f"--- TURN BRIEFING (tick {tick}, ~{tick // 25}s game time) ---",
        f"Funds: ${total_funds} (cash=${cash} + ore=${ore}) | Power: {state.get('power_balance', 0):+d} | Harvesters: {economy.get('harvester_count', 0)} | Explored: {state.get('explored_percent', 0)}%",
    ]

    minimap = state.get("minimap", "")
    if minimap:
        lines.append(minimap)

    # Base centre is the integer centroid of our buildings only.
    own_buildings = state.get("buildings_summary", [])
    if own_buildings:
        count = len(own_buildings)
        centre_x = sum(b["cell_x"] for b in own_buildings) // count
        centre_y = sum(b["cell_y"] for b in own_buildings) // count
        lines.append(f"Base center: ({centre_x},{centre_y})")

    own_units = state.get("units_summary", [])
    if own_units:
        grouped = defaultdict(list)
        for unit in own_units:
            grouped[unit["type"]].append(unit)
        idle_attackers = [u["id"] for u in own_units if u.get("idle") and u.get("can_attack")]

        def _describe(unit):
            # "id@(x,y)", plus "→(tx,ty)" for a move target, or a 3-letter
            # activity tag for busy units without one.
            desc = f"{unit['id']}@({unit['cell_x']},{unit['cell_y']})"
            if unit.get("target_x") is not None:
                return desc + f"→({unit['target_x']},{unit['target_y']})"
            if not unit.get("idle"):
                activity = unit.get("activity", "")
                if activity and activity not in ("Idle", "Unknown", "Wait"):
                    return desc + f"→{activity[:3].lower()}"
            return desc

        groups = [
            f"{len(members)}x{utype}[{','.join(_describe(m) for m in members)}]"
            for utype, members in grouped.items()
        ]
        unit_line = f"Units: {' '.join(groups)}"
        if idle_attackers:
            unit_line += f" | Idle: [{','.join(str(i) for i in idle_attackers)}]"
        lines.append(unit_line)
    else:
        lines.append(f"Units: {state.get('own_units', '?')}")

    category_of = {
        "tent": "infantry", "barr": "infantry", "weap": "vehicle",
        "hpad": "aircraft", "afld": "aircraft", "syrd": "ship", "spen": "ship",
        "gun": "defense", "ftur": "defense", "tsla": "defense",
        "sam": "defense", "agun": "defense", "pbox": "defense", "hbox": "defense",
    }
    if own_buildings:
        rendered = []
        for bldg in own_buildings:
            category = category_of.get(bldg["type"], "")
            suffix = f"[{category}]" if category else ""
            rendered.append(f"{bldg['type']}({bldg['id']})@({bldg['cell_x']},{bldg['cell_y']}){suffix}")
        lines.append(f"Buildings: {' '.join(rendered)}")
    else:
        lines.append(f"Buildings: {state.get('own_buildings', '?')} ({', '.join(state.get('building_types', []))})")

    enemy_units = state.get("enemy_summary", [])
    enemy_structures = state.get("enemy_buildings_summary", [])
    if enemy_units or enemy_structures:
        segments = []
        # Units first, then structures, each grouped by type in first-seen order.
        for group in (enemy_units, enemy_structures):
            by_kind = defaultdict(list)
            for item in group:
                by_kind[item["type"]].append(item)
            for kind, items in by_kind.items():
                coords = ",".join(f"{it['id']}@({it['cell_x']},{it['cell_y']})" for it in items)
                segments.append(f"{len(items)}x{kind}[{coords}]")
        points = [(it["cell_x"], it["cell_y"]) for it in (*enemy_units, *enemy_structures)]
        mean_x = sum(x for x, _ in points) // len(points)
        mean_y = sum(y for _, y in points) // len(points)
        lines.append(f"Enemies: {' '.join(segments)} center ({mean_x},{mean_y})")
    else:
        visible = state.get("visible_enemy_units", 0)
        enemy_text = "none visible" if visible == 0 else f"{visible} visible"
        lines.append(f"Enemies: {enemy_text}")

    queue = state.get("production_items", [])
    if queue:
        in_progress = [item for item in queue if "@100%" not in item]
        finished = [item.split("@")[0] for item in queue if "@100%" in item]
        segs = []
        if in_progress:
            segs.append(", ".join(in_progress))
        if finished:
            segs.append(f"READY TO PLACE: {', '.join(finished)}")
        lines.append(f"Production: {' | '.join(segs)}")
    else:
        lines.append("Production: IDLE")

    buildable = state.get("available_production", [])
    if buildable:
        lines.append(f"Can build: {', '.join(buildable)}")

    alerts = state.get("alerts", [])
    if alerts:
        lines.append("ALERTS:")
        lines.extend(f" ** {alert}" for alert in alerts)

    lines.append("---")

    if state.get("done"):
        lines.append(f"GAME OVER: {state.get('result', '?')}")

    return "\n".join(lines)
|
|
|
|
def mcp_tools_to_openai(tools: list) -> list[dict]:
    """Convert MCP tool definitions to the OpenAI function-calling format.

    The input schema is read from ``input_schema``. If that is missing or
    empty, the MCP-spec spelling ``inputSchema`` is used instead; previously
    such tools silently lost their parameters. The schema is shallow-copied
    (the tool object is not mutated) and its top-level ``title`` is removed.
    The result always has ``type: "object"`` and a ``properties`` mapping.

    Args:
        tools: Tool objects exposing ``name``, ``description`` and a schema
            attribute as described above.

    Returns:
        A list of ``{"type": "function", "function": {...}}`` dicts.
    """
    converted = []
    for tool in tools:
        schema = getattr(tool, "input_schema", None) or getattr(tool, "inputSchema", None)
        params = dict(schema) if schema else {}
        params.pop("title", None)
        params.setdefault("properties", {})
        params["type"] = "object"

        converted.append({
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description or "",
                "parameters": params,
            },
        })
    return converted
|
|
|
|
| def _sanitize_messages(messages: list[dict], prompts=None) -> list[dict]: |
| """Merge consecutive same-role messages for strict-alternation models (e.g. Mistral). |
| |
| Some models require strict user/assistant alternation and reject sequences |
| like ``user → user`` or ``tool → user``. This helper: |
| 1. Merges consecutive ``user`` messages by joining their content with newlines. |
| 2. Inserts a bridge ``assistant`` message when a ``tool`` result is followed |
| by a ``user`` message (Mistral requires tool → assistant → user). |
| """ |
| if not messages: |
| return messages |
|
|
| bridge = prompts.sanitize_bridge if prompts else "Acknowledged. Continuing." |
| merged: list[dict] = [dict(messages[0])] |
| for msg in messages[1:]: |
| prev = merged[-1] |
| |
| if msg["role"] == "user" and prev["role"] == "user": |
| merged[-1] = {**prev, "content": prev["content"] + "\n\n" + msg["content"]} |
| continue |
| |
| if msg["role"] == "user" and prev["role"] == "tool": |
| merged.append({"role": "assistant", "content": bridge}) |
| merged.append(msg) |
| return merged |
|
|
|
|
async def chat_completion(
    messages: list[dict],
    tools: list[dict],
    llm_config: LLMConfig,
    verbose: bool = False,
    prompts=None,
) -> dict:
    """Call an OpenAI-compatible chat completions API.

    Works with OpenRouter, Ollama, LM Studio, or any endpoint implementing the
    OpenAI Chat Completions spec with tool calling. Messages are first passed
    through ``_sanitize_messages`` for strict-alternation models.

    Args:
        messages: Chat history in OpenAI message format.
        tools: OpenAI-format tool definitions; omitted from the request if empty.
        llm_config: Endpoint, model, sampling, header and timeout settings.
        verbose: Print request and token-usage diagnostics.
        prompts: Optional prompts config forwarded to ``_sanitize_messages``.

    Returns:
        The decoded JSON response. It is always a dict with a non-empty
        ``choices`` list, so callers can safely index ``["choices"][0]``.

    Raises:
        httpx.ReadTimeout, httpx.ConnectTimeout: propagated unchanged; callers
            in this module handle these two explicitly.
        RuntimeError: for non-200 statuses, undecodable or malformed bodies,
            provider-reported errors, and other request failures. Each message
            embeds a status code so callers can spot retriable failures: the
            real HTTP status where the server sent one, otherwise a synthetic
            502 (bad body) or 503 (request failed).
    """
    clean_messages = _sanitize_messages(messages, prompts=prompts)
    payload = {
        "model": llm_config.model,
        "messages": clean_messages,
        "max_tokens": llm_config.max_tokens,
    }
    if tools:
        payload["tools"] = tools
        payload["tool_choice"] = "auto"
    if llm_config.temperature is not None:
        payload["temperature"] = llm_config.temperature
    if llm_config.top_p is not None:
        payload["top_p"] = llm_config.top_p
    if llm_config.reasoning_effort is not None:
        payload["reasoning"] = {"effort": llm_config.reasoning_effort}

    headers = dict(llm_config.extra_headers)
    if llm_config.api_key:
        headers["Authorization"] = f"Bearer {llm_config.api_key}"

    async with httpx.AsyncClient() as client:
        if verbose:
            n_msgs = len(clean_messages)
            roles = [m.get("role", "?") for m in clean_messages]
            print(f" [LLM] Sending {n_msgs} messages to {llm_config.model}...")
            print(f" [LLM] Roles: {' → '.join(roles)}")

        try:
            response = await client.post(
                llm_config.base_url,
                headers=headers,
                json=payload,
                timeout=llm_config.request_timeout_s,
            )
        except (httpx.ReadTimeout, httpx.ConnectTimeout):
            raise
        except httpx.RequestError as e:
            # Connection refused, protocol errors, write/pool timeouts, ... used
            # to escape as raw httpx exceptions that callers do not catch.
            raise RuntimeError(
                f"LLM API error 503: request to LLM endpoint failed ({type(e).__name__}: {e})"
            ) from e

        if response.status_code != 200:
            error_text = response.text[:2000]
            raise RuntimeError(
                _format_llm_api_error(response.status_code, error_text, llm_config)
            )

        try:
            data = response.json()
        except ValueError as e:  # json.JSONDecodeError subclasses ValueError
            raise RuntimeError(f"LLM API error 502: invalid JSON response ({e})") from e

        if not isinstance(data, dict):
            raise RuntimeError(
                f"LLM API error 502: unexpected response body type {type(data).__name__}"
            )

        if "error" in data:
            raise RuntimeError(f"LLM API error 500: {data['error']}")

        # Guard the callers' response["choices"][0] access.
        if not data.get("choices"):
            raise RuntimeError("LLM API error 502: response contained no choices")

        if verbose:
            usage = data.get("usage") or {}
            print(
                f" [LLM] Response: {usage.get('prompt_tokens', '?')} prompt + "
                f"{usage.get('completion_tokens', '?')} completion tokens"
            )

        return data
|
|
|
|
def compress_history(messages: list[dict], keep_last: int = 40,
                     trigger: int = 0, prompts=None, compression=None) -> list[dict]:
    """Compress message history to stay within context limits.

    Keeps the system prompt and (about) the last ``keep_last`` messages. The
    earlier messages are replaced with a single ``user`` summary message that
    preserves game context taken from tool results: last known state, strategy
    line, buildings, produced units, military stats and recent errors.

    Args:
        messages: Chat history; ``messages[0]`` is treated as the system prompt.
        keep_last: Number of recent messages to keep verbatim. The kept window
            may be slightly shorter so that it never begins with a ``tool`` result.
        trigger: Compress when total messages exceed this threshold.
            0 (default) means ``keep_last * 2``.
        prompts: PromptsConfig for customizable text (``compression_suffix``).
        compression: CompressionConfig controlling what to include in summary
            (``include_strategy``, ``include_military``, ``include_production``).

    Returns:
        ``messages`` itself when no compression applies. Otherwise a new list
        ``[system, summary, *recent]``.
    """
    import re

    threshold = trigger if trigger > 0 else keep_last * 2
    if len(messages) <= threshold:
        return messages

    system = messages[0]

    # First index of the verbatim window. If the window already covers every
    # message after the system prompt (e.g. trigger < keep_last), there is
    # nothing to summarize. Previously a negative cut wrapped around the list,
    # which dropped recent messages or raised IndexError.
    cut = len(messages) - keep_last
    if cut <= 1:
        return messages
    # Never start the kept window with tool results; push them into the summary.
    while cut < len(messages) and messages[cut].get("role") == "tool":
        cut += 1
    if cut >= len(messages) - 2:
        return messages

    old_messages = messages[1:cut]
    recent = messages[cut:]

    inc_strategy = compression.include_strategy if compression else True
    inc_military = compression.include_military if compression else True
    inc_production = compression.include_production if compression else True

    # Extracts the quoted item name from production notes like "'e1' ... queued".
    queued_re = re.compile(r"'(\w+)'.*queued")

    last_state = {}
    building_types = set()
    unit_types_produced = set()
    strategy_text = ""
    errors = []

    for msg in old_messages:
        # Strategy: the first "Strategy:" line found in any user message.
        if inc_strategy and msg.get("role") == "user" and not strategy_text:
            content_str = msg.get("content", "")
            if isinstance(content_str, str):
                for line in content_str.split("\n"):
                    if line.strip().startswith("Strategy:"):
                        strategy_text = line.strip()
                        break

        if msg.get("role") != "tool":
            continue
        try:
            content = json.loads(msg["content"]) if isinstance(msg["content"], str) else msg["content"]
            if not isinstance(content, dict):
                continue

            # Full game-state snapshots carry both "tick" and "economy".
            if "tick" in content and "economy" in content:
                last_state = content

            for bt in content.get("building_types", []):
                building_types.add(bt)

            if inc_production and "note" in content:
                note = content["note"]
                if isinstance(note, str) and "queued" in note:
                    m = queued_re.search(note)
                    if m:
                        name = m.group(1)
                        # Unit queue notes mention a per-unit/each price.
                        if "per unit" in note or "each" in note:
                            unit_types_produced.add(name)
                        else:
                            building_types.add(name)

            if content.get("placement_failed"):
                errors.append("placement failed")
            elif "error" in content and isinstance(content["error"], str):
                err = content["error"]
                if len(err) < 80:
                    errors.append(err)
        except (json.JSONDecodeError, TypeError):
            # Non-JSON or oddly shaped tool output contributes nothing.
            pass

    parts = [f"[History: {len(old_messages)} earlier messages removed]"]

    if last_state:
        eco = last_state.get("economy", {})
        parts.append(
            f"Last state at tick {last_state.get('tick', '?')}: "
            f"${eco.get('cash', '?')} cash, "
            f"{last_state.get('own_units', '?')} units, "
            f"{last_state.get('own_buildings', '?')} buildings"
        )

    if inc_strategy and strategy_text:
        parts.append(strategy_text)

    if building_types:
        parts.append(f"Buildings built: {', '.join(sorted(building_types))}")

    if inc_production and unit_types_produced:
        parts.append(f"Units produced: {', '.join(sorted(unit_types_produced))}")

    if inc_military and last_state:
        mil = last_state.get("military", {})
        if mil:
            parts.append(
                f"Military: {mil.get('units_killed', 0)} kills, "
                f"{mil.get('units_lost', 0)} losses"
            )

    if errors:
        # De-duplicate (first occurrence wins), then keep the last three.
        unique = list(dict.fromkeys(errors))[-3:]
        parts.append(f"Recent issues: {'; '.join(unique)}")

    suffix = prompts.compression_suffix if prompts else "Game continues from current state."
    parts.append(suffix)

    return [
        system,
        {"role": "user", "content": "\n".join(parts)},
        *recent,
    ]
|
|
|
|
| async def run_agent(config, verbose: bool = False): |
| """Connect to OpenRA-RL and play a game using an LLM agent.""" |
| url = config.agent.server_url |
| llm_config = config.llm |
| max_turns = config.agent.max_turns |
| max_time = config.agent.max_time_s |
|
|
| |
| is_local = any(h in llm_config.base_url for h in ("localhost", "127.0.0.1")) |
| if is_local and llm_config.request_timeout_s <= 120.0: |
| llm_config = llm_config.model_copy(update={"request_timeout_s": 300.0}) |
|
|
| print(f"Connecting to {url}...") |
| print(f"Model: {llm_config.model} @ {llm_config.base_url}") |
| if is_local: |
| print(f"Timeout: {int(llm_config.request_timeout_s)}s (local model)") |
|
|
| if "openrouter.ai" in llm_config.base_url.lower(): |
| print("Checking model route for tool-calling support...") |
| try: |
| preflight_ok, preflight_err = await _preflight_tool_calling_support(llm_config) |
| except Exception as e: |
| print(f" [ERROR] Preflight check failed: {e}") |
| print(" Aborting before game launch (no match started).") |
| return |
| if not preflight_ok: |
| print(f" [ERROR] Preflight check failed: {preflight_err}") |
| print(" Aborting before game launch (no match started).") |
| return |
|
|
| async with OpenRAMCPClient(base_url=url, message_timeout_s=300.0) as env: |
| print("Resetting environment (launching OpenRA)...") |
| await env.reset() |
|
|
| |
| mcp_tools = await env.list_tools() |
| openai_tools = mcp_tools_to_openai(mcp_tools) |
| tool_names = {t["function"]["name"] for t in openai_tools} |
| print(f"Discovered {len(mcp_tools)} MCP tools") |
|
|
| if verbose: |
| for t in mcp_tools: |
| print(f" - {t.name}: {t.description[:60]}...") |
|
|
| |
| system_prompt = load_system_prompt(config) |
| messages = [{"role": "system", "content": system_prompt}] |
|
|
| |
| planning_strategy = "" |
| planning_status = await env.call_tool("get_planning_status") |
|
|
| if planning_status.get("planning_enabled", True) is not False: |
| print("Starting pre-game planning phase...") |
| planning_data = await env.call_tool("start_planning_phase") |
|
|
| if planning_data.get("planning_active"): |
| max_planning_turns = planning_data.get("max_turns", 10) |
| opponent_summary = planning_data.get("opponent_summary", "") |
|
|
| prompts = config.prompts |
| planning_prompt = prompts.planning_prompt.format( |
| max_turns=max_planning_turns, |
| map_name=planning_data.get("map", {}).get("map_name", "?"), |
| map_width=planning_data.get("map", {}).get("width", "?"), |
| map_height=planning_data.get("map", {}).get("height", "?"), |
| base_x=planning_data.get("base_position", {}).get("x", "?"), |
| base_y=planning_data.get("base_position", {}).get("y", "?"), |
| enemy_x=planning_data.get("enemy_estimated_position", {}).get("x", "?"), |
| enemy_y=planning_data.get("enemy_estimated_position", {}).get("y", "?"), |
| faction=planning_data.get("your_faction", "?"), |
| side=planning_data.get("your_side", "?"), |
| opponent_summary=opponent_summary, |
| planning_nudge=prompts.planning_nudge, |
| ) |
| messages.append({"role": "user", "content": planning_prompt}) |
|
|
| |
| planning_done = False |
| for planning_turn in range(max_planning_turns + 2): |
| try: |
| response = await chat_completion(messages, openai_tools, llm_config, verbose, prompts=config.prompts) |
| except (RuntimeError, httpx.ReadTimeout, httpx.ConnectTimeout) as e: |
| print(f" [Planning] API error: {e}") |
| print(" Skipping planning phase.") |
| break |
| if response is None: |
| break |
|
|
| choice = response["choices"][0] |
| assistant_msg = choice["message"] |
| messages.append(assistant_msg) |
|
|
| if verbose and assistant_msg.get("content"): |
| print(f" [Planning] {assistant_msg['content'][:200]}") |
|
|
| tool_calls = assistant_msg.get("tool_calls", []) |
| if not tool_calls: |
| messages.append({ |
| "role": "user", |
| "content": prompts.planning_nudge, |
| }) |
| continue |
|
|
| for tc in tool_calls: |
| fn_name = tc["function"]["name"] |
| try: |
| fn_args = json.loads(tc["function"].get("arguments", "{}")) |
| except (json.JSONDecodeError, TypeError): |
| fn_args = {} |
|
|
| if verbose: |
| args_str = json.dumps(fn_args) |
| if len(args_str) > 80: |
| args_str = args_str[:80] + "..." |
| print(f" [Planning Tool] {fn_name}({args_str})") |
|
|
| try: |
| result = await env.call_tool(fn_name, **fn_args) |
| except Exception as e: |
| result = {"error": str(e)} |
|
|
| messages.append({ |
| "role": "tool", |
| "tool_call_id": tc["id"], |
| "content": json.dumps(result) if not isinstance(result, str) else result, |
| }) |
|
|
| |
| if isinstance(result, dict): |
| if result.get("planning_complete"): |
| planning_strategy = result.get("strategy", "") |
| planning_done = True |
| if verbose: |
| print(f" [Planning] Strategy: {planning_strategy[:150]}...") |
| elif result.get("planning_expired"): |
| planning_strategy = result.get("strategy", "") |
| planning_done = True |
| print(f" [Planning] Expired: {result.get('reason', '?')}") |
|
|
| if planning_done: |
| break |
|
|
| if not planning_done: |
| |
| try: |
| result = await env.call_tool( |
| "end_planning_phase", |
| strategy="(planning timed out, no explicit strategy)" |
| ) |
| planning_strategy = result.get("strategy", "") |
| except Exception: |
| pass |
| print(" Planning phase timed out, proceeding to gameplay.") |
|
|
| print(f"Planning phase complete. Strategy recorded: {bool(planning_strategy)}") |
| else: |
| if verbose: |
| print(f" Planning: {planning_data.get('message', 'skipped')}") |
|
|
| |
| |
| |
| |
| messages = [messages[0]] |
|
|
| state = await env.call_tool("get_game_state") |
| briefing = compose_pregame_briefing(state) |
|
|
| strategy_section = "" |
| if planning_strategy: |
| strategy_section = f"\n\n## Your Pre-Game Strategy\n{planning_strategy}\n" |
|
|
| |
| mcv_id = None |
| for u in state.get("units_summary", []): |
| if u.get("type") == "mcv": |
| mcv_id = u["id"] |
| break |
| faction = state.get("faction", "") |
| barracks_type = "tent" if faction in {"england", "france", "germany"} else "barr" |
|
|
| mcv_note = f" Your MCV is unit {mcv_id}." if mcv_id else "" |
|
|
| game_start_prompts = config.prompts |
| messages.append({ |
| "role": "user", |
| "content": game_start_prompts.game_start.format( |
| strategy_section=strategy_section, |
| briefing=briefing, |
| barracks_type=barracks_type, |
| mcv_note=mcv_note, |
| ), |
| }) |
|
|
| total_tool_calls = 0 |
| total_api_calls = 0 |
| start_time = time.time() |
| game_done = False |
| encountered_agent_error = False |
| consecutive_errors = 0 |
| MAX_CONSECUTIVE_ERRORS = 3 |
|
|
| turn = 0 |
| while True: |
| |
| elapsed = time.time() - start_time |
| if max_time and elapsed >= max_time: |
| print(f"\n TIME LIMIT reached ({max_time}s). Stopping.") |
| break |
| if max_turns and turn >= max_turns: |
| break |
| turn += 1 |
|
|
| |
| if llm_config.compression_strategy != "none": |
| messages = compress_history( |
| messages, keep_last=llm_config.keep_last_messages, |
| trigger=llm_config.compression_trigger, |
| prompts=config.prompts, |
| compression=config.prompts.compression) |
|
|
| |
| if total_api_calls > 0: |
| try: |
| briefing_state = await env.call_tool("get_game_state") |
| briefing = format_state_briefing(briefing_state) |
| if briefing: |
| messages.append({"role": "user", "content": briefing}) |
| if verbose: |
| |
| for a in briefing_state.get("alerts", []): |
| print(f" [ALERT] {a}") |
| |
| if isinstance(briefing_state, dict) and briefing_state.get("done"): |
| game_done = True |
| print(f"\n GAME OVER: {briefing_state.get('result', '?').upper()} at tick {briefing_state.get('tick', '?')}") |
| break |
| except Exception: |
| pass |
|
|
| |
| response = None |
| max_retries = llm_config.max_retries |
| is_local = any(h in llm_config.base_url for h in ("localhost", "127.0.0.1")) |
| for attempt in range(max_retries): |
| try: |
| response = await chat_completion(messages, openai_tools, llm_config, verbose, prompts=config.prompts) |
| break |
| except (httpx.ReadTimeout, httpx.ConnectTimeout): |
| timeout_s = int(llm_config.request_timeout_s) |
| print(f"\n [ERROR] Request timed out after {timeout_s}s.") |
| encountered_agent_error = True |
| if is_local: |
| print(" [HINT] Local models can be slow. Increase timeout in config.yaml:") |
| print(f" llm.request_timeout_s: {timeout_s * 2}") |
| break |
| except RuntimeError as e: |
| err_str = str(e) |
| retriable = any(code in err_str for code in ("429", "500", "502", "503", "504")) |
| if retriable and attempt < max_retries - 1: |
| wait = llm_config.retry_backoff_s * (attempt + 1) |
| print(f"\n [RETRY] Provider error, waiting {wait}s ({attempt + 1}/{max_retries})...") |
| print(f" {e}") |
| await asyncio.sleep(wait) |
| else: |
| print(f"\n [ERROR] API call failed: {e}") |
| encountered_agent_error = True |
| break |
| if response is None: |
| print(" [ERROR] Stopping agent.") |
| encountered_agent_error = True |
| break |
|
|
| total_api_calls += 1 |
| choice = response["choices"][0] |
| assistant_msg = choice["message"] |
|
|
| |
| messages.append(assistant_msg) |
|
|
| |
| if assistant_msg.get("content") and verbose: |
| print(f"\n [LLM thinks] {assistant_msg['content'][:200]}") |
|
|
| |
| tool_calls = assistant_msg.get("tool_calls", []) |
| if not tool_calls: |
| |
| if verbose: |
| content = assistant_msg.get("content", "(no content)") |
| print(f" [LLM] No tool calls. Response: {content[:100]}") |
| messages.append({ |
| "role": "user", |
| "content": config.prompts.no_tool_nudge, |
| }) |
| continue |
|
|
| |
| for tc in tool_calls: |
| fn_name = tc["function"]["name"] |
| try: |
| fn_args = json.loads(tc["function"].get("arguments", "{}")) |
| except (json.JSONDecodeError, TypeError): |
| fn_args = {} |
|
|
| total_tool_calls += 1 |
|
|
| if verbose: |
| args_str = json.dumps(fn_args) |
| if len(args_str) > 80: |
| args_str = args_str[:80] + "..." |
| print(f" [Tool] {fn_name}({args_str})") |
|
|
| try: |
| result = await env.call_tool(fn_name, **fn_args) |
| consecutive_errors = 0 |
| except Exception as e: |
| result = {"error": str(e)} |
| |
| if fn_name not in tool_names: |
| import difflib |
| close = difflib.get_close_matches(fn_name, tool_names, n=3, cutoff=0.4) |
| |
| build_keywords = {"build", "place", "train", "produce", "construct"} |
| if any(kw in fn_name.lower() for kw in build_keywords): |
| for bt in ("build_unit", "build_structure", "build_and_place"): |
| if bt in tool_names and bt not in close: |
| close.append(bt) |
| if close: |
| result["suggested_tools"] = close |
|
|
| |
| if isinstance(result, dict) and "connection lost" in str(result.get("error", "")).lower(): |
| consecutive_errors += 1 |
| if consecutive_errors >= MAX_CONSECUTIVE_ERRORS: |
| print(f"\n GAME CRASHED: {consecutive_errors} consecutive connection errors. Stopping.") |
| encountered_agent_error = True |
| game_done = True |
|
|
| |
| result_str = json.dumps(result) if not isinstance(result, str) else result |
|
|
| messages.append({ |
| "role": "tool", |
| "tool_call_id": tc["id"], |
| "content": result_str, |
| }) |
|
|
| |
| if isinstance(result, dict) and result.get("done"): |
| game_done = True |
| print(f"\n GAME OVER: {result.get('result', '?').upper()} at tick {result.get('tick', '?')}") |
|
|
| if verbose and isinstance(result, dict): |
| result_preview = json.dumps(result) |
| if len(result_preview) > 500: |
| result_preview = result_preview[:500] + "..." |
| print(f" [Result] {result_preview}") |
|
|
| |
| if total_api_calls % 5 == 0 or game_done: |
| elapsed = time.time() - start_time |
| limit_str = f"/{max_turns}" if max_turns else "" |
| time_str = f"{elapsed:.0f}/{max_time}s" if max_time else f"{elapsed:.0f}s" |
| print( |
| f" Turn {turn}{limit_str} | " |
| f"API calls: {total_api_calls} | " |
| f"Tool calls: {total_tool_calls} | " |
| f"Time: {time_str}" |
| ) |
|
|
| if game_done: |
| break |
|
|
| |
| if choice.get("finish_reason") == "stop" and not tool_calls: |
| messages.append({ |
| "role": "user", |
| "content": config.prompts.continue_nudge, |
| }) |
|
|
| |
| if not game_done: |
| try: |
| await env.call_tool("surrender") |
| print("\n Surrendered (replay will have proper ending)") |
| except Exception: |
| pass |
|
|
| |
| elapsed = time.time() - start_time |
| print() |
| print("=" * 70) |
| print(f"Agent finished after {total_api_calls} API calls, {total_tool_calls} tool calls") |
| print(f"Time: {elapsed:.1f}s ({elapsed / max(total_api_calls, 1):.1f}s per API call)") |
|
|
| |
| try: |
| final = await env.call_tool("get_game_state") |
| mil = final.get("military", {}) |
| eco = final.get("economy", {}) |
| print(f"Result: {final.get('result', 'ongoing').upper()}") |
| print() |
| print("--- SCORECARD ---") |
| print(f" Planning: {'ON — ' + planning_strategy[:100] if planning_strategy else 'OFF'}") |
| print(f" Ticks played: {final.get('tick', '?')}") |
| print(f" Units killed: {mil.get('units_killed', 0)} (value: ${mil.get('kills_cost', 0)})") |
| print(f" Units lost: {mil.get('units_lost', 0)} (value: ${mil.get('deaths_cost', 0)})") |
| print(f" Buildings killed: {mil.get('buildings_killed', 0)}") |
| print(f" Buildings lost: {mil.get('buildings_lost', 0)}") |
| print(f" Army value: ${mil.get('army_value', 0)}") |
| print(f" Assets value: ${mil.get('assets_value', 0)}") |
| print(f" Experience: {mil.get('experience', 0)}") |
| print(f" Orders issued: {mil.get('order_count', 0)}") |
| print(f" Cash remaining: ${eco.get('cash', 0)}") |
| print(f" K/D cost ratio: {mil.get('kills_cost', 0) / max(mil.get('deaths_cost', 1), 1):.2f}") |
| print(f" Own units: {final.get('own_units', '?')}") |
| print(f" Own buildings: {final.get('own_buildings', '?')}") |
| print(f" Explored: {final.get('explored_percent', 0)}%") |
| rv = final.get("reward_vector", {}) |
| if rv: |
| print(" Reward vector:") |
| for dim, val in rv.items(): |
| print(f" {dim:15s} {val:+.3f}") |
| print() |
| except Exception as e: |
| print(f" (could not get final state: {e})") |
|
|
| |
| replay = {} |
| try: |
| replay = await env.call_tool("get_replay_path") |
| if replay.get("path"): |
| print(f"Replay: {replay['path']}") |
| except Exception: |
| pass |
|
|
| |
| should_export, should_upload, skip_reason = _bench_export_policy(encountered_agent_error) |
| try: |
| from datetime import datetime, timezone |
| from pathlib import Path |
|
|
| resolved_name = config.agent.agent_name or llm_config.model |
| sub = { |
| "agent_name": resolved_name, |
| "agent_type": config.agent.agent_type or "LLM", |
| "agent_url": config.agent.agent_url, |
| "opponent": config.opponent.bot_type.capitalize(), |
| "games": 1, |
| "result": final.get("result", ""), |
| "win": final.get("result") == "win", |
| "ticks": final.get("tick", 0), |
| "kills_cost": mil.get("kills_cost", 0), |
| "deaths_cost": mil.get("deaths_cost", 0), |
| "kd_ratio": round(mil.get("kills_cost", 0) / max(mil.get("deaths_cost", 1), 1), 2), |
| "assets_value": mil.get("assets_value", 0), |
| "explored_percent": final.get("explored_percent", 0), |
| "reward_vector": final.get("reward_vector", {}), |
| "replay_path": replay.get("path", ""), |
| "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"), |
| } |
| export_dir = Path.home() / ".openra-rl" / "bench-exports" |
| export_dir.mkdir(parents=True, exist_ok=True) |
| ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") |
| slug = resolved_name.replace("/", "_")[:40] |
| export_path = export_dir / f"bench-{slug}-{ts}.json" |
| export_path.write_text(json.dumps(sub, indent=2)) |
| print(f"Bench export: {export_path}") |
|
|
| |
| bench_url = config.agent.bench_url |
| if config.agent.bench_upload and bench_url: |
| if not should_upload: |
| print(f"Skipping bench upload: {skip_reason}") |
| else: |
| try: |
| from openra_env.bench_submit import gradio_submit |
| msg = gradio_submit(bench_url, sub, replay_path=replay.get("path", "")) |
| print(f"Uploaded to bench: {msg}") |
| except Exception as e: |
| print(f" (bench upload failed: {e})") |
| except Exception as e: |
| print(f" (bench export failed: {e})") |
|
|
| print("=" * 70) |
|
|