| | import torch |
| | import math |
| |
|
class FlowMatchScheduler(torch.nn.Module):
    """
    A simplified Flow Matching scheduler specifically for the Wan template.
    Supports scalars, [B], [B, T], and higher-dimensional timesteps.

    Lifecycle: construct, call ``set_timesteps`` (pass ``training=True`` to
    also prepare per-timestep loss weights), then use ``add_noise`` /
    ``step`` / ``training_target`` / ``training_weight``.
    """

    def __init__(self):
        super().__init__()
        self.num_train_timesteps = 1000
        # Buffers start as None; set_timesteps() overwrites them in place.
        # persistent=False keeps them out of the state_dict.
        self.register_buffer("sigmas", None, persistent=False)
        self.register_buffer("timesteps", None, persistent=False)
        self.register_buffer("linear_timesteps_weights", None, persistent=False)
        # NOTE(review): this shadows nn.Module.training (normally toggled by
        # train()/eval()); here it only records whether training weights were
        # prepared by set_timesteps(..., training=True).
        self.training = False

    @property
    def device(self):
        """Device of the registered schedule, or CPU before set_timesteps()."""
        if self.timesteps is not None:
            return self.timesteps.device
        return torch.device('cpu')

    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, shift=5.0, training=False):
        """
        Set the timesteps and sigmas for the Wan template.

        Args:
            num_inference_steps: number of sigma/timestep entries to generate.
            denoising_strength: fraction of the noise range to start from
                (1.0 starts from pure noise).
            shift: Wan-style time shift; larger values concentrate steps at
                high noise levels.
            training: when True, also precompute per-timestep loss weights.
        """
        sigma_min = 0.0
        sigma_max = 1.0
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength

        # Linear schedule from sigma_start down to sigma_min ...
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
        # ... warped by the shift: sigma' = shift*sigma / (1 + (shift-1)*sigma).
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        device = self.device
        sigmas = sigmas.to(device)
        # Timesteps are sigmas scaled into the training range [0, 1000].
        timesteps = (sigmas * self.num_train_timesteps).to(device)

        # Re-registering a buffer of the same name overwrites the old value.
        self.register_buffer("sigmas", sigmas, persistent=False)
        self.register_buffer("timesteps", timesteps, persistent=False)

        if training:
            self.set_training_weight()
            self.training = True
        else:
            self.training = False

    def set_training_weight(self):
        """Precompute bell-shaped per-timestep loss weights over self.timesteps."""
        steps = 1000
        x = self.timesteps
        # Gaussian bump centered at the middle of the training range.
        y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
        y_shifted = y - y.min()
        # Normalize so the weights would sum to `steps` on a full 1000-entry grid.
        bsmntw_weighing = y_shifted * (steps / y_shifted.sum())
        if len(self.timesteps) != 1000:
            # Rescale when the grid is coarser/finer than 1000 entries.
            bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps)
            # NOTE(review): adds the second entry as a constant floor so no
            # weight is ~zero; presumably intentional — confirm upstream.
            bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1]

        self.register_buffer("linear_timesteps_weights", bsmntw_weighing.to(self.device), persistent=False)

    def _get_timestep_indices(self, timestep: torch.Tensor):
        """
        Efficiently find the nearest indices in self.timesteps for input timesteps.
        Supports any input shape by flattening, computing, and reshaping.
        """
        if not isinstance(timestep, torch.Tensor):
            timestep = torch.tensor(timestep, device=self.device)

        t_input = timestep.to(self.device)
        orig_shape = t_input.shape

        # Pairwise |t - timesteps| over a flattened view, then nearest match.
        t_flat = t_input.reshape(-1, 1)
        diff = (t_flat - self.timesteps.unsqueeze(0)).abs()
        indices = torch.argmin(diff, dim=-1)

        return indices.view(orig_shape)

    @staticmethod
    def _expand_like(sigma, reference):
        """Append singleton dims so `sigma` broadcasts against `reference`,
        and move it to `reference`'s device. Shared by step/add_noise/etc."""
        view = sigma.view(*sigma.shape, *([1] * (reference.ndim - sigma.ndim)))
        return view.to(reference.device)

    def step(self, model_output, timestep, sample, to_final=False):
        """
        One Euler step of the flow ODE: x <- x + v * (sigma_next - sigma).

        Args:
            model_output: predicted velocity, same shape as `sample`.
            timestep: current timestep(s); any shape broadcastable to sample.
            sample: current noisy sample x_t.
            to_final: when True, jump straight to sigma = 0.
        """
        indices = self._get_timestep_indices(timestep)
        sigma = self.sigmas[indices]

        if to_final:
            sigma_next = torch.zeros_like(sigma)
        else:
            next_indices = (indices + 1).clamp(max=len(self.sigmas) - 1)
            sigma_next = self.sigmas[next_indices]
            # Past the last entry there is no next sigma; treat it as 0.
            sigma_next = torch.where(indices + 1 >= len(self.sigmas), torch.zeros_like(sigma), sigma_next)

        sigma_diff = self._expand_like(sigma_next - sigma, sample)
        return sample + model_output * sigma_diff

    def return_to_timestep(self, timestep, sample, sample_stablized):
        """Recover the velocity mapping `sample_stablized` to `sample` at `timestep`."""
        indices = self._get_timestep_indices(timestep)
        sigma = self._expand_like(self.sigmas[indices], sample)
        model_output = (sample - sample_stablized) / sigma
        return model_output

    def add_noise(self, original_samples, noise, timestep):
        """Interpolate x_t = (1 - sigma) * x_0 + sigma * noise at the given timestep(s)."""
        indices = self._get_timestep_indices(timestep)
        sigma = self._expand_like(self.sigmas[indices], original_samples)
        return (1 - sigma) * original_samples + sigma * noise

    def add_independent_noise(self, original_samples, timestep):
        """
        Helper that samples noise independently for each element in original_samples
        and applies it based on the provided timestep (which should match the leading dims).
        """
        noise = torch.randn_like(original_samples)
        return self.add_noise(original_samples, noise, timestep), noise

    def training_target(self, sample, noise, timestep):
        """Flow-matching regression target: the constant velocity noise - sample."""
        return noise - sample

    def training_weight(self, timestep):
        """Per-sample loss weight looked up from the precomputed weight table."""
        # Fail loudly with a clear message instead of an opaque NoneType error.
        if self.linear_timesteps_weights is None:
            raise RuntimeError(
                "Training weights are not set; call set_timesteps(..., training=True) first."
            )
        indices = self._get_timestep_indices(timestep)
        return self.linear_timesteps_weights[indices]
