#!/usr/bin/env python3
"""Evaluate a trained checkpoint on Safe Multi-Agent MuJoCo or Harvest.
Usage:
python eval.py --checkpoint checkpoints/halfcheetah_6x1/ours_s0.pt \
--env mamujoco_HalfCheetah_6x1 \
--n_eval_episodes 100
"""
import argparse
import json
import sys
from pathlib import Path
import numpy as np
import torch
# Lookup table for the MuJoCo-based choices of --env: each key is the name
# accepted on the command line, each value is the (scenario, agent_conf)
# pair that make_env() places into the environment's env_args.
MAMUJOCO_ENVS = {
    "mamujoco_HalfCheetah_6x1": ("HalfCheetah-v2", "6x1"),
    "mamujoco_Humanoid_9_8": ("Humanoid-v2", "9|8"),
    "mamujoco_ManySegmentSwimmer_6x1": ("ManySegmentSwimmer", "6x1"),
}
def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional sequence of argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, so existing
            ``parse_args()`` callers behave exactly as before. Passing an
            explicit list makes the parser usable from other code.

    Returns:
        argparse.Namespace with the attributes ``checkpoint``, ``env``,
        ``n_eval_episodes``, ``seed``, ``device``, ``output`` and
        ``deterministic``.
    """
    parser = argparse.ArgumentParser(description="Evaluate a trained checkpoint")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to .pt checkpoint file")
    # Valid environments are every MuJoCo entry plus the Harvest environment.
    parser.add_argument("--env", type=str, required=True,
                        choices=list(MAMUJOCO_ENVS.keys()) + ["harvest_5"],
                        help="Environment name")
    parser.add_argument("--n_eval_episodes", type=int, default=100,
                        help="Number of evaluation episodes")
    parser.add_argument("--seed", type=int, default=0,
                        help="Evaluation seed")
    parser.add_argument("--device", type=str, default="cpu",
                        choices=["cpu", "cuda"],
                        help="Device for inference")
    parser.add_argument("--output", type=str, default=None,
                        help="Path to save evaluation results JSON")
    parser.add_argument("--deterministic", action="store_true",
                        help="Use deterministic (greedy) actions")
    return parser.parse_args(argv)
def load_checkpoint(ckpt_path, device="cpu"):
    """Read a checkpoint file, print a short summary of it, and return it.

    Args:
        ckpt_path: Path of the checkpoint file handed to ``torch.load``.
        device: Value passed to ``torch.load`` as ``map_location``.

    Returns:
        The object produced by ``torch.load``. The summary code calls
        ``.keys()`` on it and uses ``in`` lookups, so a dict-like object is
        expected.
    """
    # NOTE(review): weights_only=False allows torch.load to unpickle arbitrary
    # Python objects, so only open checkpoint files from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)

    print(f"Loaded checkpoint: {ckpt_path}")
    print(f" Keys: {list(checkpoint.keys())}")

    # Training configuration, shown only when the checkpoint carries one.
    if "config" in checkpoint:
        config = checkpoint["config"]
        print(f" Environment: {config.get('env_name', 'unknown')}")
        print(f" Method: {config.get('mode', 'unknown')}")
        print(f" Seed: {config.get('seed', 'unknown')}")

    # Metrics that were stored alongside the weights, if any.
    if "final_metrics" in checkpoint:
        metrics = checkpoint["final_metrics"]
        print(f" Stored welfare: {metrics.get('total_welfare', 'N/A')}")
        print(f" Stored constraint sat: {metrics.get('constraint_satisfaction_pct', 'N/A')}%")

    return checkpoint
def make_env(env_name, seed=0):
    """Build and seed the environment selected by ``env_name``.

    Args:
        env_name: Either a key of ``MAMUJOCO_ENVS`` or ``"harvest_5"``.
        seed: Value passed to the environment's ``seed()`` method.

    Returns:
        The constructed environment, already seeded.

    Raises:
        ValueError: If ``env_name`` is not one of the supported names.
        SystemExit: With status 1, after printing an error message, when the
            required ``mappo_lagrangian`` module cannot be imported.
    """
    is_mujoco = env_name in MAMUJOCO_ENVS
    if not is_mujoco and env_name != "harvest_5":
        raise ValueError(f"Unknown environment: {env_name}")

    if is_mujoco:
        scenario, agent_conf = MAMUJOCO_ENVS[env_name]
        # Imported lazily so the script can still start (and report stored
        # metrics) on machines without the environment package.
        try:
            from mappo_lagrangian.envs.safety_ma_mujoco.safety_multiagent_mujoco import MujocoMulti
        except ImportError:
            print("ERROR: mappo_lagrangian package not installed.")
            print("Install from: macpo_base/MAPPO-Lagrangian/")
            print(" cd macpo_base/MAPPO-Lagrangian && pip install -e .")
            sys.exit(1)
        mujoco_args = dict(
            scenario=scenario,
            agent_conf=agent_conf,
            agent_obsk=1,
            episode_limit=1000,
        )
        environment = MujocoMulti(env_args=mujoco_args)
    else:
        try:
            from mappo_lagrangian.envs.harvest import HarvestEnv
        except ImportError:
            print("ERROR: Harvest environment not found in mappo_lagrangian.")
            sys.exit(1)
        environment = HarvestEnv(n_agents=5)

    # Both environment kinds are seeded the same way before being returned.
    environment.seed(seed)
    return environment
def _sum_step_costs(infos):
    """Return the total of the "cost" entries in one step's info payload."""
    if isinstance(infos, dict):
        return infos.get("cost", 0)
    if isinstance(infos, (list, tuple)):
        return sum(info.get("cost", 0) for info in infos)
    # Any other payload type carries no cost information we know how to read.
    return 0


