| """Chunk-based diffusion model (no history re-noising). |
| |
| Config: history_len=m, chunk_size=n, steps=T |
| - Global time t ∈ [0, num_chunks), where num_chunks = 1 + ceil((N - (m+n)) / n) |
| - Schedule: before window → 1.0, history → 1.0 (clean), target → frac(t), after → 0.0 |
| - Inference: history stays clean, only target frames are denoised |
| - First chunk uses GT history frames as conditioning |
| """ |
|
|
| import math |
|
|
| import numpy as np |
| import torch |
|
|
| from .diffusion_forcing_wan import DiffForcingWanModel |
|
|
| EPSILON = 0.05 |
|
|
|
|
class ChunkDiffusionScheduler:
    """Time/noise scheduler for chunk-wise autoregressive diffusion.

    The sequence is produced chunk by chunk.  Global time ``t`` lives in
    ``[0, num_chunks)``: ``int(t)`` selects the active chunk and
    ``frac(t)`` is the per-frame schedule value inside that chunk
    (0.0 = pure noise, 1.0 = clean).  Frames before the current window
    are held at 1.0 (clean); frames after it stay at 0.0 (not yet
    started).  See the module docstring for the chunk layout.
    """

    def __init__(self, config: dict):
        # Number of denoising steps spent on each chunk.
        self.steps = config["steps"]
        # Number of new (target) frames produced per chunk.
        self.chunk_size = config["chunk_size"]
        # Clean conditioning frames that precede each chunk's targets.
        self.history_len = config.get("history_len", 0)
        # Model input window = history frames + target frames.
        self.window_size = self.history_len + self.chunk_size
        # Shape of the alpha/beta schedules over t:
        # "linear" | "exponential" | "exponential_rev" | "diffusion".
        self.noise_type = config.get("noise_type", "linear")
        # Sampler stochasticity: "zero" (deterministic, sigma == 0) or
        # "memoryless" (SDE-style noise injection).
        self.sigma_type = config.get("sigma_type", "zero")
        # Std of the Gaussian jitter added to the schedule in training.
        self.random_epsilon = config.get("random_epsilon", 0.0)
        # Optional cap (in frames) on the model's input context length.
        self.content_len = config.get("content_len", None)

        if self.noise_type in ("exponential", "exponential_rev"):
            # Exponent scale k used by the exponential schedules.
            self.exp_max = config.get("exp_max", 5.0)
        elif self.noise_type == "diffusion":
            # DDPM-style linear-beta (VP) schedule parameters.
            self.T = config.get("T", 1000)
            self.beta_start = config.get("beta_start", 0.0001)
            self.beta_end = config.get("beta_end", 0.02)

        if self.sigma_type == "memoryless":
            # Multiplier on the memoryless SDE noise magnitude.
            self.sigma_scale = config.get("sigma_scale", 1.0)

    def _num_chunks(self, seq_len: int) -> int:
        """Number of chunks needed to cover ``seq_len`` frames.

        The first chunk covers ``window_size`` frames; each subsequent
        chunk adds ``chunk_size`` new frames.
        """
        if seq_len <= self.window_size:
            return 1
        return 1 + math.ceil((seq_len - self.window_size) / self.chunk_size)

    def _window_range(self, seq_len, chunk_idx, training=False):
        """Return (input_start, input_end, output_start, output_end) for a chunk."""
        if chunk_idx == 0:
            # First chunk: inputs start at frame 0; the first
            # `history_len` frames are conditioning, the rest targets.
            os_ = self.history_len
            oe_ = min(self.window_size, seq_len)
            is_ = 0
        else:
            # Later chunks: the next `chunk_size` frames are targets,
            # preceded by `history_len` frames of already-clean context.
            os_ = self.window_size + (chunk_idx - 1) * self.chunk_size
            oe_ = min(os_ + self.chunk_size, seq_len)
            is_ = os_ - self.history_len
        if self.content_len is not None:
            # Cap the input context fed to the model at `content_len`
            # frames (measured back from the window's output end).
            is_ = max(is_, oe_ - self.content_len)

        return is_, oe_, os_, oe_

    def get_total_steps(self, seq_len: int) -> int:
        """Total number of sampler steps to generate ``seq_len`` frames."""
        return self._num_chunks(seq_len) * self.steps

    def get_time_steps(self, device, valid_len, current_step=None):
        """Per-sample global time t (one scalar tensor per batch item).

        - ``current_step is None`` (training): t ~ U[0, num_chunks).
        - ``int`` (inference): shared t = current_step / steps.
        - ``list`` (inference): per-sample t = current_step[i] / steps.
        """
        time_steps = []
        if current_step is None:
            for i in range(len(valid_len)):
                max_time = self._num_chunks(valid_len[i])
                time_steps.append(
                    torch.tensor(np.random.uniform(0, max_time), device=device)
                )
        elif isinstance(current_step, int):
            for i in range(len(valid_len)):
                t = current_step * (1.0 / self.steps)
                time_steps.append(torch.tensor(t, device=device))
        elif isinstance(current_step, list):
            for i in range(len(valid_len)):
                t = current_step[i] * (1.0 / self.steps)
                time_steps.append(torch.tensor(t, device=device))
        return time_steps

    def get_time_schedules(self, device, valid_len, time_steps, training=False):
        """Build the per-frame schedule ts (and its step size) per sample.

        Returns two lists of 1-D tensors of length ``valid_len[i]``:
        the schedule values in [0, 1] and a constant per-step dt.
        """
        time_schedules = []
        time_schedules_derivative = []
        for i in range(len(valid_len)):
            t = time_steps[i].item()
            # Clamp so a terminal t maps onto the last chunk.
            chunk_idx = min(int(t), self._num_chunks(valid_len[i]) - 1)
            t_frac = t - chunk_idx
            is_, ie_, os_, oe_ = self._window_range(valid_len[i], chunk_idx)

            # Frames after the window stay at 0.0 (pure noise).
            ts = torch.zeros(valid_len[i], device=device)
            # Frames before the window are fully denoised.
            ts[:is_] = 1.0
            if training:
                # Training: the whole window — history included — shares
                # the chunk's fractional level.  NOTE(review): this noises
                # history at train time while inference keeps it clean
                # (branch below); confirm this train/test asymmetry is
                # intended given the module's "no history re-noising" title.
                ts[is_:ie_] = t_frac
            else:
                # Inference: history stays clean; only targets denoise.
                ts[is_:os_] = 1.0
                ts[os_:oe_] = t_frac

            # Constant time increment per sampler step.
            tsd = torch.full((valid_len[i],), 1.0 / self.steps, device=device)
            if training:
                # Jitter the schedule for robustness, clamped to [0, 1].
                ts = torch.clamp(
                    ts + torch.randn_like(ts) * self.random_epsilon,
                    min=0.0, max=1.0,
                )
            time_schedules.append(ts)
            time_schedules_derivative.append(tsd)
        return time_schedules, time_schedules_derivative

    def get_windows(self, valid_len, time_steps, training=False):
        """Per-sample input/output window indices for the active chunk."""
        input_start, input_end, output_start, output_end = [], [], [], []
        for i in range(len(time_steps)):
            t = time_steps[i].item()
            chunk_idx = min(int(t), self._num_chunks(valid_len[i]) - 1)
            is_, ie_, os_, oe_ = self._window_range(valid_len[i], chunk_idx, training=training)
            input_start.append(is_)
            input_end.append(ie_)
            output_start.append(os_)
            output_end.append(oe_)
        return input_start, input_end, output_start, output_end

    def get_noise_levels(self, device, valid_len, time_schedules, training=False):
        """Evaluate alpha/beta (signal/noise scales), their derivatives
        w.r.t. t, log-derivatives, and the sampler's sigma, per frame.

        Convention: x_t = alpha(t) * x0 + beta(t) * eps, with
        alpha(1) = 1 (clean) and alpha(0) ≈ 0 (pure noise).
        Returns (alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta),
        each a list of 1-D tensors aligned with ``time_schedules``.
        """
        alpha, dalpha, dlog_alpha = [], [], []
        beta, dbeta, dlog_beta = [], [], []
        sigma = []
        for i in range(len(valid_len)):
            t = time_schedules[i]
            if self.noise_type == "linear":
                # alpha = t, beta = 1 - t (rectified-flow style).
                alpha_i = t
                dalpha_i = torch.ones_like(alpha_i)
                dlog_alpha_i = dalpha_i / torch.clamp(alpha_i, min=EPSILON)
                beta_i = 1 - t
                dbeta_i = -torch.ones_like(beta_i)
                dlog_beta_i = dbeta_i / torch.clamp(beta_i, min=EPSILON)
            elif self.noise_type == "exponential":
                # alpha = exp(-k (1 - t)): slow start, fast finish.
                k = self.exp_max
                alpha_i = torch.exp(-k * (1 - t))
                dalpha_i = k * alpha_i
                dlog_alpha_i = k * torch.ones_like(alpha_i)
                beta_i = 1 - alpha_i
                dbeta_i = -dalpha_i
                dlog_beta_i = dbeta_i / torch.clamp(beta_i, min=EPSILON)
            elif self.noise_type == "exponential_rev":
                # beta = exp(-k t): mirror of the exponential schedule.
                k = self.exp_max
                beta_i = torch.exp(-k * t)
                dbeta_i = -k * beta_i
                dlog_beta_i = -k * torch.ones_like(beta_i)
                alpha_i = 1 - beta_i
                dalpha_i = -dbeta_i
                dlog_alpha_i = dalpha_i / torch.clamp(alpha_i, min=EPSILON)
            elif self.noise_type == "diffusion":
                # VP/DDPM schedule evaluated at reversed time t_rev = 1 - t,
                # so t = 1 is clean.  Gamma = ∫ beta_rate; alpha = exp(-Gamma/2),
                # beta = sqrt(1 - exp(-Gamma)).
                t_rev = 1.0 - t
                beta_rate = (self.beta_start + t_rev * (self.beta_end - self.beta_start)) * self.T
                Gamma = (self.beta_start * t_rev + 0.5 * (self.beta_end - self.beta_start) * t_rev * t_rev) * self.T
                alpha_i = torch.exp(-0.5 * Gamma)
                dalpha_i = 0.5 * beta_rate * alpha_i
                dlog_alpha_i = 0.5 * beta_rate
                beta_i = torch.sqrt(torch.clamp(1 - torch.exp(-Gamma), min=0.0))
                dbeta_i = -0.5 * torch.exp(-Gamma) * beta_rate / torch.clamp(beta_i, min=EPSILON)
                dlog_beta_i = dbeta_i / torch.clamp(beta_i, min=EPSILON)
            else:
                raise ValueError(f"Unknown noise type: {self.noise_type}")

            alpha.append(torch.clamp(alpha_i, min=0.0, max=1.0))
            dalpha.append(dalpha_i)
            dlog_alpha.append(dlog_alpha_i)
            beta.append(torch.clamp(beta_i, min=0.0, max=1.0))
            dbeta.append(dbeta_i)
            dlog_beta.append(dlog_beta_i)

            if self.sigma_type == "zero":
                # Deterministic (probability-flow ODE) sampling.
                sigma_i = torch.zeros_like(t)
            elif self.sigma_type == "memoryless":
                if self.noise_type in ("linear", "exponential", "exponential_rev"):
                    sigma_i = self.sigma_scale * torch.sqrt(torch.clamp(2 * dlog_alpha_i * beta_i, min=0.0))
                elif self.noise_type == "diffusion":
                    sigma_i = self.sigma_scale * torch.sqrt(torch.clamp(2 * dlog_alpha_i, min=0.0))
                else:
                    # Unreachable with the current noise types (unknown
                    # types raise above); kept as a general fallback.
                    sigma_i = self.sigma_scale * torch.sqrt(torch.clamp(2 * beta_i * (dlog_alpha_i * beta_i - dbeta_i), min=0.0))
            sigma.append(sigma_i)
        return alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta

    def add_noise(self, x, alpha, beta, input_start, input_end, output_start, output_end, training=False, noise=None):
        """Forward-noise x and slice out training targets / model input.

        Training: returns clean targets x0, their noise eps (both sliced
        to the output window) and the noised input window xt.
        Inference: x already carries the current noisy state, so only
        the input-window slices xt are returned (x0 and eps stay empty).
        """
        x0, eps, xt = [], [], []
        if training:
            for i in range(len(x)):
                noise_i = noise[i] if noise is not None else torch.randn_like(x[i])
                # Broadcast per-frame scales over (channel, frame, H, W).
                alpha_i = alpha[i][None, :, None, None]
                beta_i = beta[i][None, :, None, None]
                noisy_x_i = x[i] * alpha_i + noise_i * beta_i
                x0.append(x[i][:, output_start[i]:output_end[i], ...])
                eps.append(noise_i[:, output_start[i]:output_end[i], ...])
                xt.append(noisy_x_i[:, input_start[i]:input_end[i], ...])
        else:
            # Inference: no re-noising; just crop the active window.
            for i in range(len(x)):
                xt.append(x[i][:, input_start[i]:input_end[i], ...])
        return x0, eps, xt

    def prepare(self, x, device, valid_len, training=True, current_step=None):
        """Single call replacing 5 separate scheduler calls."""
        time_steps = self.get_time_steps(device, valid_len, current_step)
        time_schedules, time_schedules_derivative = self.get_time_schedules(
            device, valid_len, time_steps, training=training
        )
        alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta = \
            self.get_noise_levels(device, valid_len, time_schedules, training=training)
        input_start, input_end, output_start, output_end = \
            self.get_windows(valid_len, time_steps, training=training)
        x0, eps, xt = self.add_noise(
            x, alpha, beta, input_start, input_end,
            output_start, output_end, training=training
        )

        # Crop everything to its relevant window.  Note the asymmetry:
        # time_schedules follow the INPUT window (the model conditions on
        # all input frames), while the integration quantities (dt, alpha,
        # beta, sigma, ...) follow the OUTPUT window (only target frames
        # are updated).
        batch_size = len(valid_len)
        time_schedules = [time_schedules[i][input_start[i]:input_end[i]] for i in range(batch_size)]
        time_schedules_derivative = [time_schedules_derivative[i][output_start[i]:output_end[i]] for i in range(batch_size)]
        alpha = [alpha[i][output_start[i]:output_end[i]] for i in range(batch_size)]
        dalpha = [dalpha[i][output_start[i]:output_end[i]] for i in range(batch_size)]
        beta = [beta[i][output_start[i]:output_end[i]] for i in range(batch_size)]
        dbeta = [dbeta[i][output_start[i]:output_end[i]] for i in range(batch_size)]
        sigma = [sigma[i][output_start[i]:output_end[i]] for i in range(batch_size)]
        dlog_alpha = [dlog_alpha[i][output_start[i]:output_end[i]] for i in range(batch_size)]
        dlog_beta = [dlog_beta[i][output_start[i]:output_end[i]] for i in range(batch_size)]

        return {
            "time_schedules": time_schedules,
            "time_schedules_derivative": time_schedules_derivative,
            "input_start": input_start,
            "input_end": input_end,
            "output_start": output_start,
            "output_end": output_end,
            "alpha": alpha,
            "dalpha": dalpha,
            "beta": beta,
            "dbeta": dbeta,
            "sigma": sigma,
            "dlog_alpha": dlog_alpha,
            "dlog_beta": dlog_beta,
            "xt": xt,
            "x0": x0,
            "eps": eps,
        }

    def get_committable(self, total_frames: int):
        """Return (frames, steps) that are final given ``total_frames``.

        A chunk's frames can be committed once the chunk is complete;
        partial chunks at the end are not committable yet.
        """
        if total_frames < self.window_size:
            return 0, 0
        committed = self.window_size
        committable_steps = self.steps
        remaining = total_frames - self.window_size
        extra_chunks = remaining // self.chunk_size
        committed += extra_chunks * self.chunk_size
        committable_steps += extra_chunks * self.steps
        return committed, committable_steps

    def get_step_rollback(self, seq_len: int) -> int:
        """Sampler steps already consumed by the fully completed chunks
        covering ``seq_len`` frames (used to resume mid-sequence)."""
        if seq_len < self.window_size:
            return 0
        completed = 1
        remaining = seq_len - self.window_size
        completed += remaining // self.chunk_size
        return completed * self.steps
