Spaces:
Sleeping
Sleeping
| import asyncio | |
| import json | |
| import os | |
| import re | |
| from google import genai | |
| from google.genai import types | |
| from diplomacy import Game | |
| from prompts import PLAYER_PROMPT | |
| from observation import POWERS, build_current_state, build_observation | |
# Name of the Gemini model requested for every player-agent call.
MODEL = "gemini-2.5-flash"

# The API key is read from the environment; GEMINI_API_KEY wins over
# GOOGLE_API_KEY when both are set.
_GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
if not _GEMINI_API_KEY:
    # Fail at import time instead of on the first request.
    raise RuntimeError("GEMINI_API_KEY (or GOOGLE_API_KEY) is required to run game_loop.py")

# Shared client; call_player_agent goes through its async interface (client.aio).
client = genai.Client(api_key=_GEMINI_API_KEY)

# Limits the number of generate_content calls in flight to 7.
# NOTE(review): this is constructed at import time, outside any running event
# loop. On Python < 3.10 asyncio primitives bind to a loop at construction,
# which may not be the loop asyncio.run() creates — confirm the minimum
# supported Python version.
_connection_semaphore = asyncio.Semaphore(7)

# Rank tables consumed by _phase_sort_key: first character of a phase label
# (season) and last character (phase type). Unknown letters rank as 99 there.
SEASON_ORDER = {"S": 0, "F": 1, "W": 2}
PHASE_TYPE_ORDER = {"M": 0, "R": 1, "A": 2}
def _parse_section(text, header):
    """Return the body of one ``HEADER:`` section of a model reply.

    The section runs from ``header`` up to the next *other* known header
    (``ORDERS:``, ``STRATEGY:``, ``PRIVATE MESSAGES:``, ``PUBLIC MESSAGE:``)
    or the end of ``text``.

    Args:
        text: Full reply text to search.
        header: The header literal to look for, including its colon.

    Returns:
        The section body with surrounding whitespace stripped, or ``""``
        when the header is absent or its body is empty.
    """
    headers = ["ORDERS:", "STRATEGY:", "PRIVATE MESSAGES:", "PUBLIC MESSAGE:"]
    stop = "|".join(re.escape(h) for h in headers if h != header)
    body = r"(.*?)(?=" + stop + r"|\Z)"
    start = re.escape(header)

    # Preferred layout: the header sits alone on its line and the body
    # starts on the following line.
    match = re.search(start + r"\s*\n" + body, text, re.DOTALL)
    if match is None:
        # Fallback: the body starts on the header's own line, e.g.
        # "PUBLIC MESSAGE: hello". Without this the content was dropped.
        match = re.search(start + r"[ \t]*" + body, text, re.DOTALL)
    if match is None:
        return ""
    return match.group(1).strip()
def _parse_private_messages(raw):
    """Turn a block of ``TO <POWER>: <message>`` lines into a dict.

    Args:
        raw: Text of the private-messages section. Empty text or the single
            word "none" (any case) means no messages.

    Returns:
        Mapping of upper-cased recipient name to message text. Lines that do
        not start with ``TO <word>:`` are skipped; a later line for the same
        recipient replaces an earlier one.
    """
    if not raw or raw.lower() == "none":
        return {}

    line_pattern = re.compile(r"TO\s+(\w+):\s*(.*)", re.IGNORECASE)
    parsed = {}
    for candidate in raw.splitlines():
        hit = line_pattern.match(candidate.strip())
        if hit is None:
            continue
        recipient, body = hit.groups()
        parsed[recipient.upper()] = body.strip()
    return parsed
def _format_board_state(game):
    """Render one ``POWER: unit, unit, ...`` line per power that has units.

    Powers are visited in ``POWERS`` order; powers with no units are omitted.
    Returns an empty string when no power has units.
    """
    roster = ((name, game.powers[name].units) for name in POWERS)
    return "\n".join(
        f"{name}: {', '.join(units)}" for name, units in roster if units
    )
