# multi-agent-strat / train_rl.py
# Provenance: Avnishjain's Hugging Face upload ("Upload 21 files", commit 6888575, verified).
"""Train a tabular Q-learning policy for the multi-agent strategy environment."""
from __future__ import annotations
import json
import os
import random
from collections import deque
from pathlib import Path
from typing import Dict, List
import numpy as np
from strategy_env.models import ACTION_CHOICES, StrategyAction, StrategyObservation
from strategy_env.server.environment import (
RULE_MULTIPLIERS,
RULE_ORDER,
MultiAgentStrategyEnvironment,
)
from strategy_env.tasks import TASK_ORDER
# Training hyperparameters, each overridable via an environment variable.
EPISODES = int(os.getenv("TRAIN_EPISODES", "4000"))  # total training episodes
MAX_STEPS = int(os.getenv("TRAIN_MAX_STEPS", "18"))  # per-episode step cap
ALPHA = float(os.getenv("TRAIN_ALPHA", "0.2"))  # Q-learning step size
GAMMA = float(os.getenv("TRAIN_GAMMA", "0.95"))  # discount factor
EPS_START = float(os.getenv("TRAIN_EPS_START", "1.0"))  # initial exploration rate
EPS_END = float(os.getenv("TRAIN_EPS_END", "0.05"))  # final exploration rate
SEED = int(os.getenv("TRAIN_SEED", "42"))  # base RNG seed (offset per episode)
# Output artifact locations.
POLICY_PATH = Path("artifacts/q_policy.json")
HISTORY_PATH = Path("artifacts/training_history.json")
# Reverse lookups (name -> dense index) used to build compact state keys.
RULE_INDEX = {rule: idx for idx, rule in enumerate(RULE_ORDER)}
TASK_INDEX = {task_id: idx for idx, task_id in enumerate(TASK_ORDER)}
ACTION_INDEX = {name: idx for idx, name in enumerate(ACTION_CHOICES)}
def _bucket(value: int, step: int, max_value: int) -> int:
return max(0, min(max_value // step, value // step))
def encode_state(obs: StrategyObservation) -> str:
    """Serialize an observation into a compact "|"-joined Q-table key.

    Buckets resources/defense/intel (own and visible-opponent) and the turn
    counter, and folds in task, active rule, a high-confidence-hint flag,
    shifted resource/defense gaps, and the opponent's last action index.
    """
    last_opp_idx = ACTION_INDEX.get(obs.last_opponent_action, len(ACTION_CHOICES))
    res_delta = obs.own_resources - obs.visible_opponent_resources
    def_delta = obs.own_defense - obs.visible_opponent_defense
    confident_hint = 1 if obs.rule_hint.startswith("high_confidence") else 0
    parts = [
        TASK_INDEX.get(obs.task_id, 0),
        RULE_INDEX.get(obs.active_rule, 0),
        _bucket(obs.turn, 2, obs.max_turns + 2),
        _bucket(obs.own_resources, 2, 20),
        _bucket(obs.own_defense, 2, 20),
        _bucket(obs.own_intel, 2, 20),
        _bucket(obs.visible_opponent_resources, 2, 20),
        _bucket(obs.visible_opponent_defense, 2, 20),
        confident_hint,
        # Gaps are shifted by +12 so the bucketing range is non-negative.
        _bucket(res_delta + 12, 3, 24),
        _bucket(def_delta + 12, 3, 24),
        last_opp_idx,
    ]
    return "|".join(map(str, parts))
def _ensure_state(q_table: Dict[str, List[float]], key: str) -> None:
    """Insert a zero-initialized action-value row for *key* if absent."""
    q_table.setdefault(key, [0.0] * len(ACTION_CHOICES))
def _epsilon(episode: int) -> float:
    """Linearly anneal the exploration rate from EPS_START to EPS_END."""
    frac = min(1.0, episode / max(1, EPISODES - 1))
    return EPS_START + (EPS_END - EPS_START) * frac
def _valid_action_indices(obs: StrategyObservation) -> List[int]:
    """Return the sorted indices of actions legal under *obs*'s budgets.

    Attacking requires at least one resource; adapting or bluffing requires
    at least one intel. Falls back to [noop] if nothing else is legal.
    """
    allowed = set(range(len(ACTION_CHOICES)))
    if obs.own_resources < 1:
        allowed.discard(ACTION_INDEX["attack"])
    if obs.own_intel < 1:
        allowed -= {ACTION_INDEX["adapt"], ACTION_INDEX["bluff"]}
    return sorted(allowed) if allowed else [ACTION_INDEX["noop"]]
def _greedy_action_idx(
    q_values: List[float], valid_indices: List[int], active_rule: str
) -> int:
    """Pick the highest-Q valid action, tie-breaking by the rule's multiplier."""
    top = max(q_values[idx] for idx in valid_indices)
    tied = [idx for idx in valid_indices if abs(q_values[idx] - top) < 1e-12]
    if len(tied) > 1:
        # Several actions are numerically tied: prefer the one the active
        # rule rewards most.
        multipliers = RULE_MULTIPLIERS[active_rule]
        return max(tied, key=lambda idx: multipliers[ACTION_CHOICES[idx]])
    return tied[0]
def _base_task_weights(progress: float) -> List[float]:
if progress < 0.4:
return [0.35, 0.40, 0.25]
if progress < 0.8:
return [0.20, 0.60, 0.20]
return [0.10, 0.75, 0.15]
def _sample_task(episode: int, task_score_ema: Dict[str, float]) -> str:
    """Sample a task id, boosting under-performing tasks and the medium task.

    Base weights follow the curriculum stage; each task's weight is then
    scaled up in proportion to its EMA-score shortfall below 0.65, with a
    fixed extra 1.2x boost for the medium task.
    """
    frac = min(1.0, episode / max(1, EPISODES - 1))
    weights: List[float] = []
    for task_id, base_w in zip(TASK_ORDER, _base_task_weights(frac)):
        shortfall = max(0.0, 0.65 - task_score_ema[task_id])
        factor = 1.0 + 1.25 * shortfall
        if task_id == "medium_alliance_shuffle":
            factor *= 1.2
        weights.append(base_w * factor)
    return random.choices(TASK_ORDER, weights=weights, k=1)[0]
def train() -> None:
    """Run tabular Q-learning on the strategy environment and save artifacts.

    Trains for EPISODES episodes with a linearly-annealed epsilon-greedy
    policy, maintains a per-task EMA of evaluation scores (which drives the
    task-sampling curriculum), and writes the Q-table and a training history
    to POLICY_PATH / HISTORY_PATH as JSON.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    env = MultiAgentStrategyEnvironment()
    q_table: Dict[str, List[float]] = {}
    # Per-task exponential moving average of eval scores; feeds _sample_task.
    task_score_ema = {task_id: 0.5 for task_id in TASK_ORDER}
    recent_rewards: deque[float] = deque(maxlen=200)
    recent_scores: deque[float] = deque(maxlen=200)
    history = []
    for episode in range(1, EPISODES + 1):
        task_id = _sample_task(episode, task_score_ema)
        obs = env.reset(task_id=task_id, seed=SEED + episode)
        total_reward = 0.0
        eps = _epsilon(episode)
        for _ in range(MAX_STEPS):
            state_key = encode_state(obs)
            _ensure_state(q_table, state_key)
            valid_indices = _valid_action_indices(obs)
            # Epsilon-greedy selection over the legal action set only.
            if random.random() < eps:
                action_idx = random.choice(valid_indices)
            else:
                action_idx = _greedy_action_idx(
                    q_table[state_key],
                    valid_indices,
                    obs.active_rule,
                )
            action = StrategyAction(action_type=ACTION_CHOICES[action_idx])
            next_obs = env.step(action)
            reward = float(next_obs.reward or 0.0)
            total_reward += reward
            # Q-learning target. Bug fix: do not bootstrap past a terminal
            # observation -- at episode end the target is the reward alone.
            # The previous code always added GAMMA * max Q(s'), which is
            # wrong whenever a terminal state's key (encode_state omits the
            # done flag) collides with a non-terminal key, and it also
            # padded the saved table with all-zero terminal rows.
            if next_obs.done:
                target = reward
            else:
                next_key = encode_state(next_obs)
                _ensure_state(q_table, next_key)
                next_valid = _valid_action_indices(next_obs)
                target = reward + GAMMA * max(
                    q_table[next_key][idx] for idx in next_valid
                )
            old_q = q_table[state_key][action_idx]
            q_table[state_key][action_idx] = old_q + ALPHA * (target - old_q)
            obs = next_obs
            if obs.done:
                break
        eval_report = env.evaluate()
        score = float(eval_report["score"])
        task_score_ema[task_id] = 0.9 * task_score_ema[task_id] + 0.1 * score
        recent_rewards.append(total_reward)
        recent_scores.append(score)
        # Progress log + history snapshot every 100 episodes (and at episode 1).
        if episode % 100 == 0 or episode == 1:
            avg_reward = float(np.mean(recent_rewards)) if recent_rewards else 0.0
            avg_score = float(np.mean(recent_scores)) if recent_scores else 0.0
            print(
                f"episode={episode}/{EPISODES} epsilon={eps:.3f} "
                f"avg_reward_200={avg_reward:.4f} avg_score_200={avg_score:.4f}"
            )
            history.append(
                {
                    "episode": episode,
                    "epsilon": round(eps, 4),
                    "avg_reward_200": round(avg_reward, 4),
                    "avg_score_200": round(avg_score, 4),
                    "ema_easy": round(task_score_ema["easy_frontier_probe"], 4),
                    "ema_medium": round(task_score_ema["medium_alliance_shuffle"], 4),
                    "ema_hard": round(task_score_ema["hard_chaos_conclave"], 4),
                }
            )
    POLICY_PATH.parent.mkdir(parents=True, exist_ok=True)
    # Persist hyperparameters alongside the table so the policy is reproducible.
    payload = {
        "algorithm": "tabular_q_learning",
        "episodes": EPISODES,
        "alpha": ALPHA,
        "gamma": GAMMA,
        "epsilon_start": EPS_START,
        "epsilon_end": EPS_END,
        "seed": SEED,
        "actions": ACTION_CHOICES,
        "rule_order": RULE_ORDER,
        "q_table": q_table,
    }
    with POLICY_PATH.open("w", encoding="utf-8") as f:
        json.dump(payload, f)
    with HISTORY_PATH.open("w", encoding="utf-8") as f:
        json.dump(history, f, indent=2)
    print(f"Saved policy to {POLICY_PATH}")
    print(f"Saved training history to {HISTORY_PATH}")
# Script entry point: run the full training loop when executed directly.
if __name__ == "__main__":
    train()