"""Hanabi AHT evaluation API. Upload checkpoint, get scores vs held-out partners."""
import argparse
import json
import os
import shutil
import sys
import tempfile
import time
import uuid
import threading
import zipfile

# Make the repository root importable (this file lives one directory below it).
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import jax
import jax.numpy as jnp
import numpy as np
from flask import Flask, jsonify, request, send_from_directory
from flask_cors import CORS

# Force CPU before any project module touches JAX devices.
jax.config.update('jax_platform_name', 'cpu')

from envs import make_env
from envs.log_wrapper import LogWrapper
from common.run_episodes import run_episodes
from agents.hanabi.bc_lstm_agent import BCLSTMAgent
from agents.hanabi.agent_policy_wrappers import (
    HanabiBCLSTMPolicyWrapper,
    HanabiIGGIPolicyWrapper,
    HanabiPiersPolicyWrapper,
    HanabiSmartBotPolicyWrapper,
    HanabiRandomPolicyWrapper,
    HanabiOBLPolicyWrapper,
    HanabiRuleBasedPolicyWrapper,
    HanabiFlawedPolicyWrapper,
    HanabiOuterPolicyWrapper,
    HanabiVanDenBerghPolicyWrapper,
)

app = Flask(__name__)
app.config['MAX_CONTENT_LENGTH'] = 200 * 1024 * 1024  # 200 MB upload limit
CORS(app)

DEFAULT_NUM_EPISODES = 64
MAX_NUM_EPISODES = 1000  # per-request cap; larger requests are clamped
MAX_EPISODE_STEPS = 128
LEADERBOARD_FILE = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), 'leaderboard.json'
)

HAND_SIZE, NUM_COLORS, NUM_RANKS, NUM_ACTIONS = 5, 5, 5, 21
# Shared constructor kwargs for the rule-based partner wrappers.
_RB_KW = dict(hand_size=HAND_SIZE, num_colors=NUM_COLORS, num_ranks=NUM_RANKS,
              num_actions=NUM_ACTIONS, using_log_wrapper=True)


def _obl_path(level, seed='a'):
    """Return the relative safetensors path for a pretrained OBL agent.

    Levels above 1 use the ``_LOAD1`` variant of the weight file name.
    """
    load = '_LOAD1' if level > 1 else ''
    return (f'agents/hanabi/obl-r2d2-flax/icml_OBL{level}/'
            f'OFF_BELIEF1_SHUFFLE_COLOR0{load}_BZA0_BELIEF_{seed}.safetensors')


