Sean13's picture
Add scripts/render.py
867ae15 verified
Raw
History Blame Contribute Delete
5.74 kB
#!/usr/bin/env python3
"""Render evaluation videos from a trained checkpoint.
Usage:
python render.py --checkpoint checkpoints/halfcheetah_6x1/ours_s0.pt \
--env mamujoco_HalfCheetah_6x1 \
--output_dir videos/ \
--n_episodes 3
"""
import argparse
import sys
from pathlib import Path
import numpy as np
import torch
# Lookup table from the ``--env`` command-line name to the
# ``(scenario, agent_conf)`` pair that ``make_render_env`` forwards to the
# MujocoMulti constructor. Key order is also the order of the CLI choices.
MAMUJOCO_ENVS = {
    "mamujoco_HalfCheetah_6x1": ("HalfCheetah-v2", "6x1"),
    "mamujoco_Humanoid_9_8": ("Humanoid-v2", "9|8"),
    "mamujoco_ManySegmentSwimmer_6x1": ("ManySegmentSwimmer", "6x1"),
}
def parse_args():
    """Declare the renderer's command-line options and parse ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` holding ``checkpoint``, ``env``,
        ``output_dir``, ``n_episodes``, ``seed``, ``fps``, ``width`` and
        ``height``.
    """
    # Every MAMuJoCo table key is selectable, followed by "harvest_5".
    env_choices = [*MAMUJOCO_ENVS, "harvest_5"]

    # (flag, add_argument keyword arguments), listed in --help order.
    option_table = (
        ("--checkpoint", {"type": str, "required": True,
                          "help": "Path to .pt checkpoint file"}),
        ("--env", {"type": str, "required": True,
                   "choices": env_choices,
                   "help": "Environment name"}),
        ("--output_dir", {"type": str, "default": "videos",
                          "help": "Directory to save rendered videos"}),
        ("--n_episodes", {"type": int, "default": 3,
                          "help": "Number of episodes to render"}),
        ("--seed", {"type": int, "default": 0,
                    "help": "Rendering seed"}),
        ("--fps", {"type": int, "default": 30,
                   "help": "Frames per second for video"}),
        ("--width", {"type": int, "default": 640,
                     "help": "Video width in pixels"}),
        ("--height", {"type": int, "default": 480,
                      "help": "Video height in pixels"}),
    )

    cli = argparse.ArgumentParser(description="Render evaluation videos")
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
def make_render_env(env_name, seed=0):
    """Build and seed the multi-agent MuJoCo environment for ``env_name``.

    Args:
        env_name: Key of ``MAMUJOCO_ENVS``; any other name is rejected.
        seed: Value passed to the environment's ``seed`` method.

    Returns:
        The constructed ``MujocoMulti`` environment.

    Exits the process with status 1 (via ``sys.exit``) when ``env_name`` is
    not in ``MAMUJOCO_ENVS`` or when the ``mappo_lagrangian`` package cannot
    be imported.
    """
    # Guard clause: only the MAMuJoCo table entries can be built here.
    if env_name not in MAMUJOCO_ENVS:
        print(f"Rendering not supported for {env_name}")
        sys.exit(1)

    scenario, agent_conf = MAMUJOCO_ENVS[env_name]

    # Imported lazily so the script can still report a readable error when
    # the optional package is missing.
    try:
        from mappo_lagrangian.envs.safety_ma_mujoco.safety_multiagent_mujoco import MujocoMulti
    except ImportError:
        print("ERROR: mappo_lagrangian package not installed.")
        print("Install from: macpo_base/MAPPO-Lagrangian/")
        sys.exit(1)

    render_env = MujocoMulti(
        env_args=dict(
            scenario=scenario,
            agent_conf=agent_conf,
            agent_obsk=1,
            episode_limit=1000,
        )
    )
    render_env.seed(seed)
    return render_env
def _safe_render(env, **render_kwargs):
    """Call ``env.render(**render_kwargs)`` without letting it raise.

    Returns:
        ``(frame, None)`` when the call succeeds, ``(None, exception)`` when
        it raises.
    """
    try:
        return env.render(**render_kwargs), None
    except Exception as exc:  # rendering is best-effort; the caller decides
        return None, exc


