import asyncio
import json
import os
import re

from google import genai
from google.genai import types
from diplomacy import Game

from prompts import PLAYER_PROMPT
from observation import POWERS, build_current_state, build_observation

MODEL = "gemini-2.5-flash"

# Output tokens reserved for the visible ORDERS / STRATEGY / MESSAGES answer.
_ANSWER_TOKEN_BUDGET = 1024
# Gemini 2.5 "thinking" tokens count toward max_output_tokens; bounding them explicitly
# keeps them from consuming the whole limit and leaving response.text empty.
_THINKING_TOKEN_BUDGET = 1024
# Total attempts per model call (1 initial + retries with exponential backoff).
_MAX_API_ATTEMPTS = 3
# Label used when the model produced no STRATEGY section; such samples are not saved.
_FALLBACK_STRATEGY = "No strategy provided"

_SECTION_HEADERS = ("ORDERS:", "STRATEGY:", "PRIVATE MESSAGES:", "PUBLIC MESSAGE:")
# "TO FRANCE: ..." (tolerates "**TO FRANCE**: ..." once leading bullets are stripped).
_PRIVATE_LINE_RE = re.compile(r"TO\s+(\w+)\**\s*:\s*(.*)", re.IGNORECASE)
# Leading list markers an LLM may put before an order: "- ", "* ", "\u2022 ", "1. ", "2) ".
_ORDER_BULLET_RE = re.compile(r"^(?:[-*\u2022]+|\d+[.)])\s*")

_GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
if not _GEMINI_API_KEY:
    raise RuntimeError("GEMINI_API_KEY (or GOOGLE_API_KEY) is required to run game_loop.py")

client = genai.Client(api_key=_GEMINI_API_KEY)
# Caps the number of in-flight model requests.
_connection_semaphore = asyncio.Semaphore(7)

SEASON_ORDER = {"S": 0, "F": 1, "W": 2}
PHASE_TYPE_ORDER = {"M": 0, "R": 1, "A": 2}


def _is_none_marker(text):
    """Return True when the model wrote a 'none' placeholder ("None", "none.", "(none)")."""
    return text.strip().strip("().").strip().lower() == "none"


def _parse_section(text, header):
    """Return the content after ``header`` up to the next known header or end of text.

    Content may start on the header's own line ("PUBLIC MESSAGE: hello") or on the
    following lines. Returns "" when ``header`` does not occur in ``text``.
    """
    others = "|".join(re.escape(h) for h in _SECTION_HEADERS if h != header)
    pattern = re.escape(header) + r"[ \t]*\n?(.*?)(?=" + others + r"|\Z)"
    match = re.search(pattern, text, re.DOTALL)
    if not match:
        return ""
    # A Markdown-bolded header ("**ORDERS:**") leaves stray asterisks at both ends.
    return match.group(1).strip().strip("*").strip()


def _parse_private_messages(raw):
    """Parse 'TO POWER: message' lines into a {recipient: message} dict.

    Only recipients that are members of POWERS are kept. Leading bullets are tolerated,
    empty / 'none' bodies are dropped, and repeated messages to the same recipient are
    joined with a space instead of overwriting each other.
    """
    messages = {}
    if not raw or _is_none_marker(raw):
        return messages
    for line in raw.splitlines():
        line = line.strip().lstrip("-*\u2022 ").strip()
        m = _PRIVATE_LINE_RE.match(line)
        if not m:
            continue
        recipient = m.group(1).upper()
        body = m.group(2).strip().strip("*").strip()
        if recipient not in POWERS or not body or _is_none_marker(body):
            continue
        messages[recipient] = f"{messages[recipient]} {body}" if recipient in messages else body
    return messages


def _format_board_state(game):
    """Readable string of all units on the board, one power per line."""
    lines = []
    for power in POWERS:
        units = game.powers[power].units
        if units:
            lines.append(f"{power}: {', '.join(units)}")
    return "\n".join(lines)


def _format_all_centers(game):
    """Readable string of supply center ownership, one power per line."""
    lines = []
    for power in POWERS:
        centers = game.powers[power].centers
        if centers:
            lines.append(f"{power}: {', '.join(sorted(centers))}")
    return "\n".join(lines)


def _orderable_locations(game, power):
    """Locations where ``power`` can issue an order in the current phase.

    Delegates to the engine so adjustment phases (build sites are not unit locations)
    and retreat phases are handled, not just movement phases.
    """
    return game.get_orderable_locations(power) or []


