"""Standard diffusion model (non-forcing).

All frames share the same noise level at each time step.
Scheduler replaces TriangularTimeScheduler; model inherits DiffForcingWanModel.

Config: steps=T
  - Training: random t in [0, 1), the same noise level for every frame
  - Inference: T-step denoising from t=0 (noise) to t=1 (clean)
"""
import numpy as np
import torch
from .diffusion_forcing_wan import DiffForcingWanModel
# Floor applied (via torch.clamp(..., min=EPSILON)) to alpha/beta wherever they
# appear as a denominator, so the d(log alpha)/d(log beta) ratios stay finite
# when the coefficient itself approaches zero.
EPSILON = 0.05
class DiffusionScheduler:
    """Standard diffusion scheduler - uniform noise level across all frames.

    Unlike TriangularTimeScheduler which assigns per-frame noise levels in a
    triangular pattern, this scheduler gives every frame the same noise level t.
    No windowing: input and output always span the full sequence.

    Config keys read by ``__init__``:
        steps: number of denoising steps (required).
        noise_type: "linear" (default), "exponential", "exponential_rev"
            or "diffusion".
        sigma_type: "zero" (default) or "memoryless".
        exp_max: rate ``k`` of the exponential schedules (default 5.0);
            only read for the two exponential noise types.
        T, beta_start, beta_end: parameters of the "diffusion" schedule
            (defaults 1000, 0.0001, 0.02); only read for that noise type.
        sigma_scale: multiplier on the memoryless sigma (default 1.0);
            only read when sigma_type is "memoryless".
    """

    def __init__(self, config):
        self.steps = config["steps"]
        self.noise_type = config.get("noise_type", "linear")
        self.sigma_type = config.get("sigma_type", "zero")
        if self.noise_type in ("exponential", "exponential_rev"):
            self.exp_max = config.get("exp_max", 5.0)
        elif self.noise_type == "diffusion":
            self.T = config.get("T", 1000)
            self.beta_start = config.get("beta_start", 0.0001)
            self.beta_end = config.get("beta_end", 0.02)
        if self.sigma_type == "memoryless":
            self.sigma_scale = config.get("sigma_scale", 1.0)

    def get_total_steps(self, seq_len):
        """Return the configured number of steps; ``seq_len`` is ignored."""
        return self.steps

    def get_time_steps(self, device, valid_len, current_step=None):
        """Return one scalar time tensor per batch element.

        Args:
            device: device the scalar tensors are created on.
            valid_len: per-sample lengths; only ``len(valid_len)`` is used.
            current_step: ``None`` draws each time independently with
                ``np.random.uniform(0, 1)``; an integer (Python or NumPy)
                gives every sample ``current_step / steps``; a list or tuple
                gives sample ``i`` the time ``current_step[i] / steps``.

        Raises:
            TypeError: if ``current_step`` is none of the supported kinds.
                (Previously this silently produced an empty list, which only
                failed later with an unrelated IndexError.)
        """
        batch_size = len(valid_len)
        if current_step is None:
            return [
                torch.tensor(np.random.uniform(0, 1), device=device)
                for _ in range(batch_size)
            ]

        if isinstance(current_step, (int, np.integer)):
            per_sample_steps = [current_step] * batch_size
        elif isinstance(current_step, (list, tuple)):
            # Index explicitly so a too-short sequence still raises IndexError.
            per_sample_steps = [current_step[i] for i in range(batch_size)]
        else:
            raise TypeError(
                "current_step must be None, an int, or a list/tuple of ints, "
                f"got {type(current_step).__name__}"
            )

        step_size = 1.0 / self.steps
        # float() keeps NumPy integer steps from producing a float64 tensor.
        return [
            torch.tensor(float(step * step_size), device=device)
            for step in per_sample_steps
        ]

    def get_time_schedules(self, device, valid_len, time_steps, training=False):
        """Broadcast each sample's scalar time over all of its frames.

        Returns:
            ``(time_schedules, time_schedules_derivative)``: two lists of 1-D
            tensors of length ``valid_len[i]``. The first is filled with the
            sample's time, the second with the constant ``1 / steps``.
            ``training`` is accepted but not used.
        """
        step_size = 1.0 / self.steps
        time_schedules = []
        time_schedules_derivative = []
        for i, n_frames in enumerate(valid_len):
            t = time_steps[i].item()
            time_schedules.append(torch.full((n_frames,), t, device=device))
            time_schedules_derivative.append(
                torch.full((n_frames,), step_size, device=device)
            )
        return time_schedules, time_schedules_derivative

    def get_windows(self, valid_len, time_steps, training=False):
        """Return full-sequence windows for every sample.

        Returns:
            ``(input_start, input_end, output_start, output_end)`` where the
            starts are all 0 and the ends are copies of ``valid_len``.
            ``time_steps`` and ``training`` are accepted but not used.
        """
        n = len(valid_len)
        return [0] * n, list(valid_len), [0] * n, list(valid_len)

    def _noise_coefficients(self, t):
        """Compute alpha/beta, their time derivatives and log-derivatives at t.

        Returns:
            ``(alpha, dalpha, dlog_alpha, beta, dbeta, dlog_beta)``, all
            tensors shaped like ``t`` and not yet clamped to [0, 1].

        Raises:
            ValueError: if ``self.noise_type`` is not a known schedule.
        """
        if self.noise_type == "linear":
            alpha_i = t
            dalpha_i = torch.ones_like(t)
            dlog_alpha_i = dalpha_i / torch.clamp(alpha_i, min=EPSILON)
            beta_i = 1 - t
            dbeta_i = -torch.ones_like(t)
            dlog_beta_i = dbeta_i / torch.clamp(beta_i, min=EPSILON)
        elif self.noise_type == "exponential":
            k = self.exp_max
            alpha_i = torch.exp(-k * (1 - t))
            dalpha_i = k * alpha_i
            dlog_alpha_i = k * torch.ones_like(t)
            beta_i = 1 - alpha_i
            dbeta_i = -dalpha_i
            dlog_beta_i = dbeta_i / torch.clamp(beta_i, min=EPSILON)
        elif self.noise_type == "exponential_rev":
            k = self.exp_max
            beta_i = torch.exp(-k * t)
            dbeta_i = -k * beta_i
            dlog_beta_i = -k * torch.ones_like(t)
            alpha_i = 1 - beta_i
            dalpha_i = -dbeta_i
            dlog_alpha_i = dalpha_i / torch.clamp(alpha_i, min=EPSILON)
        elif self.noise_type == "diffusion":
            # The beta schedule is parameterised on reversed time, so that
            # t = 1 corresponds to zero accumulated noise (Gamma = 0).
            t_rev = 1.0 - t
            beta_span = self.beta_end - self.beta_start
            beta_rate = (self.beta_start + t_rev * beta_span) * self.T
            # Gamma is the integral of beta_rate over [0, t_rev].
            Gamma = (
                self.beta_start * t_rev + 0.5 * beta_span * t_rev * t_rev
            ) * self.T
            alpha_i = torch.exp(-0.5 * Gamma)
            dalpha_i = 0.5 * beta_rate * alpha_i
            dlog_alpha_i = 0.5 * beta_rate
            beta_i = torch.sqrt(torch.clamp(1 - torch.exp(-Gamma), min=0.0))
            dbeta_i = (
                -0.5 * torch.exp(-Gamma) * beta_rate
                / torch.clamp(beta_i, min=EPSILON)
            )
            dlog_beta_i = dbeta_i / torch.clamp(beta_i, min=EPSILON)
        else:
            raise ValueError(f"Unknown noise type: {self.noise_type}")
        return alpha_i, dalpha_i, dlog_alpha_i, beta_i, dbeta_i, dlog_beta_i

    def _sigma(self, t, dlog_alpha_i, beta_i):
        """Compute the sigma term for one sample.

        Only called after ``_noise_coefficients`` has accepted
        ``self.noise_type``, so every non-"diffusion" type reaching the
        memoryless branch is linear / exponential / exponential_rev.

        Raises:
            ValueError: if ``self.sigma_type`` is not "zero" or "memoryless".
        """
        if self.sigma_type == "zero":
            return torch.zeros_like(t)
        if self.sigma_type == "memoryless":
            # Ignoring the EPSILON clamps, both expressions are closed forms
            # of the general 2 * beta * (dlog_alpha * beta - dbeta).
            if self.noise_type == "diffusion":
                sigma_sq = 2 * dlog_alpha_i
            else:
                sigma_sq = 2 * dlog_alpha_i * beta_i
            return self.sigma_scale * torch.sqrt(torch.clamp(sigma_sq, min=0.0))
        raise ValueError(f"Unknown sigma type: {self.sigma_type}")

    def get_noise_levels(self, device, valid_len, time_schedules, training=False):
        """Evaluate the noise schedule on each sample's time schedule.

        Args:
            device: accepted but not used; outputs follow ``time_schedules``.
            valid_len: per-sample lengths; only ``len(valid_len)`` is used.
            time_schedules: list of per-frame time tensors.
            training: accepted but not used.

        Returns:
            ``(alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta)``,
            each a list with one tensor per sample. ``alpha`` and ``beta``
            are clamped to [0, 1]; the derivatives are not.

        Raises:
            ValueError: on an unknown ``noise_type`` or ``sigma_type``.
        """
        alpha, dalpha, dlog_alpha = [], [], []
        beta, dbeta, dlog_beta = [], [], []
        sigma = []
        for i in range(len(valid_len)):
            t = time_schedules[i]
            (
                alpha_i, dalpha_i, dlog_alpha_i,
                beta_i, dbeta_i, dlog_beta_i,
            ) = self._noise_coefficients(t)
            # sigma deliberately uses the unclamped beta, as before.
            sigma_i = self._sigma(t, dlog_alpha_i, beta_i)

            alpha.append(torch.clamp(alpha_i, min=0.0, max=1.0))
            dalpha.append(dalpha_i)
            dlog_alpha.append(dlog_alpha_i)
            beta.append(torch.clamp(beta_i, min=0.0, max=1.0))
            dbeta.append(dbeta_i)
            dlog_beta.append(dlog_beta_i)
            sigma.append(sigma_i)
        return alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta

    def add_noise(
        self, x, alpha, beta, input_start, input_end,
        output_start, output_end, training=False, noise=None,
    ):
        """Build the model input (and training targets) for each sample.

        In training, ``xt = x * alpha + noise * beta`` with alpha/beta
        broadcast along axis 1 of ``x[i]``; ``noise[i]`` is used when given,
        otherwise fresh Gaussian noise is drawn. In inference ``x[i]`` is
        passed through as ``xt`` and no targets are produced.
        NOTE(review): the ``[None, :, None, None]`` indexing assumes x[i] is
        4-D with frames on axis 1 (e.g. (C, T, H, W)) - confirm with callers.

        Returns:
            ``(x0, eps, xt)`` lists. ``x0`` and ``eps`` are cut to the output
            window and are empty when ``training`` is False; ``xt`` is cut to
            the input window.
        """
        x0, eps, xt = [], [], []
        for i, x_i in enumerate(x):
            if training:
                noise_i = noise[i] if noise is not None else torch.randn_like(x_i)
                alpha_i = alpha[i][None, :, None, None]
                beta_i = beta[i][None, :, None, None]
                noisy_x_i = x_i * alpha_i + noise_i * beta_i
                x0.append(x_i[:, output_start[i]:output_end[i], ...])
                eps.append(noise_i[:, output_start[i]:output_end[i], ...])
                xt.append(noisy_x_i[:, input_start[i]:input_end[i], ...])
            else:
                xt.append(x_i[:, input_start[i]:input_end[i], ...])
        return x0, eps, xt

    @staticmethod
    def _slice_each(values, starts, ends):
        """Cut ``values[i]`` to ``[starts[i]:ends[i]]`` for every sample."""
        return [v[s:e] for v, s, e in zip(values, starts, ends)]

    def prepare(self, x, device, valid_len, training=True, current_step=None):
        """Single call replacing 5 separate scheduler calls.

        Runs get_time_steps, get_time_schedules, get_noise_levels,
        get_windows and add_noise, then cuts every coefficient to its window
        (a no-op here because the windows span the full sequence).

        Returns:
            dict with the same keys in training and inference:
            time_schedules (input window), time_schedules_derivative,
            alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta
            (output window), input_start, input_end, output_start,
            output_end, xt, x0, eps. ``x0`` and ``eps`` are empty lists when
            ``training`` is False.

        Raises:
            TypeError: for an unsupported ``current_step``.
            ValueError: for an unknown ``noise_type`` or ``sigma_type``.
        """
        time_steps = self.get_time_steps(device, valid_len, current_step)
        time_schedules, time_schedules_derivative = self.get_time_schedules(
            device, valid_len, time_steps, training=training
        )
        alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta = (
            self.get_noise_levels(
                device, valid_len, time_schedules, training=training
            )
        )
        input_start, input_end, output_start, output_end = self.get_windows(
            valid_len, time_steps, training=training
        )
        x0, eps, xt = self.add_noise(
            x, alpha, beta, input_start, input_end,
            output_start, output_end, training=training,
        )

        def on_output(values):
            return self._slice_each(values, output_start, output_end)

        return {
            "time_schedules": self._slice_each(
                time_schedules, input_start, input_end
            ),
            "time_schedules_derivative": on_output(time_schedules_derivative),
            "input_start": input_start,
            "input_end": input_end,
            "output_start": output_start,
            "output_end": output_end,
            "alpha": on_output(alpha),
            "dalpha": on_output(dalpha),
            "beta": on_output(beta),
            "dbeta": on_output(dbeta),
            "sigma": on_output(sigma),
            "dlog_alpha": on_output(dlog_alpha),
            "dlog_beta": on_output(dlog_beta),
            "xt": xt,
            "x0": x0,
            "eps": eps,
        }
class DiffusionWanModel(DiffForcingWanModel):
    """Standard (non-forcing) diffusion variant of DiffForcingWanModel.

    The only change from the parent is the scheduler: after the parent has
    been constructed, ``time_scheduler`` is replaced by a DiffusionScheduler,
    so all frames share the same noise level. The parent's forward/generate
    are intended to work unchanged. No windowing, no streaming.
    """

    def __init__(self, **kwargs):
        """Build the parent model, then install the uniform scheduler.

        ``schedule_config`` (taken from ``kwargs``, or a new empty dict) gets
        ``chunk_size = 1`` when that key is missing; a dict supplied by the
        caller is updated in place.
        NOTE(review): presumably the parent constructor needs chunk_size to
        be present - confirm against DiffForcingWanModel.
        """
        schedule_config = kwargs.get("schedule_config", {})
        chunk_size_missing = "chunk_size" not in schedule_config
        if chunk_size_missing:
            schedule_config["chunk_size"] = 1
        kwargs["schedule_config"] = schedule_config

        super().__init__(**kwargs)

        # Swap in the uniform scheduler, built from the config the parent stored.
        self.time_scheduler = DiffusionScheduler(self.schedule_config)
|