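# NOTE: the next three lines appear to be file-viewer residue (avatar caption,
# commit message, commit hash); kept as comments so the module parses.
# lainwired's picture
# Initial jaxaht-benchmark deployment
# 5146e76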
"""Hanabi AHT evaluation API. Upload checkpoint, get scores vs held-out partners."""
import argparse
import json
import os
import shutil
import sys
import tempfile
import time
import uuid
import threading
import zipfile
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import jax
import jax.numpy as jnp
import numpy as np
from flask import Flask, jsonify, request, send_from_directory
from flask_cors import CORS
jax.config.update('jax_platform_name', 'cpu')
from envs import make_env
from envs.log_wrapper import LogWrapper
from common.run_episodes import run_episodes
from agents.hanabi.bc_lstm_agent import BCLSTMAgent
from agents.hanabi.agent_policy_wrappers import (
HanabiBCLSTMPolicyWrapper,
HanabiIGGIPolicyWrapper,
HanabiPiersPolicyWrapper,
HanabiSmartBotPolicyWrapper,
HanabiRandomPolicyWrapper,
HanabiOBLPolicyWrapper,
HanabiRuleBasedPolicyWrapper,
HanabiFlawedPolicyWrapper,
HanabiOuterPolicyWrapper,
HanabiVanDenBerghPolicyWrapper,
)
# Flask application object; request bodies are capped and CORS is enabled.
app = Flask(__name__)
app.config.update(MAX_CONTENT_LENGTH=200 * 1024 * 1024)  # 200 MB upload limit
CORS(app)

# Evaluation defaults (DEFAULT_NUM_EPISODES may be overridden from the CLI).
DEFAULT_NUM_EPISODES = 64
MAX_EPISODE_STEPS = 128

# The leaderboard JSON file is stored in the same directory as this module.
LEADERBOARD_FILE = os.path.join(
    os.path.split(os.path.abspath(__file__))[0],
    'leaderboard.json',
)

# Game-size constants forwarded to the partner policy wrappers.
HAND_SIZE = 5
NUM_COLORS = 5
NUM_RANKS = 5
NUM_ACTIONS = 21

# Keyword arguments shared by the rule-based partner wrappers.
_RB_KW = {
    'hand_size': HAND_SIZE,
    'num_colors': NUM_COLORS,
    'num_ranks': NUM_RANKS,
    'num_actions': NUM_ACTIONS,
    'using_log_wrapper': True,
}
def _obl_path(level, seed='a'):
    """Build the repo-relative path of an OBL ``.safetensors`` weight file.

    Args:
        level: OBL level number; it selects the ``icml_OBL<level>`` directory.
            Levels greater than 1 use the ``_LOAD1`` filename variant.
        seed: Suffix appended after ``BELIEF_`` in the filename.

    Returns:
        The weight-file path as a string.
    """
    if level > 1:
        load_tag = '_LOAD1'
    else:
        load_tag = ''
    directory = f'agents/hanabi/obl-r2d2-flax/icml_OBL{level}'
    filename = f'OFF_BELIEF1_SHUFFLE_COLOR0{load_tag}_BZA0_BELIEF_{seed}.safetensors'
    return f'{directory}/{filename}'
# Registry of evaluation partners.
#
# Each entry maps a partner key (the value clients send in the ``partner``
# field) to a display name, a human-readable description, and a zero-argument
# ``factory`` that builds a fresh policy wrapper on demand.
#
# Every key must be a non-empty string: an empty key would be selected by an
# empty ``partner`` field and could not be used with the leaderboard's
# ``?partner=`` filter, which ignores empty values.
PARTNER_REGISTRY = {
    # Baselines
    'random': {
        'name': 'Random',
        'description': 'Uniform random legal action baseline (0/25)',
        'factory': lambda: HanabiRandomPolicyWrapper(num_actions=NUM_ACTIONS, using_log_wrapper=True),
    },
    # Rule-based strategy variants (action-category priority agents)
    'rule_based_cautious': {
        'name': 'RuleBased (cautious)',
        'description': 'Priority: play first, then discard, then hint',
        'factory': lambda: HanabiRuleBasedPolicyWrapper(strategy='cautious', **_RB_KW),
    },
    'rule_based_aggressive': {
        'name': 'RuleBased (aggressive)',
        'description': 'Priority: play only, fallback to hint',
        'factory': lambda: HanabiRuleBasedPolicyWrapper(strategy='aggressive', **_RB_KW),
    },
    'rule_based_communicative': {
        'name': 'RuleBased (communicative)',
        'description': 'Priority: hint first, then play',
        'factory': lambda: HanabiRuleBasedPolicyWrapper(strategy='communicative', **_RB_KW),
    },
    'rule_based_frugal': {
        'name': 'RuleBased (frugal)',
        'description': 'Priority: discard first to save info tokens',
        'factory': lambda: HanabiRuleBasedPolicyWrapper(strategy='frugal', **_RB_KW),
    },
    # Walton-Rivers agents
    'iggi': {
        'name': 'IGGI',
        'description': 'Game-state-aware rule-based (Walton-Rivers 2017). 11.7/25 self-play.',
        'factory': lambda: HanabiIGGIPolicyWrapper(**_RB_KW),
    },
    'piers': {
        'name': 'Piers',
        'description': 'IGGI + probabilistic play (>60%) + dispensable hints. 11.6/25.',
        'factory': lambda: HanabiPiersPolicyWrapper(**_RB_KW),
    },
    'van_den_bergh': {
        'name': 'Van Den Bergh',
        'description': 'Probabilistic discard ordering to protect valuable cards. 9.3/25.',
        'factory': lambda: HanabiVanDenBerghPolicyWrapper(**_RB_KW),
    },
    'outer': {
        'name': 'Outer',
        'description': 'Hint-heavy agent. Scores 0/25 on full Hanabi (exhausts tokens).',
        'factory': lambda: HanabiOuterPolicyWrapper(**_RB_KW),
    },
    'flawed': {
        'name': 'Flawed (30%)',
        'description': 'IGGI with 30% random-mistake probability. 0.2/25.',
        'factory': lambda: HanabiFlawedPolicyWrapper(mistake_prob=0.3, **_RB_KW),
    },
    # Previously registered under the empty-string key.
    # NOTE(review): confirm 'flawed_60' matches the key the UI / docs expect.
    'flawed_60': {
        'name': 'Flawed (60%)',
        'description': 'IGGI with 60% random-mistake probability. ~0/25.',
        'factory': lambda: HanabiFlawedPolicyWrapper(mistake_prob=0.6, **_RB_KW),
    },
    # SmartBot (strongest rule-based)
    'smartbot': {
        'name': 'SmartBot',
        'description': 'Convention-heavy rule-based agent. 18.0/25 self-play, 0.4 vs human proxy.',
        'factory': lambda: HanabiSmartBotPolicyWrapper(**_RB_KW),
    },
    # OBL R2D2 pretrained, all 5 belief levels (seed-a)
    'obl_l1': {
        'name': 'OBL-L1',
        'description': 'Off-Belief Learning level 1 (Hu et al. 2021). 20.9/25 self-play.',
        'factory': lambda: HanabiOBLPolicyWrapper(weight_file=_obl_path(1), using_log_wrapper=True),
    },
    'obl_l2': {
        'name': 'OBL-L2',
        'description': 'Off-Belief Learning level 2.',
        'factory': lambda: HanabiOBLPolicyWrapper(weight_file=_obl_path(2), using_log_wrapper=True),
    },
    'obl_l3': {
        'name': 'OBL-L3',
        'description': 'Off-Belief Learning level 3.',
        'factory': lambda: HanabiOBLPolicyWrapper(weight_file=_obl_path(3), using_log_wrapper=True),
    },
    'obl_l4': {
        'name': 'OBL-L4',
        'description': 'Off-Belief Learning level 4. 24.3/25 self-play (near-optimal).',
        'factory': lambda: HanabiOBLPolicyWrapper(weight_file=_obl_path(4), using_log_wrapper=True),
    },
    'obl_l5': {
        'name': 'OBL-L5',
        'description': 'Off-Belief Learning level 5.',
        'factory': lambda: HanabiOBLPolicyWrapper(weight_file=_obl_path(5), using_log_wrapper=True),
    },
    # Human proxy
    'bc_lstm': {
        'name': 'BC-LSTM Human Proxy',
        'description': 'Behavioral cloning LSTM trained on 1K human games (AH2AC2). 45.5% accuracy.',
        'factory': lambda: HanabiBCLSTMPolicyWrapper(
            weight_file='agents/hanabi/bc_lstm_weights/bc_2p.safetensors',
            using_log_wrapper=True, greedy=True),
    },
}
# Trained ego agents (RL checkpoints) are loaded as (policy, params) pairs
# rather than a bare policy, so they are tracked separately from
# PARTNER_REGISTRY. This dict memoises those pairs by loader arguments.
_ego_cache = dict()

# Path of the IPPO fixture checkpoint, resolved relative to this file's directory.
IPPO_FIXTURE = os.path.join(
    os.path.split(os.path.abspath(__file__))[0],
    'fixtures/hanabi_ippo_1e9_seed0/saved_train_run',
)
def _cpu_safe_restore(path):
    """Restore an orbax checkpoint, placing every array on the first CPU device.

    Fixtures saved on GPU machines encode GPU sharding, which orbax rejects
    when restoring on CPU-only machines (the eval API host). This bypasses
    common.save_load_utils.load_train_run, which uses implicit sharding.

    Args:
        path: Checkpoint directory. A relative path is resolved against
            ``REPO_PATH`` from ``common.save_load_utils``.

    Returns:
        The restored tree, with any ``numpy.ndarray`` leaves converted to
        jax arrays.
    """
    import orbax.checkpoint as ocp

    if not os.path.isabs(path):
        from common.save_load_utils import REPO_PATH
        path = os.path.join(REPO_PATH, path)

    ckptr = ocp.PyTreeCheckpointer()
    cpu_device = jax.devices('cpu')[0]
    cpu_sharding = jax.sharding.SingleDeviceSharding(cpu_device)

    def _cpu_restore_arg(_leaf):
        # Same restore argument for every leaf of the checkpoint metadata tree.
        return ocp.ArrayRestoreArgs(sharding=cpu_sharding)

    tree_meta = ckptr.metadata(path)
    per_leaf_args = jax.tree.map(_cpu_restore_arg, tree_meta)
    tree = ckptr.restore(path, restore_args=per_leaf_args)

    def _as_jax(leaf):
        # Only numpy arrays are converted; all other leaves pass through.
        if isinstance(leaf, np.ndarray):
            return jnp.array(leaf)
        return leaf

    return jax.tree_util.tree_map(_as_jax, tree)
def _load_trained_ego(fixture_path, actor_type='mlp', hidden_dim=512,
                      activation='relu', idx_list=None):
    """Load a trained ego policy and parameters from an orbax checkpoint.

    Results are memoised in ``_ego_cache`` under a key built from all
    arguments, so repeated calls with the same arguments reuse the first load.

    Args:
        fixture_path: Checkpoint directory passed to ``_cpu_safe_restore``.
        actor_type: Architecture label; currently only part of the cache key.
        hidden_dim: Hidden width label; currently only part of the cache key.
        activation: Activation name forwarded to ``MLPActorCriticPolicy``.
        idx_list: Optional sequence whose first element selects which
            extracted parameter set to use. ``None`` or an empty sequence
            selects the first one; an index past the end selects the last.

    Returns:
        A ``(policy, params)`` tuple.

    Raises:
        ValueError: If no parameter sets could be extracted from the checkpoint.
    """
    cache_key = (fixture_path, actor_type, hidden_dim, activation,
                 tuple(idx_list) if idx_list else None)
    cached = _ego_cache.get(cache_key)
    if cached is not None:
        return cached

    from evaluation.heldout_evaluator import extract_params
    from agents.mlp_actor_critic_agent import MLPActorCriticPolicy

    # Use CPU-safe restore instead of the default loader.
    restored = _cpu_safe_restore(fixture_path)
    ckpts = restored['checkpoints']

    # Build the policy directly. JaxMARL Hanabi's observation_space.shape is a
    # bare int (658), not a tuple, which breaks initialize_mlp_agent's
    # `.shape[0]` access, so the dimensions for full Hanabi are hardcoded.
    obs_dim = 658
    action_dim = 21
    # NOTE(review): actor_type and hidden_dim are not forwarded to the policy
    # constructor here; confirm whether MLPActorCriticPolicy should receive them.
    policy = MLPActorCriticPolicy(
        action_dim=action_dim,
        obs_dim=obs_dim,
        activation=activation,
    )
    init_params = policy.init_params(jax.random.PRNGKey(0))

    params_list, _ = extract_params(ckpts, init_params, idx_labels=None)
    num_param_sets = len(params_list)
    if num_param_sets == 0:
        raise ValueError('Checkpoint contains no parameter sets')

    # Default to the first checkpoint; an empty idx_list behaves like None
    # (it previously raised IndexError).
    requested = idx_list[0] if idx_list else 0
    if requested < num_param_sets:
        params = params_list[requested]
    else:
        params = params_list[-1]

    _ego_cache[cache_key] = (policy, params)
    return policy, params
def _extract_and_find_checkpoint(zip_path, extract_dir):
    """Extract a zipped orbax checkpoint and locate the saved_train_run directory.

    Args:
        zip_path: Path of the uploaded zip archive.
        extract_dir: Existing directory the archive is extracted into.

    Returns:
        ``(saved_train_run_path, optional_config_dict)``. The config dict is
        ``None`` when no ``config.yaml`` is found or it cannot be loaded.

    Raises:
        ValueError: If the file is not a valid zip archive, declares more than
            500 MB of uncompressed data, contains an absolute or ``..`` path,
            or holds no recognisable checkpoint directory.
    """
    # Reject zip bombs: cap at 500 MB uncompressed
    MAX_UNCOMPRESSED = 500 * 1024 * 1024
    try:
        with zipfile.ZipFile(zip_path, 'r') as zf:
            members = zf.infolist()
            total = sum(info.file_size for info in members)
            if total > MAX_UNCOMPRESSED:
                raise ValueError(
                    f"Uncompressed size {total} bytes exceeds {MAX_UNCOMPRESSED} limit"
                )
            # Reject absolute paths and traversal
            for info in members:
                if info.filename.startswith('/') or '..' in info.filename.split('/'):
                    raise ValueError(f"Unsafe path in zip: {info.filename}")
            zf.extractall(extract_dir)
    except zipfile.BadZipFile as exc:
        # Callers only handle ValueError; a corrupt or non-zip upload must not
        # escape as an unhandled exception.
        raise ValueError(f"Not a valid zip archive: {exc}") from exc

    # Find saved_train_run directory anywhere in the extracted tree
    saved_run = None
    config_path = None
    for root, dirs, files in os.walk(extract_dir):
        if 'saved_train_run' in dirs:
            saved_run = os.path.join(root, 'saved_train_run')
        if 'config.yaml' in files:
            config_path = os.path.join(root, 'config.yaml')
        if saved_run and config_path:
            break

    if saved_run is None:
        # Maybe the zip IS the saved_train_run directory (user zipped its
        # contents). Look for orbax marker files instead.
        for root, dirs, files in os.walk(extract_dir):
            if '_CHECKPOINT_METADATA' in files or '_METADATA' in files:
                saved_run = root
                break

    if saved_run is None:
        raise ValueError(
            "Could not find saved_train_run/ directory or orbax checkpoint "
            "in the uploaded zip. Expected structure: saved_train_run/ with "
            "orbax files, optionally alongside config.yaml."
        )

    config_dict = None
    if config_path:
        # Best effort: the config is optional, so any failure to import
        # omegaconf or parse the file falls back to "no config".
        try:
            from omegaconf import OmegaConf
            config_dict = OmegaConf.to_container(OmegaConf.load(config_path), resolve=True)
        except Exception:
            config_dict = None
    return saved_run, config_dict
# Ego registry. Every PARTNER_REGISTRY agent can also play as ego: those
# entries carry no parameters and are marked trained=False (their factories
# are looked up in PARTNER_REGISTRY at dispatch time). Trained RL checkpoints
# get dedicated entries that also record a fixture path.
EGO_REGISTRY = {}
EGO_REGISTRY.update(
    (
        partner_key,
        dict(
            name=partner_info['name'],
            description=partner_info['description'],
            trained=False,
        ),
    )
    for partner_key, partner_info in PARTNER_REGISTRY.items()
)

# Trained RL checkpoints (orbax fixtures)
EGO_REGISTRY['ippo_1e9'] = dict(
    name='IPPO (trained, 1e9)',
    description='3-seed IPPO trained 1e9 steps. 15.78/25 self-play, 0-5/25 with unseen partners.',
    trained=True,
    fixture=IPPO_FIXTURE,
)
def evaluate_ego_vs_partner(ego_policy, ego_params, partner_policy,
                            num_episodes=64, rng_seed=34957):
    """Play the ego against a partner on Hanabi and summarise the scores.

    The ego is seated as agent 0 (test mode on) with ``ego_params``; the
    partner is agent 1 (test mode off) with no parameters. Episodes are
    capped at ``MAX_EPISODE_STEPS`` steps.

    Args:
        ego_policy: Policy object for agent 0.
        ego_params: Parameters for the ego policy (may be ``None``).
        partner_policy: Policy object for agent 1.
        num_episodes: Number of episodes passed to ``run_episodes``.
        rng_seed: Seed for the JAX PRNG key.

    Returns:
        Dict with ``mean``, ``median``, ``std``, ``min``, ``max``,
        ``num_episodes`` and the ``per_episode`` score list.
    """
    wrapped_env = LogWrapper(make_env('hanabi'))
    key = jax.random.PRNGKey(rng_seed)
    episode_metrics = run_episodes(
        key, wrapped_env,
        agent_0_policy=ego_policy,
        agent_0_param=ego_params,
        agent_1_policy=partner_policy,
        agent_1_param=None,
        max_episode_steps=MAX_EPISODE_STEPS,
        num_eps=num_episodes,
        agent_0_test_mode=True,
        agent_1_test_mode=False,
    )

    # Average returns over the last axis to get one score per episode
    # (presumably the per-agent axis -- TODO confirm against run_episodes).
    scores = np.array(episode_metrics['returned_episode_returns']).mean(axis=-1)

    reducers = (
        ('mean', np.mean),
        ('median', np.median),
        ('std', np.std),
        ('min', np.min),
        ('max', np.max),
    )
    summary = {label: float(reduce_fn(scores)) for label, reduce_fn in reducers}
    summary['num_episodes'] = int(num_episodes)
    summary['per_episode'] = scores.tolist()
    return summary
# Guards read-modify-write cycles on the leaderboard file within this process.
_leaderboard_lock = threading.Lock()


def load_leaderboard():
    """Read the leaderboard entries from ``LEADERBOARD_FILE``.

    Returns:
        The decoded JSON content, or an empty list when the file does not
        exist or is empty (for example after an interrupted write).

    Raises:
        json.JSONDecodeError: If the file has content that is not valid JSON.
            Corrupt data is surfaced rather than silently discarded.
    """
    # EAFP: open directly instead of exists()+open(), which can race.
    try:
        with open(LEADERBOARD_FILE, encoding='utf-8') as f:
            raw = f.read()
    except FileNotFoundError:
        return []
    if not raw.strip():
        return []
    return json.loads(raw)
def save_leaderboard(entries):
    """Write ``entries`` to ``LEADERBOARD_FILE`` as indented JSON, atomically.

    The data is serialised to a temporary file in the same directory and then
    moved over the target with ``os.replace``. A failed serialisation therefore
    leaves the previous leaderboard intact, and readers never observe a
    truncated file.

    Args:
        entries: JSON-serialisable leaderboard content.

    Raises:
        TypeError: If ``entries`` contains values ``json`` cannot serialise.
        OSError: If the file cannot be written or replaced.
    """
    # The PID suffix keeps separate processes from sharing one temp file.
    tmp_path = f'{LEADERBOARD_FILE}.{os.getpid()}.tmp'
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(entries, f, indent=2)
        os.replace(tmp_path, LEADERBOARD_FILE)
    except BaseException:
        # Do not leave a partial temp file behind; re-raise the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def add_leaderboard_entry(entry):
    """Append ``entry`` to the stored leaderboard and re-rank it.

    The load, append, sort and save steps run while holding
    ``_leaderboard_lock``. Entries are ordered by descending ``mean_score``;
    an entry without that key is treated as scoring 0.

    Args:
        entry: Dict describing one submission.

    Returns:
        The full, re-sorted list of entries that was saved.
    """
    def _descending_mean(item):
        return -item.get('mean_score', 0)

    with _leaderboard_lock:
        ranked = load_leaderboard()
        ranked.append(entry)
        ranked.sort(key=_descending_mean)
        save_leaderboard(ranked)
    return ranked
@app.route('/api/partners', methods=['GET'])
def list_partners():
    """Return each partner key with its display name and description as JSON.

    The ``factory`` callables stored in ``PARTNER_REGISTRY`` are left out.
    """
    public_fields = ('name', 'description')
    summary = {
        partner_key: {field: partner[field] for field in public_fields}
        for partner_key, partner in PARTNER_REGISTRY.items()
    }
    return jsonify(summary)
@app.route('/api/egos', methods=['GET'])
def list_egos():
    """Return each ego key with its name, description and ``trained`` flag as JSON.

    Other fields of ``EGO_REGISTRY`` entries (such as ``fixture``) are left out.
    """
    public_fields = ('name', 'description', 'trained')
    summary = {
        ego_key: {field: ego[field] for field in public_fields}
        for ego_key, ego in EGO_REGISTRY.items()
    }
    return jsonify(summary)
@app.route('/api/evaluate_upload', methods=['POST'])
def evaluate_upload():
    """Evaluate an uploaded orbax checkpoint zip against a held-out partner.

    Multipart form fields:
        checkpoint: .zip of orbax checkpoint (must contain saved_train_run/
            directory, optionally with config.yaml at the root)
        agent_name: str (label for leaderboard)
        partner: str (partner key from PARTNER_REGISTRY)
        num_episodes: int >= 1 (default DEFAULT_NUM_EPISODES, capped at 1000)
        actor_type: str (default 'mlp', overridden by config.yaml if present)
        hidden_dim: int (default 512, overridden by config.yaml if present)
        activation: str (default 'relu', overridden by config.yaml if present)

    Returns:
        JSON with the new leaderboard entry, per-episode scores, leaderboard
        position and the architecture used. Responds 400 for invalid form
        fields, an invalid archive or a checkpoint that fails to load, and
        500 when the evaluation itself raises.
    """
    checkpoint_file = request.files.get('checkpoint')
    if checkpoint_file is None:
        return jsonify({'error': 'No checkpoint file uploaded'}), 400
    # The filename may be missing on a malformed multipart part.
    if not (checkpoint_file.filename or '').lower().endswith('.zip'):
        return jsonify({'error': 'Checkpoint must be a .zip archive'}), 400
    partner_key = request.form.get('partner', 'bc_lstm')
    if partner_key not in PARTNER_REGISTRY:
        return jsonify({'error': f'Unknown partner: {partner_key}'}), 400
    agent_name = request.form.get('agent_name') or 'Uploaded checkpoint'

    # Malformed numbers are a client error, not a server crash.
    try:
        num_episodes = int(request.form.get('num_episodes', DEFAULT_NUM_EPISODES))
        hidden_dim = int(request.form.get('hidden_dim', 512))
    except (TypeError, ValueError):
        return jsonify({'error': 'num_episodes and hidden_dim must be integers'}), 400
    if num_episodes < 1:
        return jsonify({'error': 'num_episodes must be at least 1'}), 400
    num_episodes = min(num_episodes, 1000)

    actor_type = request.form.get('actor_type', 'mlp')
    activation = request.form.get('activation', 'relu')

    # Extract to a temp directory
    upload_dir = os.path.join(os.path.dirname(LEADERBOARD_FILE), 'uploads')
    os.makedirs(upload_dir, exist_ok=True)
    temp_zip = os.path.join(upload_dir, f'{uuid.uuid4().hex}.zip')
    extract_dir = tempfile.mkdtemp(prefix='eval_upload_', dir=upload_dir)
    try:
        checkpoint_file.save(temp_zip)
        try:
            saved_run_path, cfg = _extract_and_find_checkpoint(temp_zip, extract_dir)
        except ValueError as e:
            return jsonify({'error': f'Invalid checkpoint zip: {str(e)}'}), 400

        # Pull architecture from config.yaml if present. The section must be
        # a mapping; anything else is ignored rather than crashing on .get().
        algo = cfg.get('algorithm') if isinstance(cfg, dict) else None
        if isinstance(algo, dict):
            actor_type = algo.get('ACTOR_TYPE', actor_type)
            try:
                hidden_dim = int(algo.get('FC_HIDDEN_DIM', hidden_dim))
            except (TypeError, ValueError):
                return jsonify({
                    'error': 'Invalid checkpoint zip: config.yaml FC_HIDDEN_DIM '
                             'must be an integer',
                }), 400
            activation = algo.get('ACTIVATION', activation)

        # Load the checkpoint
        try:
            ego_policy, ego_params = _load_trained_ego(
                saved_run_path, actor_type=actor_type,
                hidden_dim=hidden_dim, activation=activation,
            )
        except Exception as e:
            return jsonify({
                'error': f'Failed to load checkpoint: {str(e)}',
                'hint': ('Verify actor_type, hidden_dim, and activation match the '
                         'architecture the checkpoint was trained with. Include '
                         'config.yaml in the zip to auto-detect.'),
            }), 400

        partner_policy = PARTNER_REGISTRY[partner_key]['factory']()
        start_time = time.time()
        try:
            results = evaluate_ego_vs_partner(
                ego_policy, ego_params, partner_policy,
                num_episodes=num_episodes,
            )
        except Exception as e:
            return jsonify({'error': f'Evaluation failed: {str(e)}'}), 500
        elapsed = time.time() - start_time

        submission_id = str(uuid.uuid4())[:8]
        entry = {
            'id': submission_id,
            'agent_name': agent_name,
            'partner': partner_key,
            'partner_name': PARTNER_REGISTRY[partner_key]['name'],
            'mean_score': results['mean'],
            'median_score': results['median'],
            'std': results['std'],
            'min_score': results['min'],
            'max_score': results['max'],
            'num_episodes': results['num_episodes'],
            'ego_type': 'uploaded',
            'ego_name': f'Uploaded ({actor_type}, h={hidden_dim})',
            'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
            'eval_time_seconds': round(elapsed, 1),
        }
        board = add_leaderboard_entry(entry)
        response = {
            **entry,
            'per_episode_scores': results['per_episode'],
            'leaderboard_position': next(
                (i + 1 for i, e in enumerate(board) if e['id'] == submission_id),
                len(board),
            ),
            'total_submissions': len(board),
            'architecture': {'actor_type': actor_type, 'hidden_dim': hidden_dim,
                             'activation': activation},
        }
        return jsonify(response)
    finally:
        # _load_trained_ego caches by checkpoint path. Every upload extracts
        # to a unique temp directory, so its cache entry can never be reused
        # and would otherwise pin the loaded parameters in memory forever.
        for cache_key in list(_ego_cache):
            if str(cache_key[0]).startswith(extract_dir):
                _ego_cache.pop(cache_key, None)
        # Clean up temp files
        if os.path.exists(temp_zip):
            os.remove(temp_zip)
        if os.path.exists(extract_dir):
            shutil.rmtree(extract_dir, ignore_errors=True)
@app.route('/api/evaluate', methods=['POST'])
def evaluate():
    """Evaluate a registered ego agent against a registered partner.

    JSON body fields (all optional; an empty body uses the defaults):
        agent_name: label for the leaderboard (default 'Anonymous')
        partner: partner key from PARTNER_REGISTRY (default 'bc_lstm')
        ego: ego key from EGO_REGISTRY (default 'iggi')
        num_episodes: int >= 1 (default DEFAULT_NUM_EPISODES, capped at 1000)

    Returns:
        JSON with the new leaderboard entry, per-episode scores and the
        leaderboard position. Responds 400 for an invalid body or unknown
        keys, 404 when a trained ego's fixture is missing, and 500 when
        loading or evaluation raises.
    """
    # silent=True lets an empty body fall back to defaults; a non-empty body
    # that is not a JSON object is rejected explicitly.
    data = request.get_json(silent=True)
    if data is None:
        if request.get_data().strip():
            return jsonify({'error': 'Request body must be a JSON object'}), 400
        data = {}
    elif not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    agent_name = data.get('agent_name', 'Anonymous')
    partner_key = data.get('partner', 'bc_lstm')
    ego_key = data.get('ego', 'iggi')

    # Malformed numbers are a client error, not a server crash.
    try:
        num_episodes = int(data.get('num_episodes', DEFAULT_NUM_EPISODES))
    except (TypeError, ValueError):
        return jsonify({'error': 'num_episodes must be an integer'}), 400
    if num_episodes < 1:
        return jsonify({'error': 'num_episodes must be at least 1'}), 400
    num_episodes = min(num_episodes, 1000)

    # The isinstance checks keep unhashable JSON values (lists, objects) from
    # raising TypeError in the registry lookups.
    if not isinstance(partner_key, str) or partner_key not in PARTNER_REGISTRY:
        return jsonify({'error': f'Unknown partner: {partner_key}',
                        'available': list(PARTNER_REGISTRY.keys())}), 400
    if not isinstance(ego_key, str) or ego_key not in EGO_REGISTRY:
        return jsonify({'error': f'Unknown ego: {ego_key}',
                        'available': list(EGO_REGISTRY.keys())}), 400

    partner_policy = PARTNER_REGISTRY[partner_key]['factory']()
    ego_info = EGO_REGISTRY[ego_key]
    if ego_info['trained']:
        fixture = ego_info['fixture']
        if not os.path.exists(fixture):
            return jsonify({'error': f'Checkpoint not found: {fixture}'}), 404
        try:
            ego_policy, ego_params = _load_trained_ego(fixture)
        except Exception as e:
            return jsonify({'error': f'Failed to load checkpoint: {str(e)}'}), 500
    else:
        # Untrained egos reuse the partner factories and carry no parameters.
        if ego_key not in PARTNER_REGISTRY:
            return jsonify({'error': f'No factory for ego: {ego_key}'}), 400
        ego_policy = PARTNER_REGISTRY[ego_key]['factory']()
        ego_params = None

    start_time = time.time()
    try:
        results = evaluate_ego_vs_partner(
            ego_policy, ego_params, partner_policy,
            num_episodes=num_episodes
        )
    except Exception as e:
        return jsonify({'error': f'Evaluation failed: {str(e)}'}), 500
    elapsed = time.time() - start_time

    submission_id = str(uuid.uuid4())[:8]
    entry = {
        'id': submission_id,
        'agent_name': agent_name,
        'partner': partner_key,
        'partner_name': PARTNER_REGISTRY[partner_key]['name'],
        'mean_score': results['mean'],
        'median_score': results['median'],
        'std': results['std'],
        'min_score': results['min'],
        'max_score': results['max'],
        'num_episodes': results['num_episodes'],
        'ego_type': ego_key,
        'ego_name': EGO_REGISTRY[ego_key]['name'],
        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
        'eval_time_seconds': round(elapsed, 1),
    }
    board = add_leaderboard_entry(entry)
    response = {
        **entry,
        'per_episode_scores': results['per_episode'],
        'leaderboard_position': next(
            (i + 1 for i, e in enumerate(board) if e['id'] == submission_id),
            len(board)
        ),
        'total_submissions': len(board),
    }
    return jsonify(response)
@app.route('/api/leaderboard', methods=['GET'])
def leaderboard():
    """Return the stored leaderboard as JSON, optionally filtered by partner.

    A non-empty ``partner`` query parameter keeps only entries whose
    ``partner`` field equals it. The response holds the ``entries`` list and
    its length as ``total``.
    """
    wanted_partner = request.args.get('partner')
    rows = load_leaderboard()
    if wanted_partner:
        rows = list(filter(lambda row: row.get('partner') == wanted_partner, rows))
    payload = {
        'entries': rows,
        'total': len(rows),
    }
    return jsonify(payload)
@app.route('/api/leaderboard/clear', methods=['POST'])
def clear_leaderboard():
    """Replace the stored leaderboard with an empty list and report success.

    NOTE(review): no authentication is visible on this destructive endpoint;
    confirm it is protected elsewhere (e.g. by a proxy) before deployment.
    """
    no_entries = []
    _leaderboard_lock.acquire()
    try:
        save_leaderboard(no_entries)
    finally:
        _leaderboard_lock.release()
    payload = {'status': 'cleared'}
    return jsonify(payload)
@app.route('/api/info', methods=['GET'])
def api_info():
    """Return service metadata as JSON.

    The payload holds the service name and version, a description of every
    route this module registers under ``/api``, the supported partner keys,
    the default episode count and the leaderboard file path.
    """
    # Keep this listing in sync with the @app.route declarations in this file.
    endpoints = {
        'POST /api/evaluate': 'Submit a policy for evaluation',
        'POST /api/evaluate_upload': 'Upload a checkpoint zip for evaluation',
        'GET /api/leaderboard': 'View submission rankings',
        'POST /api/leaderboard/clear': 'Remove all leaderboard entries',
        'GET /api/partners': 'List available evaluation partners',
        'GET /api/egos': 'List available ego agents',
        'GET /api/info': 'Describe this service',
    }
    return jsonify({
        'service': 'Hanabi AHT Policy Evaluation API',
        'version': '1.0',
        'endpoints': endpoints,
        'supported_partners': list(PARTNER_REGISTRY.keys()),
        'default_num_episodes': DEFAULT_NUM_EPISODES,
        # NOTE(review): this exposes an absolute server path to any client;
        # kept for backward compatibility -- confirm it is intended.
        'leaderboard_file': LEADERBOARD_FILE,
    })
@app.route('/')
def index():
    """Serve ``index.html`` from the ``eval_ui`` directory next to this file."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    ui_dir = os.path.join(module_dir, 'eval_ui')
    return send_from_directory(ui_dir, 'index.html')
if __name__ == '__main__':
    # Command-line entry point: parse options, print a banner, start Flask.
    parser = argparse.ArgumentParser(description='Hanabi AHT Evaluation API')
    for flag, flag_type, flag_default in (
        ('--port', int, 5001),
        ('--host', str, '0.0.0.0'),
        ('--num-episodes', int, 64),
    ):
        parser.add_argument(flag, type=flag_type, default=flag_default)
    args = parser.parse_args()

    # Rebinds the module-level default that the endpoints read at request time.
    DEFAULT_NUM_EPISODES = args.num_episodes

    banner = [
        'Hanabi AHT Evaluation API',
        f' http://{args.host}:{args.port}',
        f' Partners: {", ".join(PARTNER_REGISTRY.keys())}',
        f' Default episodes: {DEFAULT_NUM_EPISODES}',
    ]
    print('\n'.join(banner))
    app.run(host=args.host, port=args.port, debug=False)