def _format_all_centers(game):
    """Render one ``POWER: center, center, ...`` line per power owning centers.

    Powers are visited in ``POWERS`` order and each power's centers are listed
    in sorted order; powers with no centers are omitted.
    """
    holdings = ((name, game.powers[name].centers) for name in POWERS)
    return "\n".join(
        f"{name}: {', '.join(sorted(owned))}" for name, owned in holdings if owned
    )
def _format_possible_orders(game, power):
    """List the engine's possible orders for each location holding one of ``power``'s units.

    Locations are taken from the last token of every entry in
    ``game.powers[power].units`` and emitted in sorted order as
    ``LOC: order, order, ...``. Locations with no possible orders are skipped.

    Returns:
        The joined lines, or ``"No orders available."`` when nothing remains.
    """
    # NOTE(review): only locations from `.units` are considered — presumably
    # retreat and build options live elsewhere in the engine; confirm whether
    # those phases need their own location source.
    orders_by_loc = game.get_all_possible_orders()
    own_locations = sorted({unit.split()[-1] for unit in game.powers[power].units})
    rendered = [
        f"{loc}: {', '.join(orders_by_loc[loc])}"
        for loc in own_locations
        if loc in orders_by_loc and orders_by_loc[loc]
    ]
    if not rendered:
        return "No orders available."
    return "\n".join(rendered)
def _validate_orders(game, power, orders):
    """Keep only orders the engine lists as possible, holding uncovered units.

    Args:
        game: Game object exposing ``get_all_possible_orders()`` and
            ``powers[power].units``.
        power: Name of the power whose units receive fallback holds.
        orders: Order strings proposed for this power.

    Returns:
        The accepted orders in their original order, followed by a
        ``"<unit> H"`` hold for every unit of ``power`` whose location is not
        the second token of an accepted order. A hold is added only when the
        engine lists that hold as possible.
    """
    legal_orders = set()
    for loc_orders in game.get_all_possible_orders().values():
        legal_orders.update(loc_orders)

    accepted = []
    covered_locs = set()
    for order in orders:
        if order not in legal_orders:
            continue
        accepted.append(order)
        parts = order.split()
        # Single-token orders carry no location to mark as covered.
        if len(parts) >= 2:
            covered_locs.add(parts[1])

    for unit in game.powers[power].units:
        if unit.split()[-1] in covered_locs:
            continue
        hold = f"{unit} H"
        # Only fall back to a hold the engine would accept; otherwise an
        # illegal order would be submitted and recorded for this power.
        if hold in legal_orders:
            accepted.append(hold)
    return accepted
def _format_messages_for_power(power, private_messages_this_turn, public_messages_this_turn):
    """Assemble the message text shown to ``power`` for the current phase.

    Args:
        power: The recipient whose inbox is being built.
        private_messages_this_turn: Mapping of sender to a
            ``{recipient: message}`` dict.
        public_messages_this_turn: Iterable of ``(sender, message)`` pairs.

    Returns:
        ``FROM <sender>: ...`` lines for private messages addressed to
        ``power``, then a ``PUBLIC (<sender>): ...`` line for every public
        message, or ``"No messages this phase."`` when there are none.
    """
    inbox = [
        f"FROM {sender}: {outgoing[power]}"
        for sender, outgoing in private_messages_this_turn.items()
        if power in outgoing
    ]
    inbox.extend(
        f"PUBLIC ({sender}): {text}" for sender, text in public_messages_this_turn
    )
    if not inbox:
        return "No messages this phase."
    return "\n".join(inbox)
