Artvv's picture
Upload src/persistentpoker_bench/cli.py with huggingface_hub
7348220 verified
from __future__ import annotations
import argparse
import json
import os
import re
import sys
from dataclasses import asdict
from pathlib import Path
from typing import Any, Callable
from persistentpoker_bench.adapters.litellm_adapter import LiteLLMConfig
from persistentpoker_bench.budget import BudgetCaps
from persistentpoker_bench.hand_runner import HandRunnerConfig
from persistentpoker_bench.interactive import (
PlaySeatKind,
PlaySeatSpec,
PlaySessionConfig,
TerminalHandObserver,
parse_play_session_config,
play_terminal_match,
run_play_session,
)
from persistentpoker_bench.leaderboard import build_leaderboard_rows, export_leaderboard_csv
from persistentpoker_bench.match_runner import MatchRunnerConfig
from persistentpoker_bench.model_registry import (
DEFAULT_MODEL_REGISTRY,
LeaderboardTrack,
RegisteredModel,
find_registered_model,
models_for_track,
)
from persistentpoker_bench.retries import RetryPolicy
from persistentpoker_bench.replay import build_match_replay, export_match_replay_json
from persistentpoker_bench.runtime_agents import LiteLLMRuntimeAgent
from persistentpoker_bench.smoke import run_local_smoke_suite
from persistentpoker_bench.spec import DEFAULT_DETERMINISTIC_SEED
from persistentpoker_bench.testsupport import static_agent_factory
from persistentpoker_bench.tournament import (
TournamentEntrant,
TournamentLineup,
TournamentRunnerConfig,
export_decision_traces_jsonl,
export_match_results_jsonl,
export_match_summaries_jsonl,
run_tournament,
)
def main(argv: list[str] | None = None) -> int:
    """Entry point of the ``persistentpoker-bench`` command line.

    Parses *argv* (``None`` lets argparse read ``sys.argv[1:]``) and hands the
    resulting namespace to the handler the chosen subcommand registered under
    ``func``. The handler's return value is the process exit code.
    """
    namespace = _build_parser().parse_args(argv)
    handler = namespace.func
    return handler(namespace)
def _build_parser() -> argparse.ArgumentParser:
    """Create the ``persistentpoker-bench`` argument parser.

    A subcommand is mandatory. Every subcommand stores its handler under
    ``func`` (via ``set_defaults``) so that ``main`` can dispatch with
    ``args.func(args)``.
    """
    root = argparse.ArgumentParser(prog="persistentpoker-bench")
    commands = root.add_subparsers(dest="command", required=True)
    track_choices = [track.value for track in LeaderboardTrack]

    # models: print the model registry, optionally filtered to one track.
    models_cmd = commands.add_parser("models", help="List official benchmark models.")
    models_cmd.add_argument("--track", choices=track_choices, default=None)
    models_cmd.set_defaults(func=_cmd_models)

    # demo: tournament driven by static agents.
    demo_cmd = commands.add_parser("demo", help="Run a deterministic static demo tournament.")
    demo_cmd.add_argument("--track", choices=track_choices, default="frontier")
    demo_cmd.add_argument("--hands", type=int, default=2)
    demo_cmd.add_argument("--seeds", default="20260428")
    demo_cmd.add_argument("--outdir", required=True)
    demo_cmd.set_defaults(func=_cmd_demo)

    # play: interactive terminal match; --human-seats holds 1-based seat numbers.
    play_cmd = commands.add_parser("play", help="Play a live terminal match with one or more humans.")
    play_cmd.add_argument("--config", default=None)
    play_cmd.add_argument("--players", default="You,CPU1,CPU2,CPU3")
    play_cmd.add_argument("--human-seats", default="1")
    play_cmd.add_argument("--hands", type=int, default=1)
    play_cmd.add_argument("--seed", type=int, default=DEFAULT_DETERMINISTIC_SEED)
    for flag, default_value in (("--starting-stack", 2000), ("--small-blind", 10), ("--big-blind", 20)):
        play_cmd.add_argument(flag, type=int, default=default_value)
    play_cmd.add_argument("--replay-out", default=None)
    play_cmd.set_defaults(func=_cmd_play)

    # web: Gradio replay studio server options.
    web_cmd = commands.add_parser("web", help="Launch the Gradio replay studio.")
    web_cmd.add_argument("--host", default="127.0.0.1")
    web_cmd.add_argument("--port", type=int, default=7860)
    web_cmd.add_argument("--share", action="store_true")
    web_cmd.set_defaults(func=_cmd_web)

    # smoke: local smoke suite.
    smoke_cmd = commands.add_parser("smoke", help="Run a longer local smoke suite.")
    smoke_cmd.add_argument("--outdir", required=True)
    for flag, default_value in (("--hands", 8), ("--play-hands", 5), ("--provider-hands", 2)):
        smoke_cmd.add_argument(flag, type=int, default=default_value)
    smoke_cmd.add_argument("--seeds", default="20260428,20260429")
    smoke_cmd.add_argument("--skip-web", action="store_true")
    smoke_cmd.set_defaults(func=_cmd_smoke)

    # run: live litellm tournament from a JSON config.
    run_cmd = commands.add_parser("run", help="Run a litellm-backed tournament from a JSON config.")
    run_cmd.add_argument("--config", required=True)
    run_cmd.add_argument("--outdir", required=True)
    run_cmd.add_argument(
        "--pool-state",
        type=str,
        help="Path to JSON file with card array to seed pool.",
        default=None,
    )
    run_cmd.set_defaults(func=_cmd_run)

    return root
