"""Evaluate every saved checkpoint in a directory on the validation split.

For each ``.pt`` file whose name contains ``checkpoint_``, the model generates
a rollout from the first ground-truth frame and the recorded actions. The
per-sample MSE against the ground-truth video is summarized (mean, 25th and
75th percentile), plotted against training step, and written to a text file
in ``--results_dir``.
"""
import argparse
import os
import random
import re
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import yaml
from torch.utils.data import DataLoader
from tqdm import tqdm

# Add project root to path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))

from wm.model.interface import get_dynamics_class  # noqa: E402
from wm.dataset.dataset import RoboticsDatasetWrapper  # noqa: E402

# Scheduler buffers that might cause size mismatches between a checkpoint and
# the freshly constructed model; they are never copied from a checkpoint.
_SCHEDULER_BUFFERS = frozenset({
    'scheduler.sigmas',
    'scheduler.timesteps',
    'scheduler.linear_timesteps_weights',
})
# Prefix DDP-wrapped models put in front of every state-dict key.
_DDP_PREFIX = 'module.'
# Extracts the integer step from names such as "checkpoint_1000.pt".
_CKPT_STEP_RE = re.compile(r"checkpoint_(\d+)")


def load_checkpoint(model, path, device, default_step=0):
    """Load the weights stored in a training checkpoint into ``model`` in place.

    Args:
        model: Module whose ``load_state_dict`` receives the filtered weights.
        path: File produced by ``torch.save`` holding a dict with a
            ``'model_state_dict'`` entry and, optionally, a ``'step'`` entry.
        device: ``map_location`` passed to ``torch.load``.
        default_step: Returned when the checkpoint has no ``'step'`` entry
            (0 by default, matching the previous behavior).

    Returns:
        The checkpoint's ``'step'`` value, or ``default_step``.

    Raises:
        KeyError: if the checkpoint has no ``'model_state_dict'``.
    """
    print(f"Loading checkpoint: {path}")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only point this script at checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location=device, weights_only=False)

    new_state_dict = {}
    for key, value in checkpoint['model_state_dict'].items():
        # Remove "module." prefix from DDP
        name = key[len(_DDP_PREFIX):] if key.startswith(_DDP_PREFIX) else key
        if name not in _SCHEDULER_BUFFERS:
            new_state_dict[name] = value

    # strict=False is needed because the scheduler buffers are skipped, but it
    # also leaves any *other* missing parameter untouched - i.e. still holding
    # whatever the previously loaded checkpoint put there. Surface those.
    incompatible = model.load_state_dict(new_state_dict, strict=False)
    missing = [
        k for k in getattr(incompatible, 'missing_keys', ())
        if k not in _SCHEDULER_BUFFERS
    ]
    unexpected = list(getattr(incompatible, 'unexpected_keys', ()))
    if missing:
        print(f"Warning: {len(missing)} missing key(s) when loading {path}: {missing[:10]}")
    if unexpected:
        print(f"Warning: {len(unexpected)} unexpected key(s) when loading {path}: {unexpected[:10]}")

    return checkpoint.get('step', default_step)