def _format_history(history, last_n=3):
    """Summarize the orders of the most recent turns for the prompt.

    Args:
        history: List of state dicts; each is read through its ``"turn"``
            label and its ``"orders"`` mapping of power to order list.
        last_n: How many trailing entries to include. Zero or a negative
            value selects none.

    Returns:
        A ``--- <turn> ---`` heading per included state followed by one line
        per power that has orders, or ``"No prior history."`` when nothing is
        selected.
    """
    # Guard last_n explicitly: history[-0:] is the whole list, so without
    # this a request for zero turns would return every turn.
    if not history or last_n <= 0:
        return "No prior history."

    lines = []
    for state in history[-last_n:]:
        lines.append(f"--- {state['turn']} ---")
        for p in POWERS:
            orders = state["orders"].get(p, [])
            if orders:
                lines.append(f" {p}: {', '.join(orders)}")
    return "\n".join(lines)
async def call_player_agent(power, game, private_messages_this_turn, public_messages_this_turn, history):
    """Query the Gemini model (``MODEL``) for one power and parse its reply.

    The prompt is ``PLAYER_PROMPT`` filled with the power's units and centers,
    the whole board, recent history, this phase's messages and the power's
    possible orders. The request is issued while holding
    ``_connection_semaphore``.

    Args:
        power: Name of the power to play.
        game: Current game object.
        private_messages_this_turn: Mapping of sender to
            ``{recipient: message}`` for the messages to show.
        public_messages_this_turn: Iterable of ``(sender, message)`` pairs.
        history: List of prior state dicts passed to ``_format_history``.

    Returns:
        Dict with keys ``power``, ``orders`` (non-blank lines of the ORDERS
        section, unvalidated), ``strategy`` (a fallback string when the
        section is empty), ``private_messages`` (recipient -> text) and
        ``public_message`` (``""`` when the model answered "none").
    """
    own = game.powers[power]
    prompt_fields = {
        "power": power,
        "phase": game.get_current_phase(),
        "units": ", ".join(own.units),
        "centers": ", ".join(sorted(own.centers)),
        "all_centers": _format_all_centers(game),
        "board_state": _format_board_state(game),
        "history": _format_history(history),
        "messages": _format_messages_for_power(
            power, private_messages_this_turn, public_messages_this_turn
        ),
        "possible_orders": _format_possible_orders(game, power),
    }
    prompt = PLAYER_PROMPT.format(**prompt_fields)

    generation_config = types.GenerateContentConfig(
        max_output_tokens=1024,
        temperature=0.7,
    )
    async with _connection_semaphore:
        response = await client.aio.models.generate_content(
            model=MODEL,
            contents=prompt,
            config=generation_config,
        )

    # A reply without text is treated as an empty reply.
    reply = response.text or ""

    orders = []
    for raw_line in _parse_section(reply, "ORDERS:").splitlines():
        cleaned = raw_line.strip()
        if cleaned:
            orders.append(cleaned)

    strategy = _parse_section(reply, "STRATEGY:")
    if not strategy:
        print(f" Warning: {power} returned empty strategy, using fallback")
        strategy = "No strategy provided"

    private_msgs = _parse_private_messages(_parse_section(reply, "PRIVATE MESSAGES:"))

    public_msg = _parse_section(reply, "PUBLIC MESSAGE:")
    if public_msg.lower() == "none":
        public_msg = ""

    return {
        "power": power,
        "orders": orders,
        "strategy": strategy,
        "private_messages": private_msgs,
        "public_message": public_msg,
    }
def _init_comm_tracker():
    """Create an empty communication tracker.

    Returns:
        ``{"curr": {...}, "prev": {}}`` where ``curr`` holds, for every power
        in ``POWERS``, empty ``messaged`` and ``ignored`` sets and a
        ``message_count`` of 0.
    """
    current = {}
    for name in POWERS:
        current[name] = {"messaged": set(), "ignored": set(), "message_count": 0}
    return {"curr": current, "prev": {}}
