| | import os |
| | import torch |
| | import torch.nn.functional as F |
| | import yaml |
| | import numpy as np |
| | import matplotlib.pyplot as plt |
| | from tqdm import tqdm |
| | import argparse |
| | import sys |
| |
|
| | |
| | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) |
| |
|
| | from wm.model.interface import get_dynamics_class |
| | from wm.dataset.dataset import RoboticsDatasetWrapper |
| | from torch.utils.data import DataLoader |
| |
|
def load_checkpoint(model, path, device):
    """Load model weights from a training checkpoint into `model`.

    Strips any DataParallel/DDP 'module.' prefix from state-dict keys and
    drops the flow-matching scheduler buffers (they are rebuilt at inference
    time, so stale persisted copies must not be restored).

    Args:
        model: the module to load weights into (in place).
        path: filesystem path of the checkpoint file.
        device: torch.device used as `map_location` for deserialization.

    Returns:
        The training step stored in the checkpoint, or 0 if absent.
    """
    print(f"Loading checkpoint: {path}")
    checkpoint = torch.load(path, map_location=device, weights_only=False)

    # Scheduler buffers are recomputed per-run; never restore them.
    skip_keys = {
        'scheduler.sigmas',
        'scheduler.timesteps',
        'scheduler.linear_timesteps_weights',
    }

    def _strip(key):
        # Remove the 'module.' wrapper added by (Distributed)DataParallel.
        return key[7:] if key.startswith('module.') else key

    filtered = {
        _strip(key): tensor
        for key, tensor in checkpoint['model_state_dict'].items()
        if _strip(key) not in skip_keys
    }

    # strict=False: the skipped scheduler buffers would otherwise count as
    # missing keys and raise.
    model.load_state_dict(filtered, strict=False)
    return checkpoint.get('step', 0)
| |
|
def _get_checkpoint_step(filename):
    """Return the training step parsed from a 'checkpoint_<step>.pt' name.

    Returns -1 for names that do not follow the pattern, so they sort first
    and do not crash the sweep.
    """
    try:
        return int(filename.split('_')[1].split('.')[0])
    except (ValueError, IndexError):
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are no longer swallowed while sorting filenames.
        return -1


def _evaluate_checkpoint(model, val_loader, device, args, step):
    """Run full-trajectory generation over the validation loader.

    Returns a 1-D numpy array with one mean RGB MSE per validation sample.
    """
    all_mse = []
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Eval Step {step}"):
            obs = batch['obs'].to(device)
            action = batch['action'].to(device)

            # First frame, converted to channels-last (B, H, W, C) for the
            # model's generate() interface.
            o_0 = obs[:, 0].permute(0, 2, 3, 1).contiguous()

            pred_video = model.generate(
                o_0, action,
                num_inference_steps=args.inference_steps,
                noise_level=args.noise_level
            )

            # Ground truth (B, T, C, H, W) -> channels-last (B, T, H, W, C)
            # to match the prediction layout.
            gt_video = obs.permute(0, 1, 3, 4, 2).contiguous()

            mse = (pred_video - gt_video) ** 2
            # Mean over time, spatial dims and channels -> one scalar per sample.
            mse_per_sample = mse.mean(dim=(1, 2, 3, 4))
            all_mse.append(mse_per_sample.cpu().numpy())

    return np.concatenate(all_mse, axis=0)


def _save_results(results, config, args):
    """Plot the MSE-vs-step curve and dump the raw numbers to a CSV-style file."""
    results.sort(key=lambda x: x['step'])
    steps = [r['step'] for r in results]
    means = [r['mean'] for r in results]
    p25s = [r['p25'] for r in results]
    p75s = [r['p75'] for r in results]

    plt.figure(figsize=(10, 6))
    plt.plot(steps, means, marker='o', linestyle='-', color='b', label='Mean MSE')
    plt.fill_between(steps, p25s, p75s, color='b', alpha=0.2, label='25th-75th Percentile')
    plt.title(f"Evaluation MSE Curve - {config['dataset']['name']} ({args.label})")
    plt.xlabel("Training Steps")
    plt.ylabel("Mean RGB MSE (Full Trajectory)")
    plt.legend()
    plt.grid(True)

    plot_path = os.path.join(args.results_dir, f"mse_curve_{config['dataset']['name']}_{args.label}.png")
    plt.savefig(plot_path)
    print(f"Saved MSE curve to {plot_path}")

    with open(os.path.join(args.results_dir, f"mse_results_{config['dataset']['name']}_{args.label}.txt"), "w") as f:
        f.write("Step,Mean_MSE,P25,P75\n")
        for r in results:
            f.write(f"{r['step']},{r['mean']:.8f},{r['p25']:.8f},{r['p75']:.8f}\n")