def _format_possible_orders(game, power):
    """List the engine's legal orders for each location ``power`` can order this phase."""
    all_possible = game.get_all_possible_orders()
    lines = [
        f"{loc}: {', '.join(all_possible[loc])}"
        for loc in sorted(_orderable_locations(game, power))
        if all_possible.get(loc)
    ]
    return "\n".join(lines) if lines else "No orders available."


def _normalize_order(order):
    """Canonicalize an LLM-written order line: drop bullets/backticks, upper-case, collapse spaces."""
    order = _ORDER_BULLET_RE.sub("", order.strip()).strip("`").strip()
    return " ".join(order.upper().split())


def _validate_orders(game, power, orders):
    """Keep only legal orders for ``power``'s own orderable locations.

    Orders are normalized before comparison with the engine's possible orders, and only
    the first legal order per province is kept. In movement phases any unit left without
    an order receives an explicit HOLD; in retreat/adjustment phases HOLD is not a legal
    order, so missing orders are left for the engine to resolve.

    Returns the list of orders to pass to ``game.set_orders``.
    """
    all_possible = game.get_all_possible_orders()
    allowed = set()
    for loc in _orderable_locations(game, power):
        allowed.update(all_possible.get(loc) or [])

    valid = []
    covered = set()
    for raw in orders:
        order = _normalize_order(raw)
        if order not in allowed:
            continue
        parts = order.split()
        if len(parts) >= 2:
            # Base province, so "STP/SC" and "STP" count as the same location.
            loc = parts[1][:3]
            if loc in covered:
                continue
            covered.add(loc)
        valid.append(order)

    if game.get_current_phase().endswith("M"):
        for unit in game.powers[power].units:
            if unit.split()[-1][:3] not in covered:
                valid.append(f"{unit} H")
    return valid


def _format_messages_for_power(power, private_messages_this_turn, public_messages_this_turn):
    """Build the messages string a player sees this phase (private to it, then all public)."""
    lines = []
    for sender, msg_dict in private_messages_this_turn.items():
        if power in msg_dict:
            lines.append(f"FROM {sender}: {msg_dict[power]}")
    for sender, msg in public_messages_this_turn:
        lines.append(f"PUBLIC ({sender}): {msg}")
    return "\n".join(lines) if lines else "No messages this phase."


def _format_history(history, last_n=3):
    """Summarize the submitted orders of the last ``last_n`` recorded phases for the prompt.

    Assumes each history entry (from build_current_state) has 'turn' and 'orders' keys.
    """
    # Guard last_n <= 0: history[-0:] would otherwise return the *entire* history.
    recent = history[-last_n:] if history and last_n > 0 else []
    if not recent:
        return "No prior history."
    lines = []
    for state in recent:
        lines.append(f"--- {state['turn']} ---")
        for p in POWERS:
            orders = state["orders"].get(p, [])
            if orders:
                lines.append(f" {p}: {', '.join(orders)}")
    return "\n".join(lines)


async def _generate_text(prompt):
    """Send ``prompt`` to MODEL and return the response text ("" if it contained none).

    Concurrency is bounded by ``_connection_semaphore``. A failed request is retried with
    exponential backoff up to ``_MAX_API_ATTEMPTS`` attempts; the last error is re-raised.
    """
    delay = 1.0
    for attempt in range(1, _MAX_API_ATTEMPTS + 1):
        try:
            async with _connection_semaphore:
                response = await client.aio.models.generate_content(
                    model=MODEL,
                    contents=prompt,
                    config=types.GenerateContentConfig(
                        max_output_tokens=_ANSWER_TOKEN_BUDGET + _THINKING_TOKEN_BUDGET,
                        temperature=0.7,
                        thinking_config=types.ThinkingConfig(
                            thinking_budget=_THINKING_TOKEN_BUDGET,
                        ),
                    ),
                )
        except Exception as exc:  # network / quota / server errors surface as exceptions
            if attempt == _MAX_API_ATTEMPTS:
                raise
            print(f"  Warning: model call failed ({exc!r}); retrying in {delay:.0f}s")
            await asyncio.sleep(delay)
            delay *= 2
        else:
            return response.text or ""
    return ""