def _update_comm_tracker(comm_tracker, results):
    """Move ``curr`` into ``prev`` and rebuild ``curr`` from this turn's results.

    Args:
        comm_tracker: Tracker dict with ``"curr"`` and ``"prev"`` entries;
            modified in place.
        results: Agent result dicts, each read through its ``"power"`` and
            ``"private_messages"`` (recipient -> text) keys.

    After the call, each power's ``curr`` entry holds the set of recipients
    it privately messaged, that set's size as ``message_count``, and
    ``ignored`` — every other power in ``POWERS`` it did not message. Powers
    absent from ``results`` count as having messaged nobody.
    """
    # Snapshot the outgoing turn with copied sets so the reset below cannot
    # alias into "prev".
    snapshot = {}
    for name in POWERS:
        outgoing = comm_tracker["curr"][name]
        snapshot[name] = {
            "messaged": set(outgoing["messaged"]),
            "ignored": set(outgoing["ignored"]),
            "message_count": outgoing["message_count"],
        }
    comm_tracker["prev"] = snapshot

    current = comm_tracker["curr"]
    for name in POWERS:
        current[name] = {"messaged": set(), "ignored": set(), "message_count": 0}

    sent_to = {name: set() for name in POWERS}
    for result in results:
        sender = result["power"]
        recipients = set(result["private_messages"].keys())
        sent_to[sender] = recipients
        entry = current[sender]
        entry["message_count"] = len(recipients)
        entry["messaged"] = recipients

    everyone = set(POWERS)
    for name in POWERS:
        current[name]["ignored"] = (everyone - {name}) - sent_to[name]
def _serializable_comm_tracker(comm_tracker):
    """Return a JSON-friendly copy of the tracker with sets turned into sorted lists.

    Both the ``"curr"`` and ``"prev"`` buckets always appear in the result; a
    bucket missing from the input comes out empty. ``messaged`` / ``ignored``
    values that are not sets are passed through unchanged.
    """
    def _listify(value):
        return sorted(value) if isinstance(value, set) else value

    return {
        bucket: {
            name: {
                "messaged": _listify(stats["messaged"]),
                "ignored": _listify(stats["ignored"]),
                "message_count": stats["message_count"],
            }
            for name, stats in comm_tracker.get(bucket, {}).items()
        }
        for bucket in ("curr", "prev")
    }
def _phase_sort_key(turn):
    """Convert a Diplomacy phase label such as ``S1904R`` into a sortable key.

    Args:
        turn: Phase label: season letter, four-digit year, phase-type letter.

    Returns:
        ``(year, season_rank, phase_type_rank)``, with unknown season or type
        letters ranked 99. Returns ``None`` when the label is empty, shorter
        than six characters, or has no numeric year at positions 1-4.
    """
    if not turn or len(turn) < 6:
        return None
    try:
        year = int(turn[1:5])
    except ValueError:
        # Long non-phase labels (e.g. "COMPLETED") have no year to parse;
        # report them as unsortable instead of raising.
        return None
    return (
        year,
        SEASON_ORDER.get(turn[0], 99),
        PHASE_TYPE_ORDER.get(turn[-1], 99),
    )
def _next_game_id(samples):
    """Work out the identifier for the next game from previously saved samples.

    Args:
        samples: Saved training samples; each is read through
            ``sample["observation"]`` and its ``"game_id"`` / ``"turn"`` keys.

    Returns:
        ``0`` when there are no samples. If any sample carries an explicit
        ``game_id``, the largest one plus 1. Otherwise the number of games
        inferred from the samples, where a new game is counted each time a
        sample's phase key sorts before the preceding sample's.
    """
    if not samples:
        return 0

    def _observation(sample):
        return sample.get("observation", {})

    known_ids = [
        gid
        for gid in (_observation(sample).get("game_id") for sample in samples)
        if gid is not None
    ]
    if known_ids:
        return max(known_ids) + 1

    # No explicit ids: count games by spotting the phase sequence restarting.
    games_seen = 1
    last_key = _phase_sort_key(_observation(samples[0]).get("turn"))
    for sample in samples[1:]:
        key = _phase_sort_key(_observation(sample).get("turn"))
        went_backwards = last_key is not None and key is not None and key < last_key
        if went_backwards:
            games_seen += 1
        last_key = key
    return games_seen