def _cmd_models(args: argparse.Namespace) -> int:
    """Print registered models as tab-separated lines.

    Each line is ``track, provider, model_id, display_name``. Without
    ``--track`` the whole ``DEFAULT_MODEL_REGISTRY`` is listed; otherwise
    ``models_for_track`` selects the models of that track. Always returns 0.
    """
    if args.track is None:
        selected = DEFAULT_MODEL_REGISTRY
    else:
        selected = models_for_track(LeaderboardTrack(args.track))
    for entry in selected:
        print(f"{entry.track.value}\t{entry.provider}\t{entry.model_id}\t{entry.display_name}")
    return 0
def _cmd_demo(args: argparse.Namespace) -> int:
    """Run the four-seat static-agent demo tournament and export its artifacts.

    Seats ``P1``..``P4`` take the first four models returned by
    ``models_for_track`` for ``--track``. A shorter model list makes the
    indexing below raise (IndexError for a list/tuple). Every seat uses
    ``static_agent_factory``. Progress is reported on stdout, and artifacts are
    written to ``--outdir``. Always returns 0.
    """
    track = LeaderboardTrack(args.track)
    seed_list = _parse_seeds(args.seeds)
    candidates = models_for_track(track)

    seats = []
    for seat_number in range(1, 5):
        seats.append(
            TournamentEntrant(
                seat_name=f"P{seat_number}",
                registered_model=candidates[seat_number - 1],
                agent_factory=static_agent_factory,
            )
        )
    demo_lineup = TournamentLineup(lineup_id=f"{track.value}-demo-lineup", entrants=tuple(seats))

    runner_config = TournamentRunnerConfig(
        track=track,
        seeds=tuple(seed_list),
        match_config_template=MatchRunnerConfig(
            hand_runner_config=HandRunnerConfig(seed=0),
            hand_count=args.hands,
        ),
    )
    result = run_tournament(
        lineups=(demo_lineup,),
        config=runner_config,
        progress_callback=_build_cli_progress_reporter(),
    )
    _export_release_artifacts(result, Path(args.outdir))
    return 0
