| """Standard diffusion model (non-forcing). |
| |
| All frames share the same noise level at each time step. |
| Scheduler replaces TriangularTimeScheduler; model inherits DiffForcingWanModel. |
| |
| Config: steps=T |
- Training: random t in [0, 1), uniform noise level across all frames
| - Inference: T-step denoising from t=0 (noise) to t=1 (clean) |
| """ |
|
|
| import numpy as np |
| import torch |
|
|
| from .diffusion_forcing_wan import DiffForcingWanModel |
|
|
| EPSILON = 0.05 |
|
|
|
|
| class DiffusionScheduler: |
| """Standard diffusion scheduler - uniform noise level across all frames. |
| |
| Unlike TriangularTimeScheduler which assigns per-frame noise levels in a |
| triangular pattern, this scheduler gives every frame the same noise level t. |
| No windowing: input and output always span the full sequence. |
| """ |
|
|
| def __init__(self, config): |
| self.steps = config["steps"] |
| self.noise_type = config.get("noise_type", "linear") |
| self.sigma_type = config.get("sigma_type", "zero") |
|
|
| if self.noise_type in ("exponential", "exponential_rev"): |
| self.exp_max = config.get("exp_max", 5.0) |
| elif self.noise_type == "diffusion": |
| self.T = config.get("T", 1000) |
| self.beta_start = config.get("beta_start", 0.0001) |
| self.beta_end = config.get("beta_end", 0.02) |
|
|
| if self.sigma_type == "memoryless": |
| self.sigma_scale = config.get("sigma_scale", 1.0) |
|
|
| def get_total_steps(self, seq_len): |
| return self.steps |
|
|
| def get_time_steps(self, device, valid_len, current_step=None): |
| time_steps = [] |
| if current_step is None: |
| for i in range(len(valid_len)): |
| time_steps.append( |
| torch.tensor(np.random.uniform(0, 1), device=device) |
| ) |
| elif isinstance(current_step, int): |
| for i in range(len(valid_len)): |
| t = current_step * (1.0 / self.steps) |
| time_steps.append(torch.tensor(t, device=device)) |
| elif isinstance(current_step, list): |
| for i in range(len(valid_len)): |
| t = current_step[i] * (1.0 / self.steps) |
| time_steps.append(torch.tensor(t, device=device)) |
| return time_steps |
|
|
| def get_time_schedules(self, device, valid_len, time_steps, training=False): |
| time_schedules = [] |
| time_schedules_derivative = [] |
| for i in range(len(valid_len)): |
| t = time_steps[i].item() |
| time_schedules.append(torch.full((valid_len[i],), t, device=device)) |
| time_schedules_derivative.append( |
| torch.full((valid_len[i],), 1.0 / self.steps, device=device) |
| ) |
| return time_schedules, time_schedules_derivative |
|
|
| def get_windows(self, valid_len, time_steps, training=False): |
| n = len(valid_len) |
| return [0] * n, list(valid_len), [0] * n, list(valid_len) |
|
|
| def get_noise_levels(self, device, valid_len, time_schedules, training=False): |
| alpha, dalpha, dlog_alpha = [], [], [] |
| beta, dbeta, dlog_beta = [], [], [] |
| sigma = [] |
| for i in range(len(valid_len)): |
| t = time_schedules[i] |
| if self.noise_type == "linear": |
| alpha_i = t |
| dalpha_i = torch.ones_like(t) |
| dlog_alpha_i = dalpha_i / torch.clamp(alpha_i, min=EPSILON) |
| beta_i = 1 - t |
| dbeta_i = -torch.ones_like(t) |
| dlog_beta_i = dbeta_i / torch.clamp(beta_i, min=EPSILON) |
| elif self.noise_type == "exponential": |
| k = self.exp_max |
| alpha_i = torch.exp(-k * (1 - t)) |
| dalpha_i = k * alpha_i |
| dlog_alpha_i = k * torch.ones_like(t) |
| beta_i = 1 - alpha_i |
| dbeta_i = -dalpha_i |
| dlog_beta_i = dbeta_i / torch.clamp(beta_i, min=EPSILON) |
| elif self.noise_type == "exponential_rev": |
| k = self.exp_max |
| beta_i = torch.exp(-k * t) |
| dbeta_i = -k * beta_i |
| dlog_beta_i = -k * torch.ones_like(t) |
| alpha_i = 1 - beta_i |
| dalpha_i = -dbeta_i |
| dlog_alpha_i = dalpha_i / torch.clamp(alpha_i, min=EPSILON) |
| elif self.noise_type == "diffusion": |
| t_rev = 1.0 - t |
| beta_rate = ( |
| self.beta_start + t_rev * (self.beta_end - self.beta_start) |
| ) * self.T |
| Gamma = ( |
| self.beta_start * t_rev |
| + 0.5 * (self.beta_end - self.beta_start) * t_rev * t_rev |
| ) * self.T |
| alpha_i = torch.exp(-0.5 * Gamma) |
| dalpha_i = 0.5 * beta_rate * alpha_i |
| dlog_alpha_i = 0.5 * beta_rate |
| beta_i = torch.sqrt(torch.clamp(1 - torch.exp(-Gamma), min=0.0)) |
| dbeta_i = ( |
| -0.5 * torch.exp(-Gamma) * beta_rate |
| / torch.clamp(beta_i, min=EPSILON) |
| ) |
| dlog_beta_i = dbeta_i / torch.clamp(beta_i, min=EPSILON) |
| else: |
| raise ValueError(f"Unknown noise type: {self.noise_type}") |
| alpha.append(torch.clamp(alpha_i, min=0.0, max=1.0)) |
| dalpha.append(dalpha_i) |
| dlog_alpha.append(dlog_alpha_i) |
| beta.append(torch.clamp(beta_i, min=0.0, max=1.0)) |
| dbeta.append(dbeta_i) |
| dlog_beta.append(dlog_beta_i) |
| if self.sigma_type == "zero": |
| sigma_i = torch.zeros_like(t) |
| elif self.sigma_type == "memoryless": |
| if self.noise_type in ("linear", "exponential", "exponential_rev"): |
| sigma_i = self.sigma_scale * torch.sqrt( |
| torch.clamp(2 * dlog_alpha_i * beta_i, min=0.0) |
| ) |
| elif self.noise_type == "diffusion": |
| sigma_i = self.sigma_scale * torch.sqrt( |
| torch.clamp(2 * dlog_alpha_i, min=0.0) |
| ) |
| else: |
| sigma_i = self.sigma_scale * torch.sqrt( |
| torch.clamp( |
| 2 * beta_i * (dlog_alpha_i * beta_i - dbeta_i), min=0.0 |
| ) |
| ) |
| sigma.append(sigma_i) |
| return alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta |
|
|
| def add_noise( |
| self, x, alpha, beta, input_start, input_end, |
| output_start, output_end, training=False, noise=None, |
| ): |
| x0, eps, xt = [], [], [] |
| if training: |
| for i in range(len(x)): |
| noise_i = noise[i] if noise is not None else torch.randn_like(x[i]) |
| alpha_i = alpha[i][None, :, None, None] |
| beta_i = beta[i][None, :, None, None] |
| noisy_x_i = x[i] * alpha_i + noise_i * beta_i |
| x0.append(x[i][:, output_start[i]:output_end[i], ...]) |
| eps.append(noise_i[:, output_start[i]:output_end[i], ...]) |
| xt.append(noisy_x_i[:, input_start[i]:input_end[i], ...]) |
| else: |
| for i in range(len(x)): |
| xt.append(x[i][:, input_start[i]:input_end[i], ...]) |
| return x0, eps, xt |
|
|
| def prepare(self, x, device, valid_len, training=True, current_step=None): |
| """Single call replacing 5 separate scheduler calls. |
| |
| Returns dict. Training keys: |
| time_schedules, dalpha, dbeta, input_start, input_end, |
| output_start, output_end, x0, eps, xt |
| Inference keys: |
| time_schedules, time_schedules_derivative, |
| alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta, |
| input_start, input_end, output_start, output_end, xt |
| """ |
| time_steps = self.get_time_steps(device, valid_len, current_step) |
| time_schedules, time_schedules_derivative = self.get_time_schedules( |
| device, valid_len, time_steps, training=training |
| ) |
| alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta = \ |
| self.get_noise_levels(device, valid_len, time_schedules, training=training) |
| input_start, input_end, output_start, output_end = \ |
| self.get_windows(valid_len, time_steps, training=training) |
| x0, eps, xt = self.add_noise( |
| x, alpha, beta, input_start, input_end, |
| output_start, output_end, training=training |
| ) |
|
|
| |
| |
| batch_size = len(valid_len) |
| time_schedules = [time_schedules[i][input_start[i]:input_end[i]] for i in range(batch_size)] |
| time_schedules_derivative = [time_schedules_derivative[i][output_start[i]:output_end[i]] for i in range(batch_size)] |
| alpha = [alpha[i][output_start[i]:output_end[i]] for i in range(batch_size)] |
| dalpha = [dalpha[i][output_start[i]:output_end[i]] for i in range(batch_size)] |
| beta = [beta[i][output_start[i]:output_end[i]] for i in range(batch_size)] |
| dbeta = [dbeta[i][output_start[i]:output_end[i]] for i in range(batch_size)] |
| sigma = [sigma[i][output_start[i]:output_end[i]] for i in range(batch_size)] |
| dlog_alpha = [dlog_alpha[i][output_start[i]:output_end[i]] for i in range(batch_size)] |
| dlog_beta = [dlog_beta[i][output_start[i]:output_end[i]] for i in range(batch_size)] |
|
|
| return { |
| "time_schedules": time_schedules, |
| "time_schedules_derivative": time_schedules_derivative, |
| "input_start": input_start, |
| "input_end": input_end, |
| "output_start": output_start, |
| "output_end": output_end, |
| "alpha": alpha, |
| "dalpha": dalpha, |
| "beta": beta, |
| "dbeta": dbeta, |
| "sigma": sigma, |
| "dlog_alpha": dlog_alpha, |
| "dlog_beta": dlog_beta, |
| "xt": xt, |
| "x0": x0, |
| "eps": eps, |
| } |
|
|
|
|
class DiffusionWanModel(DiffForcingWanModel):
    """Standard diffusion model. Inherits DiffForcingWanModel,
    only replacing the scheduler. Parent's forward/generate work as-is.

    No windowing, no streaming. All frames share the same noise level.
    """

    def __init__(self, **kwargs):
        # Default chunk_size to 1 before the parent consumes the config,
        # then swap in the uniform-noise scheduler.
        schedule_config = kwargs.get("schedule_config", {})
        schedule_config.setdefault("chunk_size", 1)
        kwargs["schedule_config"] = schedule_config
        super().__init__(**kwargs)
        self.time_scheduler = DiffusionScheduler(self.schedule_config)
|
|