import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch

# Add project root to sys.path so the project-local `wm` package imported
# below can be resolved when this file is run directly as a script.
sys.path.append("/storage/ice-shared/ae8803che/hxue/data/world_model")
from wm.model.diffusion.flow_matching import FlowMatchScheduler  # noqa: E402

# Number of trailing entries of `scheduler.timesteps` that frame 0 is sampled
# from when it is given "small noise" instead of being set to t=0.
_LOW_NOISE_WINDOW = 5


def create_grid_z(B, T, H, W, D, block_size=8):
    """Build a black/white checkerboard tensor of shape (B, T, H, W, D).

    The square in block-row ``i`` / block-column ``j`` (each square being
    ``block_size`` pixels on a side) is 1.0 when ``i + j`` is even and 0.0
    otherwise. The same pattern is repeated over every batch item, frame and
    channel. Edge squares are truncated when H or W is not a multiple of
    ``block_size``.

    Args:
        B, T, H, W, D: Output dimensions.
        block_size: Side length of each square in pixels (default 8, the
            value that was previously hard-coded).

    Returns:
        A contiguous tensor with torch's default floating dtype.

    Raises:
        ValueError: If ``block_size`` is not positive.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    row_block = torch.arange(H) // block_size
    col_block = torch.arange(W) // block_size
    # (H, W) mask: True where block-row index + block-column index is even.
    white = (row_block[:, None] + col_block[None, :]) % 2 == 0
    pattern = white.to(torch.get_default_dtype())
    # Broadcast over (B, T, ..., D), then materialize so the caller gets an
    # ordinary writable tensor rather than an expanded (stride-0) view.
    return pattern[None, None, :, :, None].expand(B, T, H, W, D).contiguous()


def test_fm_loss_vis():
    """Save a PNG grid of noised checkerboard inputs (z_t) for inspection.

    Mirrors the timestep sampling labelled "Logic from flow_matching_loss":
    each batch item draws one random scheduler timestep shared by all of its
    frames, then frame 0 is overridden with either t=0 (probability 0.5) or
    one of the last ``_LOW_NOISE_WINDOW`` scheduler timesteps. The noised
    tensor from ``scheduler.add_noise`` is plotted with rows = batch items
    and columns = frames, clipped to [0, 1], and written to a fixed path
    under the project's results directory.
    """
    device = "cpu"
    scheduler = FlowMatchScheduler()
    scheduler.set_timesteps(num_inference_steps=1000, training=True)

    B, T, H, W, D = 4, 9, 64, 64, 3
    z = create_grid_z(B, T, H, W, D).to(device)

    # Logic from flow_matching_loss
    n_steps = scheduler.timesteps.shape[0]
    t_indices = torch.randint(0, n_steps, (B,), device=device)
    t = t_indices.unsqueeze(1).expand(-1, T)
    # Advanced indexing returns a fresh tensor, so the in-place edit of
    # column 0 below cannot touch scheduler.timesteps.
    t_values = scheduler.timesteps[t]

    # Set first frame to clean (t=0) or small noise.
    # NOTE(review): assumes scheduler.timesteps is ordered from high to low
    # noise, so its tail holds the least-noisy timesteps — confirm.
    set_to_0 = torch.rand(B, device=device) < 0.5
    low_idx = torch.randint(
        max(n_steps - _LOW_NOISE_WINDOW, 0), n_steps, (B,), device=device
    )
    low_noise = scheduler.timesteps[low_idx]
    t_values[:, 0] = torch.where(set_to_0, torch.zeros_like(low_noise), low_noise)

    eps = torch.randn_like(z)
    z_t = scheduler.add_noise(z, eps, t_values)

    # Visualization. Move everything to host memory once so plotting also
    # works if `device` is changed to a GPU.
    frames = z_t.detach().cpu().numpy()
    t_host = t_values.detach().cpu()

    # squeeze=False keeps `axes` 2-D even when B or T is 1.
    fig, axes = plt.subplots(B, T, figsize=(T*2, B*2), squeeze=False)
    fig.suptitle("Flow Matching Loss Input Visualization (z_t)\nRow = Batch Item (Different Noise Level), Col = Timestep (T=0:9)\nNote: Frame 0 is clean/low-noise", fontsize=16)

    for b in range(B):
        for t_idx in range(T):
            ax = axes[b, t_idx]
            ax.imshow(np.clip(frames[b, t_idx], 0, 1))
            ax.axis('off')
            t_val = t_host[b, t_idx].item()
            if t_idx == 0:
                ax.set_title(f"B{b} T_val={t_val:.1f}\n(Clean/Low)")
            else:
                ax.set_title(f"T_val={t_val:.1f}")

    fig.tight_layout()
    output_path = "/storage/ice-shared/ae8803che/hxue/data/world_model/results/test_flow_matching/fm_loss_input_vis.png"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    fig.savefig(output_path)
    # Release the figure; otherwise every call leaks an open pyplot figure.
    plt.close(fig)
    print(f"Visualization saved to {output_path}")


if __name__ == "__main__":
    test_fm_loss_vis()