def render_episodes(env, checkpoint, n_episodes, output_dir, fps=30,
                    width=640, height=480, max_steps=1000):
    """Roll out episodes in ``env`` and save each one as an MP4 video.

    Frames are first requested with ``env.render(mode="rgb_array")``. If that
    call raises, the sized variant (``width``/``height`` keywords) is tried,
    and once it succeeds it is used for all remaining frames and episodes.
    Previously the sized variant was attempted only on the first step, so
    every later frame was silently lost.

    Args:
        env: Environment exposing ``reset``, ``step``, ``render`` and an
            indexable ``action_space``; ``n_agents`` is read when present,
            otherwise one agent is assumed.
        checkpoint: Loaded checkpoint object. Currently unused: actions are
            sampled from ``env.action_space`` (template placeholder for a
            policy forward pass).
        n_episodes: Number of episodes to roll out.
        output_dir: Directory for the videos; created if missing.
        fps: Frames per second passed to the video writer.
        width: Frame width used by the sized render fallback.
        height: Frame height used by the sized render fallback.
        max_steps: Upper bound on steps per episode (default 1000).

    Returns:
        None. Videos are written to ``output_dir/episode_NNN.mp4``.

    Exits the process with status 1 when ``imageio`` is not installed.
    Returns early, without writing further videos, when no render variant
    works on the first step of an episode.
    """
    try:
        import imageio
    except ImportError:
        print("ERROR: imageio required for video rendering.")
        print(" pip install imageio imageio-ffmpeg")
        sys.exit(1)

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    plain_kwargs = {"mode": "rgb_array"}
    sized_kwargs = {"mode": "rgb_array", "width": width, "height": height}
    # Render variant currently known to work; switches to the sized one
    # permanently after the first successful fallback.
    render_kwargs = plain_kwargs

    for ep in range(n_episodes):
        frames = []
        dropped = 0
        env.reset()
        n_agents = getattr(env, "n_agents", 1)
        done = False
        step = 0
        episode_reward = 0.0
        episode_cost = 0.0

        while not done and step < max_steps:
            # Render frame
            frame, error = _safe_render(env, **render_kwargs)
            if error is not None and render_kwargs is plain_kwargs:
                if step == 0:
                    print(f" Warning: render failed ({error}). Trying offscreen...")
                frame, error = _safe_render(env, **sized_kwargs)
                if error is None:
                    render_kwargs = sized_kwargs
            if error is not None:
                if step == 0:
                    print(" Offscreen rendering also failed. Skipping video.")
                    return
                # Mid-episode failure: keep the rollout going, count the loss.
                dropped += 1
            elif frame is not None:
                frames.append(frame)

            # Template: replace with actual policy forward pass
            action = [env.action_space[i].sample() for i in range(n_agents)]
            _, rewards, dones, infos = env.step(action)

            # np.sum / np.all accept scalars, (nested) lists and arrays alike,
            # so array-valued rewards or done flags no longer break the loop.
            episode_reward += float(np.sum(rewards))
            if isinstance(infos, (list, tuple)):
                episode_cost += sum(info.get("cost", 0) for info in infos)
            done = bool(np.all(dones))
            step += 1

        if dropped:
            print(f" Warning: {dropped} frame(s) could not be rendered "
                  f"in episode {ep}")

        if frames:
            video_path = output_dir / f"episode_{ep:03d}.mp4"
            imageio.mimwrite(str(video_path), frames, fps=fps)
            print(f" Episode {ep}: {step} steps, reward={episode_reward:.1f}, "
                  f"cost={episode_cost:.2f} -> {video_path}")
        else:
            print(f" Episode {ep}: no frames captured")
def main():
    """Script entry point: load the checkpoint, build the env, render videos.

    Steps:
        1. Parse the command line.
        2. Load the checkpoint on CPU and print its ``final_metrics`` entry
           (``total_welfare`` and ``final_rho``) when one is present.
        3. Create the environment; exit with status 1 if that fails.
        4. Render ``n_episodes`` episodes into ``output_dir``.
    """
    args = parse_args()

    print(f"Loading checkpoint: {args.checkpoint}")
    # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code. Only load checkpoints you trust.
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)

    # A checkpoint may be something other than a dict, and its metrics entry
    # may be missing or malformed; skip the summary in those cases instead of
    # crashing before rendering starts.
    metrics = ckpt.get("final_metrics") if isinstance(ckpt, dict) else None
    if isinstance(metrics, dict):
        print(f" Welfare: {metrics.get('total_welfare', 'N/A')}")
        print(f" Final rho: {metrics.get('final_rho', 'N/A')}")

    try:
        env = make_render_env(args.env, seed=args.seed)
    except (ImportError, SystemExit):
        # make_render_env reports its own error and calls sys.exit; add a
        # closing message and keep the non-zero exit status.
        print("Cannot create rendering environment. Exiting.")
        sys.exit(1)

    print(f"\nRendering {args.n_episodes} episodes to {args.output_dir}/")
    render_episodes(
        env, ckpt, args.n_episodes, args.output_dir,
        fps=args.fps, width=args.width, height=args.height,
    )
    print("Done.")


if __name__ == "__main__":
    main()