Sean13's picture
Add scripts/eval.py
6d5f894 verified
Raw
History Blame Contribute Delete
7.31 kB
#!/usr/bin/env python3
"""Evaluate a trained checkpoint on Safe Multi-Agent MuJoCo or Harvest.
Usage:
    python eval.py --checkpoint checkpoints/halfcheetah_6x1/ours_s0.pt \
        --env mamujoco_HalfCheetah_6x1 \
        --n_eval_episodes 100
"""
import argparse
import json
import sys
from pathlib import Path
import numpy as np
import torch
# Supported MuJoCo tasks, keyed by the name accepted on the command line.
# Each value is the (scenario, agent_conf) pair handed to the environment.
MAMUJOCO_ENVS = {
    "mamujoco_HalfCheetah_6x1": ("HalfCheetah-v2", "6x1"),
    "mamujoco_Humanoid_9_8": ("Humanoid-v2", "9|8"),
    "mamujoco_ManySegmentSwimmer_6x1": ("ManySegmentSwimmer", "6x1"),
}
def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the behaviour of the original zero-argument call. Passing an
            explicit list lets the parser be driven programmatically.

    Returns:
        argparse.Namespace with the attributes ``checkpoint``, ``env``,
        ``n_eval_episodes``, ``seed``, ``device``, ``output`` and
        ``deterministic``.
    """
    parser = argparse.ArgumentParser(description="Evaluate a trained checkpoint")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to .pt checkpoint file")
    parser.add_argument("--env", type=str, required=True,
                        # Every registered MuJoCo task plus the Harvest task.
                        choices=[*MAMUJOCO_ENVS, "harvest_5"],
                        help="Environment name")
    parser.add_argument("--n_eval_episodes", type=int, default=100,
                        help="Number of evaluation episodes")
    parser.add_argument("--seed", type=int, default=0,
                        help="Evaluation seed")
    parser.add_argument("--device", type=str, default="cpu",
                        choices=["cpu", "cuda"],
                        help="Device for inference")
    parser.add_argument("--output", type=str, default=None,
                        help="Path to save evaluation results JSON")
    parser.add_argument("--deterministic", action="store_true",
                        help="Use deterministic (greedy) actions")
    return parser.parse_args(argv)
def load_checkpoint(ckpt_path, device="cpu"):
    """Read a checkpoint file, print a short summary of it, and return it.

    Args:
        ckpt_path: Path to the ``.pt`` file to load.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The object produced by ``torch.load`` (indexed like a dict below).
    """
    # NOTE(security): weights_only=False allows torch.load to unpickle
    # arbitrary Python objects -- only load checkpoints from trusted sources.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    print(f"Loaded checkpoint: {ckpt_path}")
    print(f" Keys: {list(ckpt.keys())}")

    # Summarise the training configuration when the checkpoint carries one.
    if "config" in ckpt:
        config = ckpt["config"]
        summary_fields = (
            ("Environment", "env_name"),
            ("Method", "mode"),
            ("Seed", "seed"),
        )
        for label, key in summary_fields:
            print(f" {label}: {config.get(key, 'unknown')}")

    # Echo the metrics recorded at the end of training, if present.
    if "final_metrics" in ckpt:
        metrics = ckpt["final_metrics"]
        welfare = metrics.get("total_welfare", "N/A")
        print(f" Stored welfare: {welfare}")
        constraint_sat = metrics.get("constraint_satisfaction_pct", "N/A")
        print(f" Stored constraint sat: {constraint_sat}%")

    return ckpt
def make_env(env_name, seed=0):
    """Build and seed the environment registered under ``env_name``.

    Args:
        env_name: A key of ``MAMUJOCO_ENVS`` or the string ``"harvest_5"``.
        seed: Value passed to the environment's ``seed`` method.

    Returns:
        The constructed, seeded environment.

    Raises:
        ValueError: If ``env_name`` is not one of the supported names.
        SystemExit: With status 1 when the required ``mappo_lagrangian``
            module cannot be imported.
    """
    is_mujoco = env_name in MAMUJOCO_ENVS
    if not is_mujoco and env_name != "harvest_5":
        raise ValueError(f"Unknown environment: {env_name}")

    if is_mujoco:
        scenario, agent_conf = MAMUJOCO_ENVS[env_name]
        try:
            from mappo_lagrangian.envs.safety_ma_mujoco.safety_multiagent_mujoco import MujocoMulti
        except ImportError:
            print("ERROR: mappo_lagrangian package not installed.")
            print("Install from: macpo_base/MAPPO-Lagrangian/")
            print(" cd macpo_base/MAPPO-Lagrangian && pip install -e .")
            sys.exit(1)
        env = MujocoMulti(env_args={
            "scenario": scenario,
            "agent_conf": agent_conf,
            "agent_obsk": 1,
            "episode_limit": 1000,
        })
    else:
        try:
            from mappo_lagrangian.envs.harvest import HarvestEnv
        except ImportError:
            print("ERROR: Harvest environment not found in mappo_lagrangian.")
            sys.exit(1)
        env = HarvestEnv(n_agents=5)

    # Both environment families are seeded the same way before being returned.
    env.seed(seed)
    return env
def _step_cost(infos):
    """Return the total ``"cost"`` reported by one step's ``infos`` payload."""
    if isinstance(infos, dict):
        return infos.get("cost", 0)
    if isinstance(infos, (list, tuple)):
        # Per-agent infos: entries that are not dicts carry no cost.
        return sum(info.get("cost", 0) for info in infos if isinstance(info, dict))
    return 0