def evaluate(env, checkpoint, n_episodes, device="cpu", deterministic=False,
             *, max_steps=1000, cost_threshold=1.0):
    """Run evaluation episodes and collect reward/cost statistics.

    This is a template: actions are sampled from ``env.action_space`` instead
    of coming from a trained policy, and ``device`` / ``deterministic`` are
    accepted but not used yet. Adapt the marked spot to the actual actor
    architecture.

    Args:
        env: Environment exposing ``reset()``, ``step(action)`` returning
            ``(obs, rewards, dones, infos)``, and an ``action_space`` that is
            indexable per agent. ``env.n_agents`` is used when present,
            otherwise a single agent is assumed.
        checkpoint: Mapping loaded from the checkpoint file.
        n_episodes: Number of episodes to run.
        device: Reserved for the policy forward pass (currently unused).
        deterministic: Reserved for the policy forward pass (currently unused).
        max_steps: Upper bound on the number of steps per episode.
        cost_threshold: An episode counts as constraint-satisfying when its
            accumulated cost is less than or equal to this value.

    Returns:
        ``checkpoint["final_metrics"]`` (or ``{}``) when the checkpoint has no
        ``actor_state_dict``; otherwise a dict with ``n_episodes``,
        ``mean_reward``, ``std_reward``, ``mean_cost``, ``std_cost`` and
        ``constraint_satisfaction_pct``.
    """
    actor_state = checkpoint.get("actor_state_dict")
    if actor_state is None:
        print("WARNING: No actor_state_dict in checkpoint. Reporting stored metrics only.")
        return checkpoint.get("final_metrics", {})

    # This is a template -- actual model loading depends on the exact architecture
    # used in training. Users should adapt this to their setup.
    print(f"\nRunning {n_episodes} evaluation episodes...")
    print("NOTE: This script provides a template. Adapt the model loading")
    print(" to match your MAPPO-Lagrangian actor architecture.")

    all_rewards = []
    all_costs = []
    for ep in range(n_episodes):
        obs = env.reset()
        # Looked up once per episode rather than on every step.
        n_agents = getattr(env, "n_agents", 1)
        episode_reward = 0.0
        episode_cost = 0.0

        for _ in range(max_steps):
            # Template: replace with actual policy forward pass
            # action = policy(obs, deterministic=deterministic)
            action = [env.action_space[i].sample() for i in range(n_agents)]
            obs, rewards, dones, infos = env.step(action)

            # np.sum / np.all accept scalars, lists, tuples and arrays alike,
            # so per-agent values are reduced correctly whatever container the
            # environment returns (a bare ndarray previously broke the loop).
            episode_reward += float(np.sum(rewards))
            episode_cost += _sum_step_costs(infos)
            if bool(np.all(dones)):
                break

        all_rewards.append(episode_reward)
        all_costs.append(episode_cost)

        # Progress line every 20 episodes, averaged over the last 20.
        if (ep + 1) % 20 == 0:
            print(f" Episode {ep+1}/{n_episodes}: "
                  f"reward={np.mean(all_rewards[-20:]):.1f}, "
                  f"cost={np.mean(all_costs[-20:]):.2f}")

    results = {
        "n_episodes": n_episodes,
        "mean_reward": float(np.mean(all_rewards)),
        "std_reward": float(np.std(all_rewards)),
        "mean_cost": float(np.mean(all_costs)),
        "std_cost": float(np.std(all_costs)),
        "constraint_satisfaction_pct": float(
            100.0 * np.mean([c <= cost_threshold for c in all_costs])
        ),
    }
    return results
def main():
    """Script entry point: load a checkpoint, evaluate it, report the results.

    Steps, in order: parse the command line, load the checkpoint and echo any
    metrics stored in it, try to build the environment and run the evaluation
    (falling back to the stored metrics when the environment is unavailable),
    print the results, and write them as JSON when ``--output`` was given.
    """
    args = parse_args()

    # Load checkpoint
    checkpoint = load_checkpoint(args.checkpoint, device=args.device)

    # Echo the metrics recorded at training time, when present.
    if "final_metrics" in checkpoint:
        print("\n--- Stored Training Metrics ---")
        for name, value in checkpoint["final_metrics"].items():
            print(f" {name}: {value}")

    # make_env() calls sys.exit() when its package is missing, so SystemExit
    # is caught here on purpose to fall back to the stored metrics.
    try:
        environment = make_env(args.env, seed=args.seed)
        results = evaluate(
            environment,
            checkpoint,
            args.n_eval_episodes,
            device=args.device,
            deterministic=args.deterministic,
        )
    except (ImportError, SystemExit):
        print("\nCannot create environment. Reporting stored metrics only.")
        results = checkpoint.get("final_metrics", {"error": "env not available"})

    print("\n--- Evaluation Results ---")
    for name, value in results.items():
        if not isinstance(value, float):
            print(f" {name}: {value}")
            continue
        print(f" {name}: {value:.4f}")

    # Nothing left to do unless a results file was requested.
    if not args.output:
        return

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as handle:
        json.dump(results, handle, indent=2)
    print(f"\nResults saved to {output_path}")
# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()