#!/usr/bin/env python3
# Source listing metadata: file size 5,735 Bytes, 867ae15
"""Render evaluation videos from a trained checkpoint.
Usage:
python render.py --checkpoint checkpoints/halfcheetah_6x1/ours_s0.pt \
--env mamujoco_HalfCheetah_6x1 \
--output_dir videos/ \
--n_episodes 3
"""
import argparse
import sys
from pathlib import Path
import numpy as np
import torch
# Maps the --env command-line name to the (scenario, agent_conf) pair that
# make_render_env forwards to MujocoMulti through env_args.
MAMUJOCO_ENVS = dict(
    mamujoco_HalfCheetah_6x1=("HalfCheetah-v2", "6x1"),
    mamujoco_Humanoid_9_8=("Humanoid-v2", "9|8"),
    mamujoco_ManySegmentSwimmer_6x1=("ManySegmentSwimmer", "6x1"),
)
def parse_args():
    """Parse the renderer's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``checkpoint``, ``env``,
        ``output_dir``, ``n_episodes``, ``seed``, ``fps``, ``width`` and
        ``height``.

    ``--checkpoint`` and ``--env`` are mandatory. ``--env`` accepts every key
    of ``MAMUJOCO_ENVS`` plus ``"harvest_5"``.
    """
    # Declarative option table: (flag, keyword arguments for add_argument).
    # Order matters only for the layout of the generated --help text.
    option_table = (
        ("--checkpoint", dict(type=str, required=True,
                              help="Path to .pt checkpoint file")),
        ("--env", dict(type=str, required=True,
                       choices=[*MAMUJOCO_ENVS, "harvest_5"],
                       help="Environment name")),
        ("--output_dir", dict(type=str, default="videos",
                              help="Directory to save rendered videos")),
        ("--n_episodes", dict(type=int, default=3,
                              help="Number of episodes to render")),
        ("--seed", dict(type=int, default=0,
                        help="Rendering seed")),
        ("--fps", dict(type=int, default=30,
                       help="Frames per second for video")),
        ("--width", dict(type=int, default=640,
                         help="Video width in pixels")),
        ("--height", dict(type=int, default=480,
                          help="Video height in pixels")),
    )

    cli = argparse.ArgumentParser(description="Render evaluation videos")
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def make_render_env(env_name, seed=0):
    """Build and seed the multi-agent MuJoCo environment used for rendering.

    Args:
        env_name: Key into ``MAMUJOCO_ENVS`` selecting scenario and agent_conf.
        seed: Value passed to ``env.seed``.

    Returns:
        A seeded ``MujocoMulti`` instance.

    Raises:
        SystemExit: With status 1 (after printing a message) when ``env_name``
            is not a MuJoCo environment or the ``mappo_lagrangian`` package
            cannot be imported.
    """
    # Guard clause: only the environments listed in MAMUJOCO_ENVS can render.
    if env_name not in MAMUJOCO_ENVS:
        print(f"Rendering not supported for {env_name}")
        sys.exit(1)

    scenario, agent_conf = MAMUJOCO_ENVS[env_name]

    # Imported lazily so the script can still start (and print --help)
    # on machines without the MuJoCo stack installed.
    try:
        from mappo_lagrangian.envs.safety_ma_mujoco.safety_multiagent_mujoco import (
            MujocoMulti,
        )
    except ImportError:
        print("ERROR: mappo_lagrangian package not installed.")
        print("Install from: macpo_base/MAPPO-Lagrangian/")
        sys.exit(1)

    render_env = MujocoMulti(
        env_args=dict(
            scenario=scenario,
            agent_conf=agent_conf,
            agent_obsk=1,
            episode_limit=1000,
        )
    )
    render_env.seed(seed)
    return render_env
def _capture_frame(env, width, height, announce_fallback):
    """Grab one RGB frame, retrying with an explicit frame size on failure.

    Returns whatever ``env.render`` returns (possibly ``None``). Any exception
    raised by the second, size-explicit attempt propagates to the caller.
    """
    try:
        return env.render(mode="rgb_array")
    except Exception as err:  # render backends raise many unrelated error types
        if announce_fallback:
            print(f" Warning: render failed ({err}). Trying offscreen...")
    return env.render(mode="rgb_array", width=width, height=height)