def _cmd_play(args: argparse.Namespace) -> int:
    """Play a live terminal match and optionally export a replay.

    There are two modes:

    * ``--config PATH``: the whole session comes from a JSON file parsed by
      ``parse_play_session_config``. ``--replay-out`` overrides the config's
      own ``replay_out``.
    * Otherwise, the session is assembled from CLI flags. ``--players`` is a
      comma-separated name list, and ``--human-seats`` lists 1-based positions
      in that list. In the exported replay, every non-human seat is recorded
      as ``PlaySeatKind.PASSIVE_BOT``.

    Returns:
        0 on success. Returns 2 (argparse's usage-error status) if
        ``--human-seats`` names a seat outside ``1..len(players)``.
    """
    if args.config:
        session_config = parse_play_session_config(
            json.loads(Path(args.config).read_text(encoding="utf-8"))
        )
        results = run_play_session(
            session_config,
            output=sys.stdout,
            observer_factory=lambda visible_hole_seats: TerminalHandObserver(
                output=sys.stdout,
                visible_hole_seats=visible_hole_seats,
            ),
        )
        replay_path = args.replay_out or session_config.replay_out
        if replay_path:
            export_match_replay_json(
                build_match_replay(hand_results=results, session_config=session_config, label="terminal-play"),
                replay_path,
            )
        return 0

    player_names = tuple(name.strip() for name in args.players.split(",") if name.strip())
    seat_numbers = tuple(int(token.strip()) for token in args.human_seats.split(",") if token.strip())
    out_of_range = [seat for seat in seat_numbers if not 1 <= seat <= len(player_names)]
    if out_of_range:
        # Without this check, seat 0 would become index -1, and seats past the
        # end would never match a player, silently leaving no human at the table.
        print(
            f"error: --human-seats must be between 1 and {len(player_names)}; got {out_of_range}",
            file=sys.stderr,
        )
        return 2
    human_seats = tuple(seat - 1 for seat in seat_numbers)

    def hand_runner_config() -> HandRunnerConfig:
        # One definition of the hand settings. A fresh instance is built for each
        # consumer, as before, so the replay does not share the live match's object.
        return HandRunnerConfig(
            seed=args.seed,
            starting_stack=args.starting_stack,
            small_blind=args.small_blind,
            big_blind=args.big_blind,
        )

    results = play_terminal_match(
        player_names=player_names,
        human_seats=human_seats,
        hand_count=args.hands,
        config=hand_runner_config(),
        output=sys.stdout,
    )
    if args.replay_out:
        session_config = PlaySessionConfig(
            seats=tuple(
                PlaySeatSpec(
                    name=seat_name,
                    kind=PlaySeatKind.HUMAN if seat_index in human_seats else PlaySeatKind.PASSIVE_BOT,
                )
                for seat_index, seat_name in enumerate(player_names)
            ),
            hand_count=args.hands,
            hand_runner_config=hand_runner_config(),
        )
        export_match_replay_json(
            build_match_replay(hand_results=results, session_config=session_config, label="terminal-play"),
            args.replay_out,
        )
    return 0
def _cmd_web(args: argparse.Namespace) -> int:
    """Build the Gradio replay studio and launch it with the CLI's host/port/share options.

    The web module is imported inside the function, presumably so that other
    subcommands do not pay for (or require) the UI dependencies. Returns 0
    once ``launch`` returns.
    """
    from persistentpoker_bench.web_ui import build_web_app

    app = build_web_app()
    launch_options = {
        "server_name": args.host,
        "server_port": args.port,
        "share": bool(args.share),
    }
    app.launch(**launch_options)
    return 0
def _cmd_smoke(args: argparse.Namespace) -> int:
    """Run the local smoke suite and print its result as sorted, indented JSON.

    The web smoke step runs unless ``--skip-web`` is given. The result object
    is converted with ``dataclasses.asdict`` before printing. Always returns 0.
    """
    smoke_result = run_local_smoke_suite(
        outdir=args.outdir,
        seeds=tuple(_parse_seeds(args.seeds)),
        demo_hands=args.hands,
        play_hands=args.play_hands,
        provider_hands=args.provider_hands,
        run_web_smoke=not args.skip_web,
    )
    report = asdict(smoke_result)
    print(json.dumps(report, indent=2, sort_keys=True))
    return 0
def _cmd_run(args: argparse.Namespace) -> int:
    """Run a litellm-backed tournament described by a JSON config and export artifacts.

    Steps:

    1. Load ``.env`` from the current working directory into the environment.
    2. Expand ``${NAME}`` placeholders in the config.
    3. Optionally seed the card pool from ``--pool-state``.
    4. Run the tournament with incremental output to ``--outdir``.
    5. Write the leaderboard and run summary.

    The final export passes ``skip_jsonl=True``; the JSONL files are
    presumably produced by the incremental writer. Confirm this in
    ``run_tournament``.

    Raises:
        ValueError: if the ``--pool-state`` file is not valid JSON or is not a
            JSON array of strings.
    """

    def _read_pool_state(pool_path: Path) -> tuple[str, ...]:
        # Validate before tupling: tuple() of a JSON string or object would
        # silently yield characters or keys instead of cards.
        loaded = json.loads(pool_path.read_text(encoding="utf-8"))
        if not isinstance(loaded, list):
            raise ValueError(
                f"--pool-state {pool_path} must contain a JSON array of cards, "
                f"got {type(loaded).__name__}"
            )
        invalid = [card for card in loaded if not isinstance(card, str)]
        if invalid:
            raise ValueError(
                f"--pool-state {pool_path} must contain only string cards; invalid entries: {invalid!r}"
            )
        return tuple(loaded)

    _load_runtime_env(Path(".env"))
    config_payload = _expand_env_placeholders(json.loads(Path(args.config).read_text(encoding="utf-8")))
    outdir = Path(args.outdir)
    initial_pool: tuple[str, ...] = ()
    pool_state = getattr(args, "pool_state", None)
    if pool_state:
        initial_pool = _read_pool_state(Path(pool_state))
    tournament_result = _run_live_tournament_from_config(
        config_payload,
        progress_callback=_build_cli_progress_reporter(),
        incremental_outdir=outdir,
        initial_pool=initial_pool,
    )
    _export_release_artifacts(tournament_result, outdir, skip_jsonl=True)
    return 0
