Spaces:
Running
Running
| """Hanabi AHT evaluation API. Upload checkpoint, get scores vs held-out partners.""" | |
| import argparse | |
| import json | |
| import os | |
| import shutil | |
| import sys | |
| import tempfile | |
| import time | |
| import uuid | |
| import threading | |
| import zipfile | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| import jax | |
| import jax.numpy as jnp | |
| import numpy as np | |
| from flask import Flask, jsonify, request, send_from_directory | |
| from flask_cors import CORS | |
| jax.config.update('jax_platform_name', 'cpu') | |
| from envs import make_env | |
| from envs.log_wrapper import LogWrapper | |
| from common.run_episodes import run_episodes | |
| from agents.hanabi.bc_lstm_agent import BCLSTMAgent | |
| from agents.hanabi.agent_policy_wrappers import ( | |
| HanabiBCLSTMPolicyWrapper, | |
| HanabiIGGIPolicyWrapper, | |
| HanabiPiersPolicyWrapper, | |
| HanabiSmartBotPolicyWrapper, | |
| HanabiRandomPolicyWrapper, | |
| HanabiOBLPolicyWrapper, | |
| HanabiRuleBasedPolicyWrapper, | |
| HanabiFlawedPolicyWrapper, | |
| HanabiOuterPolicyWrapper, | |
| HanabiVanDenBerghPolicyWrapper, | |
| ) | |
# Flask application object; CORS is enabled on it right after configuration.
app = Flask(__name__)
app.config.update(MAX_CONTENT_LENGTH=200 * 1024 * 1024)  # 200 MB upload limit
CORS(app)

# Evaluation defaults. DEFAULT_NUM_EPISODES may be overridden by --num-episodes
# when this file is run as a script.
DEFAULT_NUM_EPISODES = 64
MAX_EPISODE_STEPS = 128

# The leaderboard is persisted as JSON next to this module.
LEADERBOARD_FILE = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    'leaderboard.json',
)

# Game dimensions handed to the partner policy wrappers.
HAND_SIZE = 5
NUM_COLORS = 5
NUM_RANKS = 5
NUM_ACTIONS = 21

# Keyword arguments shared by the rule-based partner factories in
# PARTNER_REGISTRY.
_RB_KW = {
    'hand_size': HAND_SIZE,
    'num_colors': NUM_COLORS,
    'num_ranks': NUM_RANKS,
    'num_actions': NUM_ACTIONS,
    'using_log_wrapper': True,
}
def _obl_path(level, seed='a'):
    """Return the relative path of the OBL weight file for ``level`` and ``seed``.

    File names for levels above 1 carry an extra ``_LOAD1`` marker; level 1
    file names do not.
    """
    load_marker = '' if level <= 1 else '_LOAD1'
    filename = f'OFF_BELIEF1_SHUFFLE_COLOR0{load_marker}_BZA0_BELIEF_{seed}.safetensors'
    return f'agents/hanabi/obl-r2d2-flax/icml_OBL{level}/{filename}'
def _rule_based_entry(strategy, description):
    """Build the registry entry for a rule-based partner using ``strategy``."""
    return {
        'name': f'RuleBased ({strategy})',
        'description': description,
        'factory': lambda: HanabiRuleBasedPolicyWrapper(strategy=strategy, **_RB_KW),
    }


def _obl_entry(level, description):
    """Build the registry entry for the pretrained OBL partner at ``level``."""
    return {
        'name': f'OBL-L{level}',
        'description': description,
        'factory': lambda: HanabiOBLPolicyWrapper(
            weight_file=_obl_path(level), using_log_wrapper=True),
    }


