| | """LLM agent for game-theory environments.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | import random |
| | from typing import Any, Callable, Dict, List, Optional |
| |
|
| | from env.models import GameAction, GameObservation |
| | from constant_definitions.train.agent_constants import ( |
| | MAX_ACTION_TOKENS, |
| | MAX_PROMPT_HISTORY_ROUNDS, |
| | PARSE_FAILURE_SENTINEL, |
| | PROMPT_SECTION_ACTIONS, |
| | PROMPT_SECTION_GAME, |
| | PROMPT_SECTION_HISTORY, |
| | PROMPT_SECTION_INSTRUCTION, |
| | PROMPT_SECTION_SCORES, |
| | SYSTEM_PROMPT, |
| | TRAIN_TEMPERATURE_DENOMINATOR, |
| | TRAIN_TEMPERATURE_NUMERATOR, |
| | ) |
| |
|
# Small integer constants, spelled without bare numeric literals.
_ZERO = int()
_ONE = int(True)

# Whitespace used to lay out the prompt.
_NEWLINE = "\n"
_SECTION_SEP = _NEWLINE + _NEWLINE

# Punctuation for section headers and list items.
_BRACKET_OPEN = "["
_BRACKET_CLOSE = "]"
_COLON_SPACE = ": "
_DASH_SPACE = "- "

# Fragments that make up one line of the round-history section.
_ROUND_PREFIX = "Round "
_YOU_PLAYED = " | You played: "
_OPP_PLAYED = " | Opponent played: "
_YOUR_PAYOFF = " | Your payoff: "
_OPP_PAYOFF = " | Opp payoff: "
| |
|
| |
|
class PromptBuilder:
    """Render a GameObservation as a sectioned plain-text prompt.

    Sections are emitted in a fixed order (game, history, scores, actions,
    instruction).  Each one starts with a bracketed header line and sections
    are separated by a blank line.  The opponent strategy name is deliberately
    left out so the model cannot shortcut via strategy recognition.
    """

    @staticmethod
    def build(obs: GameObservation) -> str:
        """Return the full prompt text for *obs*.

        The history section is omitted when ``obs.history`` is empty;
        otherwise only the trailing ``MAX_PROMPT_HISTORY_ROUNDS`` entries
        are rendered.
        """

        def _section(title: str, body: str) -> str:
            # "[title]" on its own line, followed by the section body.
            return "".join(
                (_BRACKET_OPEN, title, _BRACKET_CLOSE, _NEWLINE, body)
            )

        # Game name and description always come first.
        parts: List[str] = [
            _section(
                PROMPT_SECTION_GAME,
                "".join((obs.game_name, _NEWLINE, obs.game_description)),
            )
        ]

        # Past rounds, one line per round, most recent rounds only.
        if obs.history:
            recent = obs.history[-MAX_PROMPT_HISTORY_ROUNDS:]
            rendered = [
                "".join((
                    _ROUND_PREFIX, str(entry.round_number),
                    _YOU_PLAYED, entry.player_action,
                    _OPP_PLAYED, entry.opponent_action,
                    _YOUR_PAYOFF, str(entry.player_payoff),
                    _OPP_PAYOFF, str(entry.opponent_payoff),
                ))
                for entry in recent
            ]
            parts.append(
                _section(PROMPT_SECTION_HISTORY, _NEWLINE.join(rendered))
            )

        # Running scores and round progress.
        score_lines = [
            f"Your score{_COLON_SPACE}{obs.player_score!s}",
            f"Opponent score{_COLON_SPACE}{obs.opponent_score!s}",
            f"Round{_COLON_SPACE}{obs.current_round!s}"
            f" of {obs.total_rounds!s}",
        ]
        parts.append(
            _section(PROMPT_SECTION_SCORES, _NEWLINE.join(score_lines))
        )

        # Legal actions as a dash-prefixed list.
        bullet_lines = [_DASH_SPACE + name for name in obs.available_actions]
        parts.append(
            _section(PROMPT_SECTION_ACTIONS, _NEWLINE.join(bullet_lines))
        )

        # Closing instruction block.
        parts.append(_section(PROMPT_SECTION_INSTRUCTION, SYSTEM_PROMPT))

        return _SECTION_SEP.join(parts)
| |
|
| |
|
def parse_action(response: str, available_actions: List[str]) -> str:
    """Map raw LLM response text onto one of the available actions.

    Matching is attempted in order of decreasing strictness:

    1. exact match of the whitespace-stripped response,
    2. case-insensitive match of the whole response,
    3. case-insensitive substring match (an action name appearing anywhere
       in the response),
    4. uniformly random choice among ``available_actions``.

    In the substring stage an action is skipped when it is itself a proper
    substring of another action that also occurs in the response.  This
    keeps ``"bid"`` from winning over ``"bid_high"`` when the response says
    ``"bid_high"``.  Among the remaining candidates the first one in
    ``available_actions`` order is returned.

    Parameters
    ----------
    response : str
        Raw completion text produced by the model.
    available_actions : list of str
        Legal action names for the current round.

    Returns
    -------
    str
        An element of ``available_actions``.

    Raises
    ------
    IndexError
        If nothing matches and ``available_actions`` is empty (raised by
        ``random.choice``).
    """
    stripped = response.strip()

    # Stage 1: exact match.
    if stripped in available_actions:
        return stripped

    # Stage 2: whole-response match ignoring case.
    lower = stripped.lower()
    for action in available_actions:
        if action.lower() == lower:
            return action

    # Stage 3: action name embedded somewhere in the response.
    contained = [a for a in available_actions if a.lower() in lower]
    for action in contained:
        key = action.lower()
        # A shorter name that sits inside a longer matched name is almost
        # certainly an artefact of the longer one, so prefer the longer.
        shadowed = any(
            key != other.lower() and key in other.lower()
            for other in contained
        )
        if not shadowed:
            return action

    # Stage 4: nothing recognisable -- fall back to a random legal action.
    return random.choice(available_actions)
| |
|
| |
|
class LLMAgent:
    """LLM-based agent compatible with TournamentRunner agent_fn interface.

    Calling the agent builds a prompt from the observation, asks the model
    for a completion, and parses that completion into a legal action.  The
    most recent prompt and raw completion are kept for inspection.

    Parameters
    ----------
    generate_fn : callable
        A function that takes a prompt string and returns a completion string.
        This abstracts over different model backends (HF, vLLM, API).
    prompt_builder : PromptBuilder, optional
        Custom prompt builder. Defaults to the standard PromptBuilder.
    """

    def __init__(
        self,
        generate_fn: Callable[[str], str],
        prompt_builder: Optional[PromptBuilder] = None,
    ) -> None:
        self._generate_fn = generate_fn
        # Compare against None explicitly: a custom builder that happens to
        # evaluate falsy (e.g. defines __len__ or __bool__) must still be
        # honoured rather than silently replaced by the default.
        self._prompt_builder = (
            PromptBuilder() if prompt_builder is None else prompt_builder
        )
        self._last_prompt: str = ""
        self._last_completion: str = ""

    def __call__(self, obs: GameObservation) -> GameAction:
        """Select an action given a game observation.

        The prompt is recorded before generation and the completion after
        it, so both remain available through ``last_prompt`` and
        ``last_completion`` once the call returns.
        """
        prompt = self._prompt_builder.build(obs)
        self._last_prompt = prompt
        completion = self._generate_fn(prompt)
        self._last_completion = completion
        action_str = parse_action(completion, obs.available_actions)
        return GameAction(action=action_str)

    @property
    def last_prompt(self) -> str:
        """The most recently constructed prompt (empty before any call)."""
        return self._last_prompt

    @property
    def last_completion(self) -> str:
        """The most recent raw model completion (empty before any call)."""
        return self._last_completion
| |
|
| |
|
class APIAgent(LLMAgent):
    """LLMAgent whose completions come from an external chat-style API.

    Parameters
    ----------
    api_call_fn : callable
        Function(system_prompt, user_prompt) -> str that calls the API
        (OpenAI/Anthropic).  It is invoked with ``SYSTEM_PROMPT`` as the
        system prompt and the built prompt as the user prompt.
    prompt_builder : PromptBuilder, optional
        Custom prompt builder, forwarded unchanged to ``LLMAgent``.
    """

    def __init__(
        self,
        api_call_fn: Callable[[str, str], str],
        prompt_builder: Optional[PromptBuilder] = None,
    ) -> None:
        # Adapt the two-argument API callable to the one-argument
        # generate_fn interface by supplying SYSTEM_PROMPT on every call.
        super().__init__(
            generate_fn=lambda prompt: api_call_fn(SYSTEM_PROMPT, prompt),
            prompt_builder=prompt_builder,
        )
| |
|