def main():
    """Evaluate every checkpoint in a directory on the validation split.

    For each 'checkpoint_<step>.pt' file, loads the weights, generates full
    trajectories from the first observation frame, computes per-sample RGB
    MSE against ground truth, then plots mean/percentile curves over training
    steps and writes the raw numbers to a text file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to the yaml config file")
    parser.add_argument("--ckpt_dir", type=str, default="checkpoints/lang_table_fulltraj_dit_v1", help="Directory containing checkpoints")
    parser.add_argument("--results_dir", type=str, default="results", help="Directory to save plots")
    parser.add_argument("--num_samples", type=int, default=-1, help="Number of validation samples to evaluate (-1 for all)")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size for evaluation")
    parser.add_argument("--inference_steps", type=int, default=50, help="Number of denoising steps")
    parser.add_argument("--noise_level", type=float, default=0.0, help="Noise level (t0) for the first frame (0.0 to 1.0)")
    parser.add_argument("--label", type=str, default="", help="Label for the evaluation run")
    args = parser.parse_args()

    os.makedirs(args.results_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Derive a default run label from the sampling settings when none given.
    if not args.label:
        args.label = f"{args.inference_steps}steps"
        if args.noise_level > 0:
            args.label += f"_noise{args.noise_level}"

    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)

    print("Initializing dataset...")
    full_dataset = RoboticsDatasetWrapper.get_dataset(config['dataset']['name'])

    # train_test_split is interpreted as a train:val ratio of split:1 —
    # everything past the train fraction is the validation split.
    num_train = int(len(full_dataset) * (config['dataset']['train_test_split'] / (config['dataset']['train_test_split'] + 1)))
    val_indices = list(range(num_train, len(full_dataset)))

    # Optional fixed-seed subsampling so repeated runs evaluate the same subset.
    if args.num_samples > 0:
        import random
        random.seed(42)
        random.shuffle(val_indices)
        val_indices = val_indices[:args.num_samples]

    print(f"Evaluating on {len(val_indices)} samples.")
    val_subset = torch.utils.data.Subset(full_dataset, val_indices)
    val_loader = DataLoader(val_subset, batch_size=args.batch_size, shuffle=False, num_workers=8)

    print("Initializing model...")
    dynamics_class = get_dynamics_class(config['dynamics_class'])
    model = dynamics_class(config['model_name'], config['model_config'])
    model.to(device)
    model.eval()

    # Collect checkpoint files and evaluate them in increasing-step order.
    ckpt_files = [f for f in os.listdir(args.ckpt_dir) if f.endswith(".pt") and "checkpoint_" in f]
    ckpt_files.sort(key=_get_checkpoint_step)

    if not ckpt_files:
        print(f"No checkpoints found in {args.ckpt_dir}")
        return

    results = []
    for ckpt_file in ckpt_files:
        ckpt_path = os.path.join(args.ckpt_dir, ckpt_file)
        step = load_checkpoint(model, ckpt_path, device)

        all_mse_np = _evaluate_checkpoint(model, val_loader, device, args, step)
        mean_mse = np.mean(all_mse_np)
        p25 = np.percentile(all_mse_np, 25)
        p75 = np.percentile(all_mse_np, 75)

        print(f"Step {step} | Mean MSE: {mean_mse:.6f} | P25: {p25:.6f} | P75: {p75:.6f}")
        results.append({
            'step': step,
            'mean': mean_mse,
            'p25': p25,
            'p75': p75
        })

    _save_results(results, config, args)
| |
|
# Script entry point: run the evaluation sweep only when executed directly.
if __name__ == "__main__":
    main()
| |
|