# Held-out partners, keyed by the identifier clients send as ``partner``.
# Each entry has a display name, a description, and a zero-argument factory
# that constructs a fresh policy wrapper. Insertion order is the listing order.
PARTNER_REGISTRY = {
    # Baselines
    'random': {
        'name': 'Random',
        'description': 'Uniform random legal action baseline (0/25)',
        'factory': lambda: HanabiRandomPolicyWrapper(
            num_actions=NUM_ACTIONS, using_log_wrapper=True),
    },
    # Rule-based strategy variants (action-category priority agents)
    'rule_based_cautious': _rule_based_entry(
        'cautious', 'Priority: play first, then discard, then hint'),
    'rule_based_aggressive': _rule_based_entry(
        'aggressive', 'Priority: play only, fallback to hint'),
    'rule_based_communicative': _rule_based_entry(
        'communicative', 'Priority: hint first, then play'),
    'rule_based_frugal': _rule_based_entry(
        'frugal', 'Priority: discard first to save info tokens'),
    # Walton-Rivers agents
    'iggi': {
        'name': 'IGGI',
        'description': 'Game-state-aware rule-based (Walton-Rivers 2017). 11.7/25 self-play.',
        'factory': lambda: HanabiIGGIPolicyWrapper(**_RB_KW),
    },
    'piers': {
        'name': 'Piers',
        'description': 'IGGI + probabilistic play (>60%) + dispensable hints. 11.6/25.',
        'factory': lambda: HanabiPiersPolicyWrapper(**_RB_KW),
    },
    'van_den_bergh': {
        'name': 'Van Den Bergh',
        'description': 'Probabilistic discard ordering to protect valuable cards. 9.3/25.',
        'factory': lambda: HanabiVanDenBerghPolicyWrapper(**_RB_KW),
    },
    'outer': {
        'name': 'Outer',
        'description': 'Hint-heavy agent. Scores 0/25 on full Hanabi (exhausts tokens).',
        'factory': lambda: HanabiOuterPolicyWrapper(**_RB_KW),
    },
    'flawed': {
        'name': 'Flawed (30%)',
        'description': 'IGGI with 30% random-mistake probability. 0.2/25.',
        'factory': lambda: HanabiFlawedPolicyWrapper(mistake_prob=0.3, **_RB_KW),
    },
    # NOTE(review): this entry is keyed by the empty string, so clients must
    # send partner='' to select it. That looks unintentional (a key such as
    # 'flawed_60' was probably meant) -- confirm before renaming, because the
    # key is part of the public API and is stored in leaderboard entries.
    '': {
        'name': 'Flawed (60%)',
        'description': 'IGGI with 60% random-mistake probability. ~0/25.',
        'factory': lambda: HanabiFlawedPolicyWrapper(mistake_prob=0.6, **_RB_KW),
    },
    # SmartBot (strongest rule-based)
    'smartbot': {
        'name': 'SmartBot',
        'description': 'Convention-heavy rule-based agent. 18.0/25 self-play, 0.4 vs human proxy.',
        'factory': lambda: HanabiSmartBotPolicyWrapper(**_RB_KW),
    },
    # OBL R2D2 pretrained, all 5 belief levels (seed-a)
    'obl_l1': _obl_entry(
        1, 'Off-Belief Learning level 1 (Hu et al. 2021). 20.9/25 self-play.'),
    'obl_l2': _obl_entry(2, 'Off-Belief Learning level 2.'),
    'obl_l3': _obl_entry(3, 'Off-Belief Learning level 3.'),
    'obl_l4': _obl_entry(
        4, 'Off-Belief Learning level 4. 24.3/25 self-play (near-optimal).'),
    'obl_l5': _obl_entry(5, 'Off-Belief Learning level 5.'),
    # Human proxy
    'bc_lstm': {
        'name': 'BC-LSTM Human Proxy',
        'description': 'Behavioral cloning LSTM trained on 1K human games (AH2AC2). 45.5% accuracy.',
        'factory': lambda: HanabiBCLSTMPolicyWrapper(
            weight_file='agents/hanabi/bc_lstm_weights/bc_2p.safetensors',
            using_log_wrapper=True, greedy=True),
    },
}
# Trained ego agents (RL checkpoints) are kept apart from PARTNER_REGISTRY
# because loading one yields a (policy, params) pair rather than a bare policy.
# _ego_cache memoizes those pairs; it is filled by _load_trained_ego.
_ego_cache = {}

