# File size: 2,140 Bytes (f17ae24)
import torch
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
# Put the directory two levels above this file on the import path so the
# `wm` package imported below resolves when this file is run directly.
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
)
from wm.model.diffusion.flow_matching import FlowMatchScheduler
def test_independent_noise():
    """Visualize per-frame independent noising with ``FlowMatchScheduler``.

    Builds a 5-frame clip where every frame holds the same white grid on a
    black background. It then draws one random timestep per frame from
    ``scheduler.timesteps`` (configured with ``training=True``) and noises the
    clip with ``scheduler.add_independent_noise``. Finally it saves a figure
    showing the original frames (top row) above the noisy frames (bottom row)
    to ``results/test_diffusion_forcing/independent_noise_test.png``.

    This is a visual smoke test; it makes no assertions about the output.
    """
    out_dir = "results/test_diffusion_forcing"
    out_path = f"{out_dir}/independent_noise_test.png"
    os.makedirs(out_dir, exist_ok=True)

    # 1. Initialize scheduler with a 50-step training schedule.
    scheduler = FlowMatchScheduler()
    num_steps = 50
    scheduler.set_timesteps(num_inference_steps=num_steps, training=True)

    # 2. Create a white-black grid as a "video" sequence.
    # Shape: [B=1, T=5, C=3, H=256, W=256]
    B, T, C, H, W = 1, 5, 3, 256, 256
    video = torch.zeros(B, T, C, H, W)
    # Every frame gets the SAME grid so differences between noisy frames come
    # only from their independently sampled timesteps.
    grid_size = 32
    video[..., ::grid_size, :] = 1.0  # horizontal lines (every grid_size-th row)
    video[..., ::grid_size] = 1.0     # vertical lines (every grid_size-th column)

    # 3. Sample an independent timestep for each frame, as during training.
    t_indices = torch.randint(0, len(scheduler.timesteps), (B, T))
    t_values = scheduler.timesteps[t_indices]

    # 4. Add independent noise (the returned noise tensor is not needed here).
    video_noisy, _ = scheduler.add_independent_noise(video, t_values)

    # 5. Visualize. squeeze=False keeps `axes` 2-D even if T == 1.
    fig, axes = plt.subplots(2, T, figsize=(15, 6), squeeze=False)
    try:
        for t in range(T):
            # Original frame: CHW -> HWC for imshow.
            orig_img = video[0, t].permute(1, 2, 0).numpy()
            axes[0, t].imshow(orig_img)
            axes[0, t].set_title(f"Original Frame {t}")
            axes[0, t].axis('off')

            # Noisy frame: detach/cpu so .numpy() works even if the scheduler
            # returns a grad-tracking or non-CPU tensor; clip to imshow's
            # valid float range [0, 1].
            noisy_img = video_noisy[0, t].detach().cpu().permute(1, 2, 0).numpy()
            noisy_img = np.clip(noisy_img, 0, 1)
            axes[1, t].imshow(noisy_img)
            axes[1, t].set_title(f"Noisy (t={t_values[0, t].item():.0f})")
            axes[1, t].axis('off')

        fig.suptitle("Diffusion Forcing: Independent Noise Addition per Frame")
        fig.tight_layout()
        fig.savefig(out_path)
    finally:
        # Release the figure so repeated runs (e.g. under pytest) don't
        # accumulate open figures.
        plt.close(fig)
    print(f"Saved test results to {out_path}")
if __name__ == "__main__":
    # Run the visual check directly as a script; it writes the PNG figure.
    test_independent_noise()