# world_model/wm/model/diffusion/flow_matching.py
# Uploaded via huggingface_hub by t1an (revision f17ae24, verified).
import torch
import math
class FlowMatchScheduler(torch.nn.Module):
    """A simplified Flow Matching scheduler for the Wan template.

    Sigmas interpolate linearly between data (sigma=0) and noise (sigma=1):
    x_t = (1 - sigma) * x_0 + sigma * eps, with the Wan shift warping the
    sigma schedule. Timesteps may be scalars, [B], [B, T], or any
    higher-rank tensor; public methods broadcast sigmas against the sample.
    """

    def __init__(self):
        super().__init__()
        # Sigmas in [0, 1] map onto training timesteps in [0, num_train_timesteps].
        self.num_train_timesteps = 1000
        # Schedule buffers are populated by set_timesteps(); registered as
        # non-persistent so they never enter the state_dict.
        self.register_buffer("sigmas", None, persistent=False)
        self.register_buffer("timesteps", None, persistent=False)
        self.register_buffer("linear_timesteps_weights", None, persistent=False)
        # NOTE(review): this overwrites nn.Module's built-in `training` flag
        # (super().__init__() sets it to True), so .train()/.eval() will also
        # flip it. Kept as-is for backward compatibility with callers that
        # read `scheduler.training`.
        self.training = False

    @property
    def device(self):
        """Device of the schedule buffers; CPU before set_timesteps() runs."""
        if self.timesteps is not None:
            return self.timesteps.device
        return torch.device('cpu')

    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, shift=5.0, training=False):
        """Build the sigma/timestep schedule for the Wan template.

        Args:
            num_inference_steps: Number of discrete denoising steps.
            denoising_strength: Fraction of the noise range to start from
                (1.0 = pure noise; <1.0 resumes partway, e.g. for img2img).
            shift: Wan schedule shift; larger values concentrate steps at
                high noise levels (default 5 for Wan).
            training: If True, also precompute per-timestep loss weights.
        """
        sigma_min = 0.0
        sigma_max = 1.0
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        # Linearly spaced sigmas from sigma_start down to 0.0 so the final
        # entry corresponds to a fully clean sample.
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
        # Apply the Wan shift: warps the schedule toward high-noise sigmas.
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        # Keep buffers on the module's current device.
        device = self.device
        sigmas = sigmas.to(device)
        timesteps = (sigmas * self.num_train_timesteps).to(device)
        self.register_buffer("sigmas", sigmas, persistent=False)
        self.register_buffer("timesteps", timesteps, persistent=False)
        if training:
            self.set_training_weight()
            self.training = True
        else:
            self.training = False

    def set_training_weight(self):
        """Precompute per-timestep loss weights (Gaussian bump at mid-schedule)."""
        # Use the class constant instead of a hard-coded 1000 so the weights
        # stay consistent with num_train_timesteps.
        steps = self.num_train_timesteps
        x = self.timesteps
        # Gaussian centered at the middle of the training-timestep range.
        y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
        y_shifted = y - y.min()
        weights = y_shifted * (steps / y_shifted.sum())
        if len(self.timesteps) != steps:
            # This is an empirical formula for schedules shorter than the
            # full training range.
            weights = weights * (len(self.timesteps) / steps)
            weights = weights + weights[1]
        self.register_buffer("linear_timesteps_weights", weights.to(self.device), persistent=False)

    def _get_timestep_indices(self, timestep: torch.Tensor):
        """Find the nearest schedule index for each input timestep.

        Supports any input shape by flattening, computing, and reshaping.
        """
        if not isinstance(timestep, torch.Tensor):
            timestep = torch.tensor(timestep, device=self.device)
        t_input = timestep.to(self.device)
        orig_shape = t_input.shape
        # Flatten so a single [K, N] distance matrix covers any input shape.
        t_flat = t_input.reshape(-1, 1)
        # Broadcast against self.timesteps [N] -> [len(t_flat), N].
        diff = (t_flat - self.timesteps.unsqueeze(0)).abs()
        indices = torch.argmin(diff, dim=-1)
        return indices.view(orig_shape)

    def step(self, model_output, timestep, sample, to_final=False):
        """Advance `sample` one Euler step along the predicted flow.

        Args:
            model_output: Predicted velocity, same shape as `sample`.
            timestep: Current timestep(s), scalar or any broadcastable shape.
            sample: Current noisy sample.
            to_final: If True, jump straight to sigma=0 (clean sample).

        Returns:
            sample + model_output * (sigma_next - sigma).
        """
        indices = self._get_timestep_indices(timestep)
        sigma = self.sigmas[indices]
        if to_final:
            sigma_next = torch.zeros_like(sigma)
        else:
            # Next sigma, clamped so the gather stays in bounds...
            next_indices = (indices + 1).clamp(max=len(self.sigmas) - 1)
            sigma_next = self.sigmas[next_indices]
            # ...and forced to 0 where we were already at the last step.
            sigma_next = torch.where(indices + 1 >= len(self.sigmas), torch.zeros_like(sigma), sigma_next)
        # Broadcast the sigma delta against sample (e.g. [B, T, C, H, W]).
        sigma_diff = (sigma_next - sigma).view(*sigma.shape, *([1] * (sample.ndim - sigma.ndim)))
        sigma_diff = sigma_diff.to(sample.device)
        return sample + model_output * sigma_diff

    def return_to_timestep(self, timestep, sample, sample_stablized):
        """Recover the implied model output from a sample and its stabilized form.

        NOTE(review): divides by sigma, which is 0 at the final schedule
        entry — callers must not pass the last timestep here.
        """
        indices = self._get_timestep_indices(timestep)
        sigma = self.sigmas[indices]
        sigma_view = sigma.view(*sigma.shape, *([1] * (sample.ndim - sigma.ndim)))
        sigma_view = sigma_view.to(sample.device)
        model_output = (sample - sample_stablized) / sigma_view
        return model_output

    def add_noise(self, original_samples, noise, timestep):
        """Interpolate toward noise: (1 - sigma) * x_0 + sigma * eps."""
        indices = self._get_timestep_indices(timestep)
        sigma = self.sigmas[indices]
        # Broadcast sigma to match the sample rank (e.g. [B, T, 1, 1, 1]).
        sigma_view = sigma.view(*sigma.shape, *([1] * (original_samples.ndim - sigma.ndim)))
        sigma_view = sigma_view.to(original_samples.device)
        return (1 - sigma_view) * original_samples + sigma_view * noise

    def add_independent_noise(self, original_samples, timestep):
        """Sample fresh Gaussian noise and apply it via add_noise.

        Returns:
            (noisy_samples, noise) so the caller can form a training target.
        """
        noise = torch.randn_like(original_samples)
        return self.add_noise(original_samples, noise, timestep), noise

    def training_target(self, sample, noise, timestep):
        """Flow-matching regression target: the constant velocity noise - x_0."""
        return noise - sample

    def training_weight(self, timestep):
        """Look up the precomputed loss weight for each timestep."""
        indices = self._get_timestep_indices(timestep)
        return self.linear_timesteps_weights[indices]