# Location of the bundled IPPO checkpoint, resolved relative to this module.
_MODULE_DIR = os.path.dirname(os.path.abspath(__file__))
IPPO_FIXTURE = os.path.join(
    _MODULE_DIR,
    'fixtures/hanabi_ippo_1e9_seed0/saved_train_run',
)
def _cpu_safe_restore(path):
    """Restore an orbax checkpoint, placing every array on the first CPU device.

    Fixtures saved on GPU machines encode GPU sharding, which orbax rejects
    when restoring on a CPU-only host such as the one serving this API. This
    function therefore requests an explicit single-device CPU sharding for
    every leaf instead of going through common.save_load_utils.load_train_run,
    which uses implicit sharding.

    Args:
        path: Checkpoint directory. A relative path is resolved against
            ``common.save_load_utils.REPO_PATH``.

    Returns:
        The restored pytree, with any ``numpy.ndarray`` leaves converted to
        JAX arrays.
    """
    import orbax.checkpoint as ocp

    if not os.path.isabs(path):
        from common.save_load_utils import REPO_PATH
        path = os.path.join(REPO_PATH, path)

    ckptr = ocp.PyTreeCheckpointer()
    first_cpu = jax.devices('cpu')[0]
    cpu_sharding = jax.sharding.SingleDeviceSharding(first_cpu)

    def _cpu_restore_args(_leaf):
        # One restore-args object per metadata leaf, all targeting the CPU.
        return ocp.ArrayRestoreArgs(sharding=cpu_sharding)

    tree_metadata = ckptr.metadata(path)
    per_leaf_args = jax.tree.map(_cpu_restore_args, tree_metadata)
    tree = ckptr.restore(path, restore_args=per_leaf_args)

    def _as_jax_array(leaf):
        # Only numpy arrays are converted; every other leaf passes through.
        if isinstance(leaf, np.ndarray):
            return jnp.array(leaf)
        return leaf

    return jax.tree_util.tree_map(_as_jax_array, tree)
def _load_trained_ego(fixture_path, actor_type='mlp', hidden_dim=512,
                      activation='relu', idx_list=None):
    """Load a trained ego policy and its parameters from an orbax checkpoint.

    Results are memoized in ``_ego_cache`` under a key built from all
    arguments, so repeated requests for the same fixture reuse the first load.

    Args:
        fixture_path: Checkpoint directory, passed to ``_cpu_safe_restore``.
        actor_type: Part of the cache key only.
        hidden_dim: Part of the cache key only.
        activation: Activation name forwarded to ``MLPActorCriticPolicy``.
        idx_list: Optional sequence whose first element selects the checkpoint
            to use. ``None`` or an empty sequence selects checkpoint 0. An
            index at or beyond the number of checkpoints falls back to the
            last one.

    Returns:
        A ``(policy, params)`` tuple.

    NOTE(review): ``actor_type`` and ``hidden_dim`` are not forwarded to the
    policy constructor, so a checkpoint trained with a non-default width is
    built with whatever ``MLPActorCriticPolicy`` defaults to -- confirm
    whether they should be passed through.
    """
    cache_key = (fixture_path, actor_type, hidden_dim, activation,
                 tuple(idx_list) if idx_list else None)
    if cache_key in _ego_cache:
        return _ego_cache[cache_key]

    from evaluation.heldout_evaluator import extract_params
    from agents.mlp_actor_critic_agent import MLPActorCriticPolicy

    # Use the CPU-safe restore instead of the default loader.
    restored = _cpu_safe_restore(fixture_path)
    ckpts = restored['checkpoints']

    # Build the policy directly. JaxMARL Hanabi's observation_space.shape is a
    # bare int (658), not a tuple, which breaks initialize_mlp_agent's
    # `.shape[0]` access, so obs_dim is hard-coded for full Hanabi.
    obs_dim = 658
    policy = MLPActorCriticPolicy(
        action_dim=NUM_ACTIONS,
        obs_dim=obs_dim,
        activation=activation,
    )
    init_params = policy.init_params(jax.random.PRNGKey(0))

    # An empty idx_list is treated like None; indexing it would raise
    # IndexError.
    wanted = idx_list[0] if idx_list else 0
    params_list, _ = extract_params(ckpts, init_params, idx_labels=None)
    params = params_list[wanted] if wanted < len(params_list) else params_list[-1]

    _ego_cache[cache_key] = (policy, params)
    return policy, params