def _build_live_entrant(entrant_payload: dict[str, Any], track: LeaderboardTrack) -> TournamentEntrant:
    """Convert one entrant config dict into a ``TournamentEntrant`` backed by a LiteLLM agent.

    Required keys are ``provider``, ``model_id`` and ``seat_name``. All other
    keys are optional and use the defaults shown below.
    """
    registered_model = _resolve_registered_model(entrant_payload, track)
    retry_policy = RetryPolicy(
        max_attempts=int(entrant_payload.get("max_attempts", 3)),
        initial_delay_seconds=float(entrant_payload.get("initial_delay_seconds", 0.25)),
        backoff_multiplier=float(entrant_payload.get("backoff_multiplier", 2.0)),
    )
    litellm_config = LiteLLMConfig(
        # ``litellm_model`` lets the routed model string differ from the registry id.
        model=entrant_payload.get("litellm_model", entrant_payload["model_id"]),
        # An explicit ``null`` temperature passes through as None.
        temperature=_optional_float(entrant_payload.get("temperature", 0.0)),
        max_tokens=int(entrant_payload.get("max_tokens", 400)),
        timeout=float(entrant_payload.get("timeout", 60.0)),
        prefer_json_mode=bool(entrant_payload.get("prefer_json_mode", True)),
        # ``or {}`` so an explicit ``"extra_kwargs": null`` means "none" instead of dict(None) failing.
        extra_kwargs=dict(entrant_payload.get("extra_kwargs") or {}),
    )
    return TournamentEntrant(
        seat_name=entrant_payload["seat_name"],
        registered_model=registered_model,
        agent_factory=_runtime_factory(
            provider=registered_model.provider,
            litellm_config=litellm_config,
            retry_policy=retry_policy,
        ),
    )


def _run_live_tournament_from_config(
    payload: dict[str, Any],
    progress_callback: Callable[[dict[str, Any]], None] | None = None,
    incremental_outdir: Path | None = None,
    initial_pool: tuple[str, ...] = (),
):
    """Build litellm-backed lineups from a parsed run config and execute the tournament.

    Required payload keys:

    * ``track``, ``seeds``, ``hand_count``
    * ``lineups`` — each with ``lineup_id`` and ``entrants``

    Optional payload keys and defaults:

    * ``game_mode``: ``"holdem"``
    * ``termination_rule``: ``"hand_limit"``
    * ``starting_hand_number``: 1
    * ``base_seed``: 0 (the hand-runner seed)
    * ``initial_button_index``: 0
    * ``budget_caps``: none

    Args:
        payload: Config dict, already placeholder-expanded.
        progress_callback: Receives progress event dicts; forwarded to ``run_tournament``.
        incremental_outdir: Forwarded to ``run_tournament``.
        initial_pool: Starting card pool, forwarded to both the match template
            and the tournament config.

    Returns:
        Whatever ``run_tournament`` returns (the tournament result).
    """
    track = LeaderboardTrack(payload["track"])
    game_mode = str(payload.get("game_mode", "holdem"))
    termination_rule = str(payload.get("termination_rule", "hand_limit"))
    seeds = tuple(int(seed) for seed in payload["seeds"])
    hand_count = int(payload["hand_count"])
    starting_hand_number = int(payload.get("starting_hand_number", 1))
    hand_seed = int(payload.get("base_seed", 0))
    initial_button_index = int(payload.get("initial_button_index", 0))
    budget_caps = _parse_budget_caps(payload.get("budget_caps"))

    lineups = []
    for lineup_payload in payload["lineups"]:
        entrants = tuple(
            _build_live_entrant(entrant_payload, track) for entrant_payload in lineup_payload["entrants"]
        )
        lineups.append(TournamentLineup(lineup_id=lineup_payload["lineup_id"], entrants=entrants))

    return run_tournament(
        lineups=tuple(lineups),
        config=TournamentRunnerConfig(
            track=track,
            seeds=seeds,
            match_config_template=MatchRunnerConfig(
                hand_runner_config=HandRunnerConfig(seed=hand_seed, game_mode=game_mode),
                hand_count=hand_count,
                initial_button_index=initial_button_index,
                game_mode=game_mode,
                termination_rule=termination_rule,
                starting_hand_number=starting_hand_number,
                initial_pool=initial_pool,
            ),
            budget_caps=budget_caps,
            game_mode=game_mode,
            termination_rule=termination_rule,
            initial_pool=initial_pool,
        ),
        progress_callback=progress_callback,
        incremental_outdir=incremental_outdir,
    )
