File size: 5,735 Bytes
867ae15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#!/usr/bin/env python3
"""Render evaluation videos from a trained checkpoint.

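NOTE: the policy forward pass is still a template. The checkpoint is loaded
and its stored metrics are printed, but the actions used for the rollout are
sampled randomly from the environment's action space.

Usage: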
    python render.py --checkpoint checkpoints/halfcheetah_6x1/ours_s0.pt \
                     --env mamujoco_HalfCheetah_6x1 \
                     --output_dir videos/ \
                     --n_episodes 3
"""

import argparse
import sys
from pathlib import Path

import numpy as np
import torch


# Maps each supported ``--env`` name to the ``(scenario, agent_conf)`` pair
# that ``make_render_env`` forwards to ``MujocoMulti`` through ``env_args``.
# NOTE(review): the agent_conf strings ("6x1", "9|8") are passed through
# verbatim; their exact meaning is defined by the MujocoMulti package, not here.
MAMUJOCO_ENVS = {
    "mamujoco_HalfCheetah_6x1": ("HalfCheetah-v2", "6x1"),
    "mamujoco_Humanoid_9_8": ("Humanoid-v2", "9|8"),
    "mamujoco_ManySegmentSwimmer_6x1": ("ManySegmentSwimmer", "6x1"),
}


def parse_args():
    """Parse the command-line options of the rendering script.

    Returns:
        argparse.Namespace with the attributes ``checkpoint``, ``env``,
        ``output_dir``, ``n_episodes``, ``seed``, ``fps``, ``width`` and
        ``height``.
    """
    cli = argparse.ArgumentParser(description="Render evaluation videos")

    # The two mandatory options.
    cli.add_argument("--checkpoint", type=str, required=True,
                     help="Path to .pt checkpoint file")
    # "harvest_5" is accepted by the parser in addition to the MuJoCo names.
    env_choices = list(MAMUJOCO_ENVS.keys()) + ["harvest_5"]
    cli.add_argument("--env", type=str, required=True,
                     choices=env_choices,
                     help="Environment name")

    # Optional settings as (flag, type, default, help) rows; registered in
    # this order so the generated --help output keeps the same layout.
    optional_flags = (
        ("--output_dir", str, "videos", "Directory to save rendered videos"),
        ("--n_episodes", int, 3, "Number of episodes to render"),
        ("--seed", int, 0, "Rendering seed"),
        ("--fps", int, 30, "Frames per second for video"),
        ("--width", int, 640, "Video width in pixels"),
        ("--height", int, 480, "Video height in pixels"),
    )
    for flag, value_type, fallback, description in optional_flags:
        cli.add_argument(flag, type=value_type, default=fallback,
                         help=description)

    return cli.parse_args()


def make_render_env(env_name, seed=0):
    """Build and seed the multi-agent MuJoCo environment used for rendering.

    Args:
        env_name: Key of ``MAMUJOCO_ENVS``; any other name is rejected.
        seed: Value passed to ``env.seed``.

    Returns:
        A seeded ``MujocoMulti`` instance.

    Raises:
        SystemExit: With status 1 when ``env_name`` is not a supported MuJoCo
            environment or when the ``mappo_lagrangian`` package cannot be
            imported.
    """
    # Guard clause: only the MuJoCo environments can be rendered.
    if env_name not in MAMUJOCO_ENVS:
        print(f"Rendering not supported for {env_name}")
        sys.exit(1)

    scenario, agent_conf = MAMUJOCO_ENVS[env_name]

    # Imported lazily so the script can still start (and report a clear
    # error) when the optional dependency is missing.
    try:
        from mappo_lagrangian.envs.safety_ma_mujoco.safety_multiagent_mujoco import MujocoMulti
    except ImportError:
        print("ERROR: mappo_lagrangian package not installed.")
        print("Install from: macpo_base/MAPPO-Lagrangian/")
        sys.exit(1)

    render_env = MujocoMulti(
        env_args={
            "scenario": scenario,
            "agent_conf": agent_conf,
            "agent_obsk": 1,
            "episode_limit": 1000,
        }
    )
    render_env.seed(seed)
    return render_env


def render_episodes(env, checkpoint, n_episodes, output_dir, fps=30,
                    width=640, height=480):
    """Roll out episodes in ``env`` and write one MP4 video per episode.

    Actions are currently sampled at random from ``env.action_space``: the
    policy forward pass is still a template, so ``checkpoint`` is accepted
    but not used yet.

    Args:
        env: Environment exposing ``reset()``, ``step(action)``,
            ``render(mode="rgb_array", ...)`` and an indexable per-agent
            ``action_space``. ``n_agents`` is read when present, otherwise a
            single agent is assumed.
        checkpoint: Loaded checkpoint object (unused by the template policy).
        n_episodes: Number of episodes to roll out.
        output_dir: Directory for the videos; created if it does not exist.
        fps: Frame rate passed to ``imageio.mimwrite``.
        width: Frame width used by the explicit-size render fallback.
        height: Frame height used by the explicit-size render fallback.

    Returns:
        None. Videos are written to ``<output_dir>/episode_XXX.mp4``. The
        function returns early, skipping the remaining episodes, when no
        render call works on the first frame of an episode.

    Raises:
        SystemExit: With status 1 when ``imageio`` cannot be imported.
    """
    try:
        import imageio
    except ImportError:
        print("ERROR: imageio required for video rendering.")
        print("  pip install imageio imageio-ffmpeg")
        sys.exit(1)

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Arguments for env.render(). We start with the plain call and switch to
    # the explicit-size variant if the plain call fails. Whichever variant
    # worked is reused for every later frame; retrying the failing variant
    # on each step would leave the video with a single frame.
    render_kwargs = {"mode": "rgb_array"}
    sized_kwargs = {"mode": "rgb_array", "width": width, "height": height}
    using_sized = False

    for ep in range(n_episodes):
        frames = []
        obs = env.reset()  # kept for the future policy forward pass
        n_agents = getattr(env, "n_agents", 1)
        done = False
        step = 0
        dropped_frames = 0
        episode_reward = 0.0
        episode_cost = 0.0

        while not done and step < 1000:
            # Render frame. Rendering is best-effort: backends raise a wide
            # variety of exception types, hence the broad except clauses.
            frame = None
            try:
                frame = env.render(**render_kwargs)
            except Exception as e:
                if step > 0:
                    # Mid-episode failure: skip the frame, report it later.
                    dropped_frames += 1
                elif using_sized:
                    # The fallback is already active and still fails.
                    print("  Offscreen rendering also failed. Skipping video.")
                    return
                else:
                    print(f"  Warning: render failed ({e}). Trying offscreen...")
                    try:
                        frame = env.render(**sized_kwargs)
                    except Exception:
                        print("  Offscreen rendering also failed. Skipping video.")
                        return
                    render_kwargs = sized_kwargs
                    using_sized = True
            if frame is not None:
                frames.append(frame)

            # Template: replace with actual policy forward pass
            action = [env.action_space[i].sample() for i in range(n_agents)]

            obs, rewards, dones, infos = env.step(action)

            # Per-agent sequences are aggregated; scalars are used as-is.
            if isinstance(rewards, (list, tuple)):
                episode_reward += sum(rewards)
            else:
                episode_reward += rewards
            if isinstance(infos, (list, tuple)):
                episode_cost += sum(info.get("cost", 0) for info in infos)
            elif isinstance(infos, dict):
                episode_cost += infos.get("cost", 0)
            done = all(dones) if isinstance(dones, (list, tuple)) else dones
            step += 1

        if dropped_frames:
            print(f"  Warning: {dropped_frames} frame(s) failed to render "
                  f"and were skipped.")

        if frames:
            video_path = output_dir / f"episode_{ep:03d}.mp4"
            imageio.mimwrite(str(video_path), frames, fps=fps)
            print(f"  Episode {ep}: {step} steps, reward={episode_reward:.1f}, "
                  f"cost={episode_cost:.2f} -> {video_path}")
        else:
            print(f"  Episode {ep}: no frames captured")


def main():
    """Script entry point: load the checkpoint, build the env, render videos.

    Prints the stored ``final_metrics`` (when the checkpoint has them),
    creates the rendering environment and hands both to ``render_episodes``.

    Raises:
        SystemExit: With status 1 when the environment cannot be created.
    """
    opts = parse_args()

    print(f"Loading checkpoint: {opts.checkpoint}")
    # NOTE(review): weights_only=False allows full unpickling of the file;
    # only load checkpoints that come from a trusted source.
    checkpoint = torch.load(
        opts.checkpoint, map_location="cpu", weights_only=False
    )

    if "final_metrics" in checkpoint:
        metrics = checkpoint["final_metrics"]
        welfare = metrics.get('total_welfare', 'N/A')
        final_rho = metrics.get('final_rho', 'N/A')
        print(f"  Welfare: {welfare}")
        print(f"  Final rho: {final_rho}")

    # make_render_env reports its own error and exits; the SystemExit is
    # intercepted here so a final summary line is printed before exiting.
    try:
        render_env = make_render_env(opts.env, seed=opts.seed)
    except (ImportError, SystemExit):
        print("Cannot create rendering environment. Exiting.")
        sys.exit(1)

    print(f"\nRendering {opts.n_episodes} episodes to {opts.output_dir}/")
    render_episodes(
        render_env,
        checkpoint,
        opts.n_episodes,
        opts.output_dir,
        fps=opts.fps,
        width=opts.width,
        height=opts.height,
    )
    print("Done.")


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()