def _extract_and_find_checkpoint(zip_path, extract_dir):
    """Extract a zipped orbax checkpoint and locate its checkpoint directory.

    Args:
        zip_path: Path of the uploaded zip archive.
        extract_dir: Existing directory the archive is extracted into.

    Returns:
        ``(saved_train_run_path, config_dict)``. ``config_dict`` is the parsed
        ``config.yaml`` when one is found and can be loaded, otherwise ``None``.

    Raises:
        ValueError: If the file is not a valid zip archive, declares more than
            500 MB of uncompressed data, contains an absolute or
            parent-traversing member path, or holds no recognizable
            checkpoint directory.
    """
    # Reject zip bombs: cap at 500 MB uncompressed
    MAX_UNCOMPRESSED = 500 * 1024 * 1024
    try:
        with zipfile.ZipFile(zip_path, 'r') as zf:
            members = zf.infolist()
            total = sum(info.file_size for info in members)
            if total > MAX_UNCOMPRESSED:
                raise ValueError(
                    f"Uncompressed size {total} bytes exceeds {MAX_UNCOMPRESSED} limit"
                )
            # Reject absolute paths and traversal. Backslashes are normalized
            # first so Windows-style member names get the same scrutiny.
            for info in members:
                normalized = info.filename.replace('\\', '/')
                if (normalized.startswith('/')
                        or os.path.isabs(info.filename)
                        or '..' in normalized.split('/')):
                    raise ValueError(f"Unsafe path in zip: {info.filename}")
            zf.extractall(extract_dir)
    except zipfile.BadZipFile as exc:
        # Callers only handle ValueError; without this a corrupt upload would
        # surface as an unhandled server error.
        raise ValueError(f"Not a valid zip archive: {exc}") from exc

    # Find saved_train_run directory anywhere in the extracted tree
    saved_run = None
    config_path = None
    for root, dirs, files in os.walk(extract_dir):
        if 'saved_train_run' in dirs:
            saved_run = os.path.join(root, 'saved_train_run')
        if 'config.yaml' in files:
            config_path = os.path.join(root, 'config.yaml')
        if saved_run and config_path:
            break

    if saved_run is None:
        # Maybe the zip IS the saved_train_run directory (user zipped its
        # contents): look for orbax marker files instead.
        for root, _dirs, files in os.walk(extract_dir):
            if '_CHECKPOINT_METADATA' in files or '_METADATA' in files:
                saved_run = root
                break

    if saved_run is None:
        raise ValueError(
            "Could not find saved_train_run/ directory or orbax checkpoint "
            "in the uploaded zip. Expected structure: saved_train_run/ with "
            "orbax files, optionally alongside config.yaml."
        )

    config_dict = None
    if config_path:
        # Best effort by design: the config only refines architecture
        # defaults, so a missing omegaconf or an unreadable file is not fatal.
        try:
            from omegaconf import OmegaConf
            config_dict = OmegaConf.to_container(OmegaConf.load(config_path), resolve=True)
        except Exception:
            config_dict = None
    return saved_run, config_dict
