| import math
|
| import numpy as np
| import torch
| import torch.nn as nn
|
|
| from .tools.t5 import T5EncoderModel |
| from .tools.wan_model import WanModel |
|
|
| EPSILON = 0.05 |
|
|
|
|
| class TriangularTimeScheduler:
| """Per-frame triangular time schedule for diffusion forcing.
|
| A single global time t advances by 1 / steps per step; frame j sees the
| local time clamp(t - j / chunk_size, 0, 1), so earlier frames are denoised
| first. Convention: local time 1 is clean, local time 0 is pure noise.
| """
|
| def __init__(self, config):
| self.steps = config["steps"]
| self.chunk_size = config["chunk_size"]
| self.random_epsilon = config.get("random_epsilon", 0.00)
| self.noise_type = config.get("noise_type", "linear")
| self.sigma_type = config.get("sigma_type", "zero")
|
|
| if self.noise_type == "exponential" or self.noise_type == "exponential_rev": |
| self.exp_max = config.get("exp_max", 5.0) |
| elif self.noise_type == "diffusion": |
| self.T = config.get("T", 1000) |
| self.beta_start = config.get("beta_start", 0.0001) |
| self.beta_end = config.get("beta_end", 0.02) |
|
|
| if self.sigma_type == "memoryless": |
| self.sigma_scale = config.get("sigma_scale", 1.0) |
| self.content_len = config.get("content_len", None) |
| |
|
|
| def get_total_steps(self, seq_len): |
| return int(self.steps * seq_len / self.chunk_size) |
|
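| # Worked example (illustrative numbers): steps=10, chunk_size=5 and
| # seq_len=20 give int(10 * 20 / 5) = 40 global denoising steps.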
|
| def get_time_steps(self, device, valid_len, current_step=None):
| """Global time per sample: uniform in [0, valid_len / chunk_size] when
| current_step is None (training); otherwise current_step / steps."""
| time_steps = []
| if current_step is None: |
| for i in range(len(valid_len)): |
| max_time = valid_len[i] / self.chunk_size |
| time_steps.append( |
| torch.tensor(np.random.uniform(0, max_time), device=device) |
| ) |
| elif isinstance(current_step, int): |
| for i in range(len(valid_len)): |
| t = current_step * (1 / self.steps) |
| time_steps.append(torch.tensor(t, device=device)) |
| elif isinstance(current_step, list): |
| for i in range(len(valid_len)): |
| t = current_step[i] * (1 / self.steps) |
| time_steps.append(torch.tensor(t, device=device)) |
| return time_steps |
|
|
| def get_time_schedules(self, device, valid_len, time_steps, training=False):
| """Per-frame local times clamp(t - j / chunk_size, 0, 1) and their time
| derivative (a constant 1 / steps per global step)."""
| time_schedules = []
| time_schedules_derivative = []
| for i in range(len(valid_len)): |
| t = time_steps[i].item() |
| current_time_schedules = torch.clamp( |
| -torch.arange(valid_len[i], device=device) / self.chunk_size + t, |
| min=0.0, |
| max=1.0, |
| ) |
| current_time_schedules_derivative = torch.ones_like( |
| current_time_schedules |
| ) * (1 / self.steps) |
| if training: |
| current_time_schedules = torch.clamp( |
| current_time_schedules |
| + torch.randn_like(current_time_schedules) * self.random_epsilon, |
| min=0.0, |
| max=1.0, |
| ) |
| time_schedules.append(current_time_schedules) |
| time_schedules_derivative.append(current_time_schedules_derivative) |
| return time_schedules, time_schedules_derivative |
|
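| # Worked example (hedged; numbers are illustrative): with chunk_size=5 and
| # global time t=1.5, frame j gets local time clamp(1.5 - j / 5, 0, 1):
| # frames 0-2 sit at 1.0 (clean), frames 3-7 ramp down through
| # 0.9, 0.7, 0.5, 0.3, 0.1, and frames >= 8 sit at 0.0 (pure noise).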
|
| def get_windows(self, valid_len, time_steps, training=False):
| """Index windows at global time t. The output window covers the frames
| still being denoised (local time inside (0, 1)); the input window
| extends left to include already-clean context, capped at content_len
| total frames when it is set. The half-step offsets guard floor()
| against floating-point edge cases."""
| input_start, input_end, output_start, output_end = [], [], [], []
| for i in range(len(time_steps)): |
| t = time_steps[i].item() |
| start_index = max( |
| 0, |
| math.floor( |
| (t - 1) * self.chunk_size |
| + 0.5 * (1 / (self.steps * self.chunk_size)) |
| ) |
| + 1, |
| ) |
| end_index = min( |
| valid_len[i], |
| math.floor( |
| t * self.chunk_size + 0.5 * (1 / (self.steps * self.chunk_size)) |
| ) |
| + 1, |
| ) |
|
|
| if self.content_len is not None:
| # Cap the total input-window length at content_len.
| input_start.append(max(0, end_index - self.content_len))
| else:
| input_start.append(0)
| input_end.append(end_index) |
| output_start.append(start_index) |
| output_end.append(end_index) |
| return input_start, input_end, output_start, output_end |
|
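| # Worked example (hedged; continues the t=1.5, chunk_size=5, steps=10 case
| # above): start_index = floor(0.5 * 5 + 0.01) + 1 = 3 and
| # end_index = floor(1.5 * 5 + 0.01) + 1 = 8, so the output window is
| # frames [3, 8) -- exactly the ramp -- and, with content_len=None, the
| # input window is frames [0, 8) including the clean left context.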
|
| def get_noise_levels(self, device, valid_len, time_schedules, training=False):
| """Per-frame signal/noise scales for x_t = alpha * x0 + beta * eps,
| their time derivatives and log-derivatives, and the SDE diffusion
| coefficient sigma."""
| alpha = []
| dalpha = [] |
| dlog_alpha = [] |
| beta = [] |
| dbeta = [] |
| dlog_beta = [] |
| sigma = [] |
| for i in range(len(valid_len)): |
| t = time_schedules[i] |
| if self.noise_type == "linear": |
| alpha_i = t |
| dalpha_i = torch.ones_like(alpha_i) |
| dlog_alpha_i = dalpha_i / torch.clamp(alpha_i, min=EPSILON) |
| beta_i = 1 - t |
| dbeta_i = -torch.ones_like(beta_i) |
| dlog_beta_i = dbeta_i / torch.clamp(beta_i, min=EPSILON) |
| elif self.noise_type == "exponential": |
| |
| k = self.exp_max |
| alpha_i = torch.exp(-k * (1 - t)) |
| dalpha_i = k * alpha_i |
| dlog_alpha_i = k * torch.ones_like(alpha_i) |
| beta_i = 1 - alpha_i |
| dbeta_i = -dalpha_i |
| dlog_beta_i = dbeta_i / torch.clamp(beta_i, min=EPSILON) |
| elif self.noise_type == "exponential_rev": |
| |
| k = self.exp_max |
| beta_i = torch.exp(-k * t) |
| dbeta_i = -k * beta_i |
| dlog_beta_i = -k * torch.ones_like(beta_i) |
| alpha_i = 1 - beta_i |
| dalpha_i = -dbeta_i |
| dlog_alpha_i = dalpha_i / torch.clamp(alpha_i, min=EPSILON) |
| elif self.noise_type == "diffusion": |
| t_rev = 1.0 - t |
| beta_rate = ( |
| self.beta_start + t_rev * (self.beta_end - self.beta_start) |
| ) * self.T |
| Gamma = ( |
| self.beta_start * t_rev |
| + 0.5 * (self.beta_end - self.beta_start) * t_rev * t_rev |
| ) * self.T |
| alpha_i = torch.exp(-0.5 * Gamma) |
| dalpha_i = 0.5 * beta_rate * alpha_i |
| dlog_alpha_i = 0.5 * beta_rate |
| beta_i = torch.sqrt(torch.clamp(1 - torch.exp(-Gamma), min=0.0)) |
| dbeta_i = ( |
| -0.5 |
| * torch.exp(-Gamma) |
| * beta_rate |
| / torch.clamp(beta_i, min=EPSILON) |
| ) |
| dlog_beta_i = dbeta_i / torch.clamp(beta_i, min=EPSILON) |
| else: |
| raise ValueError(f"Unknown noise type: {self.noise_type}") |
| alpha.append(torch.clamp(alpha_i, min=0.0, max=1.0)) |
| dalpha.append(dalpha_i) |
| dlog_alpha.append(dlog_alpha_i) |
| beta.append(torch.clamp(beta_i, min=0.0, max=1.0)) |
| dbeta.append(dbeta_i) |
| dlog_beta.append(dlog_beta_i) |
| if self.sigma_type == "zero": |
| sigma_i = torch.zeros_like(t) |
| elif self.sigma_type == "memoryless": |
| if ( |
| self.noise_type == "linear" |
| or self.noise_type == "exponential" |
| or self.noise_type == "exponential_rev" |
| ): |
| sigma_i = self.sigma_scale * torch.sqrt( |
| torch.clamp(2 * dlog_alpha_i * beta_i, min=0.0) |
| ) |
| elif self.noise_type == "diffusion": |
| sigma_i = self.sigma_scale * torch.sqrt( |
| torch.clamp(2 * dlog_alpha_i, min=0.0) |
| ) |
| else:
| # Unreachable with the noise types accepted above; kept as a guard.
| sigma_i = self.sigma_scale * torch.sqrt(
| torch.clamp( |
| 2 * beta_i * (dlog_alpha_i * beta_i - dbeta_i), min=0.0 |
| ) |
| ) |
| sigma.append(sigma_i) |
| return alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta |
|
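| # For the linear schedule this reduces to standard rectified flow on each
| # frame (a sketch, consistent with the branch above): x_t = t * x0 +
| # (1 - t) * eps, so dalpha = 1, dbeta = -1, and the target velocity used
| # in forward() is dx_t/dt = x0 - eps.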
|
| def add_noise( |
| self, |
| x, |
| alpha, |
| beta, |
| input_start, |
| input_end, |
| output_start, |
| output_end, |
| training=False, |
| noise=None, |
| ): |
| """Add noise and slice into input/reference regions. |
| Args: |
| x: list of (C, T, H, W), x0 in training, xt in inference |
| alpha: list of (T,) |
| beta: list of (T,) |
| input_start/input_end: per-sample input window indices |
| output_start/output_end: per-sample output window indices |
| Returns: |
| x0: list of (C, output_len, H, W) |
| eps: list of (C, output_len, H, W) |
| xt: list of (C, input_len, H, W) |
| """ |
| x0 = [] |
| eps = [] |
| xt = [] |
| if training: |
| for i in range(len(x)): |
| if noise is not None: |
| noise_i = noise[i] |
| else: |
| noise_i = torch.randn_like(x[i]) |
| alpha_i = alpha[i][None, :, None, None] |
| beta_i = beta[i][None, :, None, None] |
| noisy_x_i = x[i] * alpha_i + noise_i * beta_i |
| x0.append(x[i][:, output_start[i] : output_end[i], ...]) |
| eps.append(noise_i[:, output_start[i] : output_end[i], ...]) |
| xt.append(noisy_x_i[:, input_start[i] : input_end[i], ...]) |
| else: |
| for i in range(len(x)): |
| xt.append(x[i][:, input_start[i] : input_end[i], ...]) |
| return x0, eps, xt |
|
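| # Shape sketch (hedged, illustrative): with x[i] of shape (64, 20, 1, 1),
| # input window [0, 8) and output window [3, 8), training returns
| # x0[i]/eps[i] of shape (64, 5, 1, 1) and xt[i] of shape (64, 8, 1, 1);
| # at inference only xt is populated.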
|
| def prepare(self, x, device, valid_len, training=True, current_step=None): |
| """Single call replacing get_time_steps + get_time_schedules + |
| get_noise_levels + get_windows + add_noise. |
| |
| Args: |
| x: list of (C, T, H, W). Training: clean features. Inference: current state. |
| device: torch device |
| valid_len: list of int |
| training: bool |
| current_step: int (inference only) |
| |
| Returns a dict with keys:
| time_schedules, time_schedules_derivative,
| alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta,
| input_start, input_end, output_start, output_end, xt, x0, eps.
| time_schedules is sliced to the input window; the schedule-derived
| quantities are sliced to the output window; x0 and eps are empty
| lists at inference.
| """
| time_steps = self.get_time_steps(device, valid_len, current_step) |
| time_schedules, time_schedules_derivative = self.get_time_schedules( |
| device, valid_len, time_steps, training=training |
| ) |
| alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta = \ |
| self.get_noise_levels(device, valid_len, time_schedules, training=training) |
| input_start, input_end, output_start, output_end = \ |
| self.get_windows(valid_len, time_steps, training=training) |
| x0, eps, xt = self.add_noise( |
| x, alpha, beta, input_start, input_end, |
| output_start, output_end, training=training |
| ) |
|
|
| # Slice per-frame quantities to their windows: time_schedules to the
| # input window (it is fed to the model alongside xt); everything else
| # to the output window (used for the update/loss).
| batch_size = len(valid_len)
| time_schedules = [time_schedules[i][input_start[i]:input_end[i]] for i in range(batch_size)]
| time_schedules_derivative = [time_schedules_derivative[i][output_start[i]:output_end[i]] for i in range(batch_size)]
| alpha = [alpha[i][output_start[i]:output_end[i]] for i in range(batch_size)] |
| dalpha = [dalpha[i][output_start[i]:output_end[i]] for i in range(batch_size)] |
| beta = [beta[i][output_start[i]:output_end[i]] for i in range(batch_size)] |
| dbeta = [dbeta[i][output_start[i]:output_end[i]] for i in range(batch_size)] |
| sigma = [sigma[i][output_start[i]:output_end[i]] for i in range(batch_size)] |
| dlog_alpha = [dlog_alpha[i][output_start[i]:output_end[i]] for i in range(batch_size)] |
| dlog_beta = [dlog_beta[i][output_start[i]:output_end[i]] for i in range(batch_size)] |
|
|
| result = { |
| "time_schedules": time_schedules, |
| "time_schedules_derivative": time_schedules_derivative, |
| "input_start": input_start, |
| "input_end": input_end, |
| "output_start": output_start, |
| "output_end": output_end, |
| "alpha": alpha, |
| "dalpha": dalpha, |
| "beta": beta, |
| "dbeta": dbeta, |
| "sigma": sigma, |
| "dlog_alpha": dlog_alpha, |
| "dlog_beta": dlog_beta, |
| "xt": xt, |
| "x0": x0, |
| "eps": eps, |
| } |
| return result |
|
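| # Minimal usage sketch (hedged; see the runnable __main__ demo at the end
| # of this file):
| #     s = scheduler.prepare(features, device, valid_len, training=True)
| #     model_input, model_times = s["xt"], s["time_schedules"]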
|
| |
|
|
| def get_committable(self, total_frames):
| """Given the number of accumulated condition frames, return how many
| frames can be committed and how many denoising steps they unlock.
| Assumes steps % chunk_size == 0 for simplicity."""
| committable_length = max(0, total_frames - self.chunk_size + 1) |
| committable_steps = total_frames * (self.steps // self.chunk_size) |
| return committable_length, committable_steps |
|
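| # Worked example (hedged; steps=10, chunk_size=5): after 7 condition
| # frames, committable_length = max(0, 7 - 5 + 1) = 3 fully-clean frames
| # and committable_steps = 7 * (10 // 5) = 14 denoising steps.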
|
| def get_step_rollback(self, seq_len): |
| """Get the step count to subtract when wrapping the buffer by seq_len. |
| Corresponds to how many steps were consumed by seq_len frames.""" |
| steps = seq_len * (self.steps // self.chunk_size) |
| return steps |
|
|
|
|
| class T5TextCrossModule(nn.Module): |
| """Cross-attention module for T5 text conditioning.""" |
|
|
| def __init__( |
| self, |
| len=512, |
| dim=4096, |
| t5_size="xxl", |
| checkpoint_path=None, |
| tokenizer_path=None, |
| drop_out=0.1, |
| input_keys={ |
| "text": "text", |
| "text_end": "text_end", |
| }, |
| ): |
| assert checkpoint_path is not None and tokenizer_path is not None, ( |
| "T5 checkpoint and tokenizer paths must be provided." |
| ) |
| super().__init__() |
| self.len = len |
| self.dim = dim |
| self.cross_attn_norm = True |
| self.cross_rope = False |
| self.drop_out = drop_out |
| self.input_keys = input_keys |
|
|
| self.text_encoder = T5EncoderModel( |
| text_len=len, |
| dtype=torch.bfloat16, |
| device=torch.device("cpu"), |
| checkpoint_path=checkpoint_path, |
| tokenizer_path=tokenizer_path, |
| shard_fn=None, |
| t5_size=t5_size, |
| ) |
| self.text_cache = {} |
|
|
| def encode(self, text_list, device):
| """Encode a list of strings with caching. Returns List[Tensor]."""
| # Collect cache misses, deduplicated, preserving first-seen order.
| texts_to_encode = []
| for text in text_list:
| if text not in self.text_cache and text not in texts_to_encode:
| texts_to_encode.append(text)
|
|
|
| # Encode the misses in one batch; features are cached on CPU so the
| # cache does not hold GPU memory (its size is otherwise unbounded).
| if texts_to_encode:
| self.text_encoder.model.to(device)
| encoded = self.text_encoder(texts_to_encode, device)
| for text, feature in zip(texts_to_encode, encoded):
| self.text_cache[text] = feature.cpu()
|
|
| |
| return [self.text_cache[text].to(device) for text in text_list] |
|
|
| def get_context(self, x, valid_len, device, param_dtype, training=False): |
| """ |
| Get cross-attention context from input dict. |
| |
| Returns: |
| context: List[Tensor] |
| metadata: dict, may contain 'full_text' |
| """ |
| text_key = self.input_keys.get("text", "text") |
| text_end_key = self.input_keys.get("text_end", "text_end") |
| metadata = {} |
|
|
| if text_key not in x: |
| text_list = ["" for _ in range(len(valid_len))] |
| else: |
| text_list = x[text_key] |
|
|
| if isinstance(text_list[0], list):
| # Segmented conditioning: each sample carries a list of texts plus
| # segment end frames; each segment's context is replicated across
| # its frames.
| full_text = []
| all_context = []
| text_end_list = x[text_end_key]
|
|
| for i in range(len(valid_len)):
| # Classifier-free guidance dropout: drop all segments at once.
| if training and np.random.rand() <= self.drop_out:
| single_text_list = [""]
| single_text_end_list = [0, valid_len[i]]
| else: |
| single_text_list = text_list[i] |
| single_text_end_list = [0] + [ |
| min(t, valid_len[i]) for t in text_end_list[i] |
| ] |
| single_text_length_list = [ |
| t - b |
| for t, b in zip(single_text_end_list[1:], single_text_end_list[:-1]) |
| ] |
|
|
| full_text.append( |
| " ////////// ".join( |
| [ |
| f"{u} //dur:{t}" |
| for u, t in zip(single_text_list, single_text_length_list) |
| ] |
| ) |
| ) |
|
|
| single_text_context = self.encode(single_text_list, device) |
| single_text_context = [u.to(param_dtype) for u in single_text_context] |
| sample_context = [] |
| for u, duration in zip(single_text_context, single_text_length_list): |
| sample_context.extend([u for _ in range(duration)]) |
| all_context.append(sample_context) |
| metadata["full_text"] = full_text |
| return all_context, metadata |
| else:
| # One text per sample.
| full_text = list(text_list)
| metadata["full_text"] = full_text
| if training:
| # Classifier-free guidance dropout at rate drop_out.
| text_list = [
| ("" if np.random.rand() <= self.drop_out else u) for u in text_list
| ]
| context = self.encode(text_list, device) |
| context = [u.to(param_dtype) for u in context] |
|
|
| return context, metadata |
|
|
| def get_null_context(self, batch_size, device, param_dtype): |
| """Get null/empty context for classifier-free guidance.""" |
| null_ctx = self.encode([""] * batch_size, device) |
| return [u.to(param_dtype) for u in null_ctx] |
|
|
| |
|
|
| def init_stream(self, batch_size):
| """Reset the per-sample lists of per-frame streaming text contexts."""
| self.stream_condition_list = [[] for _ in range(batch_size)]
|
|
| def update_stream(self, x, device, param_dtype): |
| """Add one frame of context for a streaming step.""" |
| text_key = self.input_keys.get("text", "text") |
| text_input = x[text_key] |
| new_ctx = self.encode(text_input, device) |
| new_ctx = [u.to(param_dtype) for u in new_ctx] |
| for i in range(len(self.stream_condition_list)): |
| self.stream_condition_list[i].append(new_ctx[i]) |
|
|
| def get_stream_context(self, start_index, end_index): |
| context = [] |
| for i in range(len(self.stream_condition_list)): |
| context.append(self.stream_condition_list[i][start_index:end_index]) |
| return context |
|
|
| def trim_stream(self, trim_len): |
| """Trim stream state when wrapping around.""" |
| for i in range(len(self.stream_condition_list)): |
| self.stream_condition_list[i] = self.stream_condition_list[i][trim_len:] |
|
|
|
|
| class DiffForcingWanModel(nn.Module): |
| def __init__( |
| self, |
| input_dim=256, |
| mean_path=None, |
| std_path=None, |
| hidden_dim=1024, |
| ffn_dim=2048, |
| freq_dim=256, |
| num_heads=8, |
| num_layers=8, |
| time_embedding_scale=1.0, |
| causal=False, |
| rope_channel_split=[1, 0, 0], |
| spatial_shape=(1, 1), |
| prediction_type="vel", |
| text_config={ |
| "len": 512, |
| "dim": 4096, |
| }, |
| schedule_config={ |
| "noise_type": "linear", |
| "chunk_size": 5, |
| "steps": 10, |
| "extra_len": 4, |
| "random_epsilon": 0.00, |
| }, |
| cfg_config={
| # Classifier-free guidance: pred = text_scale * pred_text
| # + null_scale * pred_null. Scales summing to 1 recover the usual
| # pred_null + s * (pred_text - pred_null) with s = text_scale.
| "text_scale": 5.0,
| "null_scale": -4.0,
| },
| input_keys={ |
| "feature": "feature", |
| "feature_length": "feature_length", |
| "text": "text", |
| "text_end": "text_end", |
| }, |
| ): |
| super().__init__() |
| self.input_keys = input_keys |
|
|
| self.mean_path = mean_path |
| self.std_path = std_path |
| self.input_dim = input_dim |
| self.spatial_shape = tuple(spatial_shape) |
| self.hidden_dim = hidden_dim |
| self.ffn_dim = ffn_dim |
| self.freq_dim = freq_dim |
| self.num_heads = num_heads |
| self.num_layers = num_layers |
| self.time_embedding_scale = time_embedding_scale |
| self.causal = causal |
| self.rope_channel_split = rope_channel_split |
| self.prediction_type = prediction_type |
| self.cfg_config = dict(cfg_config)
| # Copy before storing: init_generated mutates this dict in place, and
| # the shared default argument must not be modified across instances.
| self.schedule_config = dict(schedule_config)
| self.time_scheduler = TriangularTimeScheduler(self.schedule_config)
| |
| self.text_module = T5TextCrossModule(**text_config) |
|
|
| if self.mean_path is not None: |
| self.register_buffer( |
| "mean", torch.from_numpy(np.load(self.mean_path)).float() |
| ) |
| else: |
| self.register_buffer("mean", torch.zeros(input_dim)) |
|
|
| if self.std_path is not None: |
| self.register_buffer( |
| "std", torch.from_numpy(np.load(self.std_path)).float() |
| ) |
| else: |
| self.register_buffer("std", torch.ones(input_dim)) |
|
|
| self.model = WanModel( |
| patch_size=(1, 1, 1), |
| text_len=self.text_module.len, |
| text_dim=self.text_module.dim, |
| cross_attn_norm=self.text_module.cross_attn_norm, |
| cross_rope=self.text_module.cross_rope, |
| in_dim=self.input_dim, |
| dim=self.hidden_dim, |
| ffn_dim=self.ffn_dim, |
| freq_dim=self.freq_dim, |
| out_dim=self.input_dim, |
| num_heads=self.num_heads, |
| num_layers=self.num_layers, |
| window_size=(-1, -1), |
| qk_norm=True, |
| eps=1e-6, |
| causal=self.causal, |
| rope_channel_split=self.rope_channel_split, |
| ) |
| self.param_dtype = torch.float32 |
|
|
| def _extract_inputs(self, x): |
| """Extract inputs from x using input_keys mapping.""" |
| inputs = {} |
| for internal_key, external_key in self.input_keys.items(): |
| if external_key in x: |
| inputs[internal_key] = x[external_key] |
| return inputs |
|
|
| def preprocess(self, x): |
| """Convert last-channel format to channel-first, padding to 4D (C, T, H, W). |
| (T, C) -> (C, T, 1, 1) |
| (T, H, C) -> (C, T, H, 1) |
| (T, H, W, C) -> (C, T, H, W) |
| """ |
| for i in range(len(x)): |
| ndim = x[i].ndim |
| if ndim == 2: |
| x[i] = x[i].permute(1, 0)[:, :, None, None] |
| elif ndim == 3: |
| x[i] = x[i].permute(2, 0, 1)[:, :, :, None] |
| elif ndim == 4: |
| x[i] = x[i].permute(3, 0, 1, 2) |
| return x |
|
|
| def postprocess(self, x): |
| """Reverse of preprocess: channel-first 4D back to last-channel, stripping padding dims. |
| (C, T, 1, 1) -> (T, C) |
| (C, T, H, 1) -> (T, H, C) |
| (C, T, H, W) -> (T, H, W, C) |
| """ |
| for i in range(len(x)): |
| shape = x[i].shape |
| if shape[2] == 1 and shape[3] == 1: |
| x[i] = x[i][:, :, 0, 0].permute(1, 0) |
| elif shape[3] == 1: |
| x[i] = x[i][:, :, :, 0].permute(1, 2, 0) |
| else: |
| x[i] = x[i].permute(1, 2, 3, 0) |
| return x |
|
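| # Round-trip sketch (hedged): a (T=20, C=64) feature becomes (64, 20, 1, 1)
| # under preprocess and maps back to (20, 64) under postprocess; both
| # functions modify the input list in place and also return it.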
|
| def forward(self, x):
| """Training forward pass. Returns {"total": loss, "mse": loss}."""
| x = self._extract_inputs(x)
| feature_original = x["feature"] |
| feature_length = x["feature_length"] |
| feature_original = (feature_original - self.mean) / self.std |
| batch_size = feature_original.shape[0] |
| seq_len = feature_original.shape[1] |
| device = feature_original.device |
| feature = [] |
| valid_len = [] |
| for i in range(batch_size): |
| length = min(feature_length[i].item(), seq_len) |
| valid_len.append(length) |
| feature.append(feature_original[i, :length, ...]) |
|
|
| |
| # To channel-first (C, T, H, W).
| feature = self.preprocess(feature)
|
|
| |
| context, _ = self.text_module.get_context( |
| x, |
| valid_len, |
| device, |
| self.param_dtype, |
| training=True, |
| ) |
|
|
| |
| # Sample times, noise the sequence, and slice out the windows.
| s = self.time_scheduler.prepare(feature, device, valid_len, training=True)
| time_schedules = s["time_schedules"] |
| input_start_index = s["input_start"] |
| input_end_index = s["input_end"] |
| output_start_index = s["output_start"] |
| output_end_index = s["output_end"] |
| dalpha = s["dalpha"] |
| dbeta = s["dbeta"] |
| x0, eps, xt = s["x0"], s["eps"], s["xt"] |
|
|
| |
| # Per-frame contexts are sliced to the model input window.
| if isinstance(context[0], (list, tuple)):
| context = [ |
| context[i][input_start_index[i] : input_end_index[i]] |
| for i in range(batch_size) |
| ] |
|
|
| |
| time_schedules_input = [ |
| time_schedules[i] * self.time_embedding_scale |
| for i in range(batch_size) |
| ] |
|
|
| |
| predicted_result = self.model( |
| xt, |
| time_schedules_input, |
| context, |
| seq_len, |
| y=None, |
| ) |
|
|
| loss = 0.0 |
| for b in range(batch_size): |
| # Output-window indices relative to the model input window.
| pred_os = output_start_index[b] - input_start_index[b]
| pred_oe = output_end_index[b] - input_start_index[b]
| dalpha_i = dalpha[b]
| dbeta_i = dbeta[b]
| if self.prediction_type == "vel": |
| vel = ( |
| x0[b] * dalpha_i[None, :, None, None] |
| + eps[b] * dbeta_i[None, :, None, None] |
| ) |
| squared_error = ( |
| predicted_result[b][:, pred_os:pred_oe, ...] - vel |
| ) ** 2 |
| elif self.prediction_type == "x0": |
| squared_error = ( |
| predicted_result[b][:, pred_os:pred_oe, ...] - x0[b] |
| ) ** 2 |
| elif self.prediction_type == "eps": |
| squared_error = ( |
| predicted_result[b][:, pred_os:pred_oe, ...] - eps[b] |
| ) ** 2 |
| sample_loss = squared_error.mean() |
| loss += sample_loss |
| loss = loss / batch_size |
| loss_dict = {"total": loss, "mse": loss} |
| return loss_dict |
|
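| # Hedged usage sketch (keys follow the default input_keys; the names
| # feats, lengths, and texts are illustrative, not defined in this file):
| #     loss_dict = model({"feature": feats, "feature_length": lengths,
| #                        "text": texts})
| #     loss_dict["total"].backward()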
|
| def generate(self, x): |
| """ |
| Generation - Diffusion Forcing inference |
| Uses triangular noise schedule, progressively generating from left to right |
| |
| Generation process: |
| 1. Start from t=0, gradually increase t |
| 2. Each t corresponds to a noise schedule: clean on left, noisy on right, gradient in middle |
| 3. After each denoising step, t increases slightly and continues |
| """ |
| x = self._extract_inputs(x) |
| extra_len = self.schedule_config.get("extra_len", 0) |
| feature_length = x["feature_length"] |
| batch_size = len(feature_length) |
| seq_len = max(feature_length).item() + extra_len |
| device = next(self.parameters()).device |
| valid_len = [] |
| for i in range(batch_size): |
| length = min(feature_length[i].item(), seq_len) |
| valid_len.append(length) |
| generated_len = [seq_len for _ in range(batch_size)] |
|
|
| |
| # Initialize the whole sequence as Gaussian noise.
| generated = torch.randn(
| batch_size, seq_len, *self.spatial_shape, self.input_dim, device=device |
| ) |
| generated = [generated[i] for i in range(batch_size)] |
| generated = self.preprocess(generated) |
|
|
| |
| # Text and null contexts for classifier-free guidance.
| text_context, metadata = self.text_module.get_context(
| x, |
| generated_len, |
| device, |
| self.param_dtype, |
| training=False, |
| ) |
| null_context = self.text_module.get_null_context( |
| batch_size, device, self.param_dtype |
| ) |
| full_text = metadata["full_text"] |
|
|
| total_steps = self.time_scheduler.get_total_steps(seq_len) |
| |
| for step in range(total_steps): |
| s = self.time_scheduler.prepare( |
| generated, device, generated_len, training=False, current_step=step |
| ) |
| time_schedules = s["time_schedules"] |
| time_schedules_derivative = s["time_schedules_derivative"] |
| alpha = s["alpha"] |
| dalpha = s["dalpha"] |
| beta = s["beta"] |
| dbeta = s["dbeta"] |
| sigma = s["sigma"] |
| dlog_alpha = s["dlog_alpha"] |
| dlog_beta = s["dlog_beta"] |
| input_start_index = s["input_start"] |
| input_end_index = s["input_end"] |
| output_start_index = s["output_start"] |
| output_end_index = s["output_end"] |
| xt = s["xt"] |
|
|
| |
| time_schedules_input = [ |
| time_schedules[i] * self.time_embedding_scale |
| for i in range(batch_size) |
| ] |
|
|
| |
| if isinstance(text_context[0], (list, tuple)): |
| window_text_context = [ |
| text_context[i][input_start_index[i] : input_end_index[i]] |
| for i in range(batch_size) |
| ] |
| else: |
| window_text_context = text_context |
|
|
| |
| # Two passes for classifier-free guidance: text and null conditions.
| pred_text = self.model(
| xt, |
| time_schedules_input, |
| window_text_context, |
| seq_len, |
| y=None, |
| ) |
| pred_null = self.model( |
| xt, |
| time_schedules_input, |
| null_context, |
| seq_len, |
| y=None, |
| ) |
| predicted_result = [ |
| self.cfg_config["text_scale"] * pt + self.cfg_config["null_scale"] * pn |
| for pt, pn in zip(pred_text, pred_null) |
| ] |
|
|
| |
| # Euler-Maruyama update on the output window.
| for i in range(batch_size):
| os_i, oe_i = output_start_index[i], output_end_index[i]
| pred_os = os_i - input_start_index[i]
| pred_oe = oe_i - input_start_index[i]
| predicted_result_i = predicted_result[i][:, pred_os:pred_oe, ...]
| generated_i = generated[i][:, os_i:oe_i, ...]
| dt = time_schedules_derivative[i][None, :, None, None] |
| alpha_i = alpha[i][None, :, None, None] |
| dalpha_i = dalpha[i][None, :, None, None] |
| beta_i = beta[i][None, :, None, None] |
| dbeta_i = dbeta[i][None, :, None, None] |
| sigma_i = sigma[i][None, :, None, None] |
| dlog_alpha_i = dlog_alpha[i][None, :, None, None] |
| dlog_beta_i = dlog_beta[i][None, :, None, None] |
| if self.prediction_type == "vel": |
| vel = predicted_result_i |
| elif self.prediction_type == "x0": |
| vel = ( |
| predicted_result_i * (-dlog_beta_i * alpha_i + dalpha_i) |
| + generated_i * dlog_beta_i |
| ) |
| elif self.prediction_type == "eps": |
| vel = ( |
| predicted_result_i * (-dlog_alpha_i * beta_i + dbeta_i) |
| + generated_i * dlog_alpha_i |
| ) |
| st = (vel - generated_i * dlog_alpha_i) / ( |
| (beta_i * dlog_alpha_i - dbeta_i) * beta_i |
| ) |
| generated[i][:, os:oe, ...] += ( |
| vel * dt |
| + st * 0.5 * sigma_i**2 * dt |
| + sigma_i * torch.sqrt(dt) * torch.randn_like(generated_i) |
| ) |
|
|
| generated = self.postprocess(generated) |
| y_hat_out = [] |
| for i in range(batch_size): |
| single_generated = generated[i][: valid_len[i], :] * self.std + self.mean |
| y_hat_out.append(single_generated) |
| out = {} |
| out["generated"] = y_hat_out |
| out["text"] = full_text |
|
|
| return out |
|
|
| def init_generated(self, seq_len, batch_size=1, schedule_config={}): |
| """Initialize streaming generation state. |
| |
| Args: |
| seq_len: Model window size (how many frames WanModel processes per step). |
| schedule_config: Optional schedule config overrides. |
| |
| Buffer is 2 * seq_len. When accumulated conditions overflow the
| buffer, it is shifted left by seq_len and generation continues.
| """ |
| self.schedule_config.update(schedule_config) |
| content_len = self.schedule_config.get("content_len", None) |
| if content_len is None: |
| self.schedule_config["content_len"] = seq_len |
| else: |
| self.schedule_config["content_len"] = min(seq_len, content_len) |
| self.time_scheduler = TriangularTimeScheduler(self.schedule_config) |
|
|
| self.batch_size = batch_size |
| self.seq_len = seq_len |
| self.buf_len = seq_len * 2 |
| self.current_step = 0 |
| self.current_commit = 0 |
| self.condition_frames = 0 |
|
|
| device = next(self.parameters()).device |
| |
| # Streaming buffer initialized as Gaussian noise.
| generated = torch.randn(
| batch_size, self.buf_len, *self.spatial_shape, self.input_dim, device=device |
| ) |
| generated = [generated[i] for i in range(batch_size)] |
| self.generated = self.preprocess(generated) |
|
|
| |
| self.text_module.init_stream(self.batch_size) |
|
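| # Hedged streaming sketch (one frame of text per call; frame_texts and
| # committed are illustrative names, not defined in this file):
| #     model.init_generated(seq_len=40, batch_size=1)
| #     committed = []
| #     for frame_text in frame_texts:
| #         out = model.stream_generate_step({"text": [frame_text]})
| #         committed.append(out["generated"][0])  # may hold 0 new frames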
|
| def _rollback(self): |
| """Shift buffer by seq_len when conditions overflow the window.""" |
| for i in range(self.batch_size): |
| self.generated[i][:, : self.seq_len, ...] = self.generated[i][ |
| :, self.seq_len :, ... |
| ].clone() |
| self.generated[i][:, self.seq_len :, ...] = torch.randn_like( |
| self.generated[i][:, self.seq_len :, ...] |
| ) |
| self.current_step -= self.time_scheduler.get_step_rollback(self.seq_len) |
| self.condition_frames -= self.seq_len |
| self.current_commit -= self.seq_len |
| self.text_module.trim_stream(self.seq_len) |
|
|
| @torch.no_grad() |
| def stream_generate_step(self, x): |
| """ |
| Streaming generation step. Each call provides 1 frame of conditions. |
| The scheduler determines committable frames from accumulated conditions. |
| |
| Returns:
| dict with "generated": list of per-sample (N, C) tensors of newly
| committed frames; N may be 0 when nothing new is committable.
| """ |
| x = self._extract_inputs(x) |
| device = next(self.parameters()).device |
| self.generated = [g.to(device) for g in self.generated] |
|
|
| |
| # Accumulate this frame's text context.
| self.text_module.update_stream(x, device, self.param_dtype)
| self.condition_frames += 1 |
|
|
| |
| # Wrap the buffer once accumulated conditions exceed its capacity.
| if self.condition_frames > self.buf_len:
| self._rollback() |
|
|
| |
| # Run every denoising step unlocked by the accumulated conditions.
| committable_length, committable_steps = self.time_scheduler.get_committable(
| self.condition_frames |
| ) |
| while self.current_step < committable_steps: |
| s = self.time_scheduler.prepare( |
| self.generated, device, [self.buf_len] * self.batch_size, |
| training=False, current_step=self.current_step |
| ) |
| time_schedules = s["time_schedules"] |
| time_schedules_derivative = s["time_schedules_derivative"] |
| alpha = s["alpha"] |
| dalpha = s["dalpha"] |
| beta = s["beta"] |
| dbeta = s["dbeta"] |
| sigma = s["sigma"] |
| dlog_alpha = s["dlog_alpha"] |
| dlog_beta = s["dlog_beta"] |
| is_ = s["input_start"] |
| ie_ = s["input_end"] |
| os_ = s["output_start"] |
| oe_ = s["output_end"] |
| xt = s["xt"] |
|
|
| |
| # All samples share the same step and length, hence schedule [0]
| # applies to the whole batch.
| time_schedules_input = [
| time_schedules[0] * self.time_embedding_scale |
| ] * self.batch_size |
|
|
| |
| text_context = self.text_module.get_stream_context(is_[0], ie_[0]) |
| null_context = self.text_module.get_null_context( |
| self.batch_size, device, self.param_dtype |
| ) |
|
| # Replicate the null context per frame to match the per-frame text
| # context layout, then run a single batched pass for classifier-free
| # guidance: text- and null-conditioned halves concatenated along batch.
| window_len = ie_[0] - is_[0]
| null_context_pf = [
| [null_context[i]] * window_len for i in range(self.batch_size)
| ]
| pred_all = self.model(
| xt + xt,
| time_schedules_input + time_schedules_input, |
| text_context + null_context_pf, |
| self.seq_len, |
| y=None, |
| ) |
| pred_text = pred_all[: self.batch_size] |
| pred_null = pred_all[self.batch_size :] |
| predicted_result = [ |
| self.cfg_config["text_scale"] * pt + self.cfg_config["null_scale"] * pn |
| for pt, pn in zip(pred_text, pred_null) |
| ] |
|
|
| |
| # Shared (index 0) windows and schedules apply to every sample.
| os_idx, oe_idx = os_[0], oe_[0]
| pred_os_idx = os_idx - is_[0] |
| pred_oe_idx = oe_idx - is_[0] |
| dt = time_schedules_derivative[0][None, :, None, None] |
| alpha_i = alpha[0][None, :, None, None] |
| dalpha_i = dalpha[0][None, :, None, None] |
| beta_i = beta[0][None, :, None, None] |
| dbeta_i = dbeta[0][None, :, None, None] |
| sigma_i = sigma[0][None, :, None, None] |
| dlog_alpha_i = dlog_alpha[0][None, :, None, None] |
| dlog_beta_i = dlog_beta[0][None, :, None, None] |
| for i in range(self.batch_size): |
| predicted_result_i = predicted_result[i][ |
| :, pred_os_idx:pred_oe_idx, ... |
| ] |
| generated_i = self.generated[i][:, os_idx:oe_idx, ...] |
| if self.prediction_type == "vel": |
| vel = predicted_result_i |
| elif self.prediction_type == "x0": |
| vel = ( |
| predicted_result_i * (-dlog_beta_i * alpha_i + dalpha_i) |
| + generated_i * dlog_beta_i |
| ) |
| elif self.prediction_type == "eps": |
| vel = ( |
| predicted_result_i * (-dlog_alpha_i * beta_i + dbeta_i) |
| + generated_i * dlog_alpha_i |
| ) |
| st = (vel - generated_i * dlog_alpha_i) / ( |
| (beta_i * dlog_alpha_i - dbeta_i) * beta_i |
| ) |
| self.generated[i][:, os_idx:oe_idx, ...] += ( |
| vel * dt |
| + st * 0.5 * sigma_i**2 * dt |
| + sigma_i * torch.sqrt(dt) * torch.randn_like(generated_i) |
| ) |
| self.current_step += 1 |
|
|
| |
| # Commit newly cleaned frames, denormalized back to feature space.
| if self.current_commit < committable_length:
| output = [ |
| self.generated[i][:, self.current_commit : committable_length, ...] |
| for i in range(self.batch_size) |
| ] |
| output = self.postprocess(output) |
| output = [o * self.std + self.mean for o in output] |
| self.current_commit = committable_length |
| return {"generated": output} |
| else: |
| empty = [ |
| torch.zeros(self.input_dim, 0, *self.spatial_shape, device=device) |
| for _ in range(self.batch_size) |
| ] |
| empty = self.postprocess(empty) |
| return {"generated": empty} |
|
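|
| if __name__ == "__main__":
|     # Runnable smoke test of the scheduler alone (a minimal sketch; no
|     # model weights or T5 checkpoints needed; config values illustrative).
|     torch.manual_seed(0)
|     demo_sched = TriangularTimeScheduler({"steps": 10, "chunk_size": 5})
|     demo_x = [torch.randn(8, 20, 1, 1)]  # one sample, (C, T, H, W)
|     demo_out = demo_sched.prepare(demo_x, torch.device("cpu"), [20], training=True)
|     print("input window:", demo_out["input_start"][0], demo_out["input_end"][0])
|     print("output window:", demo_out["output_start"][0], demo_out["output_end"][0])
|     print("xt:", tuple(demo_out["xt"][0].shape), "x0:", tuple(demo_out["x0"][0].shape))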
|