def _parse_args():
    """Define and parse the command-line options of this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to the yaml config file")
    parser.add_argument("--ckpt_dir", type=str, default="checkpoints/lang_table_fulltraj_dit_v1", help="Directory containing checkpoints")
    parser.add_argument("--results_dir", type=str, default="results", help="Directory to save plots")
    parser.add_argument("--num_samples", type=int, default=-1, help="Number of validation samples to evaluate (-1 for all)")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size for evaluation")
    parser.add_argument("--inference_steps", type=int, default=50, help="Number of denoising steps")
    parser.add_argument("--noise_level", type=float, default=0.0, help="Noise level (t0) for the first frame (0.0 to 1.0)")
    parser.add_argument("--label", type=str, default="", help="Label for the evaluation run")
    return parser.parse_args()


def _checkpoint_step(filename):
    """Return the step encoded as ``checkpoint_<digits>`` in *filename*, else -1."""
    match = _CKPT_STEP_RE.search(filename)
    return int(match.group(1)) if match else -1


def _find_checkpoints(ckpt_dir):
    """List ``.pt`` files in *ckpt_dir* containing "checkpoint_", ordered by step."""
    ckpt_files = [
        f for f in os.listdir(ckpt_dir)
        if f.endswith(".pt") and "checkpoint_" in f
    ]
    # Names without a parseable step get -1 and therefore sort first.
    return sorted(ckpt_files, key=_checkpoint_step)


def _select_val_indices(dataset_len, train_test_split, num_samples, seed=42):
    """Return the dataset indices used for validation.

    Validation samples are picked from the end of the dataset, as in the
    training split: the first ``train_test_split / (train_test_split + 1)``
    fraction is skipped. When ``num_samples > 0`` a seeded random subset of
    that size is returned instead of the whole tail.
    """
    num_train = int(dataset_len * (train_test_split / (train_test_split + 1)))
    val_indices = list(range(num_train, dataset_len))
    if num_samples > 0:
        # A private generator seeded with `seed` yields the same permutation
        # as random.seed(seed); random.shuffle(...), without reseeding the
        # process-wide RNG.
        random.Random(seed).shuffle(val_indices)
        val_indices = val_indices[:num_samples]
    return val_indices


def _per_sample_mse(model, val_loader, device, step, inference_steps, noise_level):
    """Generate rollouts for every batch and return the per-sample MSE (1-D array).

    Raises:
        ValueError: if the generated video's shape differs from the
            channels-last ground-truth video.
    """
    all_mse = []
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Eval Step {step}"):
            obs = batch['obs'].to(device)  # [B, T, C, H, W]
            action = batch['action'].to(device)  # [B, T, A]

            # First frame: [B, H, W, 3] -> permute from [B, C, H, W]
            o_0 = obs[:, 0].permute(0, 2, 3, 1).contiguous()

            # Generate rollout
            pred_video = model.generate(
                o_0,
                action,
                num_inference_steps=inference_steps,
                noise_level=noise_level,
            )

            # Ground truth: [B, T, C, H, W] -> [B, T, H, W, 3]
            gt_video = obs.permute(0, 1, 3, 4, 2).contiguous()

            # Without this check, broadcasting could silently compare
            # mismatched elements instead of failing.
            if pred_video.shape != gt_video.shape:
                raise ValueError(
                    f"Generated video shape {tuple(pred_video.shape)} does not "
                    f"match ground-truth shape {tuple(gt_video.shape)}"
                )

            # Compute in float32: integer pixel tensors would wrap around on
            # subtraction and cannot be averaged with Tensor.mean.
            sq_err = (pred_video.float() - gt_video.float()) ** 2
            # Average over every non-batch dimension (time and space) -> [B]
            mse_per_sample = sq_err.flatten(start_dim=1).mean(dim=1)
            all_mse.append(mse_per_sample.cpu().numpy())

    return np.concatenate(all_mse, axis=0)


def _summarize(step, per_sample_mse):
    """Reduce per-sample MSE values to the mean and 25th/75th percentiles."""
    p25, p75 = np.percentile(per_sample_mse, [25, 75])
    return {
        'step': step,
        'mean': np.mean(per_sample_mse),
        'p25': p25,
        'p75': p75,
    }


def _plot_results(results, title, plot_path):
    """Plot mean MSE with a shaded 25th-75th percentile band versus step."""
    steps = [r['step'] for r in results]
    means = [r['mean'] for r in results]
    p25s = [r['p25'] for r in results]
    p75s = [r['p75'] for r in results]

    fig = plt.figure(figsize=(10, 6))
    plt.plot(steps, means, marker='o', linestyle='-', color='b', label='Mean MSE')
    plt.fill_between(steps, p25s, p75s, color='b', alpha=0.2, label='25th-75th Percentile')
    plt.title(title)
    plt.xlabel("Training Steps")
    plt.ylabel("Mean RGB MSE (Full Trajectory)")
    plt.legend()
    plt.grid(True)
    plt.savefig(plot_path)
    # pyplot keeps figures alive until closed; release this one.
    plt.close(fig)


def _save_results(results, path):
    """Write a header plus one ``step,mean,p25,p75`` row per checkpoint."""
    with open(path, "w", encoding="utf-8") as f:
        f.write("Step,Mean_MSE,P25,P75\n")
        for r in results:
            f.write(f"{r['step']},{r['mean']:.8f},{r['p25']:.8f},{r['p75']:.8f}\n")


def main():
    """Evaluate all checkpoints in ``--ckpt_dir`` and save an MSE-vs-step curve."""
    args = _parse_args()
    os.makedirs(args.results_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # If label is not provided, generate one
    if not args.label:
        args.label = f"{args.inference_steps}steps"
        if args.noise_level > 0:
            args.label += f"_noise{args.noise_level}"

    # 1. Load Config
    with open(args.config, 'r', encoding="utf-8") as f:
        config = yaml.safe_load(f)
    dataset_name = config['dataset']['name']

    # 2. Find Checkpoints - before building the dataset and model, so an
    # empty directory is reported without doing that work first.
    ckpt_files = _find_checkpoints(args.ckpt_dir)
    if not ckpt_files:
        print(f"No checkpoints found in {args.ckpt_dir}")
        return

    # 3. Setup Dataset
    print("Initializing dataset...")
    full_dataset = RoboticsDatasetWrapper.get_dataset(dataset_name)
    val_indices = _select_val_indices(
        len(full_dataset),
        config['dataset']['train_test_split'],
        args.num_samples,
    )
    print(f"Evaluating on {len(val_indices)} samples.")
    if not val_indices:
        # Nothing to evaluate; np.concatenate would otherwise fail later.
        print("No validation samples to evaluate.")
        return

    val_subset = torch.utils.data.Subset(full_dataset, val_indices)
    val_loader = DataLoader(val_subset, batch_size=args.batch_size, shuffle=False, num_workers=8)

    # 4. Initialize Model
    print("Initializing model...")
    dynamics_class = get_dynamics_class(config['dynamics_class'])
    model = dynamics_class(config['model_name'], config['model_config'])
    model.to(device)
    model.eval()

    # 5. Evaluate each checkpoint
    results = []
    for ckpt_file in ckpt_files:
        ckpt_path = os.path.join(args.ckpt_dir, ckpt_file)
        # Fall back to the step in the file name when the checkpoint has none.
        step = load_checkpoint(
            model, ckpt_path, device,
            default_step=max(_checkpoint_step(ckpt_file), 0),
        )

        per_sample_mse = _per_sample_mse(
            model, val_loader, device, step,
            args.inference_steps, args.noise_level,
        )
        result = _summarize(step, per_sample_mse)
        print(f"Step {step} | Mean MSE: {result['mean']:.6f} | P25: {result['p25']:.6f} | P75: {result['p75']:.6f}")
        results.append(result)

    # 6. Plot and Save
    results.sort(key=lambda r: r['step'])

    plot_path = os.path.join(args.results_dir, f"mse_curve_{dataset_name}_{args.label}.png")
    _plot_results(results, f"Evaluation MSE Curve - {dataset_name} ({args.label})", plot_path)
    print(f"Saved MSE curve to {plot_path}")

    # Save raw results to a text file
    _save_results(
        results,
        os.path.join(args.results_dir, f"mse_results_{dataset_name}_{args.label}.txt"),
    )


if __name__ == "__main__":
    main()