# Ego registry. Every PARTNER_REGISTRY agent can also play as ego: those
# entries copy the partner's name and description and are marked
# trained=False (the matching partner factory is used at dispatch time).
# Trained RL checkpoints get their own entry carrying the orbax fixture path.
EGO_REGISTRY = {
    **{
        partner_key: {
            'name': partner['name'],
            'description': partner['description'],
            'trained': False,
        }
        for partner_key, partner in PARTNER_REGISTRY.items()
    },
    # Trained RL checkpoints (orbax fixtures)
    'ippo_1e9': {
        'name': 'IPPO (trained, 1e9)',
        'description': '3-seed IPPO trained 1e9 steps. 15.78/25 self-play, 0-5/25 with unseen partners.',
        'trained': True,
        'fixture': IPPO_FIXTURE,
    },
}
def evaluate_ego_vs_partner(ego_policy, ego_params, partner_policy,
                            num_episodes=64, rng_seed=34957):
    """Play ``num_episodes`` Hanabi episodes of ego vs partner and summarize them.

    The ego is seated as agent 0 (run in test mode, with ``ego_params``) and
    the partner as agent 1 (no params, not in test mode). Episode returns
    reported by ``run_episodes`` are averaged over their last axis to obtain
    one score per episode.

    Returns:
        A dict with ``mean``, ``median``, ``std``, ``min`` and ``max`` of the
        per-episode scores as floats, ``num_episodes`` as an int, and the
        scores themselves under ``per_episode`` as a list.
    """
    hanabi_env = LogWrapper(make_env('hanabi'))
    episode_metrics = run_episodes(
        jax.random.PRNGKey(rng_seed), hanabi_env,
        agent_0_policy=ego_policy,
        agent_0_param=ego_params,
        agent_1_policy=partner_policy,
        agent_1_param=None,
        max_episode_steps=MAX_EPISODE_STEPS,
        num_eps=num_episodes,
        agent_0_test_mode=True,
        agent_1_test_mode=False,
    )
    raw_returns = np.array(episode_metrics['returned_episode_returns'])
    scores = raw_returns.mean(axis=-1)

    reducers = (
        ('mean', np.mean),
        ('median', np.median),
        ('std', np.std),
        ('min', np.min),
        ('max', np.max),
    )
    summary = {label: float(reduce(scores)) for label, reduce in reducers}
    summary['num_episodes'] = int(num_episodes)
    summary['per_episode'] = scores.tolist()
    return summary
# Held by the functions in this module that rewrite the leaderboard file.
_leaderboard_lock = threading.Lock()


def load_leaderboard():
    """Return the entries stored in LEADERBOARD_FILE, or ``[]`` if it is absent."""
    try:
        handle = open(LEADERBOARD_FILE)
    except FileNotFoundError:
        return []
    with handle:
        return json.load(handle)
def save_leaderboard(entries):
    """Atomically replace LEADERBOARD_FILE with ``entries`` serialized as JSON.

    The data is written to a temporary file in the same directory and then
    moved over the target with ``os.replace``. A reader therefore sees either
    the old or the new content, never a partially written file, and a failed
    serialization leaves the existing leaderboard untouched.

    Args:
        entries: JSON-serializable list of leaderboard entries.

    Raises:
        TypeError: If ``entries`` contains a value ``json`` cannot serialize.
        OSError: If the temporary file cannot be written or moved into place.
    """
    target_dir = os.path.dirname(LEADERBOARD_FILE) or '.'
    # The temp file must live on the same filesystem for os.replace to be atomic.
    fd, tmp_path = tempfile.mkstemp(
        prefix='.leaderboard_', suffix='.tmp', dir=target_dir)
    try:
        with os.fdopen(fd, 'w') as f:
            json.dump(entries, f, indent=2)
        os.replace(tmp_path, LEADERBOARD_FILE)
    except BaseException:
        # Do not leave stray temp files behind when the write fails.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def add_leaderboard_entry(entry):
    """Append ``entry`` to the persisted leaderboard and return the full list.

    The read-modify-write cycle runs under ``_leaderboard_lock``. After the
    append, entries are ordered by ``mean_score`` from highest to lowest
    (entries lacking the field count as 0) and written back.
    """
    with _leaderboard_lock:
        ranked = load_leaderboard()
        ranked.append(entry)
        # A stable descending sort keeps tied entries in their previous order.
        ranked.sort(key=lambda item: item.get('mean_score', 0), reverse=True)
        save_leaderboard(ranked)
        return ranked
def list_partners():
    """Return a JSON object mapping each partner key to its name and description."""
    public_fields = ('name', 'description')
    catalog = {
        partner_key: {field: partner[field] for field in public_fields}
        for partner_key, partner in PARTNER_REGISTRY.items()
    }
    return jsonify(catalog)
def list_egos():
    """Return a JSON object mapping each ego key to its name, description and trained flag."""
    public_fields = ('name', 'description', 'trained')
    catalog = {
        ego_key: {field: ego[field] for field in public_fields}
        for ego_key, ego in EGO_REGISTRY.items()
    }
    return jsonify(catalog)
