# world_model/wm/test/test_fm_loss_vis.py
# Uploaded by t1an via huggingface_hub ("Upload folder using huggingface_hub"),
# commit f17ae24 (verified).
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
# Add project root to sys.path
sys.path.append("/storage/ice-shared/ae8803che/hxue/data/world_model")
from wm.model.diffusion.flow_matching import FlowMatchScheduler
def create_grid_z(B, T, H, W, D, *, block_size=8):
    """Build a batch of black/white checkerboard frames.

    Every frame of every batch item carries the same pattern. Square cells of
    ``block_size`` x ``block_size`` pixels alternate between 1.0 and 0.0, and
    the top-left cell is 1.0. Cells on the right/bottom edge are truncated
    when ``H`` or ``W`` is not a multiple of ``block_size``.

    Args:
        B: Batch size.
        T: Number of frames per batch item.
        H: Frame height in pixels.
        W: Frame width in pixels.
        D: Number of channels; every channel receives the same pattern.
        block_size: Side length of one checkerboard cell in pixels. The
            default of 8 is the previously hard-coded value.

    Returns:
        A tensor of shape ``(B, T, H, W, D)`` in torch's default dtype,
        holding only 0.0 and 1.0.

    Raises:
        ValueError: If ``block_size`` is not a positive integer.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    z = torch.zeros((B, T, H, W, D))
    # Cell coordinate of every pixel row / column. A pixel is "white" when the
    # sum of its cell coordinates is even, which is the same rule as the
    # original nested loop over blocks.
    row_cells = torch.arange(H) // block_size
    col_cells = torch.arange(W) // block_size
    white = (row_cells[:, None] + col_cells[None, :]) % 2 == 0  # (H, W) bool
    # A 2-D boolean mask on dims 2-3 selects the same pixels in every batch
    # item, frame and channel.
    z[:, :, white] = 1.0
    return z
def _sample_frame_timesteps(scheduler, B, T, device):
    """Sample per-frame timestep values, mirroring the logic in ``flow_matching_loss``.

    One timestep index is drawn per batch item and shared by all ``T`` frames.
    Frame 0 of each item is then overridden:

    - with probability 0.5 it is set to 0;
    - otherwise it is set to a value drawn from the last five entries of
      ``scheduler.timesteps``.

    The random draws happen in the same order as the original inline code.

    Returns:
        A tensor of shape ``(B, T)`` with the timestep value for every frame.
    """
    num_timesteps = scheduler.timesteps.shape[0]
    t_indices = torch.randint(0, num_timesteps, (B, ), device=device)
    t = t_indices.unsqueeze(1).expand(-1, T).clone()
    t_values = scheduler.timesteps[t].clone()
    # Set first frame to clean (t=0) or small noise
    set_to_0 = torch.rand(B, device=device) < 0.5
    for b in range(B):
        if set_to_0[b]:
            t_values[b, 0] = 0
        else:
            # Last five schedule entries. This presumably means the lowest-noise
            # timesteps, assuming scheduler.timesteps is ordered from high to
            # low noise — TODO confirm against FlowMatchScheduler.
            small_noise_indices = torch.randint(num_timesteps - 5, num_timesteps, (1,), device=device)
            t_values[b, 0] = scheduler.timesteps[small_noise_indices]
    return t_values


def _save_noisy_frame_grid(z_t, t_values, output_path):
    """Plot every frame of ``z_t`` in a (batch x frame) grid and save it to ``output_path``.

    Rows are batch items and columns are frames. Each panel is titled with the
    frame's timestep value, and pixel values are clipped to [0, 1] for display.
    The figure is always closed, even if saving fails, so repeated calls do not
    accumulate open matplotlib figures.
    """
    B, T = t_values.shape
    # squeeze=False keeps `axes` 2-D even when B or T is 1.
    fig, axes = plt.subplots(B, T, figsize=(T*2, B*2), squeeze=False)
    try:
        fig.suptitle(
            "Flow Matching Loss Input Visualization (z_t)\n"
            f"Row = Batch Item (Different Noise Level), Col = Timestep (T=0:{T})\n"
            "Note: Frame 0 is clean/low-noise",
            fontsize=16,
        )
        for b in range(B):
            for t_idx in range(T):
                ax = axes[b, t_idx]
                # detach/cpu so plotting also works for grad-tracking or GPU tensors.
                img = np.clip(z_t[b, t_idx].detach().cpu().numpy(), 0, 1)
                ax.imshow(img)
                ax.axis('off')
                if t_idx == 0:
                    ax.set_title(f"B{b} T_val={t_values[b, t_idx]:.1f}\n(Clean/Low)")
                else:
                    ax.set_title(f"T_val={t_values[b, t_idx]:.1f}")
        fig.tight_layout()
        out_dir = os.path.dirname(output_path)
        if out_dir:  # os.makedirs("") fails for a bare filename
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(output_path)
    finally:
        plt.close(fig)


def test_fm_loss_vis(output_path=None):
    """Visualize the noisy inputs ``z_t`` that the flow-matching loss trains on.

    The function:

    1. Builds a checkerboard video batch.
    2. Samples per-frame timesteps the same way ``flow_matching_loss`` does
       (frame 0 is clean or low-noise).
    3. Noises the batch with ``scheduler.add_noise``.
    4. Saves a grid image of every frame.

    Args:
        output_path: Destination PNG path. Defaults to the project's
            ``results/test_flow_matching/fm_loss_input_vis.png``. The default
            value keeps pytest from treating this argument as a fixture.
    """
    if output_path is None:
        output_path = "/storage/ice-shared/ae8803che/hxue/data/world_model/results/test_flow_matching/fm_loss_input_vis.png"
    device = "cpu"
    scheduler = FlowMatchScheduler()
    scheduler.set_timesteps(num_inference_steps=1000, training=True)
    B, T, H, W, D = 4, 9, 64, 64, 3
    z = create_grid_z(B, T, H, W, D).to(device)
    t_values = _sample_frame_timesteps(scheduler, B, T, device)
    eps = torch.randn_like(z)
    z_t = scheduler.add_noise(z, eps, t_values)
    # The plotting code indexes z_t as (B, T, H, W, D) image frames.
    assert z_t.shape == z.shape, (
        f"add_noise changed the shape: {tuple(z_t.shape)} != {tuple(z.shape)}"
    )
    _save_noisy_frame_grid(z_t, t_values, output_path)
    print(f"Visualization saved to {output_path}")
def _main():
    """Run the flow-matching input visualization when executed as a script."""
    test_fm_loss_vis()


if __name__ == "__main__":
    _main()