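"""Diffusion-forcing generation on top of a Wan transformer backbone.

Contains:
  - TriangularTimeScheduler: per-frame triangular noise schedule, denoising
    windows, and streaming bookkeeping.
  - T5TextCrossModule: T5 text conditioning with caching and stream state.
  - DiffForcingWanModel: training loss, windowed CFG generation, and
    frame-by-frame streaming generation.

Typical use (a sketch; assumes a batch dict keyed per ``input_keys`` and valid
T5 checkpoint/tokenizer paths in ``text_config``):

    model = DiffForcingWanModel(text_config={...})
    loss = model(batch)["total"]          # training
    out = model.generate(batch)           # offline inference
    model.init_generated(seq_len=120)     # streaming setup
    step_out = model.stream_generate_step(frame_batch)
"""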
import math
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .tools.t5 import T5EncoderModel
from .tools.wan_model import WanModel
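# Clamp floor that keeps the dlog_alpha / dlog_beta divisions numerically stable.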
EPSILON = 0.05
class TriangularTimeScheduler:
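    """Per-frame triangular time schedule for diffusion forcing.

    At global time t, frame j gets schedule time clamp(t - j / chunk_size, 0, 1):
    time 1 means clean data and time 0 pure noise, so earlier frames finish
    denoising before later ones. E.g. with chunk_size=5 and t=1.2, frame 0 is
    clean, frame 1 sits at 1.0, frame 3 at 0.6, and frames >= 6 are still noise.
    """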
def __init__(self, config):
self.steps = config["steps"]
self.chunk_size = config["chunk_size"]
self.random_epsilon = config.get("random_epsilon", 0.00) # schedule jittering
self.noise_type = config.get("noise_type", "linear")
self.sigma_type = config.get("sigma_type", "zero") # "zero", "memoryless"
if self.noise_type == "exponential" or self.noise_type == "exponential_rev":
self.exp_max = config.get("exp_max", 5.0)
elif self.noise_type == "diffusion":
self.T = config.get("T", 1000)
self.beta_start = config.get("beta_start", 0.0001)
self.beta_end = config.get("beta_end", 0.02)
if self.sigma_type == "memoryless":
self.sigma_scale = config.get("sigma_scale", 1.0)
self.content_len = config.get("content_len", None)
# For simplicity we require steps to be divisible by chunk_size, so that time windows align well.
def get_total_steps(self, seq_len):
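        """Total solver steps needed to denoise seq_len frames,
        e.g. steps=10, chunk_size=5, seq_len=20 -> 40."""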
return int(self.steps * seq_len / self.chunk_size)
def get_time_steps(self, device, valid_len, current_step=None):
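        """Per-sample global times.

        Training (current_step is None): t ~ Uniform[0, valid_len / chunk_size).
        Inference: t = current_step / steps, with current_step an int shared
        across the batch or a per-sample list.
        """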
time_steps = []
if current_step is None:
for i in range(len(valid_len)):
max_time = valid_len[i] / self.chunk_size
time_steps.append(
torch.tensor(np.random.uniform(0, max_time), device=device)
)
elif isinstance(current_step, int):
for i in range(len(valid_len)):
t = current_step * (1 / self.steps)
time_steps.append(torch.tensor(t, device=device))
elif isinstance(current_step, list):
for i in range(len(valid_len)):
t = current_step[i] * (1 / self.steps)
time_steps.append(torch.tensor(t, device=device))
return time_steps
def get_time_schedules(self, device, valid_len, time_steps, training=False):
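        """Per-frame schedule times clamp(t - j / chunk_size, 0, 1) for frames j,
        plus the per-solver-step time increment dt = 1 / steps. Training jitters
        the schedule with Gaussian noise of scale random_epsilon."""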
time_schedules = []
time_schedules_derivative = []
for i in range(len(valid_len)):
t = time_steps[i].item()
current_time_schedules = torch.clamp(
-torch.arange(valid_len[i], device=device) / self.chunk_size + t,
min=0.0,
max=1.0,
)
current_time_schedules_derivative = torch.ones_like(
current_time_schedules
) * (1 / self.steps)
if training:
current_time_schedules = torch.clamp(
current_time_schedules
+ torch.randn_like(current_time_schedules) * self.random_epsilon,
min=0.0,
max=1.0,
)
time_schedules.append(current_time_schedules)
time_schedules_derivative.append(current_time_schedules_derivative)
return time_schedules, time_schedules_derivative
def get_windows(self, valid_len, time_steps, training=False):
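        """Per-sample index windows at time t.

        The output window [output_start, output_end) holds frames currently being
        denoised (schedule time strictly between 0 and 1); the input window
        extends it leftward with already-clean context (all history, or the last
        content_len frames when content_len is set).
        """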
        # Nudge the indices by half of the smallest time increment,
        # 0.5 * (1 / (steps * chunk_size)), to guard against floating-point error
        # at window boundaries.
input_start, input_end, output_start, output_end = [], [], [], []
for i in range(len(time_steps)):
t = time_steps[i].item()
start_index = max(
0,
math.floor(
(t - 1) * self.chunk_size
+ 0.5 * (1 / (self.steps * self.chunk_size))
)
+ 1,
)
end_index = min(
valid_len[i],
math.floor(
t * self.chunk_size + 0.5 * (1 / (self.steps * self.chunk_size))
)
+ 1,
)
if self.content_len is not None:
input_start.append(max(0, end_index - self.content_len))
else:
input_start.append(0)
input_end.append(end_index)
output_start.append(start_index)
output_end.append(end_index)
return input_start, input_end, output_start, output_end
def get_noise_levels(self, device, valid_len, time_schedules, training=False):
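        """Per-frame coefficients for the interpolation x_t = alpha(t)*x0 + beta(t)*eps.

        Schedule time runs from 0 (pure noise, alpha=0) to 1 (clean data, beta->0).
        Returns alpha/beta, their time derivatives, log-derivatives (divisions
        clamped by EPSILON), and the sampler noise scale sigma.
        """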
alpha = []
dalpha = []
dlog_alpha = []
beta = []
dbeta = []
dlog_beta = []
sigma = []
for i in range(len(valid_len)):
t = time_schedules[i]
if self.noise_type == "linear":
alpha_i = t
dalpha_i = torch.ones_like(alpha_i)
dlog_alpha_i = dalpha_i / torch.clamp(alpha_i, min=EPSILON)
beta_i = 1 - t
dbeta_i = -torch.ones_like(beta_i)
dlog_beta_i = dbeta_i / torch.clamp(beta_i, min=EPSILON)
elif self.noise_type == "exponential":
# "eps" prediction
k = self.exp_max
alpha_i = torch.exp(-k * (1 - t))
dalpha_i = k * alpha_i
dlog_alpha_i = k * torch.ones_like(alpha_i)
beta_i = 1 - alpha_i
dbeta_i = -dalpha_i
dlog_beta_i = dbeta_i / torch.clamp(beta_i, min=EPSILON)
elif self.noise_type == "exponential_rev":
# "x0" prediction
k = self.exp_max
beta_i = torch.exp(-k * t)
dbeta_i = -k * beta_i
dlog_beta_i = -k * torch.ones_like(beta_i)
alpha_i = 1 - beta_i
dalpha_i = -dbeta_i
dlog_alpha_i = dalpha_i / torch.clamp(alpha_i, min=EPSILON)
elif self.noise_type == "diffusion":
t_rev = 1.0 - t
beta_rate = (
self.beta_start + t_rev * (self.beta_end - self.beta_start)
) * self.T
Gamma = (
self.beta_start * t_rev
+ 0.5 * (self.beta_end - self.beta_start) * t_rev * t_rev
) * self.T
alpha_i = torch.exp(-0.5 * Gamma)
dalpha_i = 0.5 * beta_rate * alpha_i
dlog_alpha_i = 0.5 * beta_rate
beta_i = torch.sqrt(torch.clamp(1 - torch.exp(-Gamma), min=0.0))
dbeta_i = (
-0.5
* torch.exp(-Gamma)
* beta_rate
/ torch.clamp(beta_i, min=EPSILON)
)
dlog_beta_i = dbeta_i / torch.clamp(beta_i, min=EPSILON)
else:
raise ValueError(f"Unknown noise type: {self.noise_type}")
alpha.append(torch.clamp(alpha_i, min=0.0, max=1.0))
dalpha.append(dalpha_i)
dlog_alpha.append(dlog_alpha_i)
beta.append(torch.clamp(beta_i, min=0.0, max=1.0))
dbeta.append(dbeta_i)
dlog_beta.append(dlog_beta_i)
if self.sigma_type == "zero":
sigma_i = torch.zeros_like(t)
elif self.sigma_type == "memoryless":
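                # Marginal-preserving noise injection: in general
                # sigma^2 = 2 * beta * (beta * dlog_alpha - dbeta) keeps x_t
                # distributed as N(alpha * x0, beta^2) under the sampler's extra
                # score drift; the branches below are this formula specialized
                # per noise_type.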
if (
self.noise_type == "linear"
or self.noise_type == "exponential"
or self.noise_type == "exponential_rev"
):
sigma_i = self.sigma_scale * torch.sqrt(
torch.clamp(2 * dlog_alpha_i * beta_i, min=0.0)
)
elif self.noise_type == "diffusion":
sigma_i = self.sigma_scale * torch.sqrt(
torch.clamp(2 * dlog_alpha_i, min=0.0)
)
else:
sigma_i = self.sigma_scale * torch.sqrt(
torch.clamp(
2 * beta_i * (dlog_alpha_i * beta_i - dbeta_i), min=0.0
)
)
sigma.append(sigma_i)
return alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta
def add_noise(
self,
x,
alpha,
beta,
input_start,
input_end,
output_start,
output_end,
training=False,
noise=None,
):
"""Add noise and slice into input/reference regions.
Args:
x: list of (C, T, H, W), x0 in training, xt in inference
alpha: list of (T,)
beta: list of (T,)
input_start/input_end: per-sample input window indices
output_start/output_end: per-sample output window indices
Returns:
x0: list of (C, output_len, H, W)
eps: list of (C, output_len, H, W)
xt: list of (C, input_len, H, W)
"""
x0 = []
eps = []
xt = []
if training:
for i in range(len(x)):
if noise is not None:
noise_i = noise[i]
else:
noise_i = torch.randn_like(x[i])
alpha_i = alpha[i][None, :, None, None] # (1, T, 1, 1)
beta_i = beta[i][None, :, None, None] # (1, T, 1, 1)
noisy_x_i = x[i] * alpha_i + noise_i * beta_i # (C, T, H, W)
x0.append(x[i][:, output_start[i] : output_end[i], ...])
eps.append(noise_i[:, output_start[i] : output_end[i], ...])
xt.append(noisy_x_i[:, input_start[i] : input_end[i], ...])
else:
for i in range(len(x)):
xt.append(x[i][:, input_start[i] : input_end[i], ...])
return x0, eps, xt
def prepare(self, x, device, valid_len, training=True, current_step=None):
"""Single call replacing get_time_steps + get_time_schedules +
get_noise_levels + get_windows + add_noise.
Args:
x: list of (C, T, H, W). Training: clean features. Inference: current state.
device: torch device
valid_len: list of int
training: bool
current_step: int (inference only)
        Returns a dict that always contains:
            time_schedules (sliced to the input window),
            time_schedules_derivative, alpha, dalpha, beta, dbeta, sigma,
            dlog_alpha, dlog_beta (sliced to the output window),
            input_start, input_end, output_start, output_end,
            xt (input window), x0, eps (output window; empty lists at inference).
        """
time_steps = self.get_time_steps(device, valid_len, current_step)
time_schedules, time_schedules_derivative = self.get_time_schedules(
device, valid_len, time_steps, training=training
)
alpha, dalpha, beta, dbeta, sigma, dlog_alpha, dlog_beta = \
self.get_noise_levels(device, valid_len, time_schedules, training=training)
input_start, input_end, output_start, output_end = \
self.get_windows(valid_len, time_steps, training=training)
x0, eps, xt = self.add_noise(
x, alpha, beta, input_start, input_end,
output_start, output_end, training=training
)
        # Slice time_schedules to the input window; the derivative and all noise
        # coefficients to the output window
batch_size = len(valid_len)
time_schedules = [time_schedules[i][input_start[i]:input_end[i]] for i in range(batch_size)]
time_schedules_derivative = [time_schedules_derivative[i][output_start[i]:output_end[i]] for i in range(batch_size)]
alpha = [alpha[i][output_start[i]:output_end[i]] for i in range(batch_size)]
dalpha = [dalpha[i][output_start[i]:output_end[i]] for i in range(batch_size)]
beta = [beta[i][output_start[i]:output_end[i]] for i in range(batch_size)]
dbeta = [dbeta[i][output_start[i]:output_end[i]] for i in range(batch_size)]
sigma = [sigma[i][output_start[i]:output_end[i]] for i in range(batch_size)]
dlog_alpha = [dlog_alpha[i][output_start[i]:output_end[i]] for i in range(batch_size)]
dlog_beta = [dlog_beta[i][output_start[i]:output_end[i]] for i in range(batch_size)]
result = {
"time_schedules": time_schedules,
"time_schedules_derivative": time_schedules_derivative,
"input_start": input_start,
"input_end": input_end,
"output_start": output_start,
"output_end": output_end,
"alpha": alpha,
"dalpha": dalpha,
"beta": beta,
"dbeta": dbeta,
"sigma": sigma,
"dlog_alpha": dlog_alpha,
"dlog_beta": dlog_beta,
"xt": xt,
"x0": x0,
"eps": eps,
}
return result
# --- Streaming support ---
def get_committable(self, total_frames):
"""Given total accumulated conditions, return how many frames can be committed.
Currently, we suppose steps % chunk_size == 0 for simplicity."""
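        # e.g. chunk_size=5, steps=10, total_frames=8 -> committable_length=4,
        # committable_steps=16: each condition frame licenses steps // chunk_size
        # solver steps, and frame j commits once total_frames >= j + chunk_size.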
committable_length = max(0, total_frames - self.chunk_size + 1)
committable_steps = total_frames * (self.steps // self.chunk_size)
return committable_length, committable_steps
def get_step_rollback(self, seq_len):
"""Get the step count to subtract when wrapping the buffer by seq_len.
Corresponds to how many steps were consumed by seq_len frames."""
steps = seq_len * (self.steps // self.chunk_size)
return steps
class T5TextCrossModule(nn.Module):
"""Cross-attention module for T5 text conditioning."""
def __init__(
self,
len=512,
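        # NOTE: 'len' shadows the builtin; the name is kept for compatibility
        # with existing text_config keys.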
dim=4096,
t5_size="xxl",
checkpoint_path=None,
tokenizer_path=None,
drop_out=0.1,
input_keys={
"text": "text",
"text_end": "text_end",
},
):
assert checkpoint_path is not None and tokenizer_path is not None, (
"T5 checkpoint and tokenizer paths must be provided."
)
super().__init__()
self.len = len
self.dim = dim
self.cross_attn_norm = True
self.cross_rope = False
self.drop_out = drop_out
self.input_keys = input_keys
self.text_encoder = T5EncoderModel(
text_len=len,
dtype=torch.bfloat16,
device=torch.device("cpu"),
checkpoint_path=checkpoint_path,
tokenizer_path=tokenizer_path,
shard_fn=None,
t5_size=t5_size,
)
self.text_cache = {}
def encode(self, text_list, device):
"""Encode text list with cache. Returns List[Tensor]."""
# Deduplicate uncached texts
texts_to_encode = []
for text in text_list:
if text not in self.text_cache and text not in texts_to_encode:
texts_to_encode.append(text)
# Batch encode deduplicated texts
if texts_to_encode:
self.text_encoder.model.to(device)
encoded = self.text_encoder(texts_to_encode, device)
for text, feature in zip(texts_to_encode, encoded):
self.text_cache[text] = feature.cpu()
# Collect from cache
return [self.text_cache[text].to(device) for text in text_list]
def get_context(self, x, valid_len, device, param_dtype, training=False):
"""
Get cross-attention context from input dict.
Returns:
context: List[Tensor]
metadata: dict, may contain 'full_text'
"""
text_key = self.input_keys.get("text", "text")
text_end_key = self.input_keys.get("text_end", "text_end")
metadata = {}
if text_key not in x:
text_list = ["" for _ in range(len(valid_len))]
else:
text_list = x[text_key]
if isinstance(text_list[0], list):
# Multi-segment text (stream mode)
full_text = []
all_context = []
text_end_list = x[text_end_key]
for i in range(len(valid_len)):
if training and np.random.rand() <= self.drop_out:
single_text_list = [""]
single_text_end_list = [0, valid_len[i]]
else:
single_text_list = text_list[i]
single_text_end_list = [0] + [
min(t, valid_len[i]) for t in text_end_list[i]
]
single_text_length_list = [
t - b
for t, b in zip(single_text_end_list[1:], single_text_end_list[:-1])
]
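                # Human-readable transcript with per-segment durations; returned
                # only as metadata (not fed to the model).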
full_text.append(
" ////////// ".join(
[
f"{u} //dur:{t}"
for u, t in zip(single_text_list, single_text_length_list)
]
)
)
single_text_context = self.encode(single_text_list, device)
single_text_context = [u.to(param_dtype) for u in single_text_context]
sample_context = []
for u, duration in zip(single_text_context, single_text_length_list):
sample_context.extend([u for _ in range(duration)])
all_context.append(sample_context)
metadata["full_text"] = full_text
return all_context, metadata
else:
# Single text per sample
full_text = [u for u in text_list]
metadata["full_text"] = full_text
if training:
text_list = [
("" if np.random.rand() <= self.drop_out else u) for u in text_list
]
else:
text_list = [u for u in text_list]
context = self.encode(text_list, device)
context = [u.to(param_dtype) for u in context]
return context, metadata
def get_null_context(self, batch_size, device, param_dtype):
"""Get null/empty context for classifier-free guidance."""
null_ctx = self.encode([""] * batch_size, device)
return [u.to(param_dtype) for u in null_ctx]
# --- Streaming state management ---
def init_stream(self, batch_size):
self.stream_condition_list = [[] for _ in range(batch_size)]
def update_stream(self, x, device, param_dtype):
"""Add one frame of context for a streaming step."""
text_key = self.input_keys.get("text", "text")
text_input = x[text_key]
new_ctx = self.encode(text_input, device)
new_ctx = [u.to(param_dtype) for u in new_ctx]
for i in range(len(self.stream_condition_list)):
self.stream_condition_list[i].append(new_ctx[i])
def get_stream_context(self, start_index, end_index):
context = []
for i in range(len(self.stream_condition_list)):
context.append(self.stream_condition_list[i][start_index:end_index])
return context
def trim_stream(self, trim_len):
"""Trim stream state when wrapping around."""
for i in range(len(self.stream_condition_list)):
self.stream_condition_list[i] = self.stream_condition_list[i][trim_len:]
class DiffForcingWanModel(nn.Module):
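    """Diffusion-forcing model over a WanModel backbone.

    forward() computes the training loss on noised windows; generate() runs
    windowed CFG sampling over full sequences; init_generated() /
    stream_generate_step() run the same sampler one condition frame at a time.
    """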
def __init__(
self,
input_dim=256,
mean_path=None,
std_path=None,
hidden_dim=1024,
ffn_dim=2048,
freq_dim=256,
num_heads=8,
num_layers=8,
time_embedding_scale=1.0,
causal=False,
rope_channel_split=[1, 0, 0],
spatial_shape=(1, 1),
prediction_type="vel", # "vel", "x0", "eps"
text_config={
"len": 512,
"dim": 4096,
},
schedule_config={
"noise_type": "linear",
"chunk_size": 5,
"steps": 10,
"extra_len": 4,
"random_epsilon": 0.00,
},
cfg_config={
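        # Classifier-free guidance weights: pred = text_scale * pred_text
        # + null_scale * pred_null; the defaults (5.0, -4.0) sum to 1, matching
        # the usual w * text + (1 - w) * null form with w = 5.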
"text_scale": 5.0,
"null_scale": -4.0,
},
input_keys={
"feature": "feature",
"feature_length": "feature_length",
"text": "text",
"text_end": "text_end",
},
):
super().__init__()
self.input_keys = input_keys
self.mean_path = mean_path
self.std_path = std_path
self.input_dim = input_dim
self.spatial_shape = tuple(spatial_shape)
self.hidden_dim = hidden_dim
self.ffn_dim = ffn_dim
self.freq_dim = freq_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.time_embedding_scale = time_embedding_scale
self.causal = causal
self.rope_channel_split = rope_channel_split
        assert prediction_type in ("vel", "x0", "eps"), (
            f"Unknown prediction type: {prediction_type}"
        )
        self.prediction_type = prediction_type
self.cfg_config = cfg_config
        self.schedule_config = dict(schedule_config)  # copy: init_generated() mutates it
self.time_scheduler = TriangularTimeScheduler(schedule_config)
# Cross-attention module (text)
self.text_module = T5TextCrossModule(**text_config)
if self.mean_path is not None:
self.register_buffer(
"mean", torch.from_numpy(np.load(self.mean_path)).float()
)
else:
self.register_buffer("mean", torch.zeros(input_dim))
if self.std_path is not None:
self.register_buffer(
"std", torch.from_numpy(np.load(self.std_path)).float()
)
else:
self.register_buffer("std", torch.ones(input_dim))
self.model = WanModel(
patch_size=(1, 1, 1),
text_len=self.text_module.len,
text_dim=self.text_module.dim,
cross_attn_norm=self.text_module.cross_attn_norm,
cross_rope=self.text_module.cross_rope,
in_dim=self.input_dim,
dim=self.hidden_dim,
ffn_dim=self.ffn_dim,
freq_dim=self.freq_dim,
out_dim=self.input_dim,
num_heads=self.num_heads,
num_layers=self.num_layers,
window_size=(-1, -1),
qk_norm=True,
eps=1e-6,
causal=self.causal,
rope_channel_split=self.rope_channel_split,
)
self.param_dtype = torch.float32
def _extract_inputs(self, x):
"""Extract inputs from x using input_keys mapping."""
inputs = {}
for internal_key, external_key in self.input_keys.items():
if external_key in x:
inputs[internal_key] = x[external_key]
return inputs
def preprocess(self, x):
"""Convert last-channel format to channel-first, padding to 4D (C, T, H, W).
(T, C) -> (C, T, 1, 1)
(T, H, C) -> (C, T, H, 1)
(T, H, W, C) -> (C, T, H, W)
"""
for i in range(len(x)):
ndim = x[i].ndim
if ndim == 2: # (T, C)
x[i] = x[i].permute(1, 0)[:, :, None, None]
elif ndim == 3: # (T, H, C)
x[i] = x[i].permute(2, 0, 1)[:, :, :, None]
elif ndim == 4: # (T, H, W, C)
x[i] = x[i].permute(3, 0, 1, 2)
return x
def postprocess(self, x):
"""Reverse of preprocess: channel-first 4D back to last-channel, stripping padding dims.
(C, T, 1, 1) -> (T, C)
(C, T, H, 1) -> (T, H, C)
(C, T, H, W) -> (T, H, W, C)
"""
for i in range(len(x)):
shape = x[i].shape # (C, T, H, W)
if shape[2] == 1 and shape[3] == 1: # (C, T, 1, 1) -> (T, C)
x[i] = x[i][:, :, 0, 0].permute(1, 0)
elif shape[3] == 1: # (C, T, H, 1) -> (T, H, C)
x[i] = x[i][:, :, :, 0].permute(1, 2, 0)
else: # (C, T, H, W) -> (T, H, W, C)
x[i] = x[i].permute(1, 2, 3, 0)
return x
def forward(self, x):
x = self._extract_inputs(x)
feature_original = x["feature"] # (B, T, C)
feature_length = x["feature_length"] # (B,)
feature_original = (feature_original - self.mean) / self.std
batch_size = feature_original.shape[0]
seq_len = feature_original.shape[1]
device = feature_original.device
feature = []
valid_len = []
for i in range(batch_size):
length = min(feature_length[i].item(), seq_len)
valid_len.append(length)
feature.append(feature_original[i, :length, ...])
# Preprocess to (C, T, 1, 1) per sample
feature = self.preprocess(feature)
# Get context from text cross module
context, _ = self.text_module.get_context(
x,
valid_len,
device,
self.param_dtype,
training=True,
)
# Prepare noised data and schedule
s = self.time_scheduler.prepare(feature, device, valid_len, training=True)
time_schedules = s["time_schedules"]
input_start_index = s["input_start"]
input_end_index = s["input_end"]
output_start_index = s["output_start"]
output_end_index = s["output_end"]
dalpha = s["dalpha"]
dbeta = s["dbeta"]
x0, eps, xt = s["x0"], s["eps"], s["xt"]
# Slice per-frame context to match input window
if isinstance(context[0], (list, tuple)):
context = [
context[i][input_start_index[i] : input_end_index[i]]
for i in range(batch_size)
]
# time_schedules already sliced to input window by prepare()
time_schedules_input = [
time_schedules[i] * self.time_embedding_scale
for i in range(batch_size)
]
# Through WanModel
predicted_result = self.model(
xt,
time_schedules_input,
context,
seq_len,
y=None,
        )  # list of per-sample (C, T, 1, 1) predictions over the input window
loss = 0.0
for b in range(batch_size):
pred_os = output_start_index[b] - input_start_index[b]
pred_oe = output_end_index[b] - input_start_index[b]
# dalpha, dbeta already sliced to output window by prepare()
dalpha_i = dalpha[b]
dbeta_i = dbeta[b]
if self.prediction_type == "vel":
vel = (
x0[b] * dalpha_i[None, :, None, None]
+ eps[b] * dbeta_i[None, :, None, None]
) # (C, output_length, 1, 1)
squared_error = (
predicted_result[b][:, pred_os:pred_oe, ...] - vel
) ** 2
elif self.prediction_type == "x0":
squared_error = (
predicted_result[b][:, pred_os:pred_oe, ...] - x0[b]
) ** 2
elif self.prediction_type == "eps":
squared_error = (
predicted_result[b][:, pred_os:pred_oe, ...] - eps[b]
) ** 2
sample_loss = squared_error.mean()
loss += sample_loss
loss = loss / batch_size
loss_dict = {"total": loss, "mse": loss}
return loss_dict
def generate(self, x):
"""
Generation - Diffusion Forcing inference
Uses triangular noise schedule, progressively generating from left to right
Generation process:
1. Start from t=0, gradually increase t
2. Each t corresponds to a noise schedule: clean on left, noisy on right, gradient in middle
3. After each denoising step, t increases slightly and continues
"""
x = self._extract_inputs(x)
extra_len = self.schedule_config.get("extra_len", 0)
feature_length = x["feature_length"] # (B,)
batch_size = len(feature_length)
seq_len = max(feature_length).item() + extra_len
device = next(self.parameters()).device
valid_len = []
for i in range(batch_size):
length = min(feature_length[i].item(), seq_len)
valid_len.append(length)
generated_len = [seq_len for _ in range(batch_size)]
# Initialize entire sequence as pure noise
generated = torch.randn(
batch_size, seq_len, *self.spatial_shape, self.input_dim, device=device
)
generated = [generated[i] for i in range(batch_size)]
generated = self.preprocess(generated)
# Precompute text and null contexts for CFG
text_context, metadata = self.text_module.get_context(
x,
generated_len,
device,
self.param_dtype,
training=False,
)
null_context = self.text_module.get_null_context(
batch_size, device, self.param_dtype
)
full_text = metadata["full_text"]
total_steps = self.time_scheduler.get_total_steps(seq_len)
# Progressively advance from t=0 to t=max_t
for step in range(total_steps):
s = self.time_scheduler.prepare(
generated, device, generated_len, training=False, current_step=step
)
time_schedules = s["time_schedules"]
time_schedules_derivative = s["time_schedules_derivative"]
alpha = s["alpha"]
dalpha = s["dalpha"]
beta = s["beta"]
dbeta = s["dbeta"]
sigma = s["sigma"]
dlog_alpha = s["dlog_alpha"]
dlog_beta = s["dlog_beta"]
input_start_index = s["input_start"]
input_end_index = s["input_end"]
output_start_index = s["output_start"]
output_end_index = s["output_end"]
xt = s["xt"]
# time_schedules already sliced to input window by prepare()
time_schedules_input = [
time_schedules[i] * self.time_embedding_scale
for i in range(batch_size)
]
# Slice per-frame context to match input window
if isinstance(text_context[0], (list, tuple)):
window_text_context = [
text_context[i][input_start_index[i] : input_end_index[i]]
for i in range(batch_size)
]
else:
window_text_context = text_context
# CFG: text_scale * pred_text + null_scale * pred_null
pred_text = self.model(
xt,
time_schedules_input,
window_text_context,
seq_len,
y=None,
)
pred_null = self.model(
xt,
time_schedules_input,
null_context,
seq_len,
y=None,
)
predicted_result = [
self.cfg_config["text_scale"] * pt + self.cfg_config["null_scale"] * pn
for pt, pn in zip(pred_text, pred_null)
]
# All noise coefficients already sliced to output window by prepare()
for i in range(batch_size):
                out_s, out_e = output_start_index[i], output_end_index[i]
                pred_os = out_s - input_start_index[i]
                pred_oe = out_e - input_start_index[i]
                predicted_result_i = predicted_result[i][:, pred_os:pred_oe, ...]
                generated_i = generated[i][:, out_s:out_e, ...]
dt = time_schedules_derivative[i][None, :, None, None]
alpha_i = alpha[i][None, :, None, None]
dalpha_i = dalpha[i][None, :, None, None]
beta_i = beta[i][None, :, None, None]
dbeta_i = dbeta[i][None, :, None, None]
sigma_i = sigma[i][None, :, None, None]
dlog_alpha_i = dlog_alpha[i][None, :, None, None]
dlog_beta_i = dlog_beta[i][None, :, None, None]
if self.prediction_type == "vel":
vel = predicted_result_i
elif self.prediction_type == "x0":
vel = (
predicted_result_i * (-dlog_beta_i * alpha_i + dalpha_i)
+ generated_i * dlog_beta_i
)
elif self.prediction_type == "eps":
vel = (
predicted_result_i * (-dlog_alpha_i * beta_i + dbeta_i)
+ generated_i * dlog_alpha_i
)
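                # st recovers the score of N(alpha*x0, beta^2) from the velocity:
                # vel - x*dlog_alpha = eps_hat*(dbeta - beta*dlog_alpha), so
                # st = -eps_hat / beta. The update below is one Euler-Maruyama
                # step of dx = [vel + 0.5*sigma^2*st] dt + sigma dW.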
st = (vel - generated_i * dlog_alpha_i) / (
(beta_i * dlog_alpha_i - dbeta_i) * beta_i
)
                generated[i][:, out_s:out_e, ...] += (
vel * dt
+ st * 0.5 * sigma_i**2 * dt
+ sigma_i * torch.sqrt(dt) * torch.randn_like(generated_i)
)
generated = self.postprocess(generated) # list of (T, C)
y_hat_out = []
for i in range(batch_size):
single_generated = generated[i][: valid_len[i], :] * self.std + self.mean
y_hat_out.append(single_generated)
out = {}
out["generated"] = y_hat_out
out["text"] = full_text
return out
    def init_generated(self, seq_len, batch_size=1, schedule_config=None):
        """Initialize streaming generation state.
        Args:
            seq_len: Model window size (how many frames WanModel processes per step).
            schedule_config: Optional schedule config overrides.
        The buffer holds 2*seq_len frames; the model sees at most content_len
        (capped at seq_len) frames ending at the denoising frontier. When the
        accumulated conditions overflow the buffer, it is shifted left by seq_len
        (see _rollback) and generation continues.
        """
        if schedule_config:
            self.schedule_config.update(schedule_config)
content_len = self.schedule_config.get("content_len", None)
if content_len is None:
self.schedule_config["content_len"] = seq_len
else:
self.schedule_config["content_len"] = min(seq_len, content_len)
self.time_scheduler = TriangularTimeScheduler(self.schedule_config)
self.batch_size = batch_size
self.seq_len = seq_len
self.buf_len = seq_len * 2
self.current_step = 0
self.current_commit = 0
self.condition_frames = 0
device = next(self.parameters()).device
# Initialize entire buffer as pure noise
generated = torch.randn(
batch_size, self.buf_len, *self.spatial_shape, self.input_dim, device=device
)
generated = [generated[i] for i in range(batch_size)]
self.generated = self.preprocess(generated)
# Initialize streaming state for cross module
self.text_module.init_stream(self.batch_size)
def _rollback(self):
"""Shift buffer by seq_len when conditions overflow the window."""
for i in range(self.batch_size):
self.generated[i][:, : self.seq_len, ...] = self.generated[i][
:, self.seq_len :, ...
].clone()
self.generated[i][:, self.seq_len :, ...] = torch.randn_like(
self.generated[i][:, self.seq_len :, ...]
)
self.current_step -= self.time_scheduler.get_step_rollback(self.seq_len)
self.condition_frames -= self.seq_len
self.current_commit -= self.seq_len
self.text_module.trim_stream(self.seq_len)
@torch.no_grad()
def stream_generate_step(self, x):
"""
Streaming generation step. Each call provides 1 frame of conditions.
The scheduler determines committable frames from accumulated conditions.
Returns:
            dict with "generated": a list of (N, C) tensors, one per sample;
            N is 0 when nothing new is committable.
"""
x = self._extract_inputs(x)
device = next(self.parameters()).device
self.generated = [g.to(device) for g in self.generated]
# 1. Update conditions (1 frame per call)
self.text_module.update_stream(x, device, self.param_dtype)
self.condition_frames += 1
# 2. Rollback if conditions overflow the window
if self.condition_frames > self.buf_len:
self._rollback()
# 3. Determine how many frames can be committed
committable_length, committable_steps = self.time_scheduler.get_committable(
self.condition_frames
)
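        # 4. Advance the solver until every licensed step has been taken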
while self.current_step < committable_steps:
s = self.time_scheduler.prepare(
self.generated, device, [self.buf_len] * self.batch_size,
training=False, current_step=self.current_step
)
time_schedules = s["time_schedules"]
time_schedules_derivative = s["time_schedules_derivative"]
alpha = s["alpha"]
dalpha = s["dalpha"]
beta = s["beta"]
dbeta = s["dbeta"]
sigma = s["sigma"]
dlog_alpha = s["dlog_alpha"]
dlog_beta = s["dlog_beta"]
is_ = s["input_start"]
ie_ = s["input_end"]
os_ = s["output_start"]
oe_ = s["output_end"]
xt = s["xt"]
# time_schedules already sliced to input window by prepare()
time_schedules_input = [
time_schedules[0] * self.time_embedding_scale
] * self.batch_size
# CFG: batch text + null in one forward pass
text_context = self.text_module.get_stream_context(is_[0], ie_[0])
null_context = self.text_module.get_null_context(
self.batch_size, device, self.param_dtype
)
# Convert null to per-frame format to match text_context
window_len = ie_[0] - is_[0]
null_context_pf = [
[null_context[i]] * window_len for i in range(self.batch_size)
]
pred_all = self.model(
                xt + xt,  # list concatenation: duplicates the batch for the text / null branches
time_schedules_input + time_schedules_input,
text_context + null_context_pf,
self.seq_len,
y=None,
)
pred_text = pred_all[: self.batch_size]
pred_null = pred_all[self.batch_size :]
predicted_result = [
self.cfg_config["text_scale"] * pt + self.cfg_config["null_scale"] * pn
for pt, pn in zip(pred_text, pred_null)
]
# All noise coefficients already sliced to output window by prepare()
os_idx, oe_idx = os_[0], oe_[0]
pred_os_idx = os_idx - is_[0]
pred_oe_idx = oe_idx - is_[0]
dt = time_schedules_derivative[0][None, :, None, None]
alpha_i = alpha[0][None, :, None, None]
dalpha_i = dalpha[0][None, :, None, None]
beta_i = beta[0][None, :, None, None]
dbeta_i = dbeta[0][None, :, None, None]
sigma_i = sigma[0][None, :, None, None]
dlog_alpha_i = dlog_alpha[0][None, :, None, None]
dlog_beta_i = dlog_beta[0][None, :, None, None]
for i in range(self.batch_size):
predicted_result_i = predicted_result[i][
:, pred_os_idx:pred_oe_idx, ...
]
generated_i = self.generated[i][:, os_idx:oe_idx, ...]
if self.prediction_type == "vel":
vel = predicted_result_i
elif self.prediction_type == "x0":
vel = (
predicted_result_i * (-dlog_beta_i * alpha_i + dalpha_i)
+ generated_i * dlog_beta_i
)
elif self.prediction_type == "eps":
vel = (
predicted_result_i * (-dlog_alpha_i * beta_i + dbeta_i)
+ generated_i * dlog_alpha_i
)
st = (vel - generated_i * dlog_alpha_i) / (
(beta_i * dlog_alpha_i - dbeta_i) * beta_i
)
self.generated[i][:, os_idx:oe_idx, ...] += (
vel * dt
+ st * 0.5 * sigma_i**2 * dt
+ sigma_i * torch.sqrt(dt) * torch.randn_like(generated_i)
)
self.current_step += 1
# 5. Extract newly committed frames
if self.current_commit < committable_length:
output = [
self.generated[i][:, self.current_commit : committable_length, ...]
for i in range(self.batch_size)
]
output = self.postprocess(output)
output = [o * self.std + self.mean for o in output]
self.current_commit = committable_length
return {"generated": output}
else:
empty = [
torch.zeros(self.input_dim, 0, *self.spatial_shape, device=device)
for _ in range(self.batch_size)
]
empty = self.postprocess(empty)
return {"generated": empty}