"""Evaluate a trained checkpoint on Safe Multi-Agent MuJoCo or Harvest.

Usage:
    python eval.py --checkpoint checkpoints/halfcheetah_6x1/ours_s0.pt \
        --env mamujoco_HalfCheetah_6x1 \
        --n_eval_episodes 100
"""
|
|
| import argparse |
| import json |
| import sys |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
|
|
|
|
| |
# Maps each Multi-Agent MuJoCo ``--env`` name accepted on the command line to
# the (scenario, agent_conf) pair that make_env() places into the ``env_args``
# dict handed to MujocoMulti.
# NOTE(review): confirm these scenario / agent_conf strings match what the
# installed MujocoMulti actually accepts.
MAMUJOCO_ENVS = {
    "mamujoco_HalfCheetah_6x1": ("HalfCheetah-v2", "6x1"),
    "mamujoco_Humanoid_9_8": ("Humanoid-v2", "9|8"),
    "mamujoco_ManySegmentSwimmer_6x1": ("ManySegmentSwimmer", "6x1"),
}
|
|
|
|
def _positive_int(text):
    """argparse ``type`` callable: parse *text* as an int and reject values < 1."""
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(
            f"must be a positive integer, got {text!r}"
        )
    return value


def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script entry point relies on; passing a list allows the parser to
            be driven programmatically.

    Returns:
        argparse.Namespace with the attributes ``checkpoint``, ``env``,
        ``n_eval_episodes``, ``seed``, ``device``, ``output`` and
        ``deterministic``.
    """
    parser = argparse.ArgumentParser(description="Evaluate a trained checkpoint")
    parser.add_argument(
        "--checkpoint", type=str, required=True,
        help="Path to .pt checkpoint file",
    )
    parser.add_argument(
        "--env", type=str, required=True,
        choices=[*MAMUJOCO_ENVS, "harvest_5"],
        help="Environment name",
    )
    # Zero or negative episode counts would make every aggregated metric NaN,
    # so they are rejected at the command line.
    parser.add_argument(
        "--n_eval_episodes", type=_positive_int, default=100,
        help="Number of evaluation episodes",
    )
    parser.add_argument(
        "--seed", type=int, default=0,
        help="Evaluation seed",
    )
    parser.add_argument(
        "--device", type=str, default="cpu",
        choices=["cpu", "cuda"],
        help="Device for inference",
    )
    parser.add_argument(
        "--output", type=str, default=None,
        help="Path to save evaluation results JSON",
    )
    parser.add_argument(
        "--deterministic", action="store_true",
        help="Use deterministic (greedy) actions",
    )
    return parser.parse_args(argv)
|
|
|
|
def load_checkpoint(ckpt_path, device="cpu"):
    """Load a checkpoint file and print a short summary of its contents.

    Args:
        ckpt_path: Path of the ``.pt`` file, passed straight to ``torch.load``.
        device: Value used as ``map_location`` when loading.

    Returns:
        The complete checkpoint dict exactly as it was stored (all keys, not
        only the config or state dicts).

    Raises:
        TypeError: If the file does not contain a dict, which the rest of the
            script requires.
    """
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
    # objects, which can execute code. Only load checkpoints from a trusted
    # source.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    if not isinstance(ckpt, dict):
        raise TypeError(
            f"Checkpoint {ckpt_path} must contain a dict, "
            f"got {type(ckpt).__name__}"
        )

    print(f"Loaded checkpoint: {ckpt_path}")
    print(f" Keys: {list(ckpt.keys())}")

    if "config" in ckpt:
        cfg = ckpt["config"]
        print(f" Environment: {cfg.get('env_name', 'unknown')}")
        print(f" Method: {cfg.get('mode', 'unknown')}")
        print(f" Seed: {cfg.get('seed', 'unknown')}")

    if "final_metrics" in ckpt:
        m = ckpt["final_metrics"]
        print(f" Stored welfare: {m.get('total_welfare', 'N/A')}")
        print(f" Stored constraint sat: {m.get('constraint_satisfaction_pct', 'N/A')}%")

    return ckpt
|
|
|
|
def make_env(env_name, seed=0, episode_limit=1000):
    """Create and seed the evaluation environment.

    Args:
        env_name: A key of ``MAMUJOCO_ENVS`` or the string ``"harvest_5"``.
        seed: Value passed to ``env.seed`` after construction.
        episode_limit: ``episode_limit`` entry of the ``env_args`` dict given
            to ``MujocoMulti``. Defaults to 1000, the value that used to be
            hard-coded. Not used for ``"harvest_5"``.

    Returns:
        The constructed, seeded environment object.

    Raises:
        ValueError: If ``env_name`` is not a known environment.
        SystemExit: With code 1, after printing install hints, when the
            ``mappo_lagrangian`` environment module cannot be imported.
    """
    if env_name in MAMUJOCO_ENVS:
        scenario, agent_conf = MAMUJOCO_ENVS[env_name]
        # Imported lazily so the script can still start (and report stored
        # metrics) on machines without the simulator package.
        try:
            from mappo_lagrangian.envs.safety_ma_mujoco.safety_multiagent_mujoco import MujocoMulti
        except ImportError:
            print("ERROR: mappo_lagrangian package not installed.")
            print("Install from: macpo_base/MAPPO-Lagrangian/")
            print(" cd macpo_base/MAPPO-Lagrangian && pip install -e .")
            sys.exit(1)

        env_args = {
            "scenario": scenario,
            "agent_conf": agent_conf,
            "agent_obsk": 1,
            "episode_limit": episode_limit,
        }
        env = MujocoMulti(env_args=env_args)
        env.seed(seed)
        return env

    if env_name == "harvest_5":
        try:
            from mappo_lagrangian.envs.harvest import HarvestEnv
        except ImportError:
            print("ERROR: Harvest environment not found in mappo_lagrangian.")
            sys.exit(1)
        env = HarvestEnv(n_agents=5)
        env.seed(seed)
        return env

    raise ValueError(f"Unknown environment: {env_name}")
|
|
|
|
def evaluate(env, checkpoint, n_episodes, device="cpu", deterministic=False,
             max_steps=1000, cost_threshold=1.0):
    """Run evaluation episodes and collect reward / cost metrics.

    This is a template: the actions are sampled at random from
    ``env.action_space``; the stored actor weights are only checked for
    presence and are not loaded into a model.

    Args:
        env: Environment exposing ``reset()``, ``step(action)`` returning
            ``(obs, rewards, dones, infos)``, an indexable ``action_space``
            and, optionally, ``n_agents`` (1 is assumed when absent).
        checkpoint: Checkpoint dict. When it has no ``"actor_state_dict"``
            its ``"final_metrics"`` entry (or ``{}``) is returned unchanged.
        n_episodes: Number of episodes to run; must be positive.
        device: Currently unused (kept for the template's model loading).
        deterministic: Currently unused (kept for the template's policy).
        max_steps: Upper bound on the steps of a single episode.
        cost_threshold: An episode counts as constraint-satisfying when its
            accumulated cost is ``<=`` this value.

    Returns:
        Dict with ``n_episodes``, ``mean_reward``, ``std_reward``,
        ``mean_cost``, ``std_cost`` and ``constraint_satisfaction_pct``, or
        the stored metrics when the checkpoint has no actor weights.

    Raises:
        ValueError: If ``n_episodes`` is not positive (the statistics would
            otherwise all be NaN).
    """
    actor_state = checkpoint.get("actor_state_dict")
    if actor_state is None:
        print("WARNING: No actor_state_dict in checkpoint. Reporting stored metrics only.")
        return checkpoint.get("final_metrics", {})

    if n_episodes <= 0:
        raise ValueError(f"n_episodes must be positive, got {n_episodes}")

    print(f"\nRunning {n_episodes} evaluation episodes...")
    print("NOTE: This script provides a template. Adapt the model loading")
    print(" to match your MAPPO-Lagrangian actor architecture.")

    n_agents = getattr(env, "n_agents", 1)
    all_rewards = []
    all_costs = []

    for ep in range(n_episodes):
        env.reset()
        episode_reward = 0.0
        episode_cost = 0.0

        for _ in range(max_steps):
            # Placeholder policy: one random action per agent.
            action = [env.action_space[i].sample() for i in range(n_agents)]
            _, rewards, dones, infos = env.step(action)

            # np.sum / np.all accept scalars, lists, tuples, nested lists and
            # arrays alike, so per-agent outputs are always reduced to a
            # single scalar / bool instead of leaking an array into the
            # accumulators or the loop condition.
            episode_reward += float(np.sum(rewards))
            if isinstance(infos, (list, tuple)):
                episode_cost += float(
                    np.sum([info.get("cost", 0) for info in infos])
                )
            elif isinstance(infos, dict):
                episode_cost += float(np.sum(infos.get("cost", 0)))

            if bool(np.all(dones)):
                break

        all_rewards.append(episode_reward)
        all_costs.append(episode_cost)

        if (ep + 1) % 20 == 0:
            print(f" Episode {ep+1}/{n_episodes}: "
                  f"reward={np.mean(all_rewards[-20:]):.1f}, "
                  f"cost={np.mean(all_costs[-20:]):.2f}")

    satisfied = np.asarray(all_costs) <= cost_threshold
    return {
        "n_episodes": n_episodes,
        "mean_reward": float(np.mean(all_rewards)),
        "std_reward": float(np.std(all_rewards)),
        "mean_cost": float(np.mean(all_costs)),
        "std_cost": float(np.std(all_costs)),
        "constraint_satisfaction_pct": float(100.0 * np.mean(satisfied)),
    }
|
|
|
|
def _to_jsonable(value):
    """Recursively convert numpy / torch values into plain Python objects."""
    if isinstance(value, dict):
        return {key: _to_jsonable(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_jsonable(item) for item in value]
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().tolist()
    return value


def main():
    """Script entry point: load a checkpoint, evaluate it, report the results.

    Prints the metrics stored in the checkpoint (if any), then tries to build
    the environment and run the evaluation. When the environment package is
    unavailable the stored metrics are reported instead. With ``--output`` the
    results are also written to that path as JSON.
    """
    args = parse_args()

    ckpt = load_checkpoint(args.checkpoint, device=args.device)

    if "final_metrics" in ckpt:
        print("\n--- Stored Training Metrics ---")
        for k, v in ckpt["final_metrics"].items():
            print(f" {k}: {v}")

    env = None
    try:
        env = make_env(args.env, seed=args.seed)
        results = evaluate(
            env, ckpt, args.n_eval_episodes,
            device=args.device, deterministic=args.deterministic,
        )
    except (ImportError, SystemExit):
        # make_env() exits when the environment package is missing; fall back
        # to whatever the checkpoint recorded at training time.
        print("\nCannot create environment. Reporting stored metrics only.")
        results = ckpt.get("final_metrics", {"error": "env not available"})
    finally:
        # Release the environment if it was created and exposes close().
        close_env = getattr(env, "close", None)
        if callable(close_env):
            close_env()

    # Metrics taken from the checkpoint may hold numpy / torch scalars, which
    # json.dump rejects and which are not instances of float; normalise once
    # so both the printing and the JSON export below handle them.
    results = _to_jsonable(results)

    print("\n--- Evaluation Results ---")
    for k, v in results.items():
        if isinstance(v, float):
            print(f" {k}: {v:.4f}")
        else:
            print(f" {k}: {v}")

    if args.output:
        output_path = Path(args.output)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        print(f"\nResults saved to {output_path}")
|
|
|
|
# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
|
|