def _runtime_factory(*, provider: str, litellm_config: LiteLLMConfig, retry_policy: RetryPolicy):
def factory() -> LiteLLMRuntimeAgent:
return LiteLLMRuntimeAgent(provider=provider, config=litellm_config, retry_policy=retry_policy)
return factory
def _parse_seeds(raw: str) -> list[int]:
return [int(token.strip()) for token in raw.split(",") if token.strip()]
def _parse_budget_caps(payload: dict[str, Any] | None) -> BudgetCaps | None:
if payload is None:
return None
return BudgetCaps(
total_cost_cap=payload.get("total_cost_cap"),
per_provider_cap=dict(payload.get("per_provider_cap", {})),
per_model_cap=dict(payload.get("per_model_cap", {})),
)
def _resolve_registered_model(
    entrant_payload: dict[str, Any],
    track: LeaderboardTrack,
) -> RegisteredModel:
    """Look up an entrant in the model registry, or synthesize a custom entry.

    ``provider`` and ``model_id`` are required and stringified. If
    ``find_registered_model`` raises ``ValueError``, a ``RegisteredModel`` is
    built on *track* instead, using these optional payload fields:

    * ``display_name`` (default: the model id)
    * ``api_style`` (default: ``"openai_compatible"``)
    * ``notes``
    """
    provider = str(entrant_payload["provider"])
    model_id = str(entrant_payload["model_id"])
    try:
        return find_registered_model(provider=provider, model_id=model_id)
    except ValueError:
        pass
    return RegisteredModel(
        provider=provider,
        model_id=model_id,
        display_name=str(entrant_payload.get("display_name", model_id)),
        track=track,
        api_style=str(entrant_payload.get("api_style", "openai_compatible")),
        notes=str(entrant_payload.get("notes", "Custom benchmark entrant from config.")),
    )
def _optional_float(value: Any) -> float | None:
if value is None:
return None
return float(value)
ENV_TOKEN_RE = re.compile(r"\$\{([A-Z0-9_]+)\}")
def _expand_env_placeholders(value: Any) -> Any:
if isinstance(value, dict):
return {key: _expand_env_placeholders(item) for key, item in value.items()}
if isinstance(value, list):
return [_expand_env_placeholders(item) for item in value]
if isinstance(value, str):
return ENV_TOKEN_RE.sub(lambda match: os.getenv(match.group(1), match.group(0)), value)
return value
def _export_release_artifacts(tournament_result, outdir: Path, skip_jsonl: bool = False) -> None:
    """Write the release artifacts for a finished tournament into *outdir*.

    Steps:

    * Create *outdir* if needed.
    * Unless *skip_jsonl* is set, export the results, match-summary and
      decision-trace JSONL files.
    * Always export ``leaderboard.csv`` and ``run_summary.json``.
    * Print a one-line summary to stdout.

    NOTE(review): ``run_summary.json`` lists the JSONL paths even when
    *skip_jsonl* is True. Presumably they are written elsewhere in that case
    (e.g. incrementally during the run); confirm.
    """
    outdir.mkdir(parents=True, exist_ok=True)
    artifact_paths = {
        "results_jsonl": outdir / "results.jsonl",
        "match_summaries_jsonl": outdir / "match_summaries.jsonl",
        "decision_traces_jsonl": outdir / "decision_traces.jsonl",
        "leaderboard_csv": outdir / "leaderboard.csv",
    }
    if not skip_jsonl:
        export_match_results_jsonl(tournament_result, artifact_paths["results_jsonl"])
        export_match_summaries_jsonl(tournament_result, artifact_paths["match_summaries_jsonl"])
        export_decision_traces_jsonl(tournament_result, artifact_paths["decision_traces_jsonl"])
    export_leaderboard_csv(build_leaderboard_rows(tournament_result), artifact_paths["leaderboard_csv"])

    track_value = tournament_result.track.value
    match_count = len(tournament_result.match_records)
    summary = {
        "track": track_value,
        "match_count": match_count,
        "artifacts": {label: str(path) for label, path in artifact_paths.items()},
    }
    (outdir / "run_summary.json").write_text(json.dumps(summary, indent=2, sort_keys=True), encoding="utf-8")
    print(
        f"[ppb] Artifacts written | outdir={outdir} | track={track_value} | matches={match_count}",
        flush=True,
    )
