# File size: 2,140 Bytes
# Revision: f17ae24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
import matplotlib.pyplot as plt
import numpy as np
import os
import sys

# Add the project root to the import path. The root is taken to be two
# directory levels above the directory containing this file; this must run
# before the `wm` import below (presumably so the package resolves when the
# script is launched directly rather than installed — confirm against the
# repository layout).
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))

from wm.model.diffusion.flow_matching import FlowMatchScheduler

def test_independent_noise():
    """Visualize per-frame independent noising of a synthetic grid video.

    Builds a 5-frame white-on-black grid "video", draws one random timestep
    per frame from a ``FlowMatchScheduler`` configured with 50 steps, asks the
    scheduler to noise each frame at its own timestep, and writes a 2 x T
    comparison figure (top row: clean frames, bottom row: noised frames) to
    ``results/test_diffusion_forcing/independent_noise_test.png``.

    This is a visual smoke test: it makes no assertions and returns ``None``.
    The result is meant to be checked by looking at the saved image.

    Side effects:
        Creates the output directory if needed, writes the PNG, and prints
        the path of the saved file.
    """
    # Single source of truth for the output location, so the directory that
    # is created, the file that is saved, and the message that is printed
    # cannot drift apart.
    out_dir = "results/test_diffusion_forcing"
    out_path = f"{out_dir}/independent_noise_test.png"
    os.makedirs(out_dir, exist_ok=True)

    # 1. Initialize scheduler
    scheduler = FlowMatchScheduler()
    num_steps = 50
    scheduler.set_timesteps(num_inference_steps=num_steps, training=True)

    # 2. Create a white-black grid as a "video" sequence
    # Shape: [B=1, T=5, C=3, H=256, W=256]
    B, T, C, H, W = 1, 5, 3, 256, 256
    video = torch.zeros(B, T, C, H, W)

    # Every frame gets the SAME grid pattern so that differences between
    # frames in the output are due to noise only. One strided assignment per
    # axis covers all frames and channels at once.
    grid_size = 32
    video[:, :, :, ::grid_size, :] = 1.0  # horizontal lines
    video[:, :, :, :, ::grid_size] = 1.0  # vertical lines

    # 3. Sample an independent timestep index for each frame (random, as in
    # training), then look up the corresponding scheduler timestep values.
    t_indices = torch.randint(0, len(scheduler.timesteps), (B, T))
    t_values = scheduler.timesteps[t_indices]

    # 4. Add independent noise. The second return value (unpacked as the
    # noise by the original code) is not needed for the visualization.
    video_noisy, _noise = scheduler.add_independent_noise(video, t_values)

    # 5. Visualize
    fig, axes = plt.subplots(2, T, figsize=(15, 6))
    try:
        for t in range(T):
            # Original: channels-first frame -> (H, W, C) for imshow.
            orig_img = video[0, t].permute(1, 2, 0).numpy()
            axes[0, t].imshow(orig_img)
            axes[0, t].set_title(f"Original Frame {t}")
            axes[0, t].axis('off')

            # Noisy: clip to [0, 1] so imshow receives a valid float RGB
            # range even where noise pushed values outside it.
            noisy_img = video_noisy[0, t].permute(1, 2, 0).numpy()
            noisy_img = np.clip(noisy_img, 0, 1)
            axes[1, t].imshow(noisy_img)
            axes[1, t].set_title(f"Noisy (t={t_values[0, t].item():.0f})")
            axes[1, t].axis('off')

        fig.suptitle("Diffusion Forcing: Independent Noise Addition per Frame")
        fig.tight_layout()
        fig.savefig(out_path)
    finally:
        # Release the figure even if plotting or saving fails; otherwise it
        # stays registered with pyplot for the lifetime of the process.
        plt.close(fig)

    print(f"Saved test results to {out_path}")

# Allow running this file directly as a script: executes the visualization
# once and writes the comparison image to disk.
if __name__ == "__main__":
    test_independent_noise()