File size: 7,314 Bytes
6d5f894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
#!/usr/bin/env python3
"""Evaluate a trained checkpoint on Safe Multi-Agent MuJoCo or Harvest.

Usage:
    python eval.py --checkpoint checkpoints/halfcheetah_6x1/ours_s0.pt \
                   --env mamujoco_HalfCheetah_6x1 \
                   --n_eval_episodes 100
"""

import argparse
import json
import sys
from pathlib import Path

import numpy as np
import torch


# Lookup table for the MuJoCo-based environments: each accepted --env value
# maps to the (scenario, agent_conf) pair that make_env() forwards to the
# MujocoMulti constructor via env_args.
MAMUJOCO_ENVS = dict(
    mamujoco_HalfCheetah_6x1=("HalfCheetah-v2", "6x1"),
    mamujoco_Humanoid_9_8=("Humanoid-v2", "9|8"),
    mamujoco_ManySegmentSwimmer_6x1=("ManySegmentSwimmer", "6x1"),
)


def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, which is the behavior
            callers that pass no argument have always had. Supplying a list
            lets the parser be driven programmatically.

    Returns:
        argparse.Namespace with the attributes ``checkpoint``, ``env``,
        ``n_eval_episodes``, ``seed``, ``device``, ``output`` and
        ``deterministic``.

    Note:
        argparse itself terminates the process (``SystemExit``) when a
        required option is missing or a value is not among the declared
        choices.
    """
    parser = argparse.ArgumentParser(description="Evaluate a trained checkpoint")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to .pt checkpoint file")
    # Accepted names are the MuJoCo table keys plus the single Harvest setup.
    parser.add_argument("--env", type=str, required=True,
                        choices=list(MAMUJOCO_ENVS.keys()) + ["harvest_5"],
                        help="Environment name")
    parser.add_argument("--n_eval_episodes", type=int, default=100,
                        help="Number of evaluation episodes")
    parser.add_argument("--seed", type=int, default=0,
                        help="Evaluation seed")
    parser.add_argument("--device", type=str, default="cpu",
                        choices=["cpu", "cuda"],
                        help="Device for inference")
    parser.add_argument("--output", type=str, default=None,
                        help="Path to save evaluation results JSON")
    parser.add_argument("--deterministic", action="store_true",
                        help="Use deterministic (greedy) actions")
    return parser.parse_args(argv)


def load_checkpoint(ckpt_path, device="cpu"):
    """Load a checkpoint file and print a short summary of what it holds.

    Args:
        ckpt_path: Path to the ``.pt`` file; passed unchanged to
            ``torch.load``.
        device: Value used as ``map_location`` when loading.

    Returns:
        The loaded checkpoint dict, unmodified.

    Raises:
        TypeError: If the file does not deserialize to a dict, since the rest
            of the script looks entries up by key.
    """
    # SECURITY: weights_only=False makes torch.load run the full pickle
    # machinery, which can execute arbitrary code embedded in the file. Only
    # load checkpoints that come from a trusted source.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    if not isinstance(ckpt, dict):
        raise TypeError(
            f"Expected checkpoint {ckpt_path} to contain a dict, "
            f"got {type(ckpt).__name__}"
        )

    print(f"Loaded checkpoint: {ckpt_path}")
    print(f"  Keys: {list(ckpt.keys())}")

    # The summary below is informational only, so entries that are present
    # but not dict-shaped are skipped instead of aborting the load.
    cfg = ckpt.get("config")
    if isinstance(cfg, dict):
        print(f"  Environment: {cfg.get('env_name', 'unknown')}")
        print(f"  Method: {cfg.get('mode', 'unknown')}")
        print(f"  Seed: {cfg.get('seed', 'unknown')}")

    metrics = ckpt.get("final_metrics")
    if isinstance(metrics, dict):
        print(f"  Stored welfare: {metrics.get('total_welfare', 'N/A')}")
        print(f"  Stored constraint sat: {metrics.get('constraint_satisfaction_pct', 'N/A')}%")
    return ckpt


def make_env(env_name, seed=0):
    """Construct and seed the environment selected by ``env_name``.

    Names listed in ``MAMUJOCO_ENVS`` produce a ``MujocoMulti`` instance;
    ``"harvest_5"`` produces a ``HarvestEnv`` with five agents. Both classes
    are imported lazily from the ``mappo_lagrangian`` package.

    Args:
        env_name: Environment identifier.
        seed: Value handed to ``env.seed`` before the env is returned.

    Returns:
        The seeded environment object.

    Raises:
        ValueError: If ``env_name`` is neither a ``MAMUJOCO_ENVS`` key nor
            ``"harvest_5"``.
        SystemExit: With status 1, after printing a hint, when the required
            environment module cannot be imported.
    """
    is_mujoco = env_name in MAMUJOCO_ENVS
    if not is_mujoco and env_name != "harvest_5":
        raise ValueError(f"Unknown environment: {env_name}")

    if is_mujoco:
        scenario, agent_conf = MAMUJOCO_ENVS[env_name]
        try:
            from mappo_lagrangian.envs.safety_ma_mujoco.safety_multiagent_mujoco import MujocoMulti
        except ImportError:
            install_hint = (
                "ERROR: mappo_lagrangian package not installed.",
                "Install from: macpo_base/MAPPO-Lagrangian/",
                "  cd macpo_base/MAPPO-Lagrangian && pip install -e .",
            )
            for line in install_hint:
                print(line)
            sys.exit(1)

        mujoco_args = dict(
            scenario=scenario,
            agent_conf=agent_conf,
            agent_obsk=1,
            episode_limit=1000,
        )
        env = MujocoMulti(env_args=mujoco_args)
    else:
        try:
            from mappo_lagrangian.envs.harvest import HarvestEnv
        except ImportError:
            print("ERROR: Harvest environment not found in mappo_lagrangian.")
            sys.exit(1)
        env = HarvestEnv(n_agents=5)

    # Both environment kinds are seeded the same way before being handed back.
    env.seed(seed)
    return env


def _scalar_total(values):
    """Collapse a per-agent reward container (or a plain scalar) into one float."""
    if isinstance(values, (list, tuple, np.ndarray)):
        # np.sum also flattens nested per-agent containers such as [[r], [r]].
        return float(np.sum(values))
    return float(values)


def _episode_finished(dones):
    """Return True when every entry of ``dones`` (or the scalar itself) is truthy."""
    if isinstance(dones, (list, tuple, np.ndarray)):
        return bool(np.all(dones))
    return bool(dones)


def _step_cost(infos):
    """Sum the ``"cost"`` entries of one step's info dict or sequence of info dicts."""
    if isinstance(infos, dict):
        return infos.get("cost", 0)
    if isinstance(infos, (list, tuple)):
        return sum(info.get("cost", 0) for info in infos)
    return 0


