| | import torch |
| | import numpy as np |
| | import matplotlib.pyplot as plt |
| | import os |
| | import sys |
| |
|
| | |
| | sys.path.append("/storage/ice-shared/ae8803che/hxue/data/world_model") |
| |
|
| | from wm.model.diffusion.flow_matching import FlowMatchScheduler |
| |
|
def create_grid_z(B, T, H, W, D, block_size=8):
    """Build a checkerboard test tensor of shape ``(B, T, H, W, D)``.

    The ``H x W`` plane is tiled with ``block_size x block_size`` squares.
    A square is filled with 1.0 when the sum of its block-row and
    block-column indices is even, and left at 0.0 otherwise. The same
    pattern is repeated across the ``B``, ``T`` and ``D`` axes.

    Args:
        B: Size of the first axis.
        T: Size of the second axis.
        H: Height of the checkerboard plane.
        W: Width of the checkerboard plane.
        D: Size of the last axis.
        block_size: Side length of one checker square. Defaults to 8.
            ``H`` and ``W`` need not be multiples of it; partial squares
            at the border are simply truncated.

    Returns:
        A tensor of shape ``(B, T, H, W, D)`` in torch's default dtype,
        containing only 0.0 and 1.0.

    Raises:
        ValueError: If ``block_size`` is not positive.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    # Block coordinates of every pixel row / column.
    block_rows = torch.arange(H) // block_size
    block_cols = torch.arange(W) // block_size
    # (H, W) mask that is True on squares whose block indices sum to even.
    lit = (block_rows[:, None] + block_cols[None, :]) % 2 == 0

    z = torch.zeros((B, T, H, W, D))
    # The 2-D boolean mask indexes the H and W axes together.
    z[:, :, lit] = 1.0
    return z
| |
|
def test_fm_loss_vis():
    """Save a grid image of noised checkerboard inputs ``z_t``.

    Builds a checkerboard tensor with ``create_grid_z``, draws one random
    schedule index per batch item (shared by all ``T`` frames), overrides
    the timestep of frame 0, noises the tensor with
    ``scheduler.add_noise`` and plots every ``(batch, frame)`` slice as an
    RGB image. The figure is written to a fixed PNG path, which is also
    printed.

    NOTE(review): the meaning of the timestep values (which end of the
    schedule is "clean") depends on ``FlowMatchScheduler``, which is not
    visible here; the "Clean/Low" labels reflect the original author's
    intent and should be confirmed against the scheduler.
    """
    device = "cpu"
    scheduler = FlowMatchScheduler()
    scheduler.set_timesteps(num_inference_steps=1000, training=True)

    B, T, H, W, D = 4, 9, 64, 64, 3
    z = create_grid_z(B, T, H, W, D).to(device)

    # One random schedule index per batch item, repeated for all T frames.
    t_indices = torch.randint(0, scheduler.timesteps.shape[0], (B,), device=device)
    t = t_indices.unsqueeze(1).expand(-1, T).clone()
    t_values = scheduler.timesteps[t].clone()

    # Frame 0 gets a special timestep: with probability 0.5 it is set to 0,
    # otherwise to one of the last five schedule entries (presumably the
    # lowest-noise steps; verify against the scheduler).
    num_timesteps = len(scheduler.timesteps)
    set_to_0 = torch.rand(B, device=device) < 0.5
    for b in range(B):
        if set_to_0[b]:
            t_values[b, 0] = 0
        else:
            small_noise_indices = torch.randint(
                num_timesteps - 5, num_timesteps, (1,), device=device
            )
            t_values[b, 0] = scheduler.timesteps[small_noise_indices]

    eps = torch.randn_like(z)
    z_t = scheduler.add_noise(z, eps, t_values)

    fig, axes = plt.subplots(B, T, figsize=(T * 2, B * 2))
    try:
        fig.suptitle(
            "Flow Matching Loss Input Visualization (z_t)\n"
            "Row = Batch Item (Different Noise Level), Col = Timestep (T=0:9)\n"
            "Note: Frame 0 is clean/low-noise",
            fontsize=16,
        )

        for b in range(B):
            for t_idx in range(T):
                ax = axes[b, t_idx]
                # detach/cpu keep the conversion valid even if the scheduler
                # returns a tensor that tracks gradients or lives off-CPU.
                img = z_t[b, t_idx].detach().cpu().numpy()
                # imshow expects float RGB in [0, 1]; noise pushes values outside.
                img = np.clip(img, 0, 1)
                ax.imshow(img)
                ax.axis('off')
                if t_idx == 0:
                    ax.set_title(f"B{b} T_val={t_values[b, t_idx]:.1f}\n(Clean/Low)")
                else:
                    ax.set_title(f"T_val={t_values[b, t_idx]:.1f}")

        fig.tight_layout()
        output_path = "/storage/ice-shared/ae8803che/hxue/data/world_model/results/test_flow_matching/fm_loss_input_vis.png"
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        fig.savefig(output_path)
        print(f"Visualization saved to {output_path}")
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated runs do not accumulate figures.
        plt.close(fig)
| |
|
# Script entry point: render the visualization only when this file is run
# directly, not when it is imported.
if __name__ == "__main__":
    test_fm_loss_vis()
| |
|