File size: 2,582 Bytes
f17ae24 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 | import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
# Add project root to sys.path
sys.path.append("/storage/ice-shared/ae8803che/hxue/data/world_model")
from wm.model.diffusion.flow_matching import FlowMatchScheduler
def create_grid_z(B, T, H, W, D, block_size=8):
    """Create a black/white checkerboard tensor of shape (B, T, H, W, D).

    Pixel (i, j) is 1.0 when ``(i // block_size + j // block_size)`` is even
    and 0.0 otherwise.  This pattern is identical for every batch item, frame
    and channel.  Cells on the bottom/right edge are truncated when H or W is
    not a multiple of ``block_size``.

    Args:
        B, T, H, W, D: Output dimensions (batch, frames, height, width, channels).
        block_size: Side length in pixels of one checkerboard cell (default 8).

    Returns:
        A float32 tensor of shape (B, T, H, W, D) on the default device.

    Raises:
        ValueError: If ``block_size`` is not positive.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    # Cell index of every row / column; their sum's parity decides the colour.
    row_cell = torch.arange(H) // block_size
    col_cell = torch.arange(W) // block_size
    mask = (row_cell[:, None] + col_cell[None, :]) % 2 == 0  # (H, W) bool
    z = torch.zeros((B, T, H, W, D))
    # A 2-D boolean mask indexes the H and W dims together, broadcast over B, T, D.
    z[:, :, mask] = 1.0
    return z
def _sample_first_frame_low_noise_timesteps(scheduler, B, T, device):
    """Sample a (B, T) timestep tensor mirroring the logic of ``flow_matching_loss``.

    All T frames of batch item ``b`` share one timestep drawn uniformly from
    ``scheduler.timesteps``.  Frame 0 is then overridden: with probability 0.5
    it is set to 0, otherwise to an entry drawn from the last five of
    ``scheduler.timesteps``.  NOTE(review): this assumes the schedule is
    descending, so that the tail entries are the low-noise ones.  Confirm
    against FlowMatchScheduler.
    """
    timesteps = scheduler.timesteps
    n_steps = timesteps.shape[0]
    t_indices = torch.randint(0, n_steps, (B,), device=device)
    t_values = timesteps[t_indices.unsqueeze(1).expand(-1, T)].clone()

    set_to_0 = torch.rand(B, device=device) < 0.5
    # Clamp the lower bound so schedules with fewer than 5 steps do not yield
    # negative (wrap-around) indices.
    tail_idx = torch.randint(max(n_steps - 5, 0), n_steps, (B,), device=device)
    tail_t = timesteps[tail_idx]
    # Vectorized replacement of the per-item "clean vs. small noise" choice.
    t_values[:, 0] = torch.where(set_to_0, torch.zeros_like(tail_t), tail_t)
    return t_values


def _save_zt_grid(z_t, t_values, output_path):
    """Plot z_t[b, t] (clipped to [0, 1]) in a B x T grid and save it to output_path."""
    B, T = t_values.shape
    # squeeze=False keeps `axes` 2-D even when B or T is 1.
    fig, axes = plt.subplots(B, T, figsize=(T*2, B*2), squeeze=False)
    try:
        fig.suptitle("Flow Matching Loss Input Visualization (z_t)\nRow = Batch Item (Different Noise Level), Col = Timestep (T=0:9)\nNote: Frame 0 is clean/low-noise", fontsize=16)
        for b in range(B):
            for t_idx in range(T):
                ax = axes[b, t_idx]
                # detach/cpu so the plot also works for tensors on other devices.
                img = np.clip(z_t[b, t_idx].detach().cpu().numpy(), 0, 1)
                ax.imshow(img)
                ax.axis('off')
                t_val = t_values[b, t_idx]
                if t_idx == 0:
                    ax.set_title(f"B{b} T_val={t_val:.1f}\n(Clean/Low)")
                else:
                    ax.set_title(f"T_val={t_val:.1f}")
        fig.tight_layout()
        out_dir = os.path.dirname(output_path)
        if out_dir:  # os.makedirs("") would raise for a bare filename
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(output_path)
    finally:
        # Always release the figure; otherwise pyplot keeps it alive across calls.
        plt.close(fig)


def test_fm_loss_vis(output_path=None):
    """Render the noisy inputs z_t built by the flow-matching timestep logic.

    A checkerboard latent of shape (4, 9, 64, 64, 3) is noised with
    ``scheduler.add_noise`` at per-frame timesteps.  Frame 0 of each batch
    item is forced to timestep 0 or to a tail-of-schedule timestep.  The
    resulting grid is saved as a PNG.

    Args:
        output_path: Destination image path.  Defaults to the project's
            ``results/test_flow_matching/fm_loss_input_vis.png``.
    """
    device = "cpu"
    if output_path is None:
        output_path = "/storage/ice-shared/ae8803che/hxue/data/world_model/results/test_flow_matching/fm_loss_input_vis.png"
    scheduler = FlowMatchScheduler()
    scheduler.set_timesteps(num_inference_steps=1000, training=True)
    B, T, H, W, D = 4, 9, 64, 64, 3
    z = create_grid_z(B, T, H, W, D).to(device)
    t_values = _sample_first_frame_low_noise_timesteps(scheduler, B, T, device)
    eps = torch.randn_like(z)
    z_t = scheduler.add_noise(z, eps, t_values)
    _save_zt_grid(z_t, t_values, output_path)
    print(f"Visualization saved to {output_path}")
def _main():
    """Script entry point: build and save the flow-matching input visualization."""
    test_fm_loss_vis()


if __name__ == "__main__":
    _main()
|