# Source: FloodDiffusion-MEI / models/diffusion_wan.py
# Uploaded by H-Liu1997 with huggingface_hub (commit ec7592b, verified).
"""Standard diffusion model (non-forcing).
All frames share the same noise level at each time step.
Scheduler replaces TriangularTimeScheduler; model inherits DiffForcingWanModel.
Config: steps=T
- Training: random t drawn uniformly from [0, 1); every frame gets the same noise level
- Inference: T-step denoising from t=0 (noise) to t=1 (clean)
"""
import numpy as np
import torch
from .diffusion_forcing_wan import DiffForcingWanModel
# Lower clamp applied to alpha/beta wherever they are used as a denominator
# (log-derivatives), so the ratios stay finite at the ends of the schedule.
EPSILON = 0.05


class DiffusionScheduler:
    """Standard diffusion scheduler - uniform noise level across all frames.

    Unlike TriangularTimeScheduler which assigns per-frame noise levels in a
    triangular pattern, this scheduler gives every frame the same noise level t.
    No windowing: input and output always span the full sequence.

    Convention (see ``add_noise``): ``xt = x * alpha + noise * beta``. For the
    default "linear" schedule alpha = t and beta = 1 - t, so t = 0 is pure
    noise and t = 1 is clean data.
    """

    def __init__(self, config):
        """Read the schedule settings from a mapping-like ``config``.

        Keys read:
            steps (required): number of denoising steps.
            noise_type: "linear" (default), "exponential", "exponential_rev"
                or "diffusion".
            sigma_type: "zero" (default) or "memoryless".
            exp_max: rate for the exponential schedules (default 5.0).
            T, beta_start, beta_end: parameters of the "diffusion" schedule
                (defaults 1000, 0.0001, 0.02).
            sigma_scale: multiplier for "memoryless" sigma (default 1.0).
        """
        self.steps = config["steps"]
        self.noise_type = config.get("noise_type", "linear")
        self.sigma_type = config.get("sigma_type", "zero")
        if self.noise_type in ("exponential", "exponential_rev"):
            self.exp_max = config.get("exp_max", 5.0)
        elif self.noise_type == "diffusion":
            self.T = config.get("T", 1000)
            self.beta_start = config.get("beta_start", 0.0001)
            self.beta_end = config.get("beta_end", 0.02)
        if self.sigma_type == "memoryless":
            self.sigma_scale = config.get("sigma_scale", 1.0)

    def get_total_steps(self, seq_len):
        """Return the number of steps; ``seq_len`` is accepted but not used."""
        return self.steps

    def get_time_steps(self, device, valid_len, current_step=None):
        """Return one 0-dim time tensor per sample in the batch.

        Args:
            device: device on which the tensors are created.
            valid_len: per-sample lengths; only ``len(valid_len)`` is used.
            current_step: ``None`` draws each t independently from
                ``np.random.uniform(0, 1)``; an integer gives every sample
                ``current_step / steps``; a list or tuple gives sample i
                ``current_step[i] / steps``.

        Raises:
            TypeError: if ``current_step`` is none of the above. (Previously
                such values produced an empty list that failed later with an
                unrelated IndexError.)
        """
        batch_size = len(valid_len)
        if current_step is None:
            return [
                torch.tensor(np.random.uniform(0, 1), device=device)
                for _ in range(batch_size)
            ]

        step_size = 1.0 / self.steps
        if isinstance(current_step, (int, np.integer)):
            # int() keeps NumPy integers from promoting the product to a
            # NumPy float64 (which would yield a float64 tensor).
            t = int(current_step) * step_size
            return [torch.tensor(t, device=device) for _ in range(batch_size)]
        if isinstance(current_step, (list, tuple)):
            return [
                torch.tensor(current_step[i] * step_size, device=device)
                for i in range(batch_size)
            ]
        raise TypeError(
            "current_step must be None, an int, or a list/tuple of ints, "
            f"got {type(current_step).__name__}"
        )

    def get_time_schedules(self, device, valid_len, time_steps, training=False):
        """Broadcast each sample's scalar t over its frames.

        Returns:
            (time_schedules, time_schedules_derivative): two lists of 1-D
            tensors of length ``valid_len[i]``. The first is filled with the
            sample's t, the second with the constant ``1 / steps``.
            ``training`` is accepted but not used.
        """
        time_schedules = []
        time_schedules_derivative = []
        dt = 1.0 / self.steps
        for i, length in enumerate(valid_len):
            t = time_steps[i].item()
            time_schedules.append(torch.full((length,), t, device=device))
            time_schedules_derivative.append(
                torch.full((length,), dt, device=device)
            )
        return time_schedules, time_schedules_derivative

    def get_windows(self, valid_len, time_steps, training=False):
        """Return (input_start, input_end, output_start, output_end).

        Both windows are always the full sequence ``[0, valid_len[i])``;
        ``time_steps`` and ``training`` are accepted but not used.
        """
        n = len(valid_len)
        return [0] * n, list(valid_len), [0] * n, list(valid_len)

    def _schedule_coefficients(self, t):
        """Return (alpha, dalpha, dlog_alpha, beta, dbeta, dlog_beta) at ``t``.

        All values are unclamped tensors shaped like ``t``; derivatives are
        taken with respect to t.

        Raises:
            ValueError: if ``self.noise_type`` is not a known schedule.
        """
        if self.noise_type == "linear":
            alpha = t
            dalpha = torch.ones_like(t)
            dlog_alpha = dalpha / torch.clamp(alpha, min=EPSILON)
            beta = 1 - t
            dbeta = -torch.ones_like(t)
            dlog_beta = dbeta / torch.clamp(beta, min=EPSILON)
        elif self.noise_type == "exponential":
            k = self.exp_max
            alpha = torch.exp(-k * (1 - t))
            dalpha = k * alpha
            dlog_alpha = k * torch.ones_like(t)
            beta = 1 - alpha
            dbeta = -dalpha
            dlog_beta = dbeta / torch.clamp(beta, min=EPSILON)
        elif self.noise_type == "exponential_rev":
            k = self.exp_max
            beta = torch.exp(-k * t)
            dbeta = -k * beta
            dlog_beta = -k * torch.ones_like(t)
            alpha = 1 - beta
            dalpha = -dbeta
            dlog_alpha = dalpha / torch.clamp(alpha, min=EPSILON)
        elif self.noise_type == "diffusion":
            # The schedule is written in reversed time t_rev = 1 - t, hence
            # the sign flips in the derivatives with respect to t below.
            t_rev = 1.0 - t
            beta_span = self.beta_end - self.beta_start
            beta_rate = (self.beta_start + t_rev * beta_span) * self.T
            # Gamma is the integral of beta_rate over [0, t_rev].
            Gamma = (
                self.beta_start * t_rev + 0.5 * beta_span * t_rev * t_rev
            ) * self.T
            alpha = torch.exp(-0.5 * Gamma)
            dalpha = 0.5 * beta_rate * alpha
            dlog_alpha = 0.5 * beta_rate
            beta = torch.sqrt(torch.clamp(1 - torch.exp(-Gamma), min=0.0))
            dbeta = (
                -0.5 * torch.exp(-Gamma) * beta_rate
                / torch.clamp(beta, min=EPSILON)
            )
            dlog_beta = dbeta / torch.clamp(beta, min=EPSILON)
        else:
            raise ValueError(f"Unknown noise type: {self.noise_type}")
        return alpha, dalpha, dlog_alpha, beta, dbeta, dlog_beta

    def _sigma(self, t, dlog_alpha, beta):
        """Return the sigma tensor for one sample.

        Expects ``self.noise_type`` to have been validated already (it is,
        by ``_schedule_coefficients``).

        Raises:
            ValueError: if ``self.sigma_type`` is not "zero" or "memoryless".
        """
        if self.sigma_type == "zero":
            return torch.zeros_like(t)
        if self.sigma_type == "memoryless":
            # Both expressions are the per-schedule closed form of
            # 2 * beta * (dlog_alpha * beta - dbeta), ignoring EPSILON clamps.
            if self.noise_type == "diffusion":
                sigma_sq = 2 * dlog_alpha
            else:
                sigma_sq = 2 * dlog_alpha * beta
            return self.sigma_scale * torch.sqrt(torch.clamp(sigma_sq, min=0.0))
        raise ValueError(f"Unknown sigma type: {self.sigma_type}")

    def get_noise_levels(self, device, valid_len, time_schedules, training=False):
        """Evaluate the noise schedule at every per-frame time.

        Args:
            device: accepted but not used (outputs follow ``time_schedules``).
            valid_len: only its length (the batch size) is used.
            time_schedules: list of per-sample time tensors.
            training: accepted but not used.

        Returns:
            ``(alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta)``,
            each a list with one tensor per sample. ``alpha`` and ``beta``
            are clamped to [0, 1]; the other entries are returned as computed.

        Raises:
            ValueError: on an unknown ``noise_type`` or ``sigma_type``.
        """
        alpha, dalpha, dlog_alpha = [], [], []
        beta, dbeta, dlog_beta = [], [], []
        sigma = []
        for i in range(len(valid_len)):
            t = time_schedules[i]
            (
                alpha_i, dalpha_i, dlog_alpha_i,
                beta_i, dbeta_i, dlog_beta_i,
            ) = self._schedule_coefficients(t)
            alpha.append(torch.clamp(alpha_i, min=0.0, max=1.0))
            dalpha.append(dalpha_i)
            dlog_alpha.append(dlog_alpha_i)
            beta.append(torch.clamp(beta_i, min=0.0, max=1.0))
            dbeta.append(dbeta_i)
            dlog_beta.append(dlog_beta_i)
            # sigma deliberately uses the unclamped beta_i.
            sigma.append(self._sigma(t, dlog_alpha_i, beta_i))
        return alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta

    def add_noise(
        self, x, alpha, beta, input_start, input_end,
        output_start, output_end, training=False, noise=None,
    ):
        """Build the model input (and training targets) for each sample.

        Each ``x[i]`` is sliced along axis 1 (frames). In training mode the
        coefficients are reshaped to ``(1, T, 1, 1)`` before being applied,
        so ``x[i]`` must broadcast against that shape.

        Args:
            x: list of per-sample tensors.
            alpha, beta: per-sample 1-D coefficient tensors.
            input_start, input_end: per-sample frame window for ``xt``.
            output_start, output_end: per-sample frame window for ``x0``/``eps``.
            training: when True, noise the input; otherwise pass ``x`` through.
            noise: optional list of per-sample noise tensors; when omitted,
                ``torch.randn_like(x[i])`` is drawn.

        Returns:
            ``(x0, eps, xt)``. In training mode ``xt`` is the windowed
            ``x * alpha + noise * beta``, ``x0`` the windowed clean input and
            ``eps`` the windowed noise. Otherwise ``xt`` is the windowed ``x``
            and ``x0``/``eps`` are empty lists.
        """
        x0, eps, xt = [], [], []
        if not training:
            for i, x_i in enumerate(x):
                xt.append(x_i[:, input_start[i]:input_end[i], ...])
            return x0, eps, xt

        for i, x_i in enumerate(x):
            noise_i = noise[i] if noise is not None else torch.randn_like(x_i)
            alpha_i = alpha[i][None, :, None, None]
            beta_i = beta[i][None, :, None, None]
            noisy_x_i = x_i * alpha_i + noise_i * beta_i
            x0.append(x_i[:, output_start[i]:output_end[i], ...])
            eps.append(noise_i[:, output_start[i]:output_end[i], ...])
            xt.append(noisy_x_i[:, input_start[i]:input_end[i], ...])
        return x0, eps, xt

    @staticmethod
    def _slice_each(tensors, starts, ends):
        """Slice ``tensors[i][starts[i]:ends[i]]`` for every sample."""
        return [tensors[i][starts[i]:ends[i]] for i in range(len(starts))]

    def prepare(self, x, device, valid_len, training=True, current_step=None):
        """Single call replacing 5 separate scheduler calls.

        Runs get_time_steps, get_time_schedules, get_noise_levels,
        get_windows and add_noise, then slices every coefficient to its
        window (a no-op here, since windows span the full sequence).

        Returns:
            A dict that always contains the keys
                time_schedules, time_schedules_derivative,
                input_start, input_end, output_start, output_end,
                alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta,
                xt, x0, eps
            ``time_schedules`` is sliced to the input window, every other
            coefficient to the output window. ``x0`` and ``eps`` are empty
            lists when ``training`` is False.
        """
        time_steps = self.get_time_steps(device, valid_len, current_step)
        time_schedules, time_schedules_derivative = self.get_time_schedules(
            device, valid_len, time_steps, training=training
        )
        alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta = (
            self.get_noise_levels(
                device, valid_len, time_schedules, training=training
            )
        )
        input_start, input_end, output_start, output_end = self.get_windows(
            valid_len, time_steps, training=training
        )
        x0, eps, xt = self.add_noise(
            x, alpha, beta, input_start, input_end,
            output_start, output_end, training=training,
        )

        def to_output_window(tensors):
            return self._slice_each(tensors, output_start, output_end)

        return {
            "time_schedules": self._slice_each(
                time_schedules, input_start, input_end
            ),
            "time_schedules_derivative": to_output_window(
                time_schedules_derivative
            ),
            "input_start": input_start,
            "input_end": input_end,
            "output_start": output_start,
            "output_end": output_end,
            "alpha": to_output_window(alpha),
            "dalpha": to_output_window(dalpha),
            "beta": to_output_window(beta),
            "dbeta": to_output_window(dbeta),
            "sigma": to_output_window(sigma),
            "dlog_alpha": to_output_window(dlog_alpha),
            "dlog_beta": to_output_window(dlog_beta),
            "xt": xt,
            "x0": x0,
            "eps": eps,
        }
class DiffusionWanModel(DiffForcingWanModel):
    """Standard (non-forcing) diffusion variant of DiffForcingWanModel.

    The only change relative to the parent is the scheduler: after the
    parent constructor runs, ``time_scheduler`` is replaced with a
    DiffusionScheduler built from ``self.schedule_config``. The original
    author notes that the parent's forward/generate are reused unchanged,
    with no windowing or streaming, and that all frames share one noise level.
    """

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to the parent, defaulting ``chunk_size`` to 1.

        ``schedule_config`` is created if absent and is updated in place, so
        a config object passed by the caller gains a ``chunk_size`` entry
        when it did not already have one.
        """
        # NOTE(review): presumably the parent requires chunk_size in the
        # schedule config -- confirm against DiffForcingWanModel.
        schedule = kwargs.setdefault("schedule_config", {})
        if "chunk_size" not in schedule:
            schedule["chunk_size"] = 1
        super().__init__(**kwargs)
        # Swap the scheduler installed by the parent for the uniform one.
        self.time_scheduler = DiffusionScheduler(self.schedule_config)