# FloodDiffusion-MEI / models/chunk_diffusion_wan.py
# Author: H-Liu1997 (uploaded via huggingface_hub, commit c785bc6)
"""Chunk-based diffusion model (no history re-noising).
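Config: history_len=m, chunk_size=n, steps=S (denoising steps per chunk; unrelated
to the diffusion-schedule length ``T`` used by noise_type="diffusion")
- Global time t ∈ [0, num_chunks), where num_chunks = 1 if N <= m+n,
  else 1 + ceil((N - (m+n)) / n)
- Inference schedule: before window → 1.0 (clean), history → 1.0 (clean),
  target → frac(t), after → 0.0. During training the whole window
  (history + target) uses frac(t) instead.
- Inference: history stays clean, only target frames are denoised
- First chunk uses GT history frames as conditioning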
"""
import math
import numpy as np
import torch
from .diffusion_forcing_wan import DiffForcingWanModel
# Lower bound applied (via torch.clamp) to alpha/beta before they are used as
# denominators in ChunkDiffusionScheduler.get_noise_levels, so the log-derivative
# terms never divide by zero.
EPSILON = 0.05
class ChunkDiffusionScheduler:
    """Time/noise scheduler for chunk-wise denoising with clean history.

    A sequence of ``N`` frames is processed in chunks. The first chunk covers
    frames ``[0, history_len + chunk_size)`` (its first ``history_len`` frames
    are history); every later chunk adds ``chunk_size`` target frames preceded
    by ``history_len`` history frames. Each chunk takes ``steps`` denoising
    steps, so the global time ``t`` runs over ``[0, num_chunks)``: ``int(t)``
    selects the chunk and ``t - int(t)`` is the target frames' schedule value
    (1.0 = clean).

    Config keys read here:
        steps, chunk_size (required); history_len (default 0);
        noise_type: "linear" | "exponential" | "exponential_rev" | "diffusion";
        sigma_type: "zero" | "memoryless"; random_epsilon; content_len;
        exp_max (exponential types); T, beta_start, beta_end ("diffusion");
        sigma_scale ("memoryless").
    """

    def __init__(self, config):
        self.steps = config["steps"]
        self.chunk_size = config["chunk_size"]  # n
        self.history_len = config.get("history_len", 0)  # m
        self.window_size = self.history_len + self.chunk_size  # m + n
        self.noise_type = config.get("noise_type", "linear")
        self.sigma_type = config.get("sigma_type", "zero")
        # Std-dev of the Gaussian jitter added to the training schedule.
        self.random_epsilon = config.get("random_epsilon", 0.0)
        # Optional cap on the input-window length (see _window_range).
        self.content_len = config.get("content_len", None)
        if self.noise_type in ("exponential", "exponential_rev"):
            self.exp_max = config.get("exp_max", 5.0)
        elif self.noise_type == "diffusion":
            self.T = config.get("T", 1000)
            self.beta_start = config.get("beta_start", 0.0001)
            self.beta_end = config.get("beta_end", 0.02)
        if self.sigma_type == "memoryless":
            self.sigma_scale = config.get("sigma_scale", 1.0)

    # ----------------------------------------------------------------
    # Chunks
    # ----------------------------------------------------------------
    def _num_chunks(self, seq_len):
        """Return how many chunks are needed to cover ``seq_len`` frames."""
        if seq_len <= self.window_size:
            return 1
        return 1 + math.ceil((seq_len - self.window_size) / self.chunk_size)

    def _chunk_index(self, seq_len, t):
        """Map global time ``t`` to a chunk index, clamped to the last chunk."""
        return min(int(t), self._num_chunks(seq_len) - 1)

    def _window_range(self, seq_len, chunk_idx, training=False):
        """Return ``(input_start, input_end, output_start, output_end)`` for a chunk.

        The input window is what the model sees (history + target); the output
        window covers only the chunk's target frames and always ends where the
        input window ends. ``training`` is accepted for interface symmetry and
        does not change the result.

        Guarantees ``input_start <= output_start <= output_end == input_end``.
        """
        if chunk_idx == 0:
            output_start = self.history_len  # first m frames are always GT history
            output_end = min(self.window_size, seq_len)
            input_start = 0
        else:
            output_start = self.window_size + (chunk_idx - 1) * self.chunk_size
            output_end = min(output_start + self.chunk_size, seq_len)
            input_start = output_start - self.history_len
        # A sequence shorter than the history has no target frames: keep the
        # output window empty rather than inverted.
        output_start = min(output_start, output_end)
        if self.content_len is not None:
            input_start = max(input_start, output_end - self.content_len)
            # Never let the cap cut into the target frames: callers slice the
            # model prediction at ``output_start - input_start``, which must
            # not be negative.
            input_start = min(input_start, output_start)
        return input_start, output_end, output_start, output_end

    # ----------------------------------------------------------------
    # Scheduler interface
    # ----------------------------------------------------------------
    def get_total_steps(self, seq_len):
        """Return the number of denoising steps needed for ``seq_len`` frames."""
        return self._num_chunks(seq_len) * self.steps

    def get_time_steps(self, device, valid_len, current_step=None):
        """Return one global time value per sample as a 0-dim tensor.

        * ``current_step is None`` (training): ``t ~ U(0, num_chunks)`` per sample.
        * ``int``: the same step for every sample, ``t = current_step / steps``.
        * ``list``: per-sample steps, ``t = current_step[i] / steps``.

        Any other ``current_step`` type yields an empty list.
        """
        if current_step is None:
            return [
                torch.tensor(np.random.uniform(0, self._num_chunks(n)), device=device)
                for n in valid_len
            ]
        if isinstance(current_step, int):
            current_step = [current_step] * len(valid_len)
        elif not isinstance(current_step, list):
            return []
        # Exact division keeps chunk boundaries exact; ``k * (1.0 / steps)`` can
        # land just below k (e.g. 49 * (1 / 49) == 0.9999999999999999).
        return [
            torch.tensor(current_step[i] / self.steps, device=device)
            for i in range(len(valid_len))
        ]

    def get_time_schedules(self, device, valid_len, time_steps, training=False):
        """Build per-frame schedule values (1.0 = clean) and their time derivatives.

        Frames before the input window are 1.0 and frames after the chunk stay
        0.0. In training the whole input window uses the fractional time, plus
        Gaussian jitter of std ``random_epsilon``, clamped to [0, 1]. At
        inference history frames are 1.0 (never re-noised) and only target
        frames use the fractional time. The derivative is ``1 / steps``
        everywhere.
        """
        time_schedules = []
        time_schedules_derivative = []
        for length, t_step in zip(valid_len, time_steps):
            t = t_step.item()
            chunk_idx = self._chunk_index(length, t)
            t_frac = t - chunk_idx
            is_, ie_, os_, oe_ = self._window_range(length, chunk_idx)
            ts = torch.zeros(length, device=device)
            ts[:is_] = 1.0  # before the window: clean
            if training:
                ts[is_:ie_] = t_frac
            else:
                ts[is_:os_] = 1.0  # history stays clean
                ts[os_:oe_] = t_frac
            tsd = torch.full((length,), 1.0 / self.steps, device=device)
            if training:
                ts = torch.clamp(
                    ts + torch.randn_like(ts) * self.random_epsilon,
                    min=0.0, max=1.0,
                )
            time_schedules.append(ts)
            time_schedules_derivative.append(tsd)
        return time_schedules, time_schedules_derivative

    def get_windows(self, valid_len, time_steps, training=False):
        """Return per-sample lists ``(input_start, input_end, output_start, output_end)``."""
        input_start, input_end, output_start, output_end = [], [], [], []
        for i in range(len(time_steps)):
            chunk_idx = self._chunk_index(valid_len[i], time_steps[i].item())
            is_, ie_, os_, oe_ = self._window_range(
                valid_len[i], chunk_idx, training=training
            )
            input_start.append(is_)
            input_end.append(ie_)
            output_start.append(os_)
            output_end.append(oe_)
        return input_start, input_end, output_start, output_end

    def get_noise_levels(self, device, valid_len, time_schedules, training=False):
        """Compute coefficients for ``x_t = alpha * x0 + beta * eps`` per frame.

        Returns ``(alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta)``,
        each a list holding one per-frame tensor per sample. Derivatives are
        with respect to the schedule value; ``alpha`` and ``beta`` are clamped
        to [0, 1]. ``sigma`` is zero for ``sigma_type="zero"``.
        ``device`` and ``training`` are unused.

        Raises:
            ValueError: on an unknown ``noise_type`` or ``sigma_type``.
        """
        alpha, dalpha, dlog_alpha = [], [], []
        beta, dbeta, dlog_beta = [], [], []
        sigma = []
        for i in range(len(valid_len)):
            t = time_schedules[i]
            if self.noise_type == "linear":
                alpha_i = t
                dalpha_i = torch.ones_like(alpha_i)
                dlog_alpha_i = dalpha_i / torch.clamp(alpha_i, min=EPSILON)
                beta_i = 1 - t
                dbeta_i = -torch.ones_like(beta_i)
                dlog_beta_i = dbeta_i / torch.clamp(beta_i, min=EPSILON)
            elif self.noise_type == "exponential":
                k = self.exp_max
                alpha_i = torch.exp(-k * (1 - t))
                dalpha_i = k * alpha_i
                dlog_alpha_i = k * torch.ones_like(alpha_i)
                beta_i = 1 - alpha_i
                dbeta_i = -dalpha_i
                dlog_beta_i = dbeta_i / torch.clamp(beta_i, min=EPSILON)
            elif self.noise_type == "exponential_rev":
                k = self.exp_max
                beta_i = torch.exp(-k * t)
                dbeta_i = -k * beta_i
                dlog_beta_i = -k * torch.ones_like(beta_i)
                alpha_i = 1 - beta_i
                dalpha_i = -dbeta_i
                dlog_alpha_i = dalpha_i / torch.clamp(alpha_i, min=EPSILON)
            elif self.noise_type == "diffusion":
                # Linear beta(s) schedule integrated over reversed time 1 - t.
                t_rev = 1.0 - t
                beta_rate = (self.beta_start + t_rev * (self.beta_end - self.beta_start)) * self.T
                Gamma = (self.beta_start * t_rev + 0.5 * (self.beta_end - self.beta_start) * t_rev * t_rev) * self.T
                alpha_i = torch.exp(-0.5 * Gamma)
                dalpha_i = 0.5 * beta_rate * alpha_i
                dlog_alpha_i = 0.5 * beta_rate
                beta_i = torch.sqrt(torch.clamp(1 - torch.exp(-Gamma), min=0.0))
                dbeta_i = -0.5 * torch.exp(-Gamma) * beta_rate / torch.clamp(beta_i, min=EPSILON)
                dlog_beta_i = dbeta_i / torch.clamp(beta_i, min=EPSILON)
            else:
                raise ValueError(f"Unknown noise type: {self.noise_type}")
            alpha.append(torch.clamp(alpha_i, min=0.0, max=1.0))
            dalpha.append(dalpha_i)
            dlog_alpha.append(dlog_alpha_i)
            beta.append(torch.clamp(beta_i, min=0.0, max=1.0))
            dbeta.append(dbeta_i)
            dlog_beta.append(dlog_beta_i)

            if self.sigma_type == "zero":
                sigma_i = torch.zeros_like(t)
            elif self.sigma_type == "memoryless":
                # noise_type was validated above, so only these two cases remain.
                if self.noise_type == "diffusion":
                    var = 2 * dlog_alpha_i
                else:  # linear / exponential / exponential_rev
                    var = 2 * dlog_alpha_i * beta_i
                sigma_i = self.sigma_scale * torch.sqrt(torch.clamp(var, min=0.0))
            else:
                # Previously fell through with sigma_i unbound.
                raise ValueError(f"Unknown sigma type: {self.sigma_type}")
            sigma.append(sigma_i)
        return alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta

    def add_noise(self, x, alpha, beta, input_start, input_end, output_start, output_end, training=False, noise=None):
        """Build the model input and, in training, the regression targets.

        Frames lie on dim 1 of each ``x[i]``. Training computes
        ``x_t = alpha * x + beta * noise`` over all frames (``noise`` is sampled
        when not given). It returns the clean ``x0`` and ``eps`` over the output
        window and ``x_t`` over the input window. At inference nothing is
        re-noised: ``xt`` is the input window of ``x`` and ``x0``/``eps`` are
        empty lists.
        """
        x0, eps, xt = [], [], []
        if training:
            for i in range(len(x)):
                noise_i = noise[i] if noise is not None else torch.randn_like(x[i])
                alpha_i = alpha[i][None, :, None, None]
                beta_i = beta[i][None, :, None, None]
                noisy_x_i = x[i] * alpha_i + noise_i * beta_i
                x0.append(x[i][:, output_start[i]:output_end[i], ...])
                eps.append(noise_i[:, output_start[i]:output_end[i], ...])
                xt.append(noisy_x_i[:, input_start[i]:input_end[i], ...])
        else:
            xt = [x[i][:, input_start[i]:input_end[i], ...] for i in range(len(x))]
        return x0, eps, xt

    def prepare(self, x, device, valid_len, training=True, current_step=None):
        """Run the whole scheduling pipeline for one step and window the results.

        Returns a dict holding the following entries.

        * Input window: ``xt`` and ``time_schedules``.
        * Output window: the coefficient lists (``time_schedules_derivative``,
          ``alpha``, ``dalpha``, ``beta``, ``dbeta``, ``sigma``, ``dlog_alpha``,
          ``dlog_beta``).
        * The four window bounds.
        * ``x0``/``eps``: output-window targets in training, empty lists at
          inference.
        """
        time_steps = self.get_time_steps(device, valid_len, current_step)
        time_schedules, time_schedules_derivative = self.get_time_schedules(
            device, valid_len, time_steps, training=training
        )
        alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta = self.get_noise_levels(
            device, valid_len, time_schedules, training=training
        )
        input_start, input_end, output_start, output_end = self.get_windows(
            valid_len, time_steps, training=training
        )
        # add_noise needs the full-length coefficients, so slice afterwards.
        x0, eps, xt = self.add_noise(
            x, alpha, beta, input_start, input_end,
            output_start, output_end, training=training,
        )

        def _window(values, starts, ends):
            return [v[s:e] for v, s, e in zip(values, starts, ends)]

        def _out(values):
            return _window(values, output_start, output_end)

        return {
            "time_schedules": _window(time_schedules, input_start, input_end),
            "time_schedules_derivative": _out(time_schedules_derivative),
            "input_start": input_start,
            "input_end": input_end,
            "output_start": output_start,
            "output_end": output_end,
            "alpha": _out(alpha),
            "dalpha": _out(dalpha),
            "beta": _out(beta),
            "dbeta": _out(dbeta),
            "sigma": _out(sigma),
            "dlog_alpha": _out(dlog_alpha),
            "dlog_beta": _out(dlog_beta),
            "xt": xt,
            "x0": x0,
            "eps": eps,
        }

    # ----------------------------------------------------------------
    # Streaming support
    # ----------------------------------------------------------------
    def get_committable(self, total_frames):
        """Return ``(committed_frames, committable_steps)`` for fully covered chunks.

        Only chunks whose frames all lie within ``total_frames`` count; the step
        count is ``steps`` per such chunk. Returns ``(0, 0)`` when even the
        first window does not fit.
        """
        if total_frames < self.window_size:
            return 0, 0
        extra_chunks = (total_frames - self.window_size) // self.chunk_size
        committed = self.window_size + extra_chunks * self.chunk_size
        committable_steps = (1 + extra_chunks) * self.steps
        return committed, committable_steps

    def get_step_rollback(self, seq_len):
        """Return the step count of the fully completed chunks for ``seq_len`` frames."""
        return self.get_committable(seq_len)[1]
class ChunkDiffWanModel(DiffForcingWanModel):
    """Chunk-based diffusion model with clean history conditioning.

    First chunk: GT history (history_len frames) + noisy target.
    Subsequent chunks: previously generated frames as history + noisy target.
    History is never re-noised.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.time_scheduler = ChunkDiffusionScheduler(self.schedule_config)

    def generate(self, x):
        """Sample sequences chunk by chunk with classifier-free guidance.

        Args:
            x: model inputs, passed through ``self._extract_inputs``. Reads
                ``"feature_length"`` and, when present, ``"feature"``. The
                latter is normalised with ``self.mean``/``self.std``, and its
                first ``history_len`` frames seed the history.

        Returns:
            dict with ``"generated"`` (one de-normalised tensor per sample,
            truncated to its valid length) and ``"text"`` (the text module's
            ``full_text`` metadata).
        """
        x = self._extract_inputs(x)
        extra_len = self.schedule_config.get("extra_len", 0)
        feature_length = x["feature_length"]
        batch_size = len(feature_length)
        seq_len = max(feature_length).item() + extra_len
        device = next(self.parameters()).device
        valid_len = [min(fl.item(), seq_len) for fl in feature_length]
        # Every sample is generated at the same padded length, so all of them
        # share one time schedule.
        generated_len = [seq_len] * batch_size
        history_len = self.time_scheduler.history_len

        # Initialise the entire sequence as pure noise.
        noise = torch.randn(
            batch_size, seq_len, *self.spatial_shape, self.input_dim, device=device
        )
        generated = self.preprocess([noise[i] for i in range(batch_size)])

        # Inject normalised GT frames into the first history_len frames.
        if "feature" in x:
            gt_feature = (x["feature"] - self.mean) / self.std
            gt_list = self.preprocess(
                [gt_feature[i, :valid_len[i], ...] for i in range(batch_size)]
            )
            for i in range(batch_size):
                h = min(history_len, gt_list[i].shape[1])
                generated[i][:, :h, ...] = gt_list[i][:, :h, ...]

        # Text and null contexts are fixed for the whole sampling loop.
        text_context, metadata = self.text_module.get_context(
            x, generated_len, device, self.param_dtype, training=False,
        )
        null_context = self.text_module.get_null_context(batch_size, device, self.param_dtype)
        full_text = metadata["full_text"]

        total_steps = self.time_scheduler.get_total_steps(seq_len)
        for step in range(total_steps):
            s = self.time_scheduler.prepare(
                generated, device, generated_len, training=False, current_step=step
            )
            input_start = s["input_start"]
            input_end = s["input_end"]
            xt = s["xt"]
            time_schedules_input = [
                ts * self.time_embedding_scale for ts in s["time_schedules"]
            ]
            if isinstance(text_context[0], (list, tuple)):
                # Per-sample sequence context: keep only the input window.
                window_text_context = [
                    text_context[i][input_start[i]:input_end[i]]
                    for i in range(batch_size)
                ]
            else:
                window_text_context = text_context

            # Classifier-free guidance: weighted sum of the text-conditioned
            # and null-conditioned predictions.
            pred_text = self.model(xt, time_schedules_input, window_text_context, seq_len, y=None)
            pred_null = self.model(xt, time_schedules_input, null_context, seq_len, y=None)
            predicted_result = [
                self.cfg_config["text_scale"] * pt + self.cfg_config["null_scale"] * pn
                for pt, pn in zip(pred_text, pred_null)
            ]

            # Only the target (output) frames are updated.
            for i in range(batch_size):
                self._sde_step(generated[i], predicted_result[i], s, i)

        generated = self.postprocess(generated)
        y_hat_out = [
            generated[i][:valid_len[i], :] * self.std + self.mean
            for i in range(batch_size)
        ]
        return {"generated": y_hat_out, "text": full_text}

    def _sde_step(self, generated_i, prediction_i, s, i):
        """Apply one Euler-Maruyama-style update, in place, to sample ``i``'s target frames.

        ``prediction_i`` covers the input window, so it is re-indexed relative
        to ``input_start``. Coefficients in ``s`` (from the scheduler's
        ``prepare``) already cover the output window only.

        Raises:
            ValueError: on an unknown ``self.prediction_type``.
        """
        os_idx, oe_idx = s["output_start"][i], s["output_end"][i]
        offset = s["input_start"][i]
        pred = prediction_i[:, os_idx - offset:oe_idx - offset, ...]
        x_cur = generated_i[:, os_idx:oe_idx, ...]

        def _coef(name):
            # Per-frame coefficients broadcast along dim 1 (frames).
            return s[name][i][None, :, None, None]

        dt = _coef("time_schedules_derivative")
        alpha = _coef("alpha")
        dalpha = _coef("dalpha")
        beta = _coef("beta")
        dbeta = _coef("dbeta")
        sigma = _coef("sigma")
        dlog_alpha = _coef("dlog_alpha")
        dlog_beta = _coef("dlog_beta")

        # Convert the network output into a velocity.
        if self.prediction_type == "vel":
            vel = pred
        elif self.prediction_type == "x0":
            vel = pred * (-dlog_beta * alpha + dalpha) + x_cur * dlog_beta
        elif self.prediction_type == "eps":
            vel = pred * (-dlog_alpha * beta + dbeta) + x_cur * dlog_alpha
        else:
            raise ValueError(f"Unknown prediction type: {self.prediction_type}")

        # Score implied by the velocity; it only contributes when sigma > 0.
        score = (vel - x_cur * dlog_alpha) / ((beta * dlog_alpha - dbeta) * beta)
        generated_i[:, os_idx:oe_idx, ...] += (
            vel * dt
            + score * 0.5 * sigma ** 2 * dt
            + sigma * torch.sqrt(dt) * torch.randn_like(x_cur)
        )

    def init_generated(self, seq_len, batch_size=1, schedule_config=None):
        """Delegate to the parent's ``init_generated``, then rebuild the chunk scheduler.

        ``schedule_config`` defaults to a fresh empty dict on every call; a
        shared ``{}`` default would leak mutations between calls.
        """
        if schedule_config is None:
            schedule_config = {}
        super().init_generated(seq_len, batch_size, schedule_config)
        self.time_scheduler = ChunkDiffusionScheduler(self.schedule_config)