def evaluate_upload():
    """Evaluate an uploaded orbax checkpoint zip against a held-out partner.

    Multipart form fields:
        checkpoint: .zip of orbax checkpoint (must contain saved_train_run/
            directory, optionally with config.yaml at the root)
        agent_name: str (label for leaderboard)
        partner: str (partner key from PARTNER_REGISTRY)
        num_episodes: int >= 1 (default DEFAULT_NUM_EPISODES, capped at 1000)
        actor_type: str (default 'mlp', overridden by config.yaml if present)
        hidden_dim: int (default 512, overridden by config.yaml if present)
        activation: str (default 'relu', overridden by config.yaml if present)

    Returns:
        On success, the new leaderboard entry plus per-episode scores,
        leaderboard position and the architecture used. Invalid input, an
        invalid archive or an unloadable checkpoint yields a 400 response;
        a failure during evaluation yields a 500 response.

    The uploaded zip and its extracted tree are always removed before
    returning, and the checkpoint is evicted from ``_ego_cache`` because its
    temporary path is never requested again.
    """
    checkpoint_file = request.files.get('checkpoint')
    if checkpoint_file is None:
        return jsonify({'error': 'No checkpoint file uploaded'}), 400
    # The filename may be missing, so guard before calling .lower().
    if not (checkpoint_file.filename or '').lower().endswith('.zip'):
        return jsonify({'error': 'Checkpoint must be a .zip archive'}), 400

    partner_key = request.form.get('partner', 'bc_lstm')
    if partner_key not in PARTNER_REGISTRY:
        return jsonify({'error': f'Unknown partner: {partner_key}'}), 400

    agent_name = request.form.get('agent_name') or 'Uploaded checkpoint'

    # Malformed numbers are a client error, not a server error.
    try:
        num_episodes = int(request.form.get('num_episodes', DEFAULT_NUM_EPISODES))
        hidden_dim = int(request.form.get('hidden_dim', 512))
    except (TypeError, ValueError):
        return jsonify({'error': 'num_episodes and hidden_dim must be integers'}), 400
    if num_episodes < 1:
        return jsonify({'error': 'num_episodes must be at least 1'}), 400
    num_episodes = min(num_episodes, 1000)

    actor_type = request.form.get('actor_type', 'mlp')
    activation = request.form.get('activation', 'relu')

    # Extract to a temp directory
    upload_dir = os.path.join(os.path.dirname(LEADERBOARD_FILE), 'uploads')
    os.makedirs(upload_dir, exist_ok=True)
    temp_zip = os.path.join(upload_dir, f'{uuid.uuid4().hex}.zip')
    extract_dir = tempfile.mkdtemp(prefix='eval_upload_', dir=upload_dir)
    saved_run_path = None
    try:
        checkpoint_file.save(temp_zip)
        try:
            saved_run_path, cfg = _extract_and_find_checkpoint(temp_zip, extract_dir)
        except ValueError as e:
            return jsonify({'error': f'Invalid checkpoint zip: {str(e)}'}), 400

        # Pull architecture from config.yaml if present
        algo = cfg.get('algorithm') if isinstance(cfg, dict) else None
        if isinstance(algo, dict):
            actor_type = algo.get('ACTOR_TYPE', actor_type)
            try:
                hidden_dim = int(algo.get('FC_HIDDEN_DIM', hidden_dim))
            except (TypeError, ValueError):
                return jsonify({
                    'error': 'config.yaml: algorithm.FC_HIDDEN_DIM must be an integer',
                }), 400
            activation = algo.get('ACTIVATION', activation)

        # Load the checkpoint
        try:
            ego_policy, ego_params = _load_trained_ego(
                saved_run_path, actor_type=actor_type,
                hidden_dim=hidden_dim, activation=activation,
            )
        except Exception as e:
            return jsonify({
                'error': f'Failed to load checkpoint: {str(e)}',
                'hint': ('Verify actor_type, hidden_dim, and activation match the '
                         'architecture the checkpoint was trained with. Include '
                         'config.yaml in the zip to auto-detect.'),
            }), 400

        partner_policy = PARTNER_REGISTRY[partner_key]['factory']()
        start_time = time.time()
        try:
            results = evaluate_ego_vs_partner(
                ego_policy, ego_params, partner_policy,
                num_episodes=num_episodes,
            )
        except Exception as e:
            return jsonify({'error': f'Evaluation failed: {str(e)}'}), 500
        elapsed = time.time() - start_time

        submission_id = str(uuid.uuid4())[:8]
        entry = {
            'id': submission_id,
            'agent_name': agent_name,
            'partner': partner_key,
            'partner_name': PARTNER_REGISTRY[partner_key]['name'],
            'mean_score': results['mean'],
            'median_score': results['median'],
            'std': results['std'],
            'min_score': results['min'],
            'max_score': results['max'],
            'num_episodes': results['num_episodes'],
            'ego_type': 'uploaded',
            'ego_name': f'Uploaded ({actor_type}, h={hidden_dim})',
            'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
            'eval_time_seconds': round(elapsed, 1),
        }
        leaderboard = add_leaderboard_entry(entry)
        response = {
            **entry,
            'per_episode_scores': results['per_episode'],
            'leaderboard_position': next(
                (i + 1 for i, e in enumerate(leaderboard) if e['id'] == submission_id),
                len(leaderboard),
            ),
            'total_submissions': len(leaderboard),
            'architecture': {'actor_type': actor_type, 'hidden_dim': hidden_dim,
                             'activation': activation},
        }
        return jsonify(response)
    finally:
        # Drop the cached parameters for this upload: the cache key contains
        # the temp path, which is deleted below and never requested again, so
        # keeping the entry would only grow memory with every upload.
        if saved_run_path is not None:
            for cache_key in list(_ego_cache):
                if cache_key[0] == saved_run_path:
                    _ego_cache.pop(cache_key, None)
        # Clean up temp files
        if os.path.exists(temp_zip):
            os.remove(temp_zip)
        if os.path.exists(extract_dir):
            shutil.rmtree(extract_dir, ignore_errors=True)
