import math import os import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from .tools.t5 import T5EncoderModel from .tools.wan_model import WanModel EPSILON = 0.05 class TriangularTimeScheduler: def __init__(self, config): self.steps = config["steps"] self.chunk_size = config["chunk_size"] self.random_epsilon = config.get("random_epsilon", 0.00) # schedule jittering self.noise_type = config.get("noise_type", "linear") self.sigma_type = config.get("sigma_type", "zero") # "zero", "memoryless" if self.noise_type == "exponential" or self.noise_type == "exponential_rev": self.exp_max = config.get("exp_max", 5.0) elif self.noise_type == "diffusion": self.T = config.get("T", 1000) self.beta_start = config.get("beta_start", 0.0001) self.beta_end = config.get("beta_end", 0.02) if self.sigma_type == "memoryless": self.sigma_scale = config.get("sigma_scale", 1.0) self.content_len = config.get("content_len", None) # For simplicity we require steps to be divisible by chunk_size, so that time windows align well. def get_total_steps(self, seq_len): return int(self.steps * seq_len / self.chunk_size) def get_time_steps(self, device, valid_len, current_step=None): time_steps = [] if current_step is None: for i in range(len(valid_len)): max_time = valid_len[i] / self.chunk_size time_steps.append( torch.tensor(np.random.uniform(0, max_time), device=device) ) elif isinstance(current_step, int): for i in range(len(valid_len)): t = current_step * (1 / self.steps) time_steps.append(torch.tensor(t, device=device)) elif isinstance(current_step, list): for i in range(len(valid_len)): t = current_step[i] * (1 / self.steps) time_steps.append(torch.tensor(t, device=device)) return time_steps def get_time_schedules(self, device, valid_len, time_steps, training=False): time_schedules = [] time_schedules_derivative = [] for i in range(len(valid_len)): t = time_steps[i].item() current_time_schedules = torch.clamp( -torch.arange(valid_len[i], device=device) / self.chunk_size + t, min=0.0, max=1.0, ) current_time_schedules_derivative = torch.ones_like( current_time_schedules ) * (1 / self.steps) if training: current_time_schedules = torch.clamp( current_time_schedules + torch.randn_like(current_time_schedules) * self.random_epsilon, min=0.0, max=1.0, ) time_schedules.append(current_time_schedules) time_schedules_derivative.append(current_time_schedules_derivative) return time_schedules, time_schedules_derivative def get_windows(self, valid_len, time_steps, training=False): # for the floating point issue, we can add the start_index by 0.5 / [steps * chunk_size] # for convenience, we just choose 0.5 * (1 / (self.steps * self.chunk_size)) here input_start, input_end, output_start, output_end = [], [], [], [] for i in range(len(time_steps)): t = time_steps[i].item() start_index = max( 0, math.floor( (t - 1) * self.chunk_size + 0.5 * (1 / (self.steps * self.chunk_size)) ) + 1, ) end_index = min( valid_len[i], math.floor( t * self.chunk_size + 0.5 * (1 / (self.steps * self.chunk_size)) ) + 1, ) if self.content_len is not None: input_start.append(max(0, end_index - self.content_len)) else: input_start.append(0) input_end.append(end_index) output_start.append(start_index) output_end.append(end_index) return input_start, input_end, output_start, output_end def get_noise_levels(self, device, valid_len, time_schedules, training=False): alpha = [] dalpha = [] dlog_alpha = [] beta = [] dbeta = [] dlog_beta = [] sigma = [] for i in range(len(valid_len)): t = time_schedules[i] if self.noise_type == "linear": alpha_i = t dalpha_i = torch.ones_like(alpha_i) dlog_alpha_i = dalpha_i / torch.clamp(alpha_i, min=EPSILON) beta_i = 1 - t dbeta_i = -torch.ones_like(beta_i) dlog_beta_i = dbeta_i / torch.clamp(beta_i, min=EPSILON) elif self.noise_type == "exponential": # "eps" prediction k = self.exp_max alpha_i = torch.exp(-k * (1 - t)) dalpha_i = k * alpha_i dlog_alpha_i = k * torch.ones_like(alpha_i) beta_i = 1 - alpha_i dbeta_i = -dalpha_i dlog_beta_i = dbeta_i / torch.clamp(beta_i, min=EPSILON) elif self.noise_type == "exponential_rev": # "x0" prediction k = self.exp_max beta_i = torch.exp(-k * t) dbeta_i = -k * beta_i dlog_beta_i = -k * torch.ones_like(beta_i) alpha_i = 1 - beta_i dalpha_i = -dbeta_i dlog_alpha_i = dalpha_i / torch.clamp(alpha_i, min=EPSILON) elif self.noise_type == "diffusion": t_rev = 1.0 - t beta_rate = ( self.beta_start + t_rev * (self.beta_end - self.beta_start) ) * self.T Gamma = ( self.beta_start * t_rev + 0.5 * (self.beta_end - self.beta_start) * t_rev * t_rev ) * self.T alpha_i = torch.exp(-0.5 * Gamma) dalpha_i = 0.5 * beta_rate * alpha_i dlog_alpha_i = 0.5 * beta_rate beta_i = torch.sqrt(torch.clamp(1 - torch.exp(-Gamma), min=0.0)) dbeta_i = ( -0.5 * torch.exp(-Gamma) * beta_rate / torch.clamp(beta_i, min=EPSILON) ) dlog_beta_i = dbeta_i / torch.clamp(beta_i, min=EPSILON) else: raise ValueError(f"Unknown noise type: {self.noise_type}") alpha.append(torch.clamp(alpha_i, min=0.0, max=1.0)) dalpha.append(dalpha_i) dlog_alpha.append(dlog_alpha_i) beta.append(torch.clamp(beta_i, min=0.0, max=1.0)) dbeta.append(dbeta_i) dlog_beta.append(dlog_beta_i) if self.sigma_type == "zero": sigma_i = torch.zeros_like(t) elif self.sigma_type == "memoryless": if ( self.noise_type == "linear" or self.noise_type == "exponential" or self.noise_type == "exponential_rev" ): sigma_i = self.sigma_scale * torch.sqrt( torch.clamp(2 * dlog_alpha_i * beta_i, min=0.0) ) elif self.noise_type == "diffusion": sigma_i = self.sigma_scale * torch.sqrt( torch.clamp(2 * dlog_alpha_i, min=0.0) ) else: sigma_i = self.sigma_scale * torch.sqrt( torch.clamp( 2 * beta_i * (dlog_alpha_i * beta_i - dbeta_i), min=0.0 ) ) sigma.append(sigma_i) return alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta def add_noise( self, x, alpha, beta, input_start, input_end, output_start, output_end, training=False, noise=None, ): """Add noise and slice into input/reference regions. Args: x: list of (C, T, H, W), x0 in training, xt in inference alpha: list of (T,) beta: list of (T,) input_start/input_end: per-sample input window indices output_start/output_end: per-sample output window indices Returns: x0: list of (C, output_len, H, W) eps: list of (C, output_len, H, W) xt: list of (C, input_len, H, W) """ x0 = [] eps = [] xt = [] if training: for i in range(len(x)): if noise is not None: noise_i = noise[i] else: noise_i = torch.randn_like(x[i]) alpha_i = alpha[i][None, :, None, None] # (1, T, 1, 1) beta_i = beta[i][None, :, None, None] # (1, T, 1, 1) noisy_x_i = x[i] * alpha_i + noise_i * beta_i # (C, T, H, W) x0.append(x[i][:, output_start[i] : output_end[i], ...]) eps.append(noise_i[:, output_start[i] : output_end[i], ...]) xt.append(noisy_x_i[:, input_start[i] : input_end[i], ...]) else: for i in range(len(x)): xt.append(x[i][:, input_start[i] : input_end[i], ...]) return x0, eps, xt def prepare(self, x, device, valid_len, training=True, current_step=None): """Single call replacing get_time_steps + get_time_schedules + get_noise_levels + get_windows + add_noise. Args: x: list of (C, T, H, W). Training: clean features. Inference: current state. device: torch device valid_len: list of int training: bool current_step: int (inference only) Returns dict. Training keys: time_schedules, dalpha, dbeta, input_start, input_end, output_start, output_end, x0, eps, xt Inference keys: time_schedules, time_schedules_derivative, alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta, input_start, input_end, output_start, output_end, xt """ time_steps = self.get_time_steps(device, valid_len, current_step) time_schedules, time_schedules_derivative = self.get_time_schedules( device, valid_len, time_steps, training=training ) alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta = \ self.get_noise_levels(device, valid_len, time_schedules, training=training) input_start, input_end, output_start, output_end = \ self.get_windows(valid_len, time_steps, training=training) x0, eps, xt = self.add_noise( x, alpha, beta, input_start, input_end, output_start, output_end, training=training ) # Slice all coefficients to their respective windows batch_size = len(valid_len) time_schedules = [time_schedules[i][input_start[i]:input_end[i]] for i in range(batch_size)] time_schedules_derivative = [time_schedules_derivative[i][output_start[i]:output_end[i]] for i in range(batch_size)] alpha = [alpha[i][output_start[i]:output_end[i]] for i in range(batch_size)] dalpha = [dalpha[i][output_start[i]:output_end[i]] for i in range(batch_size)] beta = [beta[i][output_start[i]:output_end[i]] for i in range(batch_size)] dbeta = [dbeta[i][output_start[i]:output_end[i]] for i in range(batch_size)] sigma = [sigma[i][output_start[i]:output_end[i]] for i in range(batch_size)] dlog_alpha = [dlog_alpha[i][output_start[i]:output_end[i]] for i in range(batch_size)] dlog_beta = [dlog_beta[i][output_start[i]:output_end[i]] for i in range(batch_size)] result = { "time_schedules": time_schedules, "time_schedules_derivative": time_schedules_derivative, "input_start": input_start, "input_end": input_end, "output_start": output_start, "output_end": output_end, "alpha": alpha, "dalpha": dalpha, "beta": beta, "dbeta": dbeta, "sigma": sigma, "dlog_alpha": dlog_alpha, "dlog_beta": dlog_beta, "xt": xt, "x0": x0, "eps": eps, } return result # --- Streaming support --- def get_committable(self, total_frames): """Given total accumulated conditions, return how many frames can be committed. Currently, we suppose steps % chunk_size == 0 for simplicity.""" committable_length = max(0, total_frames - self.chunk_size + 1) committable_steps = total_frames * (self.steps // self.chunk_size) return committable_length, committable_steps def get_step_rollback(self, seq_len): """Get the step count to subtract when wrapping the buffer by seq_len. Corresponds to how many steps were consumed by seq_len frames.""" steps = seq_len * (self.steps // self.chunk_size) return steps class T5TextCrossModule(nn.Module): """Cross-attention module for T5 text conditioning.""" def __init__( self, len=512, dim=4096, t5_size="xxl", checkpoint_path=None, tokenizer_path=None, drop_out=0.1, input_keys={ "text": "text", "text_end": "text_end", }, ): assert checkpoint_path is not None and tokenizer_path is not None, ( "T5 checkpoint and tokenizer paths must be provided." ) super().__init__() self.len = len self.dim = dim self.cross_attn_norm = True self.cross_rope = False self.drop_out = drop_out self.input_keys = input_keys self.text_encoder = T5EncoderModel( text_len=len, dtype=torch.bfloat16, device=torch.device("cpu"), checkpoint_path=checkpoint_path, tokenizer_path=tokenizer_path, shard_fn=None, t5_size=t5_size, ) self.text_cache = {} def encode(self, text_list, device): """Encode text list with cache. Returns List[Tensor].""" # Deduplicate uncached texts texts_to_encode = [] for text in text_list: if text not in self.text_cache and text not in texts_to_encode: texts_to_encode.append(text) # Batch encode deduplicated texts if texts_to_encode: self.text_encoder.model.to(device) encoded = self.text_encoder(texts_to_encode, device) for text, feature in zip(texts_to_encode, encoded): self.text_cache[text] = feature.cpu() # Collect from cache return [self.text_cache[text].to(device) for text in text_list] def get_context(self, x, valid_len, device, param_dtype, training=False): """ Get cross-attention context from input dict. Returns: context: List[Tensor] metadata: dict, may contain 'full_text' """ text_key = self.input_keys.get("text", "text") text_end_key = self.input_keys.get("text_end", "text_end") metadata = {} if text_key not in x: text_list = ["" for _ in range(len(valid_len))] else: text_list = x[text_key] if isinstance(text_list[0], list): # Multi-segment text (stream mode) full_text = [] all_context = [] text_end_list = x[text_end_key] for i in range(len(valid_len)): if training and np.random.rand() <= self.drop_out: single_text_list = [""] single_text_end_list = [0, valid_len[i]] else: single_text_list = text_list[i] single_text_end_list = [0] + [ min(t, valid_len[i]) for t in text_end_list[i] ] single_text_length_list = [ t - b for t, b in zip(single_text_end_list[1:], single_text_end_list[:-1]) ] full_text.append( " ////////// ".join( [ f"{u} //dur:{t}" for u, t in zip(single_text_list, single_text_length_list) ] ) ) single_text_context = self.encode(single_text_list, device) single_text_context = [u.to(param_dtype) for u in single_text_context] sample_context = [] for u, duration in zip(single_text_context, single_text_length_list): sample_context.extend([u for _ in range(duration)]) all_context.append(sample_context) metadata["full_text"] = full_text return all_context, metadata else: # Single text per sample full_text = [u for u in text_list] metadata["full_text"] = full_text if training: text_list = [ ("" if np.random.rand() <= self.drop_out else u) for u in text_list ] else: text_list = [u for u in text_list] context = self.encode(text_list, device) context = [u.to(param_dtype) for u in context] return context, metadata def get_null_context(self, batch_size, device, param_dtype): """Get null/empty context for classifier-free guidance.""" null_ctx = self.encode([""] * batch_size, device) return [u.to(param_dtype) for u in null_ctx] # --- Streaming state management --- def init_stream(self, batch_size): self.stream_condition_list = [[] for _ in range(batch_size)] def update_stream(self, x, device, param_dtype): """Add one frame of context for a streaming step.""" text_key = self.input_keys.get("text", "text") text_input = x[text_key] new_ctx = self.encode(text_input, device) new_ctx = [u.to(param_dtype) for u in new_ctx] for i in range(len(self.stream_condition_list)): self.stream_condition_list[i].append(new_ctx[i]) def get_stream_context(self, start_index, end_index): context = [] for i in range(len(self.stream_condition_list)): context.append(self.stream_condition_list[i][start_index:end_index]) return context def trim_stream(self, trim_len): """Trim stream state when wrapping around.""" for i in range(len(self.stream_condition_list)): self.stream_condition_list[i] = self.stream_condition_list[i][trim_len:] class DiffForcingWanModel(nn.Module): def __init__( self, input_dim=256, mean_path=None, std_path=None, hidden_dim=1024, ffn_dim=2048, freq_dim=256, num_heads=8, num_layers=8, time_embedding_scale=1.0, causal=False, rope_channel_split=[1, 0, 0], spatial_shape=(1, 1), prediction_type="vel", # "vel", "x0", "eps" text_config={ "len": 512, "dim": 4096, }, schedule_config={ "noise_type": "linear", "chunk_size": 5, "steps": 10, "extra_len": 4, "random_epsilon": 0.00, }, cfg_config={ "text_scale": 5.0, "null_scale": -4.0, }, input_keys={ "feature": "feature", "feature_length": "feature_length", "text": "text", "text_end": "text_end", }, ): super().__init__() self.input_keys = input_keys self.mean_path = mean_path self.std_path = std_path self.input_dim = input_dim self.spatial_shape = tuple(spatial_shape) self.hidden_dim = hidden_dim self.ffn_dim = ffn_dim self.freq_dim = freq_dim self.num_heads = num_heads self.num_layers = num_layers self.time_embedding_scale = time_embedding_scale self.causal = causal self.rope_channel_split = rope_channel_split self.prediction_type = prediction_type self.cfg_config = cfg_config self.schedule_config = schedule_config self.time_scheduler = TriangularTimeScheduler(schedule_config) # Cross-attention module (text) self.text_module = T5TextCrossModule(**text_config) if self.mean_path is not None: self.register_buffer( "mean", torch.from_numpy(np.load(self.mean_path)).float() ) else: self.register_buffer("mean", torch.zeros(input_dim)) if self.std_path is not None: self.register_buffer( "std", torch.from_numpy(np.load(self.std_path)).float() ) else: self.register_buffer("std", torch.ones(input_dim)) self.model = WanModel( patch_size=(1, 1, 1), text_len=self.text_module.len, text_dim=self.text_module.dim, cross_attn_norm=self.text_module.cross_attn_norm, cross_rope=self.text_module.cross_rope, in_dim=self.input_dim, dim=self.hidden_dim, ffn_dim=self.ffn_dim, freq_dim=self.freq_dim, out_dim=self.input_dim, num_heads=self.num_heads, num_layers=self.num_layers, window_size=(-1, -1), qk_norm=True, eps=1e-6, causal=self.causal, rope_channel_split=self.rope_channel_split, ) self.param_dtype = torch.float32 def _extract_inputs(self, x): """Extract inputs from x using input_keys mapping.""" inputs = {} for internal_key, external_key in self.input_keys.items(): if external_key in x: inputs[internal_key] = x[external_key] return inputs def preprocess(self, x): """Convert last-channel format to channel-first, padding to 4D (C, T, H, W). (T, C) -> (C, T, 1, 1) (T, H, C) -> (C, T, H, 1) (T, H, W, C) -> (C, T, H, W) """ for i in range(len(x)): ndim = x[i].ndim if ndim == 2: # (T, C) x[i] = x[i].permute(1, 0)[:, :, None, None] elif ndim == 3: # (T, H, C) x[i] = x[i].permute(2, 0, 1)[:, :, :, None] elif ndim == 4: # (T, H, W, C) x[i] = x[i].permute(3, 0, 1, 2) return x def postprocess(self, x): """Reverse of preprocess: channel-first 4D back to last-channel, stripping padding dims. (C, T, 1, 1) -> (T, C) (C, T, H, 1) -> (T, H, C) (C, T, H, W) -> (T, H, W, C) """ for i in range(len(x)): shape = x[i].shape # (C, T, H, W) if shape[2] == 1 and shape[3] == 1: # (C, T, 1, 1) -> (T, C) x[i] = x[i][:, :, 0, 0].permute(1, 0) elif shape[3] == 1: # (C, T, H, 1) -> (T, H, C) x[i] = x[i][:, :, :, 0].permute(1, 2, 0) else: # (C, T, H, W) -> (T, H, W, C) x[i] = x[i].permute(1, 2, 3, 0) return x def forward(self, x): x = self._extract_inputs(x) feature_original = x["feature"] # (B, T, C) feature_length = x["feature_length"] # (B,) feature_original = (feature_original - self.mean) / self.std batch_size = feature_original.shape[0] seq_len = feature_original.shape[1] device = feature_original.device feature = [] valid_len = [] for i in range(batch_size): length = min(feature_length[i].item(), seq_len) valid_len.append(length) feature.append(feature_original[i, :length, ...]) # Preprocess to (C, T, 1, 1) per sample feature = self.preprocess(feature) # Get context from text cross module context, _ = self.text_module.get_context( x, valid_len, device, self.param_dtype, training=True, ) # Prepare noised data and schedule s = self.time_scheduler.prepare(feature, device, valid_len, training=True) time_schedules = s["time_schedules"] input_start_index = s["input_start"] input_end_index = s["input_end"] output_start_index = s["output_start"] output_end_index = s["output_end"] dalpha = s["dalpha"] dbeta = s["dbeta"] x0, eps, xt = s["x0"], s["eps"], s["xt"] # Slice per-frame context to match input window if isinstance(context[0], (list, tuple)): context = [ context[i][input_start_index[i] : input_end_index[i]] for i in range(batch_size) ] # time_schedules already sliced to input window by prepare() time_schedules_input = [ time_schedules[i] * self.time_embedding_scale for i in range(batch_size) ] # Through WanModel predicted_result = self.model( xt, time_schedules_input, context, seq_len, y=None, ) # (B, C, T, 1, 1) loss = 0.0 for b in range(batch_size): pred_os = output_start_index[b] - input_start_index[b] pred_oe = output_end_index[b] - input_start_index[b] # dalpha, dbeta already sliced to output window by prepare() dalpha_i = dalpha[b] dbeta_i = dbeta[b] if self.prediction_type == "vel": vel = ( x0[b] * dalpha_i[None, :, None, None] + eps[b] * dbeta_i[None, :, None, None] ) # (C, output_length, 1, 1) squared_error = ( predicted_result[b][:, pred_os:pred_oe, ...] - vel ) ** 2 elif self.prediction_type == "x0": squared_error = ( predicted_result[b][:, pred_os:pred_oe, ...] - x0[b] ) ** 2 elif self.prediction_type == "eps": squared_error = ( predicted_result[b][:, pred_os:pred_oe, ...] - eps[b] ) ** 2 sample_loss = squared_error.mean() loss += sample_loss loss = loss / batch_size loss_dict = {"total": loss, "mse": loss} return loss_dict def generate(self, x): """ Generation - Diffusion Forcing inference Uses triangular noise schedule, progressively generating from left to right Generation process: 1. Start from t=0, gradually increase t 2. Each t corresponds to a noise schedule: clean on left, noisy on right, gradient in middle 3. After each denoising step, t increases slightly and continues """ x = self._extract_inputs(x) extra_len = self.schedule_config.get("extra_len", 0) feature_length = x["feature_length"] # (B,) batch_size = len(feature_length) seq_len = max(feature_length).item() + extra_len device = next(self.parameters()).device valid_len = [] for i in range(batch_size): length = min(feature_length[i].item(), seq_len) valid_len.append(length) generated_len = [seq_len for _ in range(batch_size)] # Initialize entire sequence as pure noise generated = torch.randn( batch_size, seq_len, *self.spatial_shape, self.input_dim, device=device ) generated = [generated[i] for i in range(batch_size)] generated = self.preprocess(generated) # Precompute text and null contexts for CFG text_context, metadata = self.text_module.get_context( x, generated_len, device, self.param_dtype, training=False, ) null_context = self.text_module.get_null_context( batch_size, device, self.param_dtype ) full_text = metadata["full_text"] total_steps = self.time_scheduler.get_total_steps(seq_len) # Progressively advance from t=0 to t=max_t for step in range(total_steps): s = self.time_scheduler.prepare( generated, device, generated_len, training=False, current_step=step ) time_schedules = s["time_schedules"] time_schedules_derivative = s["time_schedules_derivative"] alpha = s["alpha"] dalpha = s["dalpha"] beta = s["beta"] dbeta = s["dbeta"] sigma = s["sigma"] dlog_alpha = s["dlog_alpha"] dlog_beta = s["dlog_beta"] input_start_index = s["input_start"] input_end_index = s["input_end"] output_start_index = s["output_start"] output_end_index = s["output_end"] xt = s["xt"] # time_schedules already sliced to input window by prepare() time_schedules_input = [ time_schedules[i] * self.time_embedding_scale for i in range(batch_size) ] # Slice per-frame context to match input window if isinstance(text_context[0], (list, tuple)): window_text_context = [ text_context[i][input_start_index[i] : input_end_index[i]] for i in range(batch_size) ] else: window_text_context = text_context # CFG: text_scale * pred_text + null_scale * pred_null pred_text = self.model( xt, time_schedules_input, window_text_context, seq_len, y=None, ) pred_null = self.model( xt, time_schedules_input, null_context, seq_len, y=None, ) predicted_result = [ self.cfg_config["text_scale"] * pt + self.cfg_config["null_scale"] * pn for pt, pn in zip(pred_text, pred_null) ] # All noise coefficients already sliced to output window by prepare() for i in range(batch_size): os, oe = output_start_index[i], output_end_index[i] pred_os = os - input_start_index[i] pred_oe = oe - input_start_index[i] predicted_result_i = predicted_result[i][:, pred_os:pred_oe, ...] generated_i = generated[i][:, os:oe, ...] dt = time_schedules_derivative[i][None, :, None, None] alpha_i = alpha[i][None, :, None, None] dalpha_i = dalpha[i][None, :, None, None] beta_i = beta[i][None, :, None, None] dbeta_i = dbeta[i][None, :, None, None] sigma_i = sigma[i][None, :, None, None] dlog_alpha_i = dlog_alpha[i][None, :, None, None] dlog_beta_i = dlog_beta[i][None, :, None, None] if self.prediction_type == "vel": vel = predicted_result_i elif self.prediction_type == "x0": vel = ( predicted_result_i * (-dlog_beta_i * alpha_i + dalpha_i) + generated_i * dlog_beta_i ) elif self.prediction_type == "eps": vel = ( predicted_result_i * (-dlog_alpha_i * beta_i + dbeta_i) + generated_i * dlog_alpha_i ) st = (vel - generated_i * dlog_alpha_i) / ( (beta_i * dlog_alpha_i - dbeta_i) * beta_i ) generated[i][:, os:oe, ...] += ( vel * dt + st * 0.5 * sigma_i**2 * dt + sigma_i * torch.sqrt(dt) * torch.randn_like(generated_i) ) generated = self.postprocess(generated) # list of (T, C) y_hat_out = [] for i in range(batch_size): single_generated = generated[i][: valid_len[i], :] * self.std + self.mean y_hat_out.append(single_generated) out = {} out["generated"] = y_hat_out out["text"] = full_text return out def init_generated(self, seq_len, batch_size=1, schedule_config={}): """Initialize streaming generation state. Args: seq_len: Model window size (how many frames WanModel processes per step). schedule_config: Optional schedule config overrides. Buffer is 2*seq_len. Model window is always buffer[0:seq_len]. When conditions overflow seq_len, shift buffer by seq_len and restart. """ self.schedule_config.update(schedule_config) content_len = self.schedule_config.get("content_len", None) if content_len is None: self.schedule_config["content_len"] = seq_len else: self.schedule_config["content_len"] = min(seq_len, content_len) self.time_scheduler = TriangularTimeScheduler(self.schedule_config) self.batch_size = batch_size self.seq_len = seq_len self.buf_len = seq_len * 2 self.current_step = 0 self.current_commit = 0 self.condition_frames = 0 device = next(self.parameters()).device # Initialize entire buffer as pure noise generated = torch.randn( batch_size, self.buf_len, *self.spatial_shape, self.input_dim, device=device ) generated = [generated[i] for i in range(batch_size)] self.generated = self.preprocess(generated) # Initialize streaming state for cross module self.text_module.init_stream(self.batch_size) def _rollback(self): """Shift buffer by seq_len when conditions overflow the window.""" for i in range(self.batch_size): self.generated[i][:, : self.seq_len, ...] = self.generated[i][ :, self.seq_len :, ... ].clone() self.generated[i][:, self.seq_len :, ...] = torch.randn_like( self.generated[i][:, self.seq_len :, ...] ) self.current_step -= self.time_scheduler.get_step_rollback(self.seq_len) self.condition_frames -= self.seq_len self.current_commit -= self.seq_len self.text_module.trim_stream(self.seq_len) @torch.no_grad() def stream_generate_step(self, x): """ Streaming generation step. Each call provides 1 frame of conditions. The scheduler determines committable frames from accumulated conditions. Returns: dict with "generated": list of one (N, C) tensor, or [] if nothing to commit. """ x = self._extract_inputs(x) device = next(self.parameters()).device self.generated = [g.to(device) for g in self.generated] # 1. Update conditions (1 frame per call) self.text_module.update_stream(x, device, self.param_dtype) self.condition_frames += 1 # 2. Rollback if conditions overflow the window if self.condition_frames > self.buf_len: self._rollback() # 3. Determine how many frames can be committed committable_length, committable_steps = self.time_scheduler.get_committable( self.condition_frames ) while self.current_step < committable_steps: s = self.time_scheduler.prepare( self.generated, device, [self.buf_len] * self.batch_size, training=False, current_step=self.current_step ) time_schedules = s["time_schedules"] time_schedules_derivative = s["time_schedules_derivative"] alpha = s["alpha"] dalpha = s["dalpha"] beta = s["beta"] dbeta = s["dbeta"] sigma = s["sigma"] dlog_alpha = s["dlog_alpha"] dlog_beta = s["dlog_beta"] is_ = s["input_start"] ie_ = s["input_end"] os_ = s["output_start"] oe_ = s["output_end"] xt = s["xt"] # time_schedules already sliced to input window by prepare() time_schedules_input = [ time_schedules[0] * self.time_embedding_scale ] * self.batch_size # CFG: batch text + null in one forward pass text_context = self.text_module.get_stream_context(is_[0], ie_[0]) null_context = self.text_module.get_null_context( self.batch_size, device, self.param_dtype ) # Convert null to per-frame format to match text_context window_len = ie_[0] - is_[0] null_context_pf = [ [null_context[i]] * window_len for i in range(self.batch_size) ] pred_all = self.model( xt + xt, time_schedules_input + time_schedules_input, text_context + null_context_pf, self.seq_len, y=None, ) pred_text = pred_all[: self.batch_size] pred_null = pred_all[self.batch_size :] predicted_result = [ self.cfg_config["text_scale"] * pt + self.cfg_config["null_scale"] * pn for pt, pn in zip(pred_text, pred_null) ] # All noise coefficients already sliced to output window by prepare() os_idx, oe_idx = os_[0], oe_[0] pred_os_idx = os_idx - is_[0] pred_oe_idx = oe_idx - is_[0] dt = time_schedules_derivative[0][None, :, None, None] alpha_i = alpha[0][None, :, None, None] dalpha_i = dalpha[0][None, :, None, None] beta_i = beta[0][None, :, None, None] dbeta_i = dbeta[0][None, :, None, None] sigma_i = sigma[0][None, :, None, None] dlog_alpha_i = dlog_alpha[0][None, :, None, None] dlog_beta_i = dlog_beta[0][None, :, None, None] for i in range(self.batch_size): predicted_result_i = predicted_result[i][ :, pred_os_idx:pred_oe_idx, ... ] generated_i = self.generated[i][:, os_idx:oe_idx, ...] if self.prediction_type == "vel": vel = predicted_result_i elif self.prediction_type == "x0": vel = ( predicted_result_i * (-dlog_beta_i * alpha_i + dalpha_i) + generated_i * dlog_beta_i ) elif self.prediction_type == "eps": vel = ( predicted_result_i * (-dlog_alpha_i * beta_i + dbeta_i) + generated_i * dlog_alpha_i ) st = (vel - generated_i * dlog_alpha_i) / ( (beta_i * dlog_alpha_i - dbeta_i) * beta_i ) self.generated[i][:, os_idx:oe_idx, ...] += ( vel * dt + st * 0.5 * sigma_i**2 * dt + sigma_i * torch.sqrt(dt) * torch.randn_like(generated_i) ) self.current_step += 1 # 5. Extract newly committed frames if self.current_commit < committable_length: output = [ self.generated[i][:, self.current_commit : committable_length, ...] for i in range(self.batch_size) ] output = self.postprocess(output) output = [o * self.std + self.mean for o in output] self.current_commit = committable_length return {"generated": output} else: empty = [ torch.zeros(self.input_dim, 0, *self.spatial_shape, device=device) for _ in range(self.batch_size) ] empty = self.postprocess(empty) return {"generated": empty}