if __name__ == "__main__":
    import matplotlib.pyplot as plt
    import numpy as np
    import os

    # Output directory for all diagnostic figures produced below.
    os.makedirs("results/test_flow_matching", exist_ok=True)

    # 1. Build a scheduler configured for training (weights precomputed).
    scheduler = FlowMatchScheduler()
    num_steps = 50
    scheduler.set_timesteps(num_inference_steps=num_steps, training=True)

    # 2. Sanity-check (B, T)-shaped timesteps against a video-like tensor.
    B, T = 2, 4
    random_steps = torch.randint(0, num_steps, (B, T))
    timesteps_bt = scheduler.timesteps[random_steps]
    print(f"Testing with (B, T) shape: {timesteps_bt.shape}")

    # add_noise must broadcast over a (B, T, C, H, W) sample.
    clean = torch.randn(B, T, 3, 64, 64)
    eps = torch.randn_like(clean)
    xt = scheduler.add_noise(clean, eps, timesteps_bt)
    print(f"xt shape: {xt.shape}")
    assert xt.shape == clean.shape

    # 3. Plot the timestep mapping and the training-weight curve side by side.
    fig, (ax_map, ax_w) = plt.subplots(1, 2, figsize=(12, 5))

    ax_map.plot(range(len(scheduler.timesteps)), scheduler.timesteps.numpy(),
                marker='.', color='blue', label='Timesteps')
    ax_map.set_title("Timestep Mapping (Wan Shift=5)")
    ax_map.set_xlabel("Inference Step Index")
    ax_map.set_ylabel("Training Timestep (0-1000)")
    ax_map.grid(True)
    ax_map.legend()

    ax_w.plot(scheduler.timesteps.numpy(), scheduler.linear_timesteps_weights.numpy(),
              marker='.', color='red', label='Weights')
    ax_w.set_title("Training Weights vs Timestep")
    ax_w.set_xlabel("Training Timestep")
    ax_w.set_ylabel("Weight Value")
    ax_w.grid(True)
    ax_w.legend()

    plt.tight_layout()
    plt.savefig("results/test_flow_matching/scheduler_curves.png")
    print("Saved curves to results/test_flow_matching/scheduler_curves.png")

    # 4. Visualize the x_t interpolation (add_noise) on a simple grid image.
    size = 256
    pattern = np.zeros((size, size, 3), dtype=np.float32)
    pattern[::32, :] = 1.0
    pattern[:, ::32] = 1.0
    original_image = torch.from_numpy(pattern).permute(2, 0, 1).unsqueeze(0)  # [1, 3, 256, 256]
    noise = torch.randn_like(original_image)

    # A handful of schedule positions from pure data to pure noise.
    vis_indices = [0, num_steps // 4, num_steps // 2, 3 * num_steps // 4, num_steps - 1]
    fig_xt, axes_xt = plt.subplots(1, len(vis_indices), figsize=(15, 3))
    for col, step_idx in enumerate(vis_indices):
        t = scheduler.timesteps[step_idx]
        xt_img = scheduler.add_noise(original_image, noise, t)
        # Clip to [0, 1] and move channels last for display.
        panel = np.clip(xt_img.squeeze(0).permute(1, 2, 0).numpy(), 0, 1)
        axes_xt[col].imshow(panel)
        axes_xt[col].set_title(f"t={t.item():.1f}")
        axes_xt[col].axis('off')
    plt.suptitle("Flow Matching Interpolation (x_t) from Data (left) to Noise (right)")
    plt.tight_layout()
    plt.savefig("results/test_flow_matching/xt_interpolation.png")
    print("Saved x_t interpolation to results/test_flow_matching/xt_interpolation.png")