def evaluate(env, checkpoint, n_episodes, device="cpu", deterministic=False,
             max_steps=1000, cost_threshold=1.0):
    """Roll out evaluation episodes and aggregate reward and cost statistics.

    This is a template: the checkpoint's actor weights are NOT loaded into a
    policy. Actions are drawn with ``env.action_space[i].sample()``, so the
    returned numbers describe random behavior until the marked placeholder
    is replaced with a real policy forward pass.

    Args:
        env: Environment exposing ``reset()``, ``step(action)`` returning
            ``(obs, rewards, dones, infos)``, an indexable ``action_space``
            and, optionally, ``n_agents`` (1 is assumed when absent).
        checkpoint: Loaded checkpoint dict. When it has no
            ``"actor_state_dict"`` entry no episodes are run.
        n_episodes: Number of episodes to roll out.
        device: Currently unused; reserved for the policy forward pass.
        deterministic: Currently unused; reserved for the policy forward pass.
        max_steps: Upper bound on steps per episode (default 1000).
        cost_threshold: An episode counts as constraint-satisfying when its
            accumulated cost is at most this value (default 1.0).

    Returns:
        ``checkpoint["final_metrics"]`` (or ``{}``) when the checkpoint holds
        no actor weights; otherwise a dict with ``n_episodes``,
        ``mean_reward``, ``std_reward``, ``mean_cost``, ``std_cost`` and
        ``constraint_satisfaction_pct``.
    """
    actor_state = checkpoint.get("actor_state_dict")
    if actor_state is None:
        print("WARNING: No actor_state_dict in checkpoint. Reporting stored metrics only.")
        return checkpoint.get("final_metrics", {})

    # This is a template -- actual model loading depends on the exact architecture
    # used in training. Users should adapt this to their setup.
    print(f"\nRunning {n_episodes} evaluation episodes...")
    print("NOTE: This script provides a template. Adapt the model loading")
    print("      to match your MAPPO-Lagrangian actor architecture.")

    # Looked up once instead of on every step; assumes the agent count does
    # not change while the environment is in use.
    n_agents = getattr(env, "n_agents", 1)

    all_rewards = []
    all_costs = []

    for ep in range(n_episodes):
        obs = env.reset()
        episode_reward = 0.0
        episode_cost = 0.0
        done = False
        step = 0

        while not done and step < max_steps:
            # Template: replace with actual policy forward pass
            # action = policy(obs, deterministic=deterministic)
            action = [env.action_space[i].sample() for i in range(n_agents)]

            obs, rewards, dones, infos = env.step(action)
            # Rewards and dones may arrive as scalars, lists, tuples or NumPy
            # arrays; the helpers reduce each to a plain float / bool so the
            # loop condition never sees an ambiguous array truth value.
            episode_reward += _scalar_total(rewards)
            episode_cost += _step_cost(infos)
            done = _episode_finished(dones)
            step += 1

        all_rewards.append(episode_reward)
        all_costs.append(episode_cost)

        # Progress line every 20 episodes, averaged over the latest 20.
        if (ep + 1) % 20 == 0:
            print(f"  Episode {ep+1}/{n_episodes}: "
                  f"reward={np.mean(all_rewards[-20:]):.1f}, "
                  f"cost={np.mean(all_costs[-20:]):.2f}")

    within_budget = [cost <= cost_threshold for cost in all_costs]
    results = {
        "n_episodes": n_episodes,
        "mean_reward": float(np.mean(all_rewards)),
        "std_reward": float(np.std(all_rewards)),
        "mean_cost": float(np.mean(all_costs)),
        "std_cost": float(np.std(all_costs)),
        "constraint_satisfaction_pct": float(100.0 * np.mean(within_budget)),
    }
    return results


