File size: 2,582 Bytes
f17ae24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import sys

# Add project root to sys.path
sys.path.append("/storage/ice-shared/ae8803che/hxue/data/world_model")

from wm.model.diffusion.flow_matching import FlowMatchScheduler

def create_grid_z(B, T, H, W, D, block_size=8):
    """Return a (B, T, H, W, D) float tensor holding a 0.0 / 1.0 checkerboard.

    Cells are ``block_size`` x ``block_size`` pixels over the H and W axes.
    The cell containing pixel (0, 0) is 1.0, and colours alternate from there.
    When H or W is not a multiple of ``block_size``, partial cells at the
    bottom/right edge are kept. The same pattern is repeated across the B, T
    and D axes. Presumably these are batch, frames and channels; D=3 is what
    the caller passes to ``imshow``.

    Args:
        B, T, H, W, D: Output dimensions; the result has shape (B, T, H, W, D).
        block_size: Edge length of one checkerboard cell in pixels. The
            default of 8 matches the previously hard-coded value.

    Returns:
        A tensor of the default float dtype containing only 0.0 and 1.0.

    Raises:
        ValueError: If ``block_size`` is not positive.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    z = torch.zeros((B, T, H, W, D))
    # Cell index of every row / column. The operands are non-negative, so
    # floor division is exact.
    row_cell = torch.div(torch.arange(H), block_size, rounding_mode="floor")
    col_cell = torch.div(torch.arange(W), block_size, rounding_mode="floor")
    # (H, W) boolean mask: True where the cell parity is even (the "white" cells).
    white = (row_cell[:, None] + col_cell[None, :]) % 2 == 0
    # A 2-D boolean mask in positions 2-3 selects over H and W simultaneously.
    z[:, :, white] = 1.0
    return z

def test_fm_loss_vis(seed=None, output_path=None):
    """Render the noisy latents ``z_t`` built by the flow-matching-loss logic.

    The function:

    * builds a checkerboard latent of shape (4, 9, 64, 64, 3);
    * samples one training timestep index per batch item and shares it
      across all 9 frames;
    * overrides frame 0 with either exactly 0 or one of the last five
      scheduler timesteps;
    * noises the latent with ``scheduler.add_noise``;
    * saves a B x T grid of the results as a PNG.

    Args:
        seed: Optional seed passed to ``torch.manual_seed`` so the sampled
            timesteps and noise are reproducible. ``None`` (default) leaves the
            global RNG untouched, as before.
        output_path: Destination PNG path. ``None`` (default) uses the original
            hard-coded results location.
    """
    if output_path is None:
        output_path = "/storage/ice-shared/ae8803che/hxue/data/world_model/results/test_flow_matching/fm_loss_input_vis.png"
    if seed is not None:
        torch.manual_seed(seed)

    device = "cpu"
    scheduler = FlowMatchScheduler()
    scheduler.set_timesteps(num_inference_steps=1000, training=True)

    B, T, H, W, D = 4, 9, 64, 64, 3
    z = create_grid_z(B, T, H, W, D).to(device)

    # Logic from flow_matching_loss: one timestep index per batch item,
    # broadcast to every frame.
    num_steps = scheduler.timesteps.shape[0]
    t_indices = torch.randint(0, num_steps, (B,), device=device)
    t = t_indices.unsqueeze(1).expand(-1, T).clone()
    t_values = scheduler.timesteps[t].clone()

    # Frame 0 gets either t=0 exactly or one of the last five scheduler timesteps.
    # NOTE(review): treating the last five entries as "small noise" assumes
    # scheduler.timesteps is in descending order — confirm against FlowMatchScheduler.
    set_to_0 = torch.rand(B, device=device) < 0.5
    low = max(num_steps - 5, 0)  # avoid a negative lower bound on short schedules
    for b in range(B):
        if set_to_0[b]:
            t_values[b, 0] = 0
        else:
            small_noise_idx = torch.randint(low, num_steps, (1,), device=device)
            t_values[b, 0] = scheduler.timesteps[small_noise_idx]

    eps = torch.randn_like(z)
    z_t = scheduler.add_noise(z, eps, t_values)

    _save_z_t_grid(z_t, t_values, output_path)


def _save_z_t_grid(z_t, t_values, output_path):
    """Plot every (batch, frame) image of ``z_t`` in a grid and save it as a PNG.

    Assumes ``z_t`` is indexed as ``z_t[b, t]`` -> an (H, W, C) image that
    ``imshow`` accepts, and ``t_values`` is indexed the same way (TODO confirm
    for other callers). Pixel values are clipped to [0, 1] for display only.
    The figure is always closed, so repeated runs do not accumulate open figures.
    """
    B, T = z_t.shape[0], z_t.shape[1]
    # squeeze=False keeps ``axes`` 2-D even when B or T is 1.
    fig, axes = plt.subplots(B, T, figsize=(T * 2, B * 2), squeeze=False)
    try:
        fig.suptitle(
            "Flow Matching Loss Input Visualization (z_t)\n"
            f"Row = Batch Item (Different Noise Level), Col = Timestep (T=0:{T})\n"
            "Note: Frame 0 is clean/low-noise",
            fontsize=16,
        )
        for b in range(B):
            for t_idx in range(T):
                ax = axes[b, t_idx]
                # detach/cpu so conversion also works for GPU or grad-tracking tensors.
                img = np.clip(z_t[b, t_idx].detach().cpu().numpy(), 0, 1)
                ax.imshow(img)
                ax.axis('off')
                t_val = float(t_values[b, t_idx])
                if t_idx == 0:
                    ax.set_title(f"B{b} T_val={t_val:.1f}\n(Clean/Low)")
                else:
                    ax.set_title(f"T_val={t_val:.1f}")

        fig.tight_layout()
        out_dir = os.path.dirname(output_path)
        if out_dir:  # os.makedirs("") raises, so skip it for bare filenames
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(output_path)
    finally:
        plt.close(fig)
    print(f"Visualization saved to {output_path}")

def _main():
    """Script entry point: render and save the flow-matching input visualization once."""
    test_fm_loss_vis()


if __name__ == "__main__":
    _main()