# Held-out partners. Each 'factory' builds a fresh, parameter-free policy.
PARTNER_REGISTRY = {
    # Baselines
    'random': {
        'name': 'Random',
        'description': 'Uniform random legal action baseline (0/25)',
        'factory': lambda: HanabiRandomPolicyWrapper(num_actions=NUM_ACTIONS, using_log_wrapper=True),
    },
    # Rule-based strategy variants (action-category priority agents)
    'rule_based_cautious': {
        'name': 'RuleBased (cautious)',
        'description': 'Priority: play first, then discard, then hint',
        'factory': lambda: HanabiRuleBasedPolicyWrapper(strategy='cautious', **_RB_KW),
    },
    'rule_based_aggressive': {
        'name': 'RuleBased (aggressive)',
        'description': 'Priority: play only, fallback to hint',
        'factory': lambda: HanabiRuleBasedPolicyWrapper(strategy='aggressive', **_RB_KW),
    },
    'rule_based_communicative': {
        'name': 'RuleBased (communicative)',
        'description': 'Priority: hint first, then play',
        'factory': lambda: HanabiRuleBasedPolicyWrapper(strategy='communicative', **_RB_KW),
    },
    'rule_based_frugal': {
        'name': 'RuleBased (frugal)',
        'description': 'Priority: discard first to save info tokens',
        'factory': lambda: HanabiRuleBasedPolicyWrapper(strategy='frugal', **_RB_KW),
    },
    # Walton-Rivers agents
    'iggi': {
        'name': 'IGGI',
        'description': 'Game-state-aware rule-based (Walton-Rivers 2017). 11.7/25 self-play.',
        'factory': lambda: HanabiIGGIPolicyWrapper(**_RB_KW),
    },
    'piers': {
        'name': 'Piers',
        'description': 'IGGI + probabilistic play (>60%) + dispensable hints. 11.6/25.',
        'factory': lambda: HanabiPiersPolicyWrapper(**_RB_KW),
    },
    'van_den_bergh': {
        'name': 'Van Den Bergh',
        'description': 'Probabilistic discard ordering to protect valuable cards. 9.3/25.',
        'factory': lambda: HanabiVanDenBerghPolicyWrapper(**_RB_KW),
    },
    'outer': {
        'name': 'Outer',
        'description': 'Hint-heavy agent. Scores 0/25 on full Hanabi (exhausts tokens).',
        'factory': lambda: HanabiOuterPolicyWrapper(**_RB_KW),
    },
    'flawed': {
        'name': 'Flawed (30%)',
        'description': 'IGGI with 30% random-mistake probability. 0.2/25.',
        'factory': lambda: HanabiFlawedPolicyWrapper(mistake_prob=0.3, **_RB_KW),
    },
    # Previously registered under an empty-string key, which made it
    # unaddressable in practice and confusing in /api/partners output.
    'flawed_60': {
        'name': 'Flawed (60%)',
        'description': 'IGGI with 60% random-mistake probability. ~0/25.',
        'factory': lambda: HanabiFlawedPolicyWrapper(mistake_prob=0.6, **_RB_KW),
    },
    # SmartBot (strongest rule-based)
    'smartbot': {
        'name': 'SmartBot',
        'description': 'Convention-heavy rule-based agent. 18.0/25 self-play, 0.4 vs human proxy.',
        'factory': lambda: HanabiSmartBotPolicyWrapper(**_RB_KW),
    },
    # OBL R2D2 pretrained, all 5 belief levels (seed-a)
    'obl_l1': {
        'name': 'OBL-L1',
        'description': 'Off-Belief Learning level 1 (Hu et al. 2021). 20.9/25 self-play.',
        'factory': lambda: HanabiOBLPolicyWrapper(weight_file=_obl_path(1), using_log_wrapper=True),
    },
    'obl_l2': {
        'name': 'OBL-L2',
        'description': 'Off-Belief Learning level 2.',
        'factory': lambda: HanabiOBLPolicyWrapper(weight_file=_obl_path(2), using_log_wrapper=True),
    },
    'obl_l3': {
        'name': 'OBL-L3',
        'description': 'Off-Belief Learning level 3.',
        'factory': lambda: HanabiOBLPolicyWrapper(weight_file=_obl_path(3), using_log_wrapper=True),
    },
    'obl_l4': {
        'name': 'OBL-L4',
        'description': 'Off-Belief Learning level 4. 24.3/25 self-play (near-optimal).',
        'factory': lambda: HanabiOBLPolicyWrapper(weight_file=_obl_path(4), using_log_wrapper=True),
    },
    'obl_l5': {
        'name': 'OBL-L5',
        'description': 'Off-Belief Learning level 5.',
        'factory': lambda: HanabiOBLPolicyWrapper(weight_file=_obl_path(5), using_log_wrapper=True),
    },
    # Human proxy
    'bc_lstm': {
        'name': 'BC-LSTM Human Proxy',
        'description': 'Behavioral cloning LSTM trained on 1K human games (AH2AC2). 45.5% accuracy.',
        'factory': lambda: HanabiBCLSTMPolicyWrapper(
            weight_file='agents/hanabi/bc_lstm_weights/bc_2p.safetensors',
            using_log_wrapper=True, greedy=True),
    },
}

# Trained ego agents (RL checkpoints). Returns (policy, params) not just policy.
# Separate from PARTNER_REGISTRY because these carry learned parameters.
_ego_cache = {}

IPPO_FIXTURE = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    'fixtures/hanabi_ippo_1e9_seed0/saved_train_run'
)

# Observation width of full 2-player Hanabi as used by the trained egos.
_HANABI_OBS_DIM = 658