|
|
|
|
class ChunkDiffWanModel(DiffForcingWanModel):
    """Chunk-based diffusion model with clean history conditioning.

    First chunk: GT history (history_len frames) + noisy target.
    Subsequent chunks: previously generated frames as history + noisy target.
    History is never re-noised.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Replace the parent's scheduler with the chunk-based one.
        self.time_scheduler = ChunkDiffusionScheduler(self.schedule_config)

    def generate(self, x):
        """Autoregressively sample a full sequence, chunk by chunk.

        Args:
            x: raw batch dict; must contain "feature_length" and, for GT
               history seeding, optionally "feature".

        Returns:
            dict with "generated" (list of denormalized per-sample
            tensors, cropped to each sample's valid length) and "text"
            (the full conditioning text from the text module).
        """
        x = self._extract_inputs(x)
        extra_len = self.schedule_config.get("extra_len", 0)
        feature_length = x["feature_length"]
        batch_size = len(feature_length)
        seq_len = max(feature_length).item() + extra_len
        device = next(self.parameters()).device
        valid_len = [min(fl.item(), seq_len) for fl in feature_length]
        generated_len = [seq_len] * batch_size
        history_len = self.time_scheduler.history_len

        # Start every frame from pure Gaussian noise.
        generated = torch.randn(
            batch_size, seq_len, *self.spatial_shape, self.input_dim, device=device
        )
        generated = [generated[i] for i in range(batch_size)]
        generated = self.preprocess(generated)

        # Seed the first `history_len` frames with normalized GT features
        # so the first chunk conditions on clean ground truth.
        if "feature" in x:
            gt_feature = x["feature"]
            gt_feature = (gt_feature - self.mean) / self.std
            gt_list = []
            for i in range(batch_size):
                gt_list.append(gt_feature[i, :valid_len[i], ...])
            gt_list = self.preprocess(gt_list)
            for i in range(batch_size):
                # Guard against sequences shorter than history_len.
                h = min(history_len, gt_list[i].shape[1])
                generated[i][:, :h, ...] = gt_list[i][:, :h, ...]

        # Text conditioning, plus the null context for classifier-free guidance.
        text_context, metadata = self.text_module.get_context(
            x, generated_len, device, self.param_dtype, training=False,
        )
        null_context = self.text_module.get_null_context(batch_size, device, self.param_dtype)
        full_text = metadata["full_text"]

        total_steps = self.time_scheduler.get_total_steps(seq_len)
        for step in range(total_steps):
            # One scheduler call yields windows, schedules and noise scales.
            s = self.time_scheduler.prepare(
                generated, device, generated_len, training=False, current_step=step
            )
            time_schedules = s["time_schedules"]
            time_schedules_derivative = s["time_schedules_derivative"]
            alpha = s["alpha"]
            dalpha = s["dalpha"]
            beta = s["beta"]
            dbeta = s["dbeta"]
            sigma = s["sigma"]
            dlog_alpha = s["dlog_alpha"]
            dlog_beta = s["dlog_beta"]
            input_start_index = s["input_start"]
            input_end_index = s["input_end"]
            output_start_index = s["output_start"]
            output_end_index = s["output_end"]
            xt = s["xt"]

            time_schedules_input = [
                time_schedules[i] * self.time_embedding_scale for i in range(batch_size)
            ]

            # Per-frame text contexts get cropped to the input window;
            # a single shared context is passed through unchanged.
            if isinstance(text_context[0], (list, tuple)):
                window_text_context = [
                    text_context[i][input_start_index[i]:input_end_index[i]]
                    for i in range(batch_size)
                ]
            else:
                window_text_context = text_context

            # Classifier-free guidance: combine text and null predictions.
            pred_text = self.model(xt, time_schedules_input, window_text_context, seq_len, y=None)
            pred_null = self.model(xt, time_schedules_input, null_context, seq_len, y=None)
            predicted_result = [
                self.cfg_config["text_scale"] * pt + self.cfg_config["null_scale"] * pn
                for pt, pn in zip(pred_text, pred_null)
            ]

            # Integrate one SDE/ODE step on the output (target) frames only.
            for i in range(batch_size):
                os_idx, oe_idx = output_start_index[i], output_end_index[i]
                # Model output is aligned with the input window; shift to
                # pick out the target frames.
                pred_os = os_idx - input_start_index[i]
                pred_oe = oe_idx - input_start_index[i]
                predicted_result_i = predicted_result[i][:, pred_os:pred_oe, ...]
                generated_i = generated[i][:, os_idx:oe_idx, ...]
                # Broadcast per-frame quantities over (channel, frame, H, W).
                dt = time_schedules_derivative[i][None, :, None, None]
                alpha_i = alpha[i][None, :, None, None]
                dalpha_i = dalpha[i][None, :, None, None]
                beta_i = beta[i][None, :, None, None]
                dbeta_i = dbeta[i][None, :, None, None]
                sigma_i = sigma[i][None, :, None, None]
                dlog_alpha_i = dlog_alpha[i][None, :, None, None]
                dlog_beta_i = dlog_beta[i][None, :, None, None]

                # Convert the network output to a velocity field.
                if self.prediction_type == "vel":
                    vel = predicted_result_i
                elif self.prediction_type == "x0":
                    vel = (
                        predicted_result_i * (-dlog_beta_i * alpha_i + dalpha_i)
                        + generated_i * dlog_beta_i
                    )
                elif self.prediction_type == "eps":
                    vel = (
                        predicted_result_i * (-dlog_alpha_i * beta_i + dbeta_i)
                        + generated_i * dlog_alpha_i
                    )
                else:
                    # Previously an unknown prediction_type fell through to
                    # a NameError on `vel`; fail loudly and clearly instead.
                    raise ValueError(f"Unknown prediction type: {self.prediction_type}")
                # Score term used for the stochastic (sigma > 0) correction.
                st = (vel - generated_i * dlog_alpha_i) / (
                    (beta_i * dlog_alpha_i - dbeta_i) * beta_i
                )
                # Euler–Maruyama update: drift + score correction + noise.
                generated[i][:, os_idx:oe_idx, ...] += (
                    vel * dt
                    + st * 0.5 * sigma_i ** 2 * dt
                    + sigma_i * torch.sqrt(dt) * torch.randn_like(generated_i)
                )

        # Undo preprocessing/normalization and crop to each valid length.
        generated = self.postprocess(generated)
        y_hat_out = []
        for i in range(batch_size):
            single_generated = generated[i][:valid_len[i], :] * self.std + self.mean
            y_hat_out.append(single_generated)
        return {"generated": y_hat_out, "text": full_text}

    def init_generated(self, seq_len, batch_size=1, schedule_config=None):
        """Reset generation state and rebuild the chunk scheduler.

        Args:
            seq_len: target sequence length in frames.
            batch_size: number of parallel samples.
            schedule_config: optional scheduler-config overrides; defaults
                to an empty dict.  (Was a mutable default argument `{}`,
                a shared-state pitfall; `None` sentinel used instead.)
        """
        if schedule_config is None:
            schedule_config = {}
        super().init_generated(seq_len, batch_size, schedule_config)
        self.time_scheduler = ChunkDiffusionScheduler(self.schedule_config)
|
|