def main():
    """Command-line entry point: load a checkpoint, evaluate it, report results.

    Steps, in order: parse the CLI options, load the checkpoint, echo any
    ``final_metrics`` stored in it, build the environment and run
    ``evaluate``. If that raises ``ImportError`` or ``SystemExit`` (``make_env``
    exits when its environment package is missing), the stored metrics are
    reported instead. Results are printed and, when ``--output`` is given,
    written to that path as JSON.
    """
    cli = parse_args()

    # Load checkpoint
    checkpoint = load_checkpoint(cli.checkpoint, device=cli.device)

    # Echo every metric that was stored alongside the weights at training time.
    if "final_metrics" in checkpoint:
        print("\n--- Stored Training Metrics ---")
        for name, value in checkpoint["final_metrics"].items():
            print(f"  {name}: {value}")

    # Fall back to the stored metrics when the environment cannot be set up.
    try:
        results = evaluate(
            make_env(cli.env, seed=cli.seed),
            checkpoint,
            cli.n_eval_episodes,
            device=cli.device,
            deterministic=cli.deterministic,
        )
    except (ImportError, SystemExit):
        print("\nCannot create environment. Reporting stored metrics only.")
        results = checkpoint.get("final_metrics", {"error": "env not available"})

    print("\n--- Evaluation Results ---")
    for name, value in results.items():
        # Floats are shown with four decimals; everything else as-is.
        shown = f"{value:.4f}" if isinstance(value, float) else value
        print(f"  {name}: {shown}")

    # Nothing left to do unless a results file was requested.
    if not cli.output:
        return

    output_path = Path(cli.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as handle:
        json.dump(results, handle, indent=2)
    print(f"\nResults saved to {output_path}")


# Run the CLI only when this file is executed as a script, so importing the
# module (e.g. to reuse evaluate) has no side effects.
if __name__ == "__main__":
    main()