def _cpu_safe_restore(path):
    """Restore an orbax checkpoint with explicit CPU sharding.

    Needed because fixtures saved on GPU machines encode GPU sharding, which
    orbax rejects when restoring on CPU-only machines (the eval API host).
    Bypasses common.save_load_utils.load_train_run which uses implicit sharding.

    Relative paths are resolved against ``common.save_load_utils.REPO_PATH``.
    NumPy arrays in the restored tree are converted to JAX arrays.
    """
    import orbax.checkpoint as ocp

    if not os.path.isabs(path):
        from common.save_load_utils import REPO_PATH
        path = os.path.join(REPO_PATH, path)

    checkpointer = ocp.PyTreeCheckpointer()
    cpu_sharding = jax.sharding.SingleDeviceSharding(jax.devices('cpu')[0])
    meta = checkpointer.metadata(path)
    restore_args = jax.tree.map(
        lambda _: ocp.ArrayRestoreArgs(sharding=cpu_sharding),
        meta,
    )
    restored = checkpointer.restore(path, restore_args=restore_args)
    # Convert to jax arrays
    restored = jax.tree_util.tree_map(
        lambda x: jnp.array(x) if isinstance(x, np.ndarray) else x,
        restored,
    )
    return restored


def _load_trained_ego(fixture_path, actor_type='mlp', hidden_dim=512,
                      activation='relu', idx_list=None, use_cache=True):
    """Load a trained ego (policy, params) from an orbax checkpoint.

    Args:
        fixture_path: Path to a ``saved_train_run`` orbax directory.
        actor_type: Architecture label. Only part of the cache key here.
        hidden_dim: Hidden-size label. Only part of the cache key here.
        activation: Activation name forwarded to ``MLPActorCriticPolicy``.
        idx_list: Checkpoint indices; only the first is used. ``None`` or an
            empty sequence means index 0. An out-of-range index falls back to
            the last available checkpoint.
        use_cache: When True (default) the result is memoised in
            ``_ego_cache``. Pass False for one-off paths (e.g. per-upload temp
            directories) that will never be requested again, otherwise every
            upload would pin its params in memory forever.

    Returns:
        ``(policy, params)``.
    """
    cache_key = (fixture_path, actor_type, hidden_dim, activation,
                 tuple(idx_list) if idx_list else None)
    if use_cache and cache_key in _ego_cache:
        return _ego_cache[cache_key]

    from evaluation.heldout_evaluator import extract_params
    from agents.mlp_actor_critic_agent import MLPActorCriticPolicy

    # Use CPU-safe restore instead of the default loader
    restored = _cpu_safe_restore(fixture_path)
    ckpts = restored['checkpoints']

    # Build policy directly. JaxMARL Hanabi's observation_space.shape is a
    # bare int (658), not a tuple, which breaks initialize_mlp_agent's
    # `.shape[0]` access. Hardcode obs_dim=658 for full Hanabi.
    # NOTE(review): actor_type and hidden_dim are not forwarded to the policy;
    # confirm whether MLPActorCriticPolicy accepts a hidden-size argument.
    policy = MLPActorCriticPolicy(
        action_dim=NUM_ACTIONS,
        obs_dim=_HANABI_OBS_DIM,
        activation=activation,
    )
    init_params = policy.init_params(jax.random.PRNGKey(0))

    # Extract the requested checkpoint. Default to the first.
    idx = idx_list[0] if idx_list else 0
    params_list, _ = extract_params(ckpts, init_params, idx_labels=None)
    params = params_list[idx] if idx < len(params_list) else params_list[-1]

    if use_cache:
        _ego_cache[cache_key] = (policy, params)
    return policy, params