async def call_player_agent(power, game, private_messages_this_turn, public_messages_this_turn, history):
    """Ask the Gemini model (MODEL) to play ``power`` this phase and parse its reply.

    Returns a dict with keys:
        "power": the power name,
        "orders": raw order lines (not yet validated),
        "strategy": STRATEGY section text, or the fallback label if it was empty,
        "private_messages": {recipient power: text},
        "public_message": PUBLIC MESSAGE text, or "" for none.
    Raises whatever the model client raises once retries are exhausted.
    """
    prompt = PLAYER_PROMPT.format(
        power=power,
        phase=game.get_current_phase(),
        units=", ".join(game.powers[power].units),
        centers=", ".join(sorted(game.powers[power].centers)),
        all_centers=_format_all_centers(game),
        board_state=_format_board_state(game),
        history=_format_history(history),
        messages=_format_messages_for_power(power, private_messages_this_turn, public_messages_this_turn),
        possible_orders=_format_possible_orders(game, power),
    )
    text = await _generate_text(prompt)

    orders_raw = _parse_section(text, "ORDERS:")
    orders = [line.strip() for line in orders_raw.splitlines() if line.strip()]

    strategy = _parse_section(text, "STRATEGY:")
    if not strategy:
        print(f" Warning: {power} returned empty strategy, using fallback")
        strategy = _FALLBACK_STRATEGY

    private_msgs = _parse_private_messages(_parse_section(text, "PRIVATE MESSAGES:"))
    # A power "messaging" itself is noise for the communication tracker.
    private_msgs.pop(power, None)

    public_msg = _parse_section(text, "PUBLIC MESSAGE:")
    if _is_none_marker(public_msg):
        public_msg = ""

    return {
        "power": power,
        "orders": orders,
        "strategy": strategy,
        "private_messages": private_msgs,
        "public_message": public_msg,
    }


def _init_comm_tracker():
    """Create an empty communication tracker: per-power 'curr' stats and an empty 'prev'."""
    return {
        "curr": {p: {"messaged": set(), "ignored": set(), "message_count": 0} for p in POWERS},
        "prev": {},
    }


def _update_comm_tracker(comm_tracker, results):
    """Rotate curr -> prev, then rebuild curr from this phase's private messages.

    For each power: 'messaged' is the set of powers it wrote to, 'message_count' its size,
    and 'ignored' every other power it did not write to. Powers absent from ``results``
    end up with all other powers in 'ignored'.
    """
    comm_tracker["prev"] = {
        p: {
            "messaged": set(comm_tracker["curr"][p]["messaged"]),
            "ignored": set(comm_tracker["curr"][p]["ignored"]),
            "message_count": comm_tracker["curr"][p]["message_count"],
        }
        for p in POWERS
    }
    for p in POWERS:
        comm_tracker["curr"][p] = {"messaged": set(), "ignored": set(), "message_count": 0}

    all_powers_messaged_to = {p: set() for p in POWERS}
    for r in results:
        sender = r["power"]
        # Count only real powers other than the sender.
        recipients = set(r["private_messages"]).intersection(POWERS) - {sender}
        all_powers_messaged_to[sender] = recipients
        comm_tracker["curr"][sender]["message_count"] = len(recipients)
        comm_tracker["curr"][sender]["messaged"] = recipients

    for p in POWERS:
        others = set(POWERS) - {p}
        comm_tracker["curr"][p]["ignored"] = others - all_powers_messaged_to[p]


def _serializable_comm_tracker(comm_tracker):
    """Return a copy of the tracker with sets converted to sorted lists for JSON."""
    out = {}
    for key in ("curr", "prev"):
        out[key] = {}
        for p, data in comm_tracker.get(key, {}).items():
            out[key][p] = {
                "messaged": sorted(data["messaged"]) if isinstance(data["messaged"], set) else data["messaged"],
                "ignored": sorted(data["ignored"]) if isinstance(data["ignored"], set) else data["ignored"],
                "message_count": data["message_count"],
            }
    return out


def _phase_sort_key(turn):
    """Convert a Diplomacy phase like S1904R into a sortable (year, season, type) key.

    Returns None for anything that is not a phase name (e.g. "", None, "COMPLETED").
    """
    if not isinstance(turn, str) or len(turn) < 6:
        return None
    try:
        year = int(turn[1:5])
    except ValueError:
        return None
    return (
        year,
        SEASON_ORDER.get(turn[0], 99),
        PHASE_TYPE_ORDER.get(turn[-1], 99),
    )