def evaluate(env, checkpoint, n_episodes, device="cpu", deterministic=False,
             *, max_steps=1000, cost_limit=1.0):
    """Run evaluation episodes and collect reward/cost statistics.

    This is a template: actions are sampled from ``env.action_space`` rather
    than produced by a policy, and ``device`` / ``deterministic`` are accepted
    but not used yet.

    Args:
        env: Environment exposing ``reset()``, ``step(action)`` returning
            ``(obs, rewards, dones, infos)``, and an indexable
            ``action_space`` whose entries have ``sample()``.
        checkpoint: Mapping loaded from the checkpoint file.
        n_episodes: Number of episodes to run; must be at least 1.
        device: Unused placeholder for the inference device.
        deterministic: Unused placeholder for greedy action selection.
        max_steps: Hard cap on the number of steps per episode.
        cost_limit: An episode satisfies the constraint when its accumulated
            cost is ``<= cost_limit``.

    Returns:
        ``checkpoint["final_metrics"]`` (or ``{}``) when the checkpoint has
        no ``actor_state_dict``; otherwise a dict with ``n_episodes``,
        ``mean_reward``, ``std_reward``, ``mean_cost``, ``std_cost`` and
        ``constraint_satisfaction_pct``.

    Raises:
        ValueError: If episodes would be run but ``n_episodes`` is below 1
            (the statistics would otherwise be NaN).
    """
    actor_state = checkpoint.get("actor_state_dict")
    if actor_state is None:
        print("WARNING: No actor_state_dict in checkpoint. Reporting stored metrics only.")
        return checkpoint.get("final_metrics", {})

    if n_episodes < 1:
        raise ValueError(f"n_episodes must be a positive integer, got {n_episodes}")

    # This is a template -- actual model loading depends on the exact architecture
    # used in training. Users should adapt this to their setup.
    print(f"\nRunning {n_episodes} evaluation episodes...")
    print("NOTE: This script provides a template. Adapt the model loading")
    print(" to match your MAPPO-Lagrangian actor architecture.")

    # The agent count does not change between steps, so look it up once.
    n_agents = getattr(env, "n_agents", 1)

    all_rewards = []
    all_costs = []
    for ep in range(n_episodes):
        env.reset()
        episode_reward = 0.0
        episode_cost = 0.0
        for _ in range(max_steps):
            # Template: replace with actual policy forward pass
            # action = policy(obs, deterministic=deterministic)
            action = [env.action_space[i].sample() for i in range(n_agents)]
            _obs, rewards, dones, infos = env.step(action)

            # np.sum / np.all accept scalars, flat or nested lists, tuples and
            # arrays alike, so per-agent containers of any of these shapes are
            # reduced correctly instead of only flat Python lists.
            episode_reward += float(np.sum(rewards))
            episode_cost += float(np.sum(_step_cost(infos)))
            if np.all(dones):
                break

        all_rewards.append(episode_reward)
        all_costs.append(episode_cost)
        if (ep + 1) % 20 == 0:
            print(f" Episode {ep+1}/{n_episodes}: "
                  f"reward={np.mean(all_rewards[-20:]):.1f}, "
                  f"cost={np.mean(all_costs[-20:]):.2f}")

    results = {
        "n_episodes": n_episodes,
        "mean_reward": float(np.mean(all_rewards)),
        "std_reward": float(np.std(all_rewards)),
        "mean_cost": float(np.mean(all_costs)),
        "std_cost": float(np.std(all_costs)),
        "constraint_satisfaction_pct": float(
            100.0 * np.mean([c <= cost_limit for c in all_costs])
        ),
    }
    return results
def main():
    """Script entry point: load a checkpoint, evaluate it, report the results.

    Prints any metrics stored in the checkpoint, tries to build the
    environment and run the evaluation, falls back to the stored metrics if
    the environment cannot be created, and optionally writes the results to
    the JSON file given by ``--output``.
    """
    args = parse_args()
    ckpt = load_checkpoint(args.checkpoint, device=args.device)

    # Show whatever the training run recorded before evaluating anything.
    if "final_metrics" in ckpt:
        print("\n--- Stored Training Metrics ---")
        for name, value in ckpt["final_metrics"].items():
            print(f" {name}: {value}")

    # make_env signals a missing dependency via sys.exit, hence SystemExit
    # is caught here alongside ImportError.
    try:
        env = make_env(args.env, seed=args.seed)
        results = evaluate(
            env,
            ckpt,
            args.n_eval_episodes,
            device=args.device,
            deterministic=args.deterministic,
        )
    except (ImportError, SystemExit):
        print("\nCannot create environment. Reporting stored metrics only.")
        results = ckpt.get("final_metrics", {"error": "env not available"})

    print("\n--- Evaluation Results ---")
    for name, value in results.items():
        # Floats are shown with four decimals; everything else verbatim.
        shown = f"{value:.4f}" if isinstance(value, float) else f"{value}"
        print(f" {name}: {shown}")

    if not args.output:
        return

    # Persist the results, creating the destination directory if needed.
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as handle:
        json.dump(results, handle, indent=2)
    print(f"\nResults saved to {output_path}")
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()