def evaluate():
    """Evaluate a registered ego agent against a registered partner.

    JSON body fields (all optional):
        agent_name: str label for the leaderboard (default 'Anonymous')
        partner: key in PARTNER_REGISTRY (default 'bc_lstm')
        ego: key in EGO_REGISTRY (default 'iggi')
        num_episodes: int >= 1 (default DEFAULT_NUM_EPISODES, capped at 1000)

    Returns:
        On success, the new leaderboard entry plus per-episode scores and
        leaderboard position. Invalid input yields 400, a missing trained
        fixture 404, and a checkpoint-load or evaluation failure 500.
    """
    data = request.json or {}
    # Valid JSON that is not an object (e.g. a list) has no .get().
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    agent_name = data.get('agent_name', 'Anonymous')
    partner_key = data.get('partner', 'bc_lstm')
    ego_key = data.get('ego', 'iggi')

    # Malformed numbers are a client error, not a server error.
    try:
        num_episodes = int(data.get('num_episodes', DEFAULT_NUM_EPISODES))
    except (TypeError, ValueError):
        return jsonify({'error': 'num_episodes must be an integer'}), 400
    if num_episodes < 1:
        return jsonify({'error': 'num_episodes must be at least 1'}), 400
    num_episodes = min(num_episodes, 1000)

    if partner_key not in PARTNER_REGISTRY:
        return jsonify({'error': f'Unknown partner: {partner_key}',
                        'available': list(PARTNER_REGISTRY.keys())}), 400
    if ego_key not in EGO_REGISTRY:
        return jsonify({'error': f'Unknown ego: {ego_key}',
                        'available': list(EGO_REGISTRY.keys())}), 400

    partner_policy = PARTNER_REGISTRY[partner_key]['factory']()
    ego_info = EGO_REGISTRY[ego_key]
    if ego_info['trained']:
        # Trained egos come from an orbax fixture and carry parameters.
        fixture = ego_info['fixture']
        if not os.path.exists(fixture):
            return jsonify({'error': f'Checkpoint not found: {fixture}'}), 404
        try:
            ego_policy, ego_params = _load_trained_ego(fixture)
        except Exception as e:
            return jsonify({'error': f'Failed to load checkpoint: {str(e)}'}), 500
    else:
        # Untrained egos reuse the partner factory registered under the same key.
        if ego_key not in PARTNER_REGISTRY:
            return jsonify({'error': f'No factory for ego: {ego_key}'}), 400
        ego_policy = PARTNER_REGISTRY[ego_key]['factory']()
        ego_params = None

    start_time = time.time()
    try:
        results = evaluate_ego_vs_partner(
            ego_policy, ego_params, partner_policy,
            num_episodes=num_episodes
        )
    except Exception as e:
        return jsonify({'error': f'Evaluation failed: {str(e)}'}), 500
    elapsed = time.time() - start_time

    submission_id = str(uuid.uuid4())[:8]
    entry = {
        'id': submission_id,
        'agent_name': agent_name,
        'partner': partner_key,
        'partner_name': PARTNER_REGISTRY[partner_key]['name'],
        'mean_score': results['mean'],
        'median_score': results['median'],
        'std': results['std'],
        'min_score': results['min'],
        'max_score': results['max'],
        'num_episodes': results['num_episodes'],
        'ego_type': ego_key,
        'ego_name': EGO_REGISTRY[ego_key]['name'],
        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
        'eval_time_seconds': round(elapsed, 1),
    }
    leaderboard = add_leaderboard_entry(entry)
    response = {
        **entry,
        'per_episode_scores': results['per_episode'],
        'leaderboard_position': next(
            (i + 1 for i, e in enumerate(leaderboard) if e['id'] == submission_id),
            len(leaderboard)
        ),
        'total_submissions': len(leaderboard),
    }
    return jsonify(response)
