"""Standard diffusion model (non-forcing). All frames share the same noise level at each time step. Scheduler replaces TriangularTimeScheduler; model inherits DiffForcingWanModel. Config: steps=T - Training: random t in (0, 1], uniform noise across all frames - Inference: T-step denoising from t=0 (noise) to t=1 (clean) """ import numpy as np import torch from .diffusion_forcing_wan import DiffForcingWanModel EPSILON = 0.05 class DiffusionScheduler: """Standard diffusion scheduler - uniform noise level across all frames. Unlike TriangularTimeScheduler which assigns per-frame noise levels in a triangular pattern, this scheduler gives every frame the same noise level t. No windowing: input and output always span the full sequence. """ def __init__(self, config): self.steps = config["steps"] self.noise_type = config.get("noise_type", "linear") self.sigma_type = config.get("sigma_type", "zero") if self.noise_type in ("exponential", "exponential_rev"): self.exp_max = config.get("exp_max", 5.0) elif self.noise_type == "diffusion": self.T = config.get("T", 1000) self.beta_start = config.get("beta_start", 0.0001) self.beta_end = config.get("beta_end", 0.02) if self.sigma_type == "memoryless": self.sigma_scale = config.get("sigma_scale", 1.0) def get_total_steps(self, seq_len): return self.steps def get_time_steps(self, device, valid_len, current_step=None): time_steps = [] if current_step is None: for i in range(len(valid_len)): time_steps.append( torch.tensor(np.random.uniform(0, 1), device=device) ) elif isinstance(current_step, int): for i in range(len(valid_len)): t = current_step * (1.0 / self.steps) time_steps.append(torch.tensor(t, device=device)) elif isinstance(current_step, list): for i in range(len(valid_len)): t = current_step[i] * (1.0 / self.steps) time_steps.append(torch.tensor(t, device=device)) return time_steps def get_time_schedules(self, device, valid_len, time_steps, training=False): time_schedules = [] time_schedules_derivative = [] for i in range(len(valid_len)): t = time_steps[i].item() time_schedules.append(torch.full((valid_len[i],), t, device=device)) time_schedules_derivative.append( torch.full((valid_len[i],), 1.0 / self.steps, device=device) ) return time_schedules, time_schedules_derivative def get_windows(self, valid_len, time_steps, training=False): n = len(valid_len) return [0] * n, list(valid_len), [0] * n, list(valid_len) def get_noise_levels(self, device, valid_len, time_schedules, training=False): alpha, dalpha, dlog_alpha = [], [], [] beta, dbeta, dlog_beta = [], [], [] sigma = [] for i in range(len(valid_len)): t = time_schedules[i] if self.noise_type == "linear": alpha_i = t dalpha_i = torch.ones_like(t) dlog_alpha_i = dalpha_i / torch.clamp(alpha_i, min=EPSILON) beta_i = 1 - t dbeta_i = -torch.ones_like(t) dlog_beta_i = dbeta_i / torch.clamp(beta_i, min=EPSILON) elif self.noise_type == "exponential": k = self.exp_max alpha_i = torch.exp(-k * (1 - t)) dalpha_i = k * alpha_i dlog_alpha_i = k * torch.ones_like(t) beta_i = 1 - alpha_i dbeta_i = -dalpha_i dlog_beta_i = dbeta_i / torch.clamp(beta_i, min=EPSILON) elif self.noise_type == "exponential_rev": k = self.exp_max beta_i = torch.exp(-k * t) dbeta_i = -k * beta_i dlog_beta_i = -k * torch.ones_like(t) alpha_i = 1 - beta_i dalpha_i = -dbeta_i dlog_alpha_i = dalpha_i / torch.clamp(alpha_i, min=EPSILON) elif self.noise_type == "diffusion": t_rev = 1.0 - t beta_rate = ( self.beta_start + t_rev * (self.beta_end - self.beta_start) ) * self.T Gamma = ( self.beta_start * t_rev + 0.5 * (self.beta_end - self.beta_start) * t_rev * t_rev ) * self.T alpha_i = torch.exp(-0.5 * Gamma) dalpha_i = 0.5 * beta_rate * alpha_i dlog_alpha_i = 0.5 * beta_rate beta_i = torch.sqrt(torch.clamp(1 - torch.exp(-Gamma), min=0.0)) dbeta_i = ( -0.5 * torch.exp(-Gamma) * beta_rate / torch.clamp(beta_i, min=EPSILON) ) dlog_beta_i = dbeta_i / torch.clamp(beta_i, min=EPSILON) else: raise ValueError(f"Unknown noise type: {self.noise_type}") alpha.append(torch.clamp(alpha_i, min=0.0, max=1.0)) dalpha.append(dalpha_i) dlog_alpha.append(dlog_alpha_i) beta.append(torch.clamp(beta_i, min=0.0, max=1.0)) dbeta.append(dbeta_i) dlog_beta.append(dlog_beta_i) if self.sigma_type == "zero": sigma_i = torch.zeros_like(t) elif self.sigma_type == "memoryless": if self.noise_type in ("linear", "exponential", "exponential_rev"): sigma_i = self.sigma_scale * torch.sqrt( torch.clamp(2 * dlog_alpha_i * beta_i, min=0.0) ) elif self.noise_type == "diffusion": sigma_i = self.sigma_scale * torch.sqrt( torch.clamp(2 * dlog_alpha_i, min=0.0) ) else: sigma_i = self.sigma_scale * torch.sqrt( torch.clamp( 2 * beta_i * (dlog_alpha_i * beta_i - dbeta_i), min=0.0 ) ) sigma.append(sigma_i) return alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta def add_noise( self, x, alpha, beta, input_start, input_end, output_start, output_end, training=False, noise=None, ): x0, eps, xt = [], [], [] if training: for i in range(len(x)): noise_i = noise[i] if noise is not None else torch.randn_like(x[i]) alpha_i = alpha[i][None, :, None, None] beta_i = beta[i][None, :, None, None] noisy_x_i = x[i] * alpha_i + noise_i * beta_i x0.append(x[i][:, output_start[i]:output_end[i], ...]) eps.append(noise_i[:, output_start[i]:output_end[i], ...]) xt.append(noisy_x_i[:, input_start[i]:input_end[i], ...]) else: for i in range(len(x)): xt.append(x[i][:, input_start[i]:input_end[i], ...]) return x0, eps, xt def prepare(self, x, device, valid_len, training=True, current_step=None): """Single call replacing 5 separate scheduler calls. Returns dict. Training keys: time_schedules, dalpha, dbeta, input_start, input_end, output_start, output_end, x0, eps, xt Inference keys: time_schedules, time_schedules_derivative, alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta, input_start, input_end, output_start, output_end, xt """ time_steps = self.get_time_steps(device, valid_len, current_step) time_schedules, time_schedules_derivative = self.get_time_schedules( device, valid_len, time_steps, training=training ) alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta = \ self.get_noise_levels(device, valid_len, time_schedules, training=training) input_start, input_end, output_start, output_end = \ self.get_windows(valid_len, time_steps, training=training) x0, eps, xt = self.add_noise( x, alpha, beta, input_start, input_end, output_start, output_end, training=training ) # Slice all coefficients to their respective windows # (no-op for pure diffusion since windows = full sequence) batch_size = len(valid_len) time_schedules = [time_schedules[i][input_start[i]:input_end[i]] for i in range(batch_size)] time_schedules_derivative = [time_schedules_derivative[i][output_start[i]:output_end[i]] for i in range(batch_size)] alpha = [alpha[i][output_start[i]:output_end[i]] for i in range(batch_size)] dalpha = [dalpha[i][output_start[i]:output_end[i]] for i in range(batch_size)] beta = [beta[i][output_start[i]:output_end[i]] for i in range(batch_size)] dbeta = [dbeta[i][output_start[i]:output_end[i]] for i in range(batch_size)] sigma = [sigma[i][output_start[i]:output_end[i]] for i in range(batch_size)] dlog_alpha = [dlog_alpha[i][output_start[i]:output_end[i]] for i in range(batch_size)] dlog_beta = [dlog_beta[i][output_start[i]:output_end[i]] for i in range(batch_size)] return { "time_schedules": time_schedules, "time_schedules_derivative": time_schedules_derivative, "input_start": input_start, "input_end": input_end, "output_start": output_start, "output_end": output_end, "alpha": alpha, "dalpha": dalpha, "beta": beta, "dbeta": dbeta, "sigma": sigma, "dlog_alpha": dlog_alpha, "dlog_beta": dlog_beta, "xt": xt, "x0": x0, "eps": eps, } class DiffusionWanModel(DiffForcingWanModel): """Standard diffusion model. Inherits DiffForcingWanModel, only replacing the scheduler. Parent's forward/generate work as-is. No windowing, no streaming. All frames share the same noise level. """ def __init__(self, **kwargs): sc = kwargs.get("schedule_config", {}) if "chunk_size" not in sc: sc["chunk_size"] = 1 kwargs["schedule_config"] = sc super().__init__(**kwargs) self.time_scheduler = DiffusionScheduler(self.schedule_config)