def _extract_and_find_checkpoint(zip_path, extract_dir):
    """Extract a zipped orbax checkpoint and locate the saved_train_run directory.

    Returns (saved_train_run_path, optional_config_dict).
    Raises ValueError if the structure is invalid or the archive is unsafe.
    ``zipfile.BadZipFile`` propagates if the file is not a zip at all.
    """
    # Reject zip bombs: cap at 500 MB uncompressed (zipfile also refuses to
    # read past each member's declared size, so this cap is enforced).
    MAX_UNCOMPRESSED = 500 * 1024 * 1024
    extract_root = os.path.realpath(extract_dir)
    with zipfile.ZipFile(zip_path, 'r') as zf:
        members = zf.infolist()
        total = sum(info.file_size for info in members)
        if total > MAX_UNCOMPRESSED:
            raise ValueError(
                f"Uncompressed size {total} bytes exceeds {MAX_UNCOMPRESSED} limit"
            )
        # Reject absolute paths and traversal, including backslash-separated
        # and drive-qualified names that a plain '/'-split would miss.
        for info in members:
            name = info.filename
            parts = name.replace('\\', '/').split('/')
            target = os.path.realpath(os.path.join(extract_root, name))
            if (name.startswith('/') or '..' in parts
                    or os.path.commonpath([extract_root, target]) != extract_root):
                raise ValueError(f"Unsafe path in zip: {info.filename}")
        zf.extractall(extract_dir)

    # Find saved_train_run directory anywhere in the extracted tree. os.walk is
    # top-down, so keeping the first match prefers the shallowest one.
    saved_run = None
    config_path = None
    for root, dirs, files in os.walk(extract_dir):
        if saved_run is None and 'saved_train_run' in dirs:
            saved_run = os.path.join(root, 'saved_train_run')
        if config_path is None and 'config.yaml' in files:
            config_path = os.path.join(root, 'config.yaml')
        if saved_run and config_path:
            break

    if saved_run is None:
        # Maybe the zip IS the saved_train_run directory (user zipped its
        # contents). Look for orbax markers.
        for root, dirs, files in os.walk(extract_dir):
            if '_CHECKPOINT_METADATA' in files or '_METADATA' in files:
                saved_run = root
                break

    if saved_run is None:
        raise ValueError(
            "Could not find saved_train_run/ directory or orbax checkpoint "
            "in the uploaded zip. Expected structure: saved_train_run/ with "
            "orbax files, optionally alongside config.yaml."
        )

    config_dict = None
    if config_path:
        # Best-effort: an unreadable config just means no auto-detection.
        try:
            from omegaconf import OmegaConf
            config_dict = OmegaConf.to_container(OmegaConf.load(config_path), resolve=True)
        except Exception:
            config_dict = None

    return saved_run, config_dict


# Ego registry: any agent in PARTNER_REGISTRY can also be used as ego
# (heuristic agents have no params, just a factory). Trained RL checkpoints
# have their own entries with fixture paths.
EGO_REGISTRY = {
    # Heuristic/pretrained agents (reuse PARTNER_REGISTRY factories at dispatch time)
    key: {'name': info['name'], 'description': info['description'], 'trained': False}
    for key, info in PARTNER_REGISTRY.items()
}
# Trained RL checkpoints (orbax fixtures)
EGO_REGISTRY['ippo_1e9'] = {
    'name': 'IPPO (trained, 1e9)',
    'description': '3-seed IPPO trained 1e9 steps. 15.78/25 self-play, 0-5/25 with unseen partners.',
    'trained': True,
    'fixture': IPPO_FIXTURE,
}


def evaluate_ego_vs_partner(ego_policy, ego_params, partner_policy,
                            num_episodes=64, rng_seed=34957):
    """Play ego (agent 0) with partner (agent 1) and summarise episode scores.

    Ego runs in test mode; the partner does not. ``num_episodes`` must be at
    least 1 (the summary statistics are undefined on an empty set).

    Returns:
        dict with mean/median/std/min/max, num_episodes and per_episode scores.
    """
    env = make_env('hanabi')
    env = LogWrapper(env)
    rng = jax.random.PRNGKey(rng_seed)
    metrics = run_episodes(
        rng, env,
        agent_0_policy=ego_policy, agent_0_param=ego_params,
        agent_1_policy=partner_policy, agent_1_param=None,
        max_episode_steps=MAX_EPISODE_STEPS, num_eps=num_episodes,
        agent_0_test_mode=True, agent_1_test_mode=False,
    )
    returns = np.array(metrics['returned_episode_returns'])
    # Average over the last axis (presumably the per-agent axis; both agents
    # share the cooperative score) — confirm against run_episodes.
    episode_scores = returns.mean(axis=-1)
    return {
        'mean': float(np.mean(episode_scores)),
        'median': float(np.median(episode_scores)),
        'std': float(np.std(episode_scores)),
        'min': float(np.min(episode_scores)),
        'max': float(np.max(episode_scores)),
        'num_episodes': int(num_episodes),
        'per_episode': episode_scores.tolist(),
    }


# Serialises read-modify-write cycles on the leaderboard file.
_leaderboard_lock = threading.Lock()


def load_leaderboard():
    """Return the leaderboard entries list, or [] if no file exists yet."""
    if os.path.exists(LEADERBOARD_FILE):
        with open(LEADERBOARD_FILE, encoding='utf-8') as f:
            return json.load(f)
    return []


