| """Train a tabular Q-learning policy for the multi-agent strategy environment.""" | |

from __future__ import annotations

import json
import os
import random
from collections import deque
from pathlib import Path
from typing import Dict, List

import numpy as np

from strategy_env.models import ACTION_CHOICES, StrategyAction, StrategyObservation
from strategy_env.server.environment import (
    RULE_MULTIPLIERS,
    RULE_ORDER,
    MultiAgentStrategyEnvironment,
)
from strategy_env.tasks import TASK_ORDER

# Hyperparameters, each overridable via an environment variable.
EPISODES = int(os.getenv("TRAIN_EPISODES", "4000"))
MAX_STEPS = int(os.getenv("TRAIN_MAX_STEPS", "18"))
ALPHA = float(os.getenv("TRAIN_ALPHA", "0.2"))  # learning rate
GAMMA = float(os.getenv("TRAIN_GAMMA", "0.95"))  # discount factor
EPS_START = float(os.getenv("TRAIN_EPS_START", "1.0"))
EPS_END = float(os.getenv("TRAIN_EPS_END", "0.05"))
SEED = int(os.getenv("TRAIN_SEED", "42"))

POLICY_PATH = Path("artifacts/q_policy.json")
HISTORY_PATH = Path("artifacts/training_history.json")

# Stable integer indices for rules, tasks, and actions.
RULE_INDEX = {rule: idx for idx, rule in enumerate(RULE_ORDER)}
TASK_INDEX = {task_id: idx for idx, task_id in enumerate(TASK_ORDER)}
ACTION_INDEX = {name: idx for idx, name in enumerate(ACTION_CHOICES)}


def _bucket(value: int, step: int, max_value: int) -> int:
    """Discretize value into buckets of width step, clamped to [0, max_value // step]."""
    return max(0, min(max_value // step, value // step))
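# Example: _bucket(7, 2, 20) == 3; values of 20 or more clamp to the top bucket (10).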


def encode_state(obs: StrategyObservation) -> str:
    """Encode an observation as a pipe-joined string of discretized features."""
    # An unknown opponent action maps to a sentinel index one past the real actions.
    opp_last_idx = ACTION_INDEX.get(obs.last_opponent_action, len(ACTION_CHOICES))
    resource_gap = obs.own_resources - obs.visible_opponent_resources
    defense_gap = obs.own_defense - obs.visible_opponent_defense
    key = (
        TASK_INDEX.get(obs.task_id, 0),
        RULE_INDEX.get(obs.active_rule, 0),
        _bucket(obs.turn, 2, obs.max_turns + 2),
        _bucket(obs.own_resources, 2, 20),
        _bucket(obs.own_defense, 2, 20),
        _bucket(obs.own_intel, 2, 20),
        _bucket(obs.visible_opponent_resources, 2, 20),
        _bucket(obs.visible_opponent_defense, 2, 20),
        1 if obs.rule_hint.startswith("high_confidence") else 0,
        _bucket(resource_gap + 12, 3, 24),  # gaps shifted by +12 so buckets start at 0
        _bucket(defense_gap + 12, 3, 24),
        opp_last_idx,
    )
    return "|".join(str(x) for x in key)


def _ensure_state(q_table: Dict[str, List[float]], key: str) -> None:
    if key not in q_table:
        q_table[key] = [0.0 for _ in ACTION_CHOICES]


def _epsilon(episode: int) -> float:
    """Linearly anneal exploration from EPS_START to EPS_END across training."""
    progress = min(1.0, episode / max(1, EPISODES - 1))
    return EPS_START + (EPS_END - EPS_START) * progress


def _valid_action_indices(obs: StrategyObservation) -> List[int]:
    """Mask out actions the agent cannot currently afford."""
    valid = set(range(len(ACTION_CHOICES)))
    if obs.own_resources < 1:
        valid.discard(ACTION_INDEX["attack"])
    if obs.own_intel < 1:
        valid.discard(ACTION_INDEX["adapt"])
        valid.discard(ACTION_INDEX["bluff"])
    if not valid:
        return [ACTION_INDEX["noop"]]
    return sorted(valid)


def _greedy_action_idx(
    q_values: List[float], valid_indices: List[int], active_rule: str
) -> int:
    """Return the best valid action, breaking Q-value ties with the active rule's multipliers."""
    best_q = max(q_values[idx] for idx in valid_indices)
    candidates = [idx for idx in valid_indices if abs(q_values[idx] - best_q) < 1e-12]
    if len(candidates) == 1:
        return candidates[0]
    return max(
        candidates,
        key=lambda idx: RULE_MULTIPLIERS[active_rule][ACTION_CHOICES[idx]],
    )


def _base_task_weights(progress: float) -> List[float]:
    """Curriculum weights over (easy, medium, hard) tasks as training progresses."""
    if progress < 0.4:
        return [0.35, 0.40, 0.25]
    if progress < 0.8:
        return [0.20, 0.60, 0.20]
    return [0.10, 0.75, 0.15]


def _sample_task(episode: int, task_score_ema: Dict[str, float]) -> str:
    """Sample a task, boosting any task whose EMA score trails the 0.65 target."""
    progress = min(1.0, episode / max(1, EPISODES - 1))
    base = _base_task_weights(progress)
    adjusted_weights: List[float] = []
    for task_id, weight in zip(TASK_ORDER, base):
        deficit = max(0.0, 0.65 - task_score_ema[task_id])
        boost = 1.0 + 1.25 * deficit
        if task_id == "medium_alliance_shuffle":
            boost *= 1.2
        adjusted_weights.append(weight * boost)
    return random.choices(TASK_ORDER, weights=adjusted_weights, k=1)[0]
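# For example, a task whose EMA score is 0.45 has a deficit of 0.20, so its
# base weight is scaled by 1.0 + 1.25 * 0.20 = 1.25 (with a further 1.2x for
# the medium task).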


def train() -> None:
    random.seed(SEED)
    np.random.seed(SEED)
    env = MultiAgentStrategyEnvironment()
    q_table: Dict[str, List[float]] = {}
    task_score_ema = {task_id: 0.5 for task_id in TASK_ORDER}
    recent_rewards: deque[float] = deque(maxlen=200)
    recent_scores: deque[float] = deque(maxlen=200)
    history = []
    for episode in range(1, EPISODES + 1):
        task_id = _sample_task(episode, task_score_ema)
        obs = env.reset(task_id=task_id, seed=SEED + episode)
        total_reward = 0.0
        eps = _epsilon(episode)
        for _ in range(MAX_STEPS):
            state_key = encode_state(obs)
            _ensure_state(q_table, state_key)
            valid_indices = _valid_action_indices(obs)
            # Epsilon-greedy selection over the valid action set.
            if random.random() < eps:
                action_idx = random.choice(valid_indices)
            else:
                action_idx = _greedy_action_idx(
                    q_table[state_key],
                    valid_indices,
                    obs.active_rule,
                )
            action = StrategyAction(action_type=ACTION_CHOICES[action_idx])
            next_obs = env.step(action)
            reward = float(next_obs.reward or 0.0)
            total_reward += reward
            next_key = encode_state(next_obs)
            _ensure_state(q_table, next_key)
            next_valid = _valid_action_indices(next_obs)
            best_next = max(q_table[next_key][idx] for idx in next_valid)
            # Q-learning update: Q(s, a) += alpha * (r + gamma * max_a' Q(s', a') - Q(s, a)).
            old_q = q_table[state_key][action_idx]
            q_table[state_key][action_idx] = old_q + ALPHA * (
                reward + GAMMA * best_next - old_q
            )
            obs = next_obs
            if obs.done:
                break
        eval_report = env.evaluate()
        score = float(eval_report["score"])
        task_score_ema[task_id] = 0.9 * task_score_ema[task_id] + 0.1 * score
        recent_rewards.append(total_reward)
        recent_scores.append(score)
        if episode % 100 == 0 or episode == 1:
            avg_reward = float(np.mean(recent_rewards)) if recent_rewards else 0.0
            avg_score = float(np.mean(recent_scores)) if recent_scores else 0.0
            print(
                f"episode={episode}/{EPISODES} epsilon={eps:.3f} "
                f"avg_reward_200={avg_reward:.4f} avg_score_200={avg_score:.4f}"
            )
            history.append(
                {
                    "episode": episode,
                    "epsilon": round(eps, 4),
                    "avg_reward_200": round(avg_reward, 4),
                    "avg_score_200": round(avg_score, 4),
                    "ema_easy": round(task_score_ema["easy_frontier_probe"], 4),
                    "ema_medium": round(task_score_ema["medium_alliance_shuffle"], 4),
                    "ema_hard": round(task_score_ema["hard_chaos_conclave"], 4),
                }
            )

    # Persist the learned policy and the logged training curve.
    POLICY_PATH.parent.mkdir(parents=True, exist_ok=True)
    payload = {
        "algorithm": "tabular_q_learning",
        "episodes": EPISODES,
        "alpha": ALPHA,
        "gamma": GAMMA,
        "epsilon_start": EPS_START,
        "epsilon_end": EPS_END,
        "seed": SEED,
        "actions": ACTION_CHOICES,
        "rule_order": RULE_ORDER,
        "q_table": q_table,
    }
    with POLICY_PATH.open("w", encoding="utf-8") as f:
        json.dump(payload, f)
    with HISTORY_PATH.open("w", encoding="utf-8") as f:
        json.dump(history, f, indent=2)
    print(f"Saved policy to {POLICY_PATH}")
    print(f"Saved training history to {HISTORY_PATH}")


if __name__ == "__main__":
    train()
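
# Minimal inference sketch (an assumption about downstream use, not exercised
# by this script): reload the saved policy and act greedily with the same
# state encoding; unseen states fall back to all-zero Q-values.
#
#   with POLICY_PATH.open(encoding="utf-8") as f:
#       q_table = json.load(f)["q_table"]
#   env = MultiAgentStrategyEnvironment()
#   obs = env.reset(task_id="easy_frontier_probe", seed=0)
#   while not obs.done:
#       key = encode_state(obs)
#       q_values = q_table.get(key, [0.0] * len(ACTION_CHOICES))
#       idx = _greedy_action_idx(q_values, _valid_action_indices(obs), obs.active_rule)
#       obs = env.step(StrategyAction(action_type=ACTION_CHOICES[idx]))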