# Provenance: uploaded via huggingface_hub by jtowarek (commit 3ff9218, verified).
"""LLM agent for game-theory environments."""
from __future__ import annotations
import random
from typing import Any, Callable, Dict, List, Optional
from env.models import GameAction, GameObservation
from constant_definitions.train.agent_constants import (
MAX_ACTION_TOKENS,
MAX_PROMPT_HISTORY_ROUNDS,
PARSE_FAILURE_SENTINEL,
PROMPT_SECTION_ACTIONS,
PROMPT_SECTION_GAME,
PROMPT_SECTION_HISTORY,
PROMPT_SECTION_INSTRUCTION,
PROMPT_SECTION_SCORES,
SYSTEM_PROMPT,
TRAIN_TEMPERATURE_DENOMINATOR,
TRAIN_TEMPERATURE_NUMERATOR,
)
_ZERO = int()
_ONE = int(bool(True))
_NEWLINE = "\n"
_SECTION_SEP = "\n\n"
_BRACKET_OPEN = "["
_BRACKET_CLOSE = "]"
_COLON_SPACE = ": "
_DASH_SPACE = "- "
_ROUND_PREFIX = "Round "
_YOU_PLAYED = " | You played: "
_OPP_PLAYED = " | Opponent played: "
_YOUR_PAYOFF = " | Your payoff: "
_OPP_PAYOFF = " | Opp payoff: "
class PromptBuilder:
    """Formats GameObservation into a structured text prompt.

    The prompt intentionally excludes the opponent strategy name
    to prevent the model from shortcutting via strategy recognition.
    """

    @staticmethod
    def _section(title: str, body: str) -> str:
        # Render a single "[TITLE]\n<body>" prompt section.
        return _BRACKET_OPEN + title + _BRACKET_CLOSE + _NEWLINE + body

    @staticmethod
    def build(obs: GameObservation) -> str:
        """Build a structured prompt from a game observation."""
        # Game section: name followed by the free-text description.
        sections: List[str] = [
            PromptBuilder._section(
                PROMPT_SECTION_GAME,
                obs.game_name + _NEWLINE + obs.game_description,
            )
        ]

        # History section, capped at the most recent N rounds so the
        # prompt length stays bounded in long tournaments.
        if obs.history:
            rows = [
                _ROUND_PREFIX + str(r.round_number)
                + _YOU_PLAYED + r.player_action
                + _OPP_PLAYED + r.opponent_action
                + _YOUR_PAYOFF + str(r.player_payoff)
                + _OPP_PAYOFF + str(r.opponent_payoff)
                for r in obs.history[-MAX_PROMPT_HISTORY_ROUNDS:]
            ]
            sections.append(
                PromptBuilder._section(PROMPT_SECTION_HISTORY, _NEWLINE.join(rows))
            )

        # Scores section: running totals plus round progress.
        score_body = (
            "Your score" + _COLON_SPACE + str(obs.player_score)
            + _NEWLINE + "Opponent score" + _COLON_SPACE + str(obs.opponent_score)
            + _NEWLINE + "Round" + _COLON_SPACE + str(obs.current_round)
            + " of " + str(obs.total_rounds)
        )
        sections.append(PromptBuilder._section(PROMPT_SECTION_SCORES, score_body))

        # Available actions, one dash-prefixed bullet per line.
        sections.append(
            PromptBuilder._section(
                PROMPT_SECTION_ACTIONS,
                _NEWLINE.join(_DASH_SPACE + a for a in obs.available_actions),
            )
        )

        # Final instruction block telling the model how to respond.
        sections.append(
            PromptBuilder._section(PROMPT_SECTION_INSTRUCTION, SYSTEM_PROMPT)
        )

        return _SECTION_SEP.join(sections)
def parse_action(response: str, available_actions: List[str]) -> str:
    """Parse an action from LLM response text.

    Resolution order: exact match -> case-insensitive match ->
    substring match (longest action name wins) -> random fallback.

    Parameters
    ----------
    response : str
        Raw completion text produced by the model.
    available_actions : List[str]
        Legal action names for the current state. Must be non-empty,
        otherwise the random fallback raises ``IndexError``.

    Returns
    -------
    str
        One of ``available_actions``.
    """
    stripped = response.strip()
    # Exact match.
    if stripped in available_actions:
        return stripped
    # Case-insensitive match.
    lower = stripped.lower()
    for action in available_actions:
        if action.lower() == lower:
            return action
    # Substring match: collect every action mentioned in the response and
    # prefer the longest name, so an action that is a prefix of another
    # (e.g. "C" vs "CD") cannot shadow the more specific match.
    matches = [a for a in available_actions if a.lower() in lower]
    if matches:
        return max(matches, key=len)
    # Random selection as last resort so the agent always produces a
    # legal move even when the completion is unparseable.
    return random.choice(available_actions)
class LLMAgent:
    """LLM-based agent compatible with TournamentRunner agent_fn interface.

    Parameters
    ----------
    generate_fn : callable
        A function that takes a prompt string and returns a completion string.
        This abstracts over different model backends (HF, vLLM, API).
    prompt_builder : PromptBuilder, optional
        Custom prompt builder. Defaults to the standard PromptBuilder.
    """

    def __init__(
        self,
        generate_fn: Callable[[str], str],
        prompt_builder: Optional[PromptBuilder] = None,
    ) -> None:
        self._generate_fn = generate_fn
        # Fall back to the standard builder when none is supplied.
        self._prompt_builder = prompt_builder or PromptBuilder()
        # Debugging hooks: the latest prompt/completion pair is retained
        # and exposed via the read-only properties below.
        self._last_prompt: str = ""
        self._last_completion: str = ""

    def __call__(self, obs: GameObservation) -> GameAction:
        """Select an action given a game observation."""
        self._last_prompt = self._prompt_builder.build(obs)
        self._last_completion = self._generate_fn(self._last_prompt)
        chosen = parse_action(self._last_completion, obs.available_actions)
        return GameAction(action=chosen)

    @property
    def last_prompt(self) -> str:
        """The most recently constructed prompt."""
        return self._last_prompt

    @property
    def last_completion(self) -> str:
        """The most recent raw model completion."""
        return self._last_completion
class APIAgent(LLMAgent):
    """Agent that uses an external API (OpenAI/Anthropic) for generation.

    Parameters
    ----------
    api_call_fn : callable
        Function(system_prompt, user_prompt) -> str that calls the API.
    """

    def __init__(
        self,
        api_call_fn: Callable[[str, str], str],
        prompt_builder: Optional[PromptBuilder] = None,
    ) -> None:
        # Adapt the two-argument API signature to the single-prompt
        # interface LLMAgent expects; the system prompt is held fixed.
        def _call_api(user_prompt: str) -> str:
            return api_call_fn(SYSTEM_PROMPT, user_prompt)

        super().__init__(generate_fn=_call_api, prompt_builder=prompt_builder)