def save_leaderboard(entries):
    """Atomically write *entries* to LEADERBOARD_FILE as indented JSON.

    The data is written to a temp file in the same directory and then
    ``os.replace``d over the target, so readers that do not hold
    ``_leaderboard_lock`` (e.g. GET /api/leaderboard) never see a truncated
    or half-written file.
    """
    directory = os.path.dirname(LEADERBOARD_FILE)
    fd, tmp_path = tempfile.mkstemp(prefix='.leaderboard_', suffix='.json', dir=directory)
    try:
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            json.dump(entries, f, indent=2)
        os.replace(tmp_path, LEADERBOARD_FILE)
    except BaseException:
        # Don't leave an orphaned temp file behind, then propagate.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def add_leaderboard_entry(entry):
    """Append *entry*, re-sort by mean_score descending, persist, and return all entries."""
    with _leaderboard_lock:
        entries = load_leaderboard()
        entries.append(entry)
        entries.sort(key=lambda x: -x.get('mean_score', 0))
        save_leaderboard(entries)
        return entries


def _parse_int(raw, default, field, minimum=1, maximum=None):
    """Parse a user-supplied integer field.

    ``None`` or ``''`` yields *default*. Values above *maximum* are clamped.

    Raises:
        ValueError: if *raw* is not an integer or the value is below *minimum*.
    """
    if raw is None or raw == '':
        value = default
    else:
        try:
            value = int(raw)
        except (TypeError, ValueError):
            raise ValueError(f'{field} must be an integer, got {raw!r}') from None
    if value < minimum:
        raise ValueError(f'{field} must be >= {minimum}, got {value}')
    if maximum is not None:
        value = min(value, maximum)
    return value


def _record_submission(agent_name, partner_key, results, elapsed, ego_type, ego_name):
    """Store an evaluation result on the leaderboard and build the API response."""
    submission_id = str(uuid.uuid4())[:8]
    entry = {
        'id': submission_id,
        'agent_name': agent_name,
        'partner': partner_key,
        'partner_name': PARTNER_REGISTRY[partner_key]['name'],
        'mean_score': results['mean'],
        'median_score': results['median'],
        'std': results['std'],
        'min_score': results['min'],
        'max_score': results['max'],
        'num_episodes': results['num_episodes'],
        'ego_type': ego_type,
        'ego_name': ego_name,
        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
        'eval_time_seconds': round(elapsed, 1),
    }
    entries = add_leaderboard_entry(entry)
    position = next(
        (i + 1 for i, e in enumerate(entries) if e['id'] == submission_id),
        len(entries),
    )
    return {
        **entry,
        'per_episode_scores': results['per_episode'],
        'leaderboard_position': position,
        'total_submissions': len(entries),
    }


@app.route('/api/partners', methods=['GET'])
def list_partners():
    """List held-out partners as {key: {name, description}}."""
    return jsonify({
        key: {'name': info['name'], 'description': info['description']}
        for key, info in PARTNER_REGISTRY.items()
    })


@app.route('/api/egos', methods=['GET'])
def list_egos():
    """List selectable ego agents as {key: {name, description, trained}}."""
    return jsonify({
        key: {'name': info['name'], 'description': info['description'],
              'trained': info['trained']}
        for key, info in EGO_REGISTRY.items()
    })