def _summarize_step(rewards, dones, infos):
    """Collapse one step's outputs into ``(reward, cost, done)`` scalars."""
    # np.sum / np.all accept scalars, lists, tuples, ndarrays and nested
    # per-agent lists alike, unlike the builtin sum()/all() on a plain list.
    reward = float(np.sum(rewards))
    done = bool(np.all(dones))

    if isinstance(infos, dict):
        infos = [infos]  # a single info dict is treated like a one-agent list
    cost = 0.0
    if isinstance(infos, (list, tuple)):
        cost = float(np.sum(
            [info.get("cost", 0) for info in infos if isinstance(info, dict)]
        ))
    return reward, cost, done


def render_episodes(env, checkpoint, n_episodes, output_dir, fps=30,
                    width=640, height=480, max_steps=1000):
    """Roll out episodes, record their frames and save one MP4 per episode.

    Args:
        env: Environment exposing ``reset()``, ``render(mode="rgb_array")``,
            ``action_space[i].sample()`` and a ``step(actions)`` that returns
            ``(obs, rewards, dones, infos)``. ``n_agents`` is read when
            present and defaults to 1.
        checkpoint: Loaded checkpoint. Currently unused: this is a template
            and actions are sampled randomly from ``env.action_space``.
        n_episodes: Number of episodes to roll out.
        output_dir: Directory for the videos; created if missing.
        fps: Frame rate passed to ``imageio.mimwrite``.
        width: Frame width used only by the fallback render call.
        height: Frame height used only by the fallback render call.
        max_steps: Upper bound on steps per episode.

    Returns:
        None. Videos are written as ``episode_NNN.mp4`` inside ``output_dir``.
        If both render attempts fail, the function prints a message and
        returns early; remaining episodes are skipped.

    Raises:
        SystemExit: With status 1 when ``imageio`` is not installed.
    """
    try:
        import imageio
    except ImportError:
        print("ERROR: imageio required for video rendering.")
        print(" pip install imageio imageio-ffmpeg")
        sys.exit(1)

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Looked up once up front; it is not expected to change between steps.
    n_agents = getattr(env, "n_agents", 1)

    for ep in range(n_episodes):
        frames = []
        obs = env.reset()  # kept for the policy forward pass (see template note)
        done = False
        step = 0
        episode_reward = 0.0
        episode_cost = 0.0

        while not done and step < max_steps:
            # Render frame; only the first step announces the fallback.
            try:
                frame = _capture_frame(
                    env, width, height, announce_fallback=(step == 0)
                )
            except Exception:
                print(" Offscreen rendering also failed. Skipping video.")
                return
            if frame is not None:
                frames.append(frame)

            # Template: replace with actual policy forward pass
            actions = [env.action_space[i].sample() for i in range(n_agents)]
            # NOTE(review): assumes a 4-tuple step result; confirm against the
            # wrapped environment's actual step() signature.
            obs, rewards, dones, infos = env.step(actions)

            step_reward, step_cost, done = _summarize_step(rewards, dones, infos)
            episode_reward += step_reward
            episode_cost += step_cost
            step += 1

        if frames:
            video_path = output_dir / f"episode_{ep:03d}.mp4"
            imageio.mimwrite(str(video_path), frames, fps=fps)
            print(f" Episode {ep}: {step} steps, reward={episode_reward:.1f}, "
                  f"cost={episode_cost:.2f} -> {video_path}")
        else:
            print(f" Episode {ep}: no frames captured")
def main():
    """Script entry point: load a checkpoint, build the env, render videos.

    Steps:
        1. Parse command-line arguments.
        2. Load the checkpoint on CPU and, when it is a dict containing
           ``final_metrics``, print the stored welfare and rho values.
        3. Create the rendering environment; exit with status 1 on failure.
        4. Render ``args.n_episodes`` episodes into ``args.output_dir``.
    """
    args = parse_args()

    print(f"Loading checkpoint: {args.checkpoint}")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects,
    # so only load checkpoint files from a trusted source.
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)

    # A checkpoint may be any pickled object; the membership test below
    # would raise TypeError on a non-container, so check the type first.
    if isinstance(ckpt, dict) and "final_metrics" in ckpt:
        metrics = ckpt["final_metrics"]
        print(f" Welfare: {metrics.get('total_welfare', 'N/A')}")
        print(f" Final rho: {metrics.get('final_rho', 'N/A')}")

    # make_render_env reports its own failures via sys.exit, hence SystemExit
    # is caught here alongside ImportError.
    try:
        env = make_render_env(args.env, seed=args.seed)
    except (ImportError, SystemExit):
        print("Cannot create rendering environment. Exiting.")
        sys.exit(1)

    print(f"\nRendering {args.n_episodes} episodes to {args.output_dir}/")
    render_episodes(
        env, ckpt, args.n_episodes, args.output_dir,
        fps=args.fps, width=args.width, height=args.height,
    )
    print("Done.")


# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|