def _load_runtime_env(path: Path) -> None:
if path.exists():
for line in path.read_text(encoding="utf-8").splitlines():
stripped = line.strip()
if not stripped or stripped.startswith("#") or "=" not in stripped:
continue
key, value = stripped.split("=", maxsplit=1)
key = key.strip()
value = value.strip().strip('"').strip("'")
if key and key not in os.environ:
os.environ[key] = value
alias_pairs = (
("openai_api_key", "OPENAI_API_KEY"),
("xai_api_key", "XAI_API_KEY"),
("deepseek_api_key", "DEEPSEEK_API_KEY"),
("gemini_api_key", "GEMINI_API_KEY"),
("google_api_key", "GOOGLE_API_KEY"),
("mistral_api_key", "MISTRAL_API_KEY"),
)
for alias, canonical in alias_pairs:
if not os.getenv(canonical) and os.getenv(alias):
os.environ[canonical] = str(os.getenv(alias))
if os.getenv("GEMINI_API_KEY") and not os.getenv("GOOGLE_API_KEY"):
os.environ["GOOGLE_API_KEY"] = str(os.getenv("GEMINI_API_KEY"))
def _build_cli_progress_reporter() -> Callable[[dict[str, Any]], None]:
def report(event: dict[str, Any]) -> None:
event_type = str(event.get("event_type", "unknown"))
if event_type == "tournament_started":
print(
"[ppb] Tournament start"
f" | track={event.get('track')}"
f" | matches={event.get('total_matches')}"
f" | lineups={event.get('lineup_count')}"
f" | seeds={event.get('seed_count')}"
f" | hands/match={event.get('hands_per_match')}",
flush=True,
)
return
if event_type == "match_started":
entrants = event.get("entrants", [])
entrant_summary = ", ".join(
f"{entrant['seat_name']}={entrant['provider']}/{entrant['model_id']}"
for entrant in entrants
)
print(
"[ppb] Match start"
f" | {int(event.get('completed_matches', 0)) + 1}/{event.get('total_matches')}"
f" | lineup={event.get('lineup_id')}"
f" | seed={event.get('seed')}"
f" | entrants=[{entrant_summary}]",
flush=True,
)
return
if event_type == "hand_completed":
winner_indices = event.get("winning_player_indices", [])
winner_summary = (
"seat[" + ",".join(str(index) for index in winner_indices) + "]"
if winner_indices
else "unknown"
)
print(
"[ppb] Hand done"
f" | lineup={event.get('lineup_id')}"
f" | seed={event.get('seed')}"
f" | hand={event.get('hand_number')}"
f" | winners={winner_summary}"
f" | pool={event.get('pool_size_after')}"
f" | next_pool_decision={event.get('winner_pool_decision')}",
flush=True,
)
return
if event_type == "match_completed":
budget_snapshot = event.get("budget_snapshot") or {}
print(
"[ppb] Match done"
f" | {event.get('completed_matches')}/{event.get('total_matches')}"
f" | lineup={event.get('lineup_id')}"
f" | seed={event.get('seed')}"
f" | avg_pool={float(event.get('average_pool_size', 0.0)):.2f}"
f" | est_cost=${float(event.get('estimated_total_cost', 0.0)):.6f}"
f" | budget_total=${float(budget_snapshot.get('total_cost', 0.0)):.6f}",
flush=True,
)
return
if event_type == "tournament_completed":
budget_snapshot = event.get("budget_snapshot") or {}
print(
"[ppb] Tournament done"
f" | track={event.get('track')}"
f" | matches={event.get('completed_matches')}/{event.get('total_matches')}"
f" | budget_total=${float(budget_snapshot.get('total_cost', 0.0)):.6f}",
flush=True,
)
return
print(f"[ppb] Event | {json.dumps(event, sort_keys=True)}", flush=True)
return report
# Script entry: exit with the status code returned by the selected subcommand.
if __name__ == "__main__":
    sys.exit(main())