"""Chunk-based diffusion model (no history re-noising). Config: history_len=m, chunk_size=n, steps=T - Global time t ∈ [0, num_chunks), where num_chunks = 1 + ceil((N - (m+n)) / n) - Schedule: before window → 1.0, history → 1.0 (clean), target → frac(t), after → 0.0 - Inference: history stays clean, only target frames are denoised - First chunk uses GT history frames as conditioning """ import math import numpy as np import torch from .diffusion_forcing_wan import DiffForcingWanModel EPSILON = 0.05 class ChunkDiffusionScheduler: def __init__(self, config): self.steps = config["steps"] self.chunk_size = config["chunk_size"] # n self.history_len = config.get("history_len", 0) # m self.window_size = self.history_len + self.chunk_size # m+n self.noise_type = config.get("noise_type", "linear") self.sigma_type = config.get("sigma_type", "zero") self.random_epsilon = config.get("random_epsilon", 0.0) self.content_len = config.get("content_len", None) if self.noise_type in ("exponential", "exponential_rev"): self.exp_max = config.get("exp_max", 5.0) elif self.noise_type == "diffusion": self.T = config.get("T", 1000) self.beta_start = config.get("beta_start", 0.0001) self.beta_end = config.get("beta_end", 0.02) if self.sigma_type == "memoryless": self.sigma_scale = config.get("sigma_scale", 1.0) # ---------------------------------------------------------------- # Chunks # ---------------------------------------------------------------- def _num_chunks(self, seq_len): if seq_len <= self.window_size: return 1 return 1 + math.ceil((seq_len - self.window_size) / self.chunk_size) def _window_range(self, seq_len, chunk_idx, training=False): """Return (input_start, input_end, output_start, output_end) for a chunk.""" if chunk_idx == 0: os_ = self.history_len # First m frames are always GT history oe_ = min(self.window_size, seq_len) is_ = 0 else: os_ = self.window_size + (chunk_idx - 1) * self.chunk_size oe_ = min(os_ + self.chunk_size, seq_len) is_ = os_ - self.history_len if self.content_len is not None: is_ = max(is_, oe_ - self.content_len) # output always covers target only (excludes history) return is_, oe_, os_, oe_ # ---------------------------------------------------------------- # Scheduler interface # ---------------------------------------------------------------- def get_total_steps(self, seq_len): return self._num_chunks(seq_len) * self.steps def get_time_steps(self, device, valid_len, current_step=None): time_steps = [] if current_step is None: for i in range(len(valid_len)): max_time = self._num_chunks(valid_len[i]) time_steps.append( torch.tensor(np.random.uniform(0, max_time), device=device) ) elif isinstance(current_step, int): for i in range(len(valid_len)): t = current_step * (1.0 / self.steps) time_steps.append(torch.tensor(t, device=device)) elif isinstance(current_step, list): for i in range(len(valid_len)): t = current_step[i] * (1.0 / self.steps) time_steps.append(torch.tensor(t, device=device)) return time_steps def get_time_schedules(self, device, valid_len, time_steps, training=False): time_schedules = [] time_schedules_derivative = [] for i in range(len(valid_len)): t = time_steps[i].item() chunk_idx = min(int(t), self._num_chunks(valid_len[i]) - 1) t_frac = t - chunk_idx is_, ie_, os_, oe_ = self._window_range(valid_len[i], chunk_idx) ts = torch.zeros(valid_len[i], device=device) # Before window → 1.0 (clean) ts[:is_] = 1.0 if training: # Training: entire window uses t_frac ts[is_:ie_] = t_frac else: # Inference: history → 1.0 (clean, no renoise), target → t_frac ts[is_:os_] = 1.0 ts[os_:oe_] = t_frac tsd = torch.full((valid_len[i],), 1.0 / self.steps, device=device) if training: ts = torch.clamp( ts + torch.randn_like(ts) * self.random_epsilon, min=0.0, max=1.0, ) time_schedules.append(ts) time_schedules_derivative.append(tsd) return time_schedules, time_schedules_derivative def get_windows(self, valid_len, time_steps, training=False): input_start, input_end, output_start, output_end = [], [], [], [] for i in range(len(time_steps)): t = time_steps[i].item() chunk_idx = min(int(t), self._num_chunks(valid_len[i]) - 1) is_, ie_, os_, oe_ = self._window_range(valid_len[i], chunk_idx, training=training) input_start.append(is_) input_end.append(ie_) output_start.append(os_) output_end.append(oe_) return input_start, input_end, output_start, output_end def get_noise_levels(self, device, valid_len, time_schedules, training=False): alpha, dalpha, dlog_alpha = [], [], [] beta, dbeta, dlog_beta = [], [], [] sigma = [] for i in range(len(valid_len)): t = time_schedules[i] if self.noise_type == "linear": alpha_i = t dalpha_i = torch.ones_like(alpha_i) dlog_alpha_i = dalpha_i / torch.clamp(alpha_i, min=EPSILON) beta_i = 1 - t dbeta_i = -torch.ones_like(beta_i) dlog_beta_i = dbeta_i / torch.clamp(beta_i, min=EPSILON) elif self.noise_type == "exponential": k = self.exp_max alpha_i = torch.exp(-k * (1 - t)) dalpha_i = k * alpha_i dlog_alpha_i = k * torch.ones_like(alpha_i) beta_i = 1 - alpha_i dbeta_i = -dalpha_i dlog_beta_i = dbeta_i / torch.clamp(beta_i, min=EPSILON) elif self.noise_type == "exponential_rev": k = self.exp_max beta_i = torch.exp(-k * t) dbeta_i = -k * beta_i dlog_beta_i = -k * torch.ones_like(beta_i) alpha_i = 1 - beta_i dalpha_i = -dbeta_i dlog_alpha_i = dalpha_i / torch.clamp(alpha_i, min=EPSILON) elif self.noise_type == "diffusion": t_rev = 1.0 - t beta_rate = (self.beta_start + t_rev * (self.beta_end - self.beta_start)) * self.T Gamma = (self.beta_start * t_rev + 0.5 * (self.beta_end - self.beta_start) * t_rev * t_rev) * self.T alpha_i = torch.exp(-0.5 * Gamma) dalpha_i = 0.5 * beta_rate * alpha_i dlog_alpha_i = 0.5 * beta_rate beta_i = torch.sqrt(torch.clamp(1 - torch.exp(-Gamma), min=0.0)) dbeta_i = -0.5 * torch.exp(-Gamma) * beta_rate / torch.clamp(beta_i, min=EPSILON) dlog_beta_i = dbeta_i / torch.clamp(beta_i, min=EPSILON) else: raise ValueError(f"Unknown noise type: {self.noise_type}") alpha.append(torch.clamp(alpha_i, min=0.0, max=1.0)) dalpha.append(dalpha_i) dlog_alpha.append(dlog_alpha_i) beta.append(torch.clamp(beta_i, min=0.0, max=1.0)) dbeta.append(dbeta_i) dlog_beta.append(dlog_beta_i) if self.sigma_type == "zero": sigma_i = torch.zeros_like(t) elif self.sigma_type == "memoryless": if self.noise_type in ("linear", "exponential", "exponential_rev"): sigma_i = self.sigma_scale * torch.sqrt(torch.clamp(2 * dlog_alpha_i * beta_i, min=0.0)) elif self.noise_type == "diffusion": sigma_i = self.sigma_scale * torch.sqrt(torch.clamp(2 * dlog_alpha_i, min=0.0)) else: sigma_i = self.sigma_scale * torch.sqrt(torch.clamp(2 * beta_i * (dlog_alpha_i * beta_i - dbeta_i), min=0.0)) sigma.append(sigma_i) return alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta def add_noise(self, x, alpha, beta, input_start, input_end, output_start, output_end, training=False, noise=None): x0, eps, xt = [], [], [] if training: for i in range(len(x)): noise_i = noise[i] if noise is not None else torch.randn_like(x[i]) alpha_i = alpha[i][None, :, None, None] beta_i = beta[i][None, :, None, None] noisy_x_i = x[i] * alpha_i + noise_i * beta_i x0.append(x[i][:, output_start[i]:output_end[i], ...]) eps.append(noise_i[:, output_start[i]:output_end[i], ...]) xt.append(noisy_x_i[:, input_start[i]:input_end[i], ...]) else: # No re-noising: history frames stay as-is, target frames stay as-is for i in range(len(x)): xt.append(x[i][:, input_start[i]:input_end[i], ...]) return x0, eps, xt def prepare(self, x, device, valid_len, training=True, current_step=None): """Single call replacing 5 separate scheduler calls.""" time_steps = self.get_time_steps(device, valid_len, current_step) time_schedules, time_schedules_derivative = self.get_time_schedules( device, valid_len, time_steps, training=training ) alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta = \ self.get_noise_levels(device, valid_len, time_schedules, training=training) input_start, input_end, output_start, output_end = \ self.get_windows(valid_len, time_steps, training=training) x0, eps, xt = self.add_noise( x, alpha, beta, input_start, input_end, output_start, output_end, training=training ) # Slice all coefficients to their respective windows batch_size = len(valid_len) time_schedules = [time_schedules[i][input_start[i]:input_end[i]] for i in range(batch_size)] time_schedules_derivative = [time_schedules_derivative[i][output_start[i]:output_end[i]] for i in range(batch_size)] alpha = [alpha[i][output_start[i]:output_end[i]] for i in range(batch_size)] dalpha = [dalpha[i][output_start[i]:output_end[i]] for i in range(batch_size)] beta = [beta[i][output_start[i]:output_end[i]] for i in range(batch_size)] dbeta = [dbeta[i][output_start[i]:output_end[i]] for i in range(batch_size)] sigma = [sigma[i][output_start[i]:output_end[i]] for i in range(batch_size)] dlog_alpha = [dlog_alpha[i][output_start[i]:output_end[i]] for i in range(batch_size)] dlog_beta = [dlog_beta[i][output_start[i]:output_end[i]] for i in range(batch_size)] return { "time_schedules": time_schedules, "time_schedules_derivative": time_schedules_derivative, "input_start": input_start, "input_end": input_end, "output_start": output_start, "output_end": output_end, "alpha": alpha, "dalpha": dalpha, "beta": beta, "dbeta": dbeta, "sigma": sigma, "dlog_alpha": dlog_alpha, "dlog_beta": dlog_beta, "xt": xt, "x0": x0, "eps": eps, } # ---------------------------------------------------------------- # Streaming support # ---------------------------------------------------------------- def get_committable(self, total_frames): if total_frames < self.window_size: return 0, 0 committed = self.window_size committable_steps = self.steps remaining = total_frames - self.window_size extra_chunks = remaining // self.chunk_size committed += extra_chunks * self.chunk_size committable_steps += extra_chunks * self.steps return committed, committable_steps def get_step_rollback(self, seq_len): if seq_len < self.window_size: return 0 completed = 1 remaining = seq_len - self.window_size completed += remaining // self.chunk_size return completed * self.steps class ChunkDiffWanModel(DiffForcingWanModel): """Chunk-based diffusion model with clean history conditioning. First chunk: GT history (history_len frames) + noisy target. Subsequent chunks: previously generated frames as history + noisy target. History is never re-noised. """ def __init__(self, **kwargs): super().__init__(**kwargs) self.time_scheduler = ChunkDiffusionScheduler(self.schedule_config) def generate(self, x): x = self._extract_inputs(x) extra_len = self.schedule_config.get("extra_len", 0) feature_length = x["feature_length"] batch_size = len(feature_length) seq_len = max(feature_length).item() + extra_len device = next(self.parameters()).device valid_len = [min(fl.item(), seq_len) for fl in feature_length] generated_len = [seq_len] * batch_size history_len = self.time_scheduler.history_len # Initialize entire sequence as pure noise generated = torch.randn( batch_size, seq_len, *self.spatial_shape, self.input_dim, device=device ) generated = [generated[i] for i in range(batch_size)] generated = self.preprocess(generated) # Inject GT history into the first history_len frames if "feature" in x: gt_feature = x["feature"] gt_feature = (gt_feature - self.mean) / self.std gt_list = [] for i in range(batch_size): gt_list.append(gt_feature[i, :valid_len[i], ...]) gt_list = self.preprocess(gt_list) for i in range(batch_size): h = min(history_len, gt_list[i].shape[1]) generated[i][:, :h, ...] = gt_list[i][:, :h, ...] # Precompute text and null contexts text_context, metadata = self.text_module.get_context( x, generated_len, device, self.param_dtype, training=False, ) null_context = self.text_module.get_null_context(batch_size, device, self.param_dtype) full_text = metadata["full_text"] total_steps = self.time_scheduler.get_total_steps(seq_len) for step in range(total_steps): s = self.time_scheduler.prepare( generated, device, generated_len, training=False, current_step=step ) time_schedules = s["time_schedules"] time_schedules_derivative = s["time_schedules_derivative"] alpha = s["alpha"] dalpha = s["dalpha"] beta = s["beta"] dbeta = s["dbeta"] sigma = s["sigma"] dlog_alpha = s["dlog_alpha"] dlog_beta = s["dlog_beta"] input_start_index = s["input_start"] input_end_index = s["input_end"] output_start_index = s["output_start"] output_end_index = s["output_end"] xt = s["xt"] time_schedules_input = [ time_schedules[i] * self.time_embedding_scale for i in range(batch_size) ] if isinstance(text_context[0], (list, tuple)): window_text_context = [ text_context[i][input_start_index[i]:input_end_index[i]] for i in range(batch_size) ] else: window_text_context = text_context # CFG pred_text = self.model(xt, time_schedules_input, window_text_context, seq_len, y=None) pred_null = self.model(xt, time_schedules_input, null_context, seq_len, y=None) predicted_result = [ self.cfg_config["text_scale"] * pt + self.cfg_config["null_scale"] * pn for pt, pn in zip(pred_text, pred_null) ] # SDE update only on output (target) frames for i in range(batch_size): os_idx, oe_idx = output_start_index[i], output_end_index[i] pred_os = os_idx - input_start_index[i] pred_oe = oe_idx - input_start_index[i] predicted_result_i = predicted_result[i][:, pred_os:pred_oe, ...] generated_i = generated[i][:, os_idx:oe_idx, ...] dt = time_schedules_derivative[i][None, :, None, None] alpha_i = alpha[i][None, :, None, None] dalpha_i = dalpha[i][None, :, None, None] beta_i = beta[i][None, :, None, None] dbeta_i = dbeta[i][None, :, None, None] sigma_i = sigma[i][None, :, None, None] dlog_alpha_i = dlog_alpha[i][None, :, None, None] dlog_beta_i = dlog_beta[i][None, :, None, None] if self.prediction_type == "vel": vel = predicted_result_i elif self.prediction_type == "x0": vel = ( predicted_result_i * (-dlog_beta_i * alpha_i + dalpha_i) + generated_i * dlog_beta_i ) elif self.prediction_type == "eps": vel = ( predicted_result_i * (-dlog_alpha_i * beta_i + dbeta_i) + generated_i * dlog_alpha_i ) st = (vel - generated_i * dlog_alpha_i) / ( (beta_i * dlog_alpha_i - dbeta_i) * beta_i ) generated[i][:, os_idx:oe_idx, ...] += ( vel * dt + st * 0.5 * sigma_i ** 2 * dt + sigma_i * torch.sqrt(dt) * torch.randn_like(generated_i) ) generated = self.postprocess(generated) y_hat_out = [] for i in range(batch_size): single_generated = generated[i][:valid_len[i], :] * self.std + self.mean y_hat_out.append(single_generated) return {"generated": y_hat_out, "text": full_text} def init_generated(self, seq_len, batch_size=1, schedule_config={}): super().init_generated(seq_len, batch_size, schedule_config) self.time_scheduler = ChunkDiffusionScheduler(self.schedule_config)