def leaderboard():
    """Return stored leaderboard entries, optionally restricted to one partner.

    When the ``partner`` query parameter is present and non-empty, only
    entries whose ``partner`` field equals it are returned.
    """
    wanted_partner = request.args.get('partner')
    rows = load_leaderboard()
    if wanted_partner:
        rows = [row for row in rows if row.get('partner') == wanted_partner]
    payload = {'entries': rows, 'total': len(rows)}
    return jsonify(payload)
def clear_leaderboard():
    """Overwrite the persisted leaderboard with an empty list.

    NOTE(review): no authentication or confirmation step is visible in this
    function -- confirm the route is protected elsewhere before exposing it.
    """
    _leaderboard_lock.acquire()
    try:
        save_leaderboard([])
    finally:
        _leaderboard_lock.release()
    return jsonify({'status': 'cleared'})
def api_info():
    """Return service metadata: name, version, endpoints, partners and defaults.

    NOTE(review): the response includes the server-side path of the
    leaderboard file -- confirm that exposing it to clients is intended.
    """
    endpoint_help = {
        'POST /api/evaluate': 'Submit a policy for evaluation',
        'GET /api/leaderboard': 'View submission rankings',
        'GET /api/partners': 'List available evaluation partners',
    }
    info = {
        'service': 'Hanabi AHT Policy Evaluation API',
        'version': '1.0',
        'endpoints': endpoint_help,
        'supported_partners': list(PARTNER_REGISTRY),
        'default_num_episodes': DEFAULT_NUM_EPISODES,
        'leaderboard_file': LEADERBOARD_FILE,
    }
    return jsonify(info)
def index():
    """Serve ``index.html`` from the ``eval_ui`` directory next to this module."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    ui_dir = os.path.join(module_dir, 'eval_ui')
    return send_from_directory(ui_dir, 'index.html')
if __name__ == '__main__':
    # Script entry point: parse CLI options, print a banner, start the server.
    cli = argparse.ArgumentParser(description='Hanabi AHT Evaluation API')
    cli.add_argument('--port', type=int, default=5001)
    cli.add_argument('--host', type=str, default='0.0.0.0')
    cli.add_argument('--num-episodes', type=int, default=64)
    args = cli.parse_args()

    # Module-level rebinding: request handlers read this global at call time.
    DEFAULT_NUM_EPISODES = args.num_episodes

    partner_keys = ", ".join(PARTNER_REGISTRY.keys())
    banner_lines = (
        'Hanabi AHT Evaluation API',
        f' http://{args.host}:{args.port}',
        f' Partners: {partner_keys}',
        f' Default episodes: {DEFAULT_NUM_EPISODES}',
    )
    print('\n'.join(banner_lines))
    app.run(host=args.host, port=args.port, debug=False)