# wm/eval/eval_checkpoints.py
import os
import sys
import random
import argparse

import torch
import yaml
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import DataLoader

# Add the project root to the path so the wm package resolves
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))

from wm.model.interface import get_dynamics_class
from wm.dataset.dataset import RoboticsDatasetWrapper

def load_checkpoint(model, path, device):
    print(f"Loading checkpoint: {path}")
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    state_dict = checkpoint['model_state_dict']
    # Filter out scheduler buffers that might cause size mismatches
    scheduler_buffers = [
        'scheduler.sigmas',
        'scheduler.timesteps',
        'scheduler.linear_timesteps_weights'
    ]
    new_state_dict = {}
    for k, v in state_dict.items():
        # Strip the "module." prefix that DDP adds to parameter names
        name = k[7:] if k.startswith('module.') else k
        if name not in scheduler_buffers:
            new_state_dict[name] = v
    # strict=False tolerates the scheduler buffers filtered out above
    model.load_state_dict(new_state_dict, strict=False)
    return checkpoint.get('step', 0)
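
# Usage sketch (the checkpoint path is hypothetical): load one checkpoint and
# read back its training step, which labels a point on the MSE curve below.
#
#     step = load_checkpoint(model, "checkpoints/checkpoint_10000.pt", device)
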
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to the yaml config file")
    parser.add_argument("--ckpt_dir", type=str, default="checkpoints/lang_table_fulltraj_dit_v1", help="Directory containing checkpoints")
    parser.add_argument("--results_dir", type=str, default="results", help="Directory to save plots")
    parser.add_argument("--num_samples", type=int, default=-1, help="Number of validation samples to evaluate (-1 for all)")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size for evaluation")
    parser.add_argument("--inference_steps", type=int, default=50, help="Number of denoising steps")
    parser.add_argument("--noise_level", type=float, default=0.0, help="Noise level (t0) for the first frame (0.0 to 1.0)")
    parser.add_argument("--label", type=str, default="", help="Label for the evaluation run")
    args = parser.parse_args()
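
    # Example invocation (the config path is hypothetical):
    #   python wm/eval/eval_checkpoints.py --config configs/lang_table.yaml \
    #       --ckpt_dir checkpoints/lang_table_fulltraj_dit_v1 --inference_steps 50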

    os.makedirs(args.results_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # If no label is provided, generate one from the evaluation settings
    if not args.label:
        args.label = f"{args.inference_steps}steps"
        if args.noise_level > 0:
            args.label += f"_noise{args.noise_level}"

    # 1. Load Config
    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)
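
    # Keys read from the config (inferred from their use in this script):
    #   dataset.name, dataset.train_test_split, dynamics_class, model_name, model_config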

    # 2. Setup Dataset
    print("Initializing dataset...")
    full_dataset = RoboticsDatasetWrapper.get_dataset(config['dataset']['name'])
    # Take validation samples from the end of the dataset, matching the training split.
    # train_test_split is read as a train:val ratio, so the train fraction is split / (split + 1).
    num_train = int(len(full_dataset) * (config['dataset']['train_test_split'] / (config['dataset']['train_test_split'] + 1)))
    val_indices = list(range(num_train, len(full_dataset)))
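    # e.g. with train_test_split = 9 and 10,000 samples: num_train = 9000, and the
    # last 1,000 indices form the validation pool (illustrative numbers).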

    # Limit the number of samples if requested
    if args.num_samples > 0:
        random.seed(42)
        random.shuffle(val_indices)
        val_indices = val_indices[:args.num_samples]

    print(f"Evaluating on {len(val_indices)} samples.")
    val_subset = torch.utils.data.Subset(full_dataset, val_indices)
    val_loader = DataLoader(val_subset, batch_size=args.batch_size, shuffle=False, num_workers=8)

    # 3. Initialize Model
    print("Initializing model...")
    dynamics_class = get_dynamics_class(config['dynamics_class'])
    model = dynamics_class(config['model_name'], config['model_config'])
    model.to(device)
    model.eval()
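
    # Assumption from usage below: the dynamics model exposes
    # generate(o_0, action, num_inference_steps=..., noise_level=...) and returns
    # a rollout shaped [B, T, H, W, 3], matching the permuted ground truth.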

    # 4. Find Checkpoints
    ckpt_files = [f for f in os.listdir(args.ckpt_dir) if f.endswith(".pt") and "checkpoint_" in f]

    def get_step(f):
        # e.g. "checkpoint_20000.pt" -> 20000
        try:
            return int(f.split('_')[1].split('.')[0])
        except (IndexError, ValueError):
            return -1

    ckpt_files.sort(key=get_step)
    if not ckpt_files:
        print(f"No checkpoints found in {args.ckpt_dir}")
        return

    results = []

    # 5. Evaluate each checkpoint
    for ckpt_file in ckpt_files:
        ckpt_path = os.path.join(args.ckpt_dir, ckpt_file)
        step = load_checkpoint(model, ckpt_path, device)
        all_mse = []
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Eval Step {step}"):
                obs = batch['obs'].to(device)        # [B, T, C, H, W]
                action = batch['action'].to(device)  # [B, T, A]
                # First frame: [B, C, H, W] -> [B, H, W, 3]
                o_0 = obs[:, 0].permute(0, 2, 3, 1).contiguous()
                # Generate the full rollout conditioned on the first frame and actions
                pred_video = model.generate(
                    o_0, action,
                    num_inference_steps=args.inference_steps,
                    noise_level=args.noise_level
                )
                # Ground truth: [B, T, C, H, W] -> [B, T, H, W, 3]
                gt_video = obs.permute(0, 1, 3, 4, 2).contiguous()
                # Per-sample MSE, averaged over time and space
                mse = (pred_video - gt_video) ** 2
                mse_per_sample = mse.mean(dim=(1, 2, 3, 4))  # [B]
                all_mse.append(mse_per_sample.cpu().numpy())
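                # Note: this equals torch.nn.functional.mse_loss(pred_video, gt_video,
                # reduction='none').mean(dim=(1, 2, 3, 4)), computed manually here.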

        # Aggregate statistics across all validation samples
        all_mse_np = np.concatenate(all_mse, axis=0)
        mean_mse = np.mean(all_mse_np)
        p25 = np.percentile(all_mse_np, 25)
        p75 = np.percentile(all_mse_np, 75)
        print(f"Step {step} | Mean MSE: {mean_mse:.6f} | P25: {p25:.6f} | P75: {p75:.6f}")
        results.append({
            'step': step,
            'mean': mean_mse,
            'p25': p25,
            'p75': p75
        })
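
    # results now holds one {'step', 'mean', 'p25', 'p75'} entry per checkpoint.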

    # 6. Plot and Save
    results.sort(key=lambda x: x['step'])
    steps = [r['step'] for r in results]
    means = [r['mean'] for r in results]
    p25s = [r['p25'] for r in results]
    p75s = [r['p75'] for r in results]

    plt.figure(figsize=(10, 6))
    plt.plot(steps, means, marker='o', linestyle='-', color='b', label='Mean MSE')
    plt.fill_between(steps, p25s, p75s, color='b', alpha=0.2, label='25th-75th Percentile')
    plt.title(f"Evaluation MSE Curve - {config['dataset']['name']} ({args.label})")
    plt.xlabel("Training Steps")
    plt.ylabel("Mean RGB MSE (Full Trajectory)")
    plt.legend()
    plt.grid(True)
    plot_path = os.path.join(args.results_dir, f"mse_curve_{config['dataset']['name']}_{args.label}.png")
    plt.savefig(plot_path)
    print(f"Saved MSE curve to {plot_path}")

    # Save raw results as CSV-style text
    with open(os.path.join(args.results_dir, f"mse_results_{config['dataset']['name']}_{args.label}.txt"), "w") as f:
        f.write("Step,Mean_MSE,P25,P75\n")
        for r in results:
            f.write(f"{r['step']},{r['mean']:.8f},{r['p25']:.8f},{r['p75']:.8f}\n")

if __name__ == "__main__":
    main()