def _next_game_id(samples):
    """Compute the next game identifier from existing serialized samples.

    Uses max(explicit observation.game_id) + 1 when any sample carries one. Otherwise
    counts games by detecting where the phase order goes backwards between consecutive
    samples (a new game starting), and returns that count.
    """
    if not samples:
        return 0
    explicit_ids = [
        sample.get("observation", {}).get("game_id")
        for sample in samples
        if sample.get("observation", {}).get("game_id") is not None
    ]
    if explicit_ids:
        return max(explicit_ids) + 1

    game_count = 1
    previous_key = _phase_sort_key(samples[0].get("observation", {}).get("turn"))
    for sample in samples[1:]:
        current_key = _phase_sort_key(sample.get("observation", {}).get("turn"))
        if previous_key is not None and current_key is not None and current_key < previous_key:
            game_count += 1
        previous_key = current_key
    return game_count


def _save_game_data(data, path):
    """Write ``data`` as JSON to ``path`` atomically (temp file + os.replace)."""
    tmp_path = path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2, default=str)
    # Replacing in one step means an interrupted write cannot corrupt earlier games' data.
    os.replace(tmp_path, path)


async def run_game(max_turns=20, *, data_path="game_data.json", first_sample_turn=4):
    """Play one game for up to ``max_turns`` phases and append its data to ``data_path``.

    Args:
        max_turns: maximum number of phases (movement, retreat and adjustment) to play.
        data_path: JSON file holding "training_data", "public_chat_log" and "history";
            existing contents are loaded and extended.
        first_sample_turn: 0-based phase index from which training samples are recorded.

    Data gathered so far is saved even if the loop raises (e.g. the model API keeps
    failing) or is interrupted. Returns the full (existing + new) data dict.
    """
    existing = {"training_data": [], "public_chat_log": [], "history": []}
    if os.path.exists(data_path):
        with open(data_path, "r", encoding="utf-8") as f:
            existing = json.load(f)
    # Older files may lack a key; make sure every list we extend exists.
    for key in ("training_data", "public_chat_log", "history"):
        existing.setdefault(key, [])
    game_id = _next_game_id(existing["training_data"])

    game = Game()
    history = []
    training_data = []
    public_chat_log = []
    comm_tracker = _init_comm_tracker()
    private_messages_this_turn = {}
    public_messages_this_turn = []

    try:
        for turn in range(max_turns):
            phase = game.get_current_phase()
            print(f"=== Turn {turn + 1}: {phase} ===")
            if game.is_game_done:
                print("Game over.")
                break

            # Also ask powers whose only orderable items are retreats or builds.
            acting = [p for p in POWERS if game.powers[p].units or _orderable_locations(game, p)]
            results = await asyncio.gather(*(
                call_player_agent(p, game, private_messages_this_turn, public_messages_this_turn, history)
                for p in acting
            ))

            # Messages sent this phase are delivered to players in the next phase.
            private_messages_this_turn = {}
            public_messages_this_turn = []
            strategies = {}
            submitted_orders = {power: [] for power in POWERS}
            for r in results:
                power = r["power"]
                validated = _validate_orders(game, power, r["orders"])
                submitted_orders[power] = list(validated)
                game.set_orders(power, validated)
                strategies[power] = r["strategy"]
                if r["private_messages"]:
                    private_messages_this_turn[power] = r["private_messages"]
                if r["public_message"]:
                    public_messages_this_turn.append((power, r["public_message"]))

            _update_comm_tracker(comm_tracker, results)
            serializable_comm_tracker = _serializable_comm_tracker(comm_tracker)
            public_chat_log.append({
                "turn": phase,
                "messages": list(public_messages_this_turn),
            })

            current_state = build_current_state(
                game,
                phase=phase,
                submitted_orders=submitted_orders,
                comm_tracker=serializable_comm_tracker,
            )
            history.append(current_state)
            game.process()

            if turn >= first_sample_turn:
                # One sample per power that actually played this phase; placeholder labels
                # (no STRATEGY section returned) are not real training targets.
                for power, strategy in strategies.items():
                    if strategy == _FALLBACK_STRATEGY:
                        continue
                    obs = build_observation(
                        power,
                        current_state,
                        history,
                        serializable_comm_tracker,
                        public_chat_log,
                        game_id=game_id,
                        game_step_index=len(training_data),
                    )
                    training_data.append({
                        "observation": obs,
                        "true_strategy": strategy,
                    })
    finally:
        existing["training_data"].extend(training_data)
        existing["public_chat_log"].extend(public_chat_log)
        existing["history"].extend(history)
        _save_game_data(existing, data_path)
        print(f"Saved {len(training_data)} new samples ({len(existing['training_data'])} total) to {data_path}")

    return existing


if __name__ == "__main__":
    asyncio.run(run_game())