@app.route('/api/evaluate_upload', methods=['POST'])
def evaluate_upload():
    """Evaluate an uploaded orbax checkpoint zip against a held-out partner.

    Multipart form fields:
        checkpoint: .zip of orbax checkpoint (must contain saved_train_run/
                    directory, optionally with config.yaml at the root)
        agent_name: str (label for leaderboard)
        partner: str (partner key from PARTNER_REGISTRY)
        num_episodes: int >= 1 (default DEFAULT_NUM_EPISODES, capped at
                      MAX_NUM_EPISODES)
        actor_type: str (default 'mlp', overridden by config.yaml if present)
        hidden_dim: int (default 512, overridden by config.yaml if present)
        activation: str (default 'relu', overridden by config.yaml if present)

    Malformed input (bad zip, non-integer fields, bad config.yaml) -> 400.
    """
    checkpoint_file = request.files.get('checkpoint')
    if checkpoint_file is None:
        return jsonify({'error': 'No checkpoint file uploaded'}), 400
    # filename may be None/empty for some multipart clients.
    if not (checkpoint_file.filename or '').lower().endswith('.zip'):
        return jsonify({'error': 'Checkpoint must be a .zip archive'}), 400

    partner_key = request.form.get('partner', 'bc_lstm')
    if partner_key not in PARTNER_REGISTRY:
        return jsonify({'error': f'Unknown partner: {partner_key}'}), 400

    agent_name = request.form.get('agent_name') or 'Uploaded checkpoint'
    actor_type = request.form.get('actor_type', 'mlp')
    activation = request.form.get('activation', 'relu')
    try:
        num_episodes = _parse_int(request.form.get('num_episodes'), DEFAULT_NUM_EPISODES,
                                  'num_episodes', maximum=MAX_NUM_EPISODES)
        hidden_dim = _parse_int(request.form.get('hidden_dim'), 512, 'hidden_dim')
    except ValueError as e:
        return jsonify({'error': str(e)}), 400

    # Extract to a temp directory
    upload_dir = os.path.join(os.path.dirname(LEADERBOARD_FILE), 'uploads')
    os.makedirs(upload_dir, exist_ok=True)
    temp_zip = os.path.join(upload_dir, f'{uuid.uuid4().hex}.zip')
    extract_dir = tempfile.mkdtemp(prefix='eval_upload_', dir=upload_dir)

    try:
        checkpoint_file.save(temp_zip)
        try:
            saved_run_path, cfg = _extract_and_find_checkpoint(temp_zip, extract_dir)
        except (ValueError, zipfile.BadZipFile) as e:
            return jsonify({'error': f'Invalid checkpoint zip: {str(e)}'}), 400

        # Pull architecture from config.yaml if present
        algo = cfg.get('algorithm') if isinstance(cfg, dict) else None
        if isinstance(algo, dict):
            actor_type = algo.get('ACTOR_TYPE', actor_type)
            activation = algo.get('ACTIVATION', activation)
            try:
                hidden_dim = _parse_int(algo.get('FC_HIDDEN_DIM'), hidden_dim, 'FC_HIDDEN_DIM')
            except ValueError as e:
                return jsonify({'error': f'Invalid config.yaml: {e}'}), 400

        # Load the checkpoint. The extract dir is unique and deleted below, so
        # caching it would only leak memory.
        try:
            ego_policy, ego_params = _load_trained_ego(
                saved_run_path,
                actor_type=actor_type,
                hidden_dim=hidden_dim,
                activation=activation,
                use_cache=False,
            )
        except Exception as e:
            return jsonify({
                'error': f'Failed to load checkpoint: {str(e)}',
                'hint': ('Verify actor_type, hidden_dim, and activation match the '
                         'architecture the checkpoint was trained with. Include '
                         'config.yaml in the zip to auto-detect.'),
            }), 400

        partner_policy = PARTNER_REGISTRY[partner_key]['factory']()

        start_time = time.time()
        try:
            results = evaluate_ego_vs_partner(
                ego_policy, ego_params, partner_policy,
                num_episodes=num_episodes,
            )
        except Exception as e:
            return jsonify({'error': f'Evaluation failed: {str(e)}'}), 500
        elapsed = time.time() - start_time

        response = _record_submission(
            agent_name, partner_key, results, elapsed,
            ego_type='uploaded',
            ego_name=f'Uploaded ({actor_type}, h={hidden_dim})',
        )
        response['architecture'] = {'actor_type': actor_type, 'hidden_dim': hidden_dim,
                                    'activation': activation}
        return jsonify(response)
    finally:
        # Clean up temp files
        if os.path.exists(temp_zip):
            os.remove(temp_zip)
        if os.path.exists(extract_dir):
            shutil.rmtree(extract_dir, ignore_errors=True)


