| |
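"""Render evaluation videos from a trained checkpoint.

NOTE: the policy is not yet restored from the checkpoint. Actions are sampled
at random via each agent's ``action_space.sample()``, and the checkpoint is
loaded only to print the final metrics stored in it.

Usage:
    python render.py --checkpoint checkpoints/halfcheetah_6x1/ours_s0.pt \
        --env mamujoco_HalfCheetah_6x1 \
        --output_dir videos/ \
        --n_episodes 3
"""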
|
|
| import argparse |
| import sys |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
|
|
|
|
# CLI environment name -> (scenario, agent_conf), forwarded as the
# "scenario" / "agent_conf" entries of MujocoMulti's env_args.
MAMUJOCO_ENVS = dict(
    mamujoco_HalfCheetah_6x1=("HalfCheetah-v2", "6x1"),
    mamujoco_Humanoid_9_8=("Humanoid-v2", "9|8"),
    mamujoco_ManySegmentSwimmer_6x1=("ManySegmentSwimmer", "6x1"),
)
|
|
|
|
def _positive_int(text):
    """argparse ``type=`` converter that accepts only integers >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {value}")
    return value


def parse_args(argv=None):
    """Parse the renderer's command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        argparse.Namespace with ``checkpoint``, ``env``, ``output_dir``,
        ``n_episodes``, ``seed``, ``fps``, ``width`` and ``height``.
        ``n_episodes``, ``fps``, ``width`` and ``height`` are validated to be
        positive so bad values fail here instead of deep inside rendering.
    """
    parser = argparse.ArgumentParser(description="Render evaluation videos")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to .pt checkpoint file")
    parser.add_argument("--env", type=str, required=True,
                        choices=list(MAMUJOCO_ENVS.keys()) + ["harvest_5"],
                        help="Environment name")
    parser.add_argument("--output_dir", type=str, default="videos",
                        help="Directory to save rendered videos")
    parser.add_argument("--n_episodes", type=_positive_int, default=3,
                        help="Number of episodes to render")
    parser.add_argument("--seed", type=int, default=0,
                        help="Rendering seed")
    parser.add_argument("--fps", type=_positive_int, default=30,
                        help="Frames per second for video")
    parser.add_argument("--width", type=_positive_int, default=640,
                        help="Video width in pixels")
    parser.add_argument("--height", type=_positive_int, default=480,
                        help="Video height in pixels")
    return parser.parse_args(argv)
|
|
|
|
def make_render_env(env_name, seed=0, episode_limit=1000):
    """Create the multi-agent MuJoCo environment used for rendering.

    Args:
        env_name: A key of ``MAMUJOCO_ENVS``.
        seed: Passed to ``env.seed``.
        episode_limit: Forwarded as ``episode_limit`` in MujocoMulti's
            env_args (default 1000, as before).

    Returns:
        A ``MujocoMulti`` instance.

    Exits the process with status 1 (raises ``SystemExit``) when ``env_name``
    is not an MA-MuJoCo key, or when ``mappo_lagrangian`` cannot be imported.
    Diagnostics are written to stderr.
    """
    # Guard clause: only MA-MuJoCo envs have a rendering path here.
    if env_name not in MAMUJOCO_ENVS:
        print(f"Rendering not supported for {env_name}", file=sys.stderr)
        sys.exit(1)

    scenario, agent_conf = MAMUJOCO_ENVS[env_name]
    try:
        from mappo_lagrangian.envs.safety_ma_mujoco.safety_multiagent_mujoco import MujocoMulti
    except ImportError:
        print("ERROR: mappo_lagrangian package not installed.", file=sys.stderr)
        print("Install from: macpo_base/MAPPO-Lagrangian/", file=sys.stderr)
        sys.exit(1)

    env_args = {
        "scenario": scenario,
        "agent_conf": agent_conf,
        "agent_obsk": 1,
        "episode_limit": episode_limit,
    }
    env = MujocoMulti(env_args=env_args)
    env.seed(seed)
    return env
|
|
|
|
def _summarize_step(rewards, dones, infos):
    """Collapse one ``env.step`` result into ``(reward, cost, done)``.

    Args:
        rewards: Scalar, list, tuple, nested list or numpy array of rewards;
            every element is summed.
        dones: Scalar flag or sequence/array of per-agent flags; the step is
            terminal only when every flag is truthy.
        infos: List/tuple of per-agent info dicts, a single info dict, or
            anything else (treated as carrying no cost). The ``"cost"`` entry
            of each dict (default 0) is summed.

    Returns:
        Tuple ``(reward: float, cost: float, done: bool)``.
    """
    # np.sum / np.all handle lists, tuples, arrays and scalars uniformly.
    # An isinstance(list) check would treat a non-empty tuple of dones as
    # "done" and would fail on multi-element numpy arrays.
    reward = float(np.sum(rewards))
    done = bool(np.all(dones))
    if isinstance(infos, dict):
        cost = float(infos.get("cost", 0))
    elif isinstance(infos, (list, tuple)):
        cost = float(sum(info.get("cost", 0) for info in infos
                         if isinstance(info, dict)))
    else:
        cost = 0.0
    return reward, cost, done


def _render_frame(env, attempts):
    """Render one frame, trying each ``env.render`` kwargs dict in order.

    Args:
        env: Environment whose ``render`` is called with each kwargs dict.
        attempts: Ordered list of kwargs dicts to try.

    Returns:
        ``(frame, remaining)``. ``frame`` is whatever ``env.render`` returned
        (possibly ``None``). ``remaining`` is ``attempts`` starting at the
        option that succeeded, so later calls skip options that already failed.

    Raises:
        RuntimeError: if every option raised; chained to the last error.
    """
    last_err = None
    for idx, kwargs in enumerate(attempts):
        try:
            return env.render(**kwargs), attempts[idx:]
        except Exception as err:  # render backends raise heterogeneous errors
            last_err = err
            print(f" Warning: env.render({kwargs}) failed ({err}).")
    raise RuntimeError("every render option failed") from last_err


def render_episodes(env, checkpoint, n_episodes, output_dir, fps=30,
                    width=640, height=480, max_steps=1000):
    """Roll out episodes in ``env`` and save each one as an MP4 video.

    NOTE: actions are drawn with ``env.action_space[i].sample()``.
    ``checkpoint`` is accepted but is not yet used to drive a policy.

    Args:
        env: Environment providing ``reset()``, ``step(actions)``,
            ``render(**kwargs)``, an indexable ``action_space`` and,
            optionally, ``n_agents`` (1 is assumed when it is absent).
        checkpoint: Loaded checkpoint (currently unused).
        n_episodes: Number of episodes to roll out.
        output_dir: Directory for ``episode_NNN.mp4`` files; created if
            missing.
        fps: Frame rate of the written videos.
        width: Requested frame width, used if ``env.render`` accepts it.
        height: Requested frame height, used if ``env.render`` accepts it.
        max_steps: Maximum number of steps per episode.

    Exits the process with status 1 if ``imageio`` is not installed.
    Returns early, after saving any frames already captured, once every
    render option fails.
    """
    try:
        import imageio
    except ImportError:
        print("ERROR: imageio required for video rendering.")
        print(" pip install imageio imageio-ffmpeg")
        sys.exit(1)

    print(" Note: checkpoint is not used; actions are sampled at random.")

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Loop-invariant: the agent count does not change between steps.
    n_agents = env.n_agents if hasattr(env, "n_agents") else 1
    # Prefer the requested size. Fall back to the renderer's default size for
    # envs whose render() rejects width/height.
    render_attempts = [
        {"mode": "rgb_array", "width": width, "height": height},
        {"mode": "rgb_array"},
    ]

    for ep in range(n_episodes):
        env.reset()
        video_path = output_dir / f"episode_{ep:03d}.mp4"
        writer = None  # opened on the first frame so empty episodes leave no file
        n_frames = 0
        done = False
        step = 0
        episode_reward = 0.0
        episode_cost = 0.0
        render_failed = False

        try:
            while not done and step < max_steps:
                try:
                    frame, render_attempts = _render_frame(env, render_attempts)
                except RuntimeError:
                    render_failed = True
                    break
                if frame is not None:
                    if writer is None:
                        writer = imageio.get_writer(str(video_path), fps=fps)
                    # Stream to disk rather than holding the whole episode in RAM.
                    writer.append_data(frame)
                    n_frames += 1

                action = [env.action_space[i].sample() for i in range(n_agents)]
                # NOTE(review): assumes env.step returns the 4-tuple
                # (obs, rewards, dones, infos); confirm against the installed
                # MujocoMulti.
                _, rewards, dones, infos = env.step(action)
                reward, cost, done = _summarize_step(rewards, dones, infos)
                episode_reward += reward
                episode_cost += cost
                step += 1
        finally:
            if writer is not None:
                writer.close()

        if n_frames:
            print(f" Episode {ep}: {step} steps, reward={episode_reward:.1f}, "
                  f"cost={episode_cost:.2f} -> {video_path}")
        else:
            print(f" Episode {ep}: no frames captured")

        if render_failed:
            print(" Rendering failed. Skipping remaining videos.")
            return
|
|
|
|
def main():
    """Command-line entry point: load a checkpoint, build the env, render.

    Exits with status 1 if the checkpoint file does not exist or the
    rendering environment cannot be created.
    """
    args = parse_args()

    ckpt_path = Path(args.checkpoint)
    if not ckpt_path.is_file():
        print(f"ERROR: checkpoint not found: {ckpt_path}", file=sys.stderr)
        sys.exit(1)

    print(f"Loading checkpoint: {args.checkpoint}")
    # SECURITY: weights_only=False unpickles arbitrary Python objects, which
    # can execute code embedded in the file. Only load checkpoints you trust.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    # Only dict checkpoints are inspected for stored metrics.
    if isinstance(ckpt, dict) and "final_metrics" in ckpt:
        m = ckpt["final_metrics"]
        print(f" Welfare: {m.get('total_welfare', 'N/A')}")
        print(f" Final rho: {m.get('final_rho', 'N/A')}")

    try:
        env = make_render_env(args.env, seed=args.seed)
    except (ImportError, SystemExit):
        print("Cannot create rendering environment. Exiting.", file=sys.stderr)
        sys.exit(1)

    print(f"\nRendering {args.n_episodes} episodes to {args.output_dir}/")
    try:
        render_episodes(
            env, ckpt, args.n_episodes, args.output_dir,
            fps=args.fps, width=args.width, height=args.height,
        )
    finally:
        # Close the env (when it exposes close()) even if rendering aborts.
        close = getattr(env, "close", None)
        if callable(close):
            close()
    print("Done.")
|
|
|
|
if __name__ == "__main__":
    # Script entry: hand main()'s return value (None -> status 0) to sys.exit.
    sys.exit(main())
|
|