| |
|
| |
|
if __name__ == "__main__":
    import matplotlib.pyplot as plt
    import numpy as np
    import os

    # Ensure the output directory exists before saving any figures.
    os.makedirs("results/test_flow_matching", exist_ok=True)

    # Build a scheduler with training weights enabled.
    sched = FlowMatchScheduler()
    n_steps = 50
    sched.set_timesteps(num_inference_steps=n_steps, training=True)

    # Exercise [B, T]-shaped timesteps.
    batch, frames = 2, 4
    rand_idx = torch.randint(0, n_steps, (batch, frames))
    t_grid = sched.timesteps[rand_idx]
    print(f"Testing with (B, T) shape: {t_grid.shape}")

    clean = torch.randn(batch, frames, 3, 64, 64)
    eps = torch.randn_like(clean)
    noisy = sched.add_noise(clean, eps, t_grid)
    print(f"xt shape: {noisy.shape}")
    assert noisy.shape == clean.shape

    # Plot the timestep schedule and the training weights side by side.
    fig, (ax_t, ax_w) = plt.subplots(1, 2, figsize=(12, 5))

    ax_t.plot(range(len(sched.timesteps)), sched.timesteps.numpy(), marker='.', color='blue', label='Timesteps')
    ax_t.set_title("Timestep Mapping (Wan Shift=5)")
    ax_t.set_xlabel("Inference Step Index")
    ax_t.set_ylabel("Training Timestep (0-1000)")
    ax_t.grid(True)
    ax_t.legend()

    ax_w.plot(sched.timesteps.numpy(), sched.linear_timesteps_weights.numpy(), marker='.', color='red', label='Weights')
    ax_w.set_title("Training Weights vs Timestep")
    ax_w.set_xlabel("Training Timestep")
    ax_w.set_ylabel("Weight Value")
    ax_w.grid(True)
    ax_w.legend()

    plt.tight_layout()
    plt.savefig("results/test_flow_matching/scheduler_curves.png")
    print("Saved curves to results/test_flow_matching/scheduler_curves.png")

    # Build a grid-line test image and visualize x_t at several noise levels.
    size = 256
    canvas = np.zeros((size, size, 3), dtype=np.float32)
    canvas[::32, :] = 1.0
    canvas[:, ::32] = 1.0
    base_img = torch.from_numpy(canvas).permute(2, 0, 1).unsqueeze(0)

    img_noise = torch.randn_like(base_img)

    # Sample a handful of step indices spanning the whole schedule.
    picks = [0, n_steps // 4, n_steps // 2, 3 * n_steps // 4, n_steps - 1]
    fig_xt, axes_xt = plt.subplots(1, len(picks), figsize=(15, 3))
    for ax, idx in zip(axes_xt, picks):
        t = sched.timesteps[idx]
        blended = sched.add_noise(base_img, img_noise, t)

        shown = np.clip(blended.squeeze(0).permute(1, 2, 0).numpy(), 0, 1)
        ax.imshow(shown)
        ax.set_title(f"t={t:.1f}")
        ax.axis('off')

    plt.suptitle("Flow Matching Interpolation (x_t) from Data (left) to Noise (right)")
    plt.tight_layout()
    plt.savefig("results/test_flow_matching/xt_interpolation.png")
    print("Saved x_t interpolation to results/test_flow_matching/xt_interpolation.png")
| |
|