async def run_game(max_turns=20):
    """Play one game with model-driven powers and append the results to ``game_data.json``.

    Each loop iteration queries every power that still has units
    concurrently, validates and submits its orders, records messages and the
    communication tracker, and advances the game. From the fifth iteration on,
    one training sample per power with units is recorded.

    Args:
        max_turns: Maximum number of phases to play; the loop also stops as
            soon as the game reports it is done.

    Returns:
        The full contents written to ``game_data.json``: previously stored
        data plus this game's ``training_data``, ``public_chat_log`` and
        ``history``.
    """
    existing = {"training_data": [], "public_chat_log": [], "history": []}
    if os.path.exists("game_data.json"):
        with open("game_data.json", "r") as f:
            existing = json.load(f)
    # A stored file may lack some sections. Create them now so the extend()
    # calls after the game cannot raise KeyError and discard a finished run.
    for section in ("training_data", "public_chat_log", "history"):
        existing.setdefault(section, [])

    game_id = _next_game_id(existing["training_data"])
    game = Game()
    history = []
    training_data = []
    public_chat_log = []
    comm_tracker = _init_comm_tracker()
    # Messages produced in one phase are what the agents read in the next.
    private_messages_this_turn = {}
    public_messages_this_turn = []

    for turn in range(max_turns):
        phase = game.get_current_phase()
        print(f"=== Turn {turn + 1}: {phase} ===")
        if game.is_game_done:
            print("Game over.")
            break

        # Query every power that still has units; the agents see the
        # messages collected during the previous iteration.
        results = await asyncio.gather(*(
            call_player_agent(p, game, private_messages_this_turn, public_messages_this_turn, history)
            for p in POWERS
            if game.powers[p].units
        ))

        private_messages_this_turn = {}
        public_messages_this_turn = []
        strategies = {}
        submitted_orders = {power: [] for power in POWERS}
        for r in results:
            power = r["power"]
            validated = _validate_orders(game, power, r["orders"])
            submitted_orders[power] = list(validated)
            game.set_orders(power, validated)
            strategies[power] = r["strategy"]
            if r["private_messages"]:
                private_messages_this_turn[power] = r["private_messages"]
            if r["public_message"]:
                public_messages_this_turn.append((power, r["public_message"]))

        _update_comm_tracker(comm_tracker, results)
        serializable_comm_tracker = _serializable_comm_tracker(comm_tracker)
        public_chat_log.append({
            "turn": phase,
            "messages": [(s, m) for s, m in public_messages_this_turn],
        })

        # The state is captured before the engine resolves the phase.
        current_state = build_current_state(
            game,
            phase=phase,
            submitted_orders=submitted_orders,
            comm_tracker=serializable_comm_tracker,
        )
        history.append(current_state)
        game.process()

        # `turn` is 0-based, so samples start with the fifth iteration.
        # NOTE(review): presumably a warm-up so observations have some
        # history behind them — confirm the intent of the threshold.
        if turn >= 4:
            for power in POWERS:
                # Checked after process(): powers left without units get no sample.
                if not game.powers[power].units:
                    continue
                obs = build_observation(
                    power, current_state, history,
                    serializable_comm_tracker,
                    public_chat_log,
                    game_id=game_id,
                    game_step_index=len(training_data),
                )
                training_data.append({
                    "observation": obs,
                    "true_strategy": strategies.get(power, ""),
                })

    existing["training_data"].extend(training_data)
    existing["public_chat_log"].extend(public_chat_log)
    existing["history"].extend(history)
    with open("game_data.json", "w") as f:
        # default=str stringifies any value json cannot encode natively.
        json.dump(existing, f, indent=2, default=str)
    print(f"Saved {len(training_data)} new samples ({len(existing['training_data'])} total) to game_data.json")
    return existing
# Script entry point: run a single game with the default turn limit.
if __name__ == "__main__":
    asyncio.run(run_game())