@app.route('/api/evaluate', methods=['POST'])
def evaluate():
    """Evaluate a registered ego against a registered partner (JSON body).

    Body fields: agent_name, partner, ego, num_episodes (all optional).
    A missing or non-JSON body is treated as {}; a JSON non-object is a 400.
    """
    data = request.get_json(silent=True)
    if data is None:
        data = {}
    elif not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    agent_name = data.get('agent_name', 'Anonymous')
    partner_key = data.get('partner', 'bc_lstm')
    ego_key = data.get('ego', 'iggi')
    try:
        num_episodes = _parse_int(data.get('num_episodes'), DEFAULT_NUM_EPISODES,
                                  'num_episodes', maximum=MAX_NUM_EPISODES)
    except ValueError as e:
        return jsonify({'error': str(e)}), 400

    if partner_key not in PARTNER_REGISTRY:
        return jsonify({'error': f'Unknown partner: {partner_key}',
                        'available': list(PARTNER_REGISTRY.keys())}), 400
    if ego_key not in EGO_REGISTRY:
        return jsonify({'error': f'Unknown ego: {ego_key}',
                        'available': list(EGO_REGISTRY.keys())}), 400

    partner_policy = PARTNER_REGISTRY[partner_key]['factory']()
    ego_info = EGO_REGISTRY[ego_key]

    if ego_info['trained']:
        fixture = ego_info['fixture']
        if not os.path.exists(fixture):
            return jsonify({'error': f'Checkpoint not found: {fixture}'}), 404
        try:
            ego_policy, ego_params = _load_trained_ego(fixture)
        except Exception as e:
            return jsonify({'error': f'Failed to load checkpoint: {str(e)}'}), 500
    else:
        if ego_key not in PARTNER_REGISTRY:
            return jsonify({'error': f'No factory for ego: {ego_key}'}), 400
        ego_policy = PARTNER_REGISTRY[ego_key]['factory']()
        ego_params = None

    start_time = time.time()
    try:
        results = evaluate_ego_vs_partner(
            ego_policy, ego_params, partner_policy, num_episodes=num_episodes
        )
    except Exception as e:
        return jsonify({'error': f'Evaluation failed: {str(e)}'}), 500
    elapsed = time.time() - start_time

    response = _record_submission(
        agent_name, partner_key, results, elapsed,
        ego_type=ego_key,
        ego_name=EGO_REGISTRY[ego_key]['name'],
    )
    return jsonify(response)


@app.route('/api/leaderboard', methods=['GET'])
def leaderboard():
    """Return leaderboard entries, optionally filtered by ?partner=<key>."""
    partner_filter = request.args.get('partner')
    entries = load_leaderboard()
    if partner_filter:
        entries = [e for e in entries if e.get('partner') == partner_filter]
    return jsonify({
        'entries': entries,
        'total': len(entries),
    })
@app.route('/api/leaderboard/clear', methods=['POST'])
def clear_leaderboard():
    """Erase all leaderboard entries.

    NOTE(review): this endpoint is unauthenticated — any client that can reach
    the API can wipe the leaderboard. Consider gating it before exposing the
    service publicly.
    """
    with _leaderboard_lock:
        save_leaderboard([])
    return jsonify({'status': 'cleared'})


@app.route('/api/info', methods=['GET'])
def api_info():
    """Describe the service: available endpoints, partners and defaults."""
    return jsonify({
        'service': 'Hanabi AHT Policy Evaluation API',
        'version': '1.0',
        'endpoints': {
            'POST /api/evaluate': 'Submit a policy for evaluation',
            'POST /api/evaluate_upload': 'Upload a zipped orbax checkpoint for evaluation',
            'GET /api/leaderboard': 'View submission rankings',
            'POST /api/leaderboard/clear': 'Clear all leaderboard entries',
            'GET /api/partners': 'List available evaluation partners',
            'GET /api/egos': 'List available ego agents',
            'GET /api/info': 'Describe this service',
        },
        'supported_partners': list(PARTNER_REGISTRY.keys()),
        'default_num_episodes': DEFAULT_NUM_EPISODES,
        # NOTE(review): exposes an absolute server path to clients.
        'leaderboard_file': LEADERBOARD_FILE,
    })


@app.route('/')
def index():
    """Serve the static evaluation UI."""
    return send_from_directory(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), 'eval_ui'),
        'index.html'
    )


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Hanabi AHT Evaluation API')
    parser.add_argument('--port', type=int, default=5001)
    parser.add_argument('--host', type=str, default='0.0.0.0')
    parser.add_argument('--num-episodes', type=int, default=DEFAULT_NUM_EPISODES)
    args = parser.parse_args()
    if args.num_episodes < 1:
        parser.error('--num-episodes must be >= 1')
    # Module-level rebinding: request handlers read this global at call time.
    DEFAULT_NUM_EPISODES = args.num_episodes

    print('Hanabi AHT Evaluation API')
    print(f'  http://{args.host}:{args.port}')
    print(f'  Partners: {", ".join(PARTNER_REGISTRY.keys())}')
    print(f'  Default episodes: {DEFAULT_NUM_EPISODES}')
    app.run(host=args.host, port=args.port, debug=False)