# RealWonder / video_generator.py
# Wei Liu
# CPU-first startup: load all models/scenes to CPU at module level, GPU transfer at generation time
# 0cdce4a
"""Streaming block-by-block video generation.
Decomposes CausalInferencePipelineSDEdit.inference() into per-block
streaming calls, enabling real-time generation with interactive control.
"""
from typing import List, Optional
import numpy as np
import torch
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image
from vidgen import (
CausalInferencePipelineSDEdit,
WanImageEncoder,
WanVideoVAE,
WanVideoUnit_ImageEmbedderCLIP,
WanVideoUnit_ImageEmbedderVAE,
DynamicSwapInstaller,
gpu,
get_cuda_free_memory_gb,
load_first_frame,
set_seed,
)
from vidgen.utils import extract_subdim
from gpu_profiler import log_gpu
from config import (
DEFAULT_HEIGHT, DEFAULT_WIDTH,
FRAMES_PER_BLOCK, LATENT_C, LATENT_H, LATENT_W,
DEFAULT_LOCAL_ATTN_SIZE, DEFAULT_TIMESTEP_SHIFT, CONTEXT_NOISE,
)
class StreamingVideoGenerator:
    """Block-by-block video generation with SDEdit support.

    Intended call order: ``setup()`` once, ``precompute_case()`` for every
    case, ``finish_precompute()``, then for each request
    ``prepare_generation()`` followed by one ``generate_block()`` call per
    block. ``move_pipeline_to_device()`` / ``move_case_data_to_device()``
    shuttle weights and cached tensors between devices, and ``reset()``
    clears the per-request state.
    """

    def __init__(self, checkpoint_path: str, num_pixel_frames: int,
                 denoising_steps: list, device: str = "cuda",
                 use_ema: bool = False, seed: int = 42,
                 mask_dropin_step: int = -1, franka_step: int = -1,
                 enable_taehv: bool = False):
        """Record the configuration; no model is loaded until ``setup()``.

        Args:
            checkpoint_path: File handed to ``torch.load`` in ``setup()``.
            num_pixel_frames: Stored as ``num_frames`` in the pipeline config
                and in every case batch.
            denoising_steps: Stored as ``denoising_step_list`` in the
                pipeline config.
            device: Anything accepted by ``torch.device``.
            use_ema: Load the ``"generator_ema"`` checkpoint entry instead of
                ``"generator"``.
            seed: Passed to ``set_seed`` in ``setup()``.
            mask_dropin_step: Pipeline config value; per-case values replace
                it in ``prepare_generation()``.
            franka_step: Pipeline config value; per-case values replace it in
                ``prepare_generation()``.
            enable_taehv: Decode blocks with the TAEHV decoder instead of the
                pipeline VAE.
        """
        self.checkpoint_path = checkpoint_path
        self.num_pixel_frames = num_pixel_frames
        self.device = torch.device(device)
        self.use_ema = use_ema
        self.seed = seed
        self.denoising_steps = denoising_steps
        self.mask_dropin_step = mask_dropin_step
        self.franka_step = franka_step
        self.enable_taehv = enable_taehv

        self.pipeline = None
        self.taehv_decoder = None
        self.taehv_cache = None
        self.is_setup = False
        self.case_data = {}

        # Per-request state. Defined up front so generate_block() and reset()
        # never hit an AttributeError when used before prepare_generation().
        self.full_y = None
        self.conditional_dict = None
        self.current_start_frame = 0
        self._last_prepared_case = None
        self._last_prepared_prompt = None

    def setup(self):
        """Load models and initialize the pipeline.

        Builds the ``CausalInferencePipelineSDEdit``, loads the generator
        weights from ``self.checkpoint_path`` and casts the pipeline to
        bfloat16, optionally loads the TAEHV decoder, and creates the float32
        processor models used by ``precompute_case()``.

        Side effects: seeds RNGs through ``set_seed`` and disables autograd
        process-wide with ``torch.set_grad_enabled(False)``.
        """
        set_seed(self.seed)
        torch.set_grad_enabled(False)

        # Best effort: the query may fail when no GPU is attached yet.
        try:
            low_memory = get_cuda_free_memory_gb(gpu) < 40
        except Exception:
            low_memory = False  # No GPU at module level (ZeroGPU)

        config = OmegaConf.create({
            "independent_first_frame": False,
            "warp_denoising_step": True,
            "context_noise": CONTEXT_NOISE,
            "causal": True,
            "i2v": True,
            "i2v_flow": True,
            "height": DEFAULT_HEIGHT,
            "width": DEFAULT_WIDTH,
            "num_frames": self.num_pixel_frames,
            "num_frame_per_block": FRAMES_PER_BLOCK,
            "denoising_step_list": self.denoising_steps,
            "mask_dropin_step": self.mask_dropin_step,
            "franka_step": self.franka_step,
            "model_kwargs": {
                "sink_size": 1,
                "local_attn_size": DEFAULT_LOCAL_ATTN_SIZE,
                "timestep_shift": DEFAULT_TIMESTEP_SHIFT,
            },
        })

        log_gpu("before pipeline init")
        self.pipeline = CausalInferencePipelineSDEdit(
            config, device=self.device, use_separate_encode_vae=True,
        )
        log_gpu("after pipeline init (on CPU)")

        self._load_generator_weights()
        self.pipeline = self.pipeline.to(dtype=torch.bfloat16)
        log_gpu("after checkpoint load (bf16, CPU)")

        if low_memory:
            DynamicSwapInstaller.install_model(self.pipeline.text_encoder, device=self.device)
        else:
            self.pipeline.text_encoder.to(device=self.device)
        self.pipeline.generator.to(device=self.device)
        self.pipeline.vae.to(device=self.device)
        self.pipeline.encode_vae.to(device=self.device, dtype=torch.bfloat16)

        if self.enable_taehv:
            self.taehv_decoder = self._load_taehv_decoder()

        # Float32 processor models, used only by precompute_case().
        self.pipeline.processor_dtype = torch.float32
        self.pipeline.processor_device = self.device
        self.pipeline.processor_vae = WanVideoVAE().to(device=self.device, dtype=torch.float32)
        self.pipeline.processor_ienc = WanImageEncoder().to(device=self.device, dtype=torch.float32)
        self.pipeline.processor_vae.requires_grad_(False)
        self.pipeline.processor_ienc.requires_grad_(False)
        # Cast every parameter and buffer of processor_vae to float32,
        # whatever dtype it currently has.
        for p in self.pipeline.processor_vae.parameters():
            p.data = p.data.to(dtype=torch.float32)
        for b in self.pipeline.processor_vae.buffers():
            b.data = b.data.to(dtype=torch.float32)
        self.pipeline.processors = [
            WanVideoUnit_ImageEmbedderVAE(),
            WanVideoUnit_ImageEmbedderCLIP(),
        ]

        self.is_setup = True
        log_gpu("setup complete")
        print("StreamingVideoGenerator setup complete")

    def _load_generator_weights(self):
        """Copy the generator (or EMA generator) weights from the checkpoint."""
        # NOTE(review): torch.load unpickles the file; only point
        # checkpoint_path at trusted checkpoints.
        state_dict = torch.load(self.checkpoint_path, map_location="cpu")
        key = "generator_ema" if self.use_ema else "generator"
        gen_state_dict = state_dict[key]
        try:
            self.pipeline.generator.load_state_dict(gen_state_dict)
        except RuntimeError:
            # Retry with the "._fsdp_wrapped_module" infix stripped from keys.
            gen_state_dict = {
                k.replace("._fsdp_wrapped_module", ""): v
                for k, v in gen_state_dict.items()
            }
            self.pipeline.generator.load_state_dict(gen_state_dict)
        # Weights are copied into the module; release the checkpoint before
        # the remaining models are created.
        del state_dict, gen_state_dict

    def _load_taehv_decoder(self):
        """Download the TAEHV weights if missing and build the fp16 decoder."""
        import os
        import urllib.request
        from taehv import TAEHV

        taehv_path = os.path.abspath(
            os.path.join(os.path.dirname(__file__), "checkpoints", "taew2_1.pth")
        )
        if not os.path.exists(taehv_path):
            os.makedirs(os.path.dirname(taehv_path), exist_ok=True)
            url = "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth"
            print(f"Downloading TAEHV weights from {url} ...")
            # Download next to the target and rename on success, so an
            # interrupted transfer never leaves a truncated file that would
            # pass the exists() check on the next start.
            partial_path = taehv_path + ".part"
            try:
                urllib.request.urlretrieve(url, partial_path)
                os.replace(partial_path, taehv_path)
            finally:
                if os.path.exists(partial_path):
                    os.remove(partial_path)

        print("Loading TAEHV decoder...")
        # Placed on self.device like every other model;
        # move_pipeline_to_device() relocates it together with the pipeline.
        decoder = TAEHV(checkpoint_path=taehv_path).to(
            device=self.device, dtype=torch.float16,
        )
        decoder.eval()
        decoder.requires_grad_(False)
        return decoder

    def precompute_case(self, case_name: str, first_frame_path: str,
                        default_prompt: str, sdedit_cfg: dict):
        """Pre-compute per-case data: VAE+CLIP embeddings, text features, denoising steps.

        Processors (processor_vae / processor_ienc) remain alive across all cases;
        call finish_precompute() once all cases are done to free them.
        KV cache + crossattn cache are allocated on the first call only.

        Args:
            case_name: Key under which the results are stored in
                ``self.case_data``.
            first_frame_path: Passed to ``load_first_frame``.
            default_prompt: Prompt encoded ahead of time; skipped when empty.
            sdedit_cfg: Must provide ``"denoising_step_list"``,
                ``"mask_dropin_step"`` and ``"franka_step"``.

        Raises:
            AssertionError: If ``setup()`` has not completed.
        """
        assert self.is_setup, "Call setup() first"
        device = self.device
        pipeline = self.pipeline
        log_gpu(f"precompute_case {case_name}: start")

        input_image = load_first_frame(first_frame_path, height=DEFAULT_HEIGHT, width=DEFAULT_WIDTH)
        batch = {
            "input_image": input_image.unsqueeze(0),
            "end_image": None,
            "height": DEFAULT_HEIGHT,
            "width": DEFAULT_WIDTH,
            "num_frames": self.num_pixel_frames,
        }

        # Run every processor unit and collect its outputs, casting tensor
        # outputs to bfloat16.
        i2v_conditional = {}
        for unit in pipeline.processors:
            input_data = {"device": device, "torch_dtype": pipeline.processor_dtype}
            for key in unit.input_params:
                input_data[key] = batch.get(key)
            for key in unit.onload_model_names:
                if key == "image_encoder":
                    input_data["image_encoder"] = pipeline.processor_ienc
                if key == "vae":
                    input_data["vae"] = pipeline.processor_vae
            unit_output = unit.process(**input_data)
            for k, v in unit_output.items():
                i2v_conditional[k] = (
                    v.to(dtype=torch.bfloat16) if isinstance(v, torch.Tensor) else v
                )

        # Initialize KV cache + crossattn cache on first case only
        if pipeline.kv_cache1 is None:
            dtype = torch.bfloat16
            pipeline._initialize_kv_cache(batch_size=1, dtype=dtype, device=device)
            pipeline._initialize_crossattn_cache(batch_size=1, dtype=dtype, device=device)

        # Swap dims 1 and 2 of "y" so generate_block() can slice it along
        # dim 1 by frame index.
        full_y = (
            i2v_conditional["y"].permute(0, 2, 1, 3, 4)
            if "y" in i2v_conditional else None
        )

        # Clear VAE caches between cases to prevent cache contamination
        pipeline.vae.model.clear_cache()
        pipeline.encode_vae.model.clear_cache()

        # Pre-encode the default text prompt for this case
        if default_prompt:
            default_text_features = pipeline.text_encoder(text_prompts=[default_prompt])
        else:
            default_text_features = None

        # Compute the warp-transformed denoising_step_list for this case
        # (mirrors the transform done in CausalInferencePipelineSDEdit.__init__)
        # Use explicit device='cpu' for all tensors: torch.set_default_device('cuda')
        # may be active inside the WAN model, which would put torch.tensor() on GPU.
        raw_steps = torch.tensor(sdedit_cfg["denoising_step_list"], dtype=torch.long, device='cpu')
        timesteps_all = torch.cat([
            pipeline.scheduler.timesteps.detach().cpu().float(),
            torch.tensor([0], dtype=torch.float32, device='cpu'),
        ])
        # NOTE(review): the index assumes the scheduler exposes 1000
        # timesteps, so raw step s maps to entry 1000 - s - confirm.
        denoising_step_list = timesteps_all[1000 - raw_steps]

        self.case_data[case_name] = {
            "i2v_conditional": i2v_conditional,
            "full_y": full_y,
            "default_prompt": default_prompt,
            "default_text_features": default_text_features,
            "denoising_step_list": denoising_step_list,
            "mask_dropin_step": sdedit_cfg["mask_dropin_step"],
            "franka_step": sdedit_cfg["franka_step"],
        }
        log_gpu(f"precompute_case {case_name}: done")
        print(f"Case '{case_name}' pre-computation complete")

    def finish_precompute(self):
        """Free processor models after all cases have been pre-computed."""
        pipeline = self.pipeline
        for attr in ("processor_vae", "processor_ienc"):
            model = getattr(pipeline, attr, None)
            if model is None:
                continue
            model.cpu()
            # Delete, then re-bind to None, so later hasattr()/None checks on
            # the pipeline keep working.
            delattr(pipeline, attr)
            setattr(pipeline, attr, None)
        torch.cuda.empty_cache()
        print("Processor models freed after pre-computation")

    def prepare_generation(self, prompt: str, case_name: str):
        """Lightweight per-request preparation: swap case data + text encode + reset caches.

        Args:
            prompt: Text prompt; the cached features are reused when it equals
                the case's default prompt.
            case_name: Key previously registered through ``precompute_case()``.

        Raises:
            AssertionError: If ``setup()`` has not completed.
            KeyError: If ``case_name`` was never pre-computed.
        """
        assert self.is_setup, "Call setup() first"
        pipeline = self.pipeline
        case_data = self.case_data[case_name]

        # ZeroGPU frees GPU memory between calls; null caches to force re-allocation on GPU.
        pipeline.kv_cache1 = None
        pipeline.kv_cache2 = None
        if hasattr(pipeline, 'crossattn_cache'):
            pipeline.crossattn_cache = None

        # Swap pipeline settings for this case
        pipeline.denoising_step_list = case_data["denoising_step_list"]
        pipeline.mask_dropin_step = case_data["mask_dropin_step"]
        pipeline.franka_step = case_data["franka_step"]

        # Set case-specific temporal conditioning
        self.full_y = case_data["full_y"]

        # Text encoding (use cached features if prompt matches this case's default)
        if (prompt == case_data["default_prompt"]
                and case_data["default_text_features"] is not None):
            text_features = case_data["default_text_features"]
        else:
            text_features = pipeline.text_encoder(text_prompts=[prompt])

        # Copy, so the cached per-case features are never mutated.
        conditional = dict(text_features)
        conditional.update(case_data["i2v_conditional"])
        self.conditional_dict = conditional

        # Re-allocate KV + crossattn caches (nulled above for ZeroGPU compatibility)
        if pipeline.kv_cache1 is None:
            dtype = torch.bfloat16
            pipeline._initialize_kv_cache(batch_size=1, dtype=dtype, device=self.device)
            pipeline._initialize_crossattn_cache(batch_size=1, dtype=dtype, device=self.device)

        # Reset crossattn cache if case or prompt changed since last call
        case_or_prompt_changed = (
            case_name != self._last_prepared_case
            or prompt != self._last_prepared_prompt
        )
        if case_or_prompt_changed:
            for block_index in range(pipeline.num_transformer_blocks):
                pipeline.crossattn_cache[block_index]["is_init"] = False
            self._last_prepared_case = case_name
            self._last_prepared_prompt = prompt

        # Reset KV cache indices
        self._reset_kv_cache_indices()

        self.current_start_frame = 0
        self.taehv_cache = None
        pipeline.vae.model.clear_cache()
        pipeline.encode_vae.model.clear_cache()
        print("Generation prepared")

    def generate_block(self, block_idx: int, structured_noise: torch.Tensor,
                       sim_latent: torch.Tensor,
                       sde_noise: Optional[torch.Tensor] = None,
                       sim_mask: Optional[torch.Tensor] = None,
                       sim_franka_mask: Optional[torch.Tensor] = None) -> List[np.ndarray]:
        """Generate and decode one block of video.

        Args:
            block_idx: Index of the block within the current request; only
                used to pick how many decoded frames to drop on the TAEHV path.
            structured_noise: Noise mixed into ``sim_latent`` at the first
                denoising step.
            sim_latent: Latent that is noised and then denoised; its first
                two dims are flattened/unflattened as ``(1, FRAMES_PER_BLOCK)``.
            sde_noise: Optional noise source for the re-noising between
                denoising steps; fresh Gaussian noise is drawn when omitted.
            sim_mask: Optional boolean mask; where it is False the latent is
                replaced by a background re-noised latent at step
                ``pipeline.mask_dropin_step``.
            sim_franka_mask: Optional boolean mask; where it is True the
                latent is replaced by a re-noised latent at step
                ``pipeline.franka_step``.

        Returns:
            One ``uint8`` array per decoded frame, laid out ``(h, w, c)``.

        Raises:
            RuntimeError: If ``prepare_generation()`` has not been called
                since construction or the last ``reset()``.
        """
        if self.conditional_dict is None:
            raise RuntimeError("Call prepare_generation() before generate_block()")

        pipeline = self.pipeline
        device = self.device
        num_frames = FRAMES_PER_BLOCK

        structured_noise = structured_noise.to(device=device, dtype=torch.bfloat16)
        sim_latent = sim_latent.to(device=device, dtype=torch.bfloat16)
        if sde_noise is not None:
            sde_noise = sde_noise.to(device=device, dtype=torch.bfloat16)
        if sim_mask is not None:
            sim_mask = sim_mask.to(device=device)
        if sim_franka_mask is not None:
            sim_franka_mask = sim_franka_mask.to(device=device)

        # SDEdit start: noise the simulated latent to the first denoising step.
        sdedit_step = pipeline.denoising_step_list[0]
        noisy_input = pipeline.scheduler.add_noise(
            sim_latent.flatten(0, 1),
            structured_noise.flatten(0, 1),
            sdedit_step * torch.ones([num_frames], device=device, dtype=torch.long),
        ).unflatten(0, (1, num_frames))

        # Background latent dropped in where sim_mask is False.
        bg_noisy_input = None
        if pipeline.mask_dropin_step > 0 and sim_mask is not None:
            mask_step = pipeline.denoising_step_list[pipeline.mask_dropin_step]
            bg_noisy_input = pipeline.scheduler.add_noise(
                sim_latent.flatten(0, 1),
                noisy_input.flatten(0, 1),
                mask_step * torch.ones([num_frames], device=device, dtype=torch.long),
            ).unflatten(0, (1, num_frames))

        # Latent dropped in where sim_franka_mask is True (only if the mask
        # has at least one True entry).
        bg_noisy_franka = None
        use_franka = (
            pipeline.franka_step >= 0
            and sim_franka_mask is not None
            and sim_franka_mask.any()
        )
        if use_franka:
            franka_step = pipeline.denoising_step_list[pipeline.franka_step]
            bg_noisy_franka = pipeline.scheduler.add_noise(
                sim_latent.flatten(0, 1),
                noisy_input.flatten(0, 1),
                franka_step * torch.ones([num_frames], device=device, dtype=torch.long),
            ).unflatten(0, (1, num_frames))

        # Slice of the temporal conditioning that belongs to this block.
        curr_y = None
        if self.full_y is not None:
            start = self.current_start_frame
            curr_y = self.full_y[:, start:start + num_frames]

        for index, current_timestep in enumerate(pipeline.denoising_step_list):
            timestep = torch.ones(
                [1, num_frames], device=device, dtype=torch.int64
            ) * current_timestep

            if (pipeline.mask_dropin_step > 0
                    and pipeline.mask_dropin_step == index
                    and sim_mask is not None
                    and bg_noisy_input is not None):
                noisy_input = torch.where(
                    sim_mask.unsqueeze(2),
                    noisy_input, bg_noisy_input,
                )
            if (use_franka and pipeline.franka_step == index and bg_noisy_franka is not None):
                noisy_input = torch.where(
                    sim_franka_mask.unsqueeze(2),
                    bg_noisy_franka, noisy_input,
                )

            _, denoised_pred = pipeline.generator(
                noisy_image_or_video=noisy_input,
                conditional_dict=self.conditional_dict,
                curr_y=curr_y,
                timestep=timestep,
                kv_cache=pipeline.kv_cache1,
                crossattn_cache=pipeline.crossattn_cache,
                current_start=self.current_start_frame * pipeline.frame_seq_length,
            )

            # Re-noise the prediction to the next step, except after the last one.
            if index < len(pipeline.denoising_step_list) - 1:
                next_step = pipeline.denoising_step_list[index + 1]
                if sde_noise is not None:
                    sde_n = extract_subdim(sde_noise, LATENT_C, return_complement=False, channel_dim=2)
                else:
                    sde_n = torch.randn_like(noisy_input)
                noisy_input = pipeline.scheduler.add_noise(
                    denoised_pred.flatten(0, 1),
                    sde_n.flatten(0, 1),
                    next_step * torch.ones([num_frames], device=device, dtype=torch.long),
                ).unflatten(0, denoised_pred.shape[:2])

        # Extra generator pass on the final prediction at the context-noise
        # timestep; its output is discarded, so presumably it only refreshes
        # the KV cache for the next block - confirm against the pipeline.
        context_timestep = torch.ones_like(timestep) * pipeline.args.context_noise
        pipeline.generator(
            noisy_image_or_video=denoised_pred,
            conditional_dict=self.conditional_dict,
            curr_y=curr_y,
            timestep=context_timestep,
            kv_cache=pipeline.kv_cache1,
            crossattn_cache=pipeline.crossattn_cache,
            current_start=self.current_start_frame * pipeline.frame_seq_length,
        )
        self.current_start_frame += num_frames

        if self.enable_taehv:
            # Prepend the cached latents of the previous block, then keep the
            # last three latents for the next call.
            if self.taehv_cache is None:
                decode_input = denoised_pred
                self.taehv_cache = denoised_pred
            else:
                decode_input = torch.cat([self.taehv_cache, denoised_pred], dim=1)
                self.taehv_cache = decode_input[:, -3:, :, :, :]
            video = self.taehv_decoder.decode_video(
                decode_input.to(dtype=torch.float16), parallel=True,
            )
            # NOTE(review): the trim counts (3 for the first block, 12
            # afterwards) are hard-coded and presumably tied to a three-latent
            # block size - confirm before changing FRAMES_PER_BLOCK.
            if block_idx == 0:
                video = video[:, 3:]
            else:
                video = video[:, 12:]
            video = video.clamp(0, 1)
        else:
            video = pipeline.vae.decode_to_pixel(denoised_pred, use_cache=True)
            video = (video * 0.5 + 0.5).clamp(0, 1)

        video = rearrange(video, "b t c h w -> b t h w c").cpu()
        frames = (255.0 * video[0]).to(torch.uint8).numpy()
        return [frames[i] for i in range(frames.shape[0])]

    def move_case_data_to_device(self, device: str):
        """Move all precomputed case tensors to device.

        Call with 'cuda' before generation starts inside @spaces.GPU;
        call with 'cpu' in the finally block so ZeroGPU can release GPU memory.

        Only tensors stored directly in a case entry, or one dict level below
        it, are moved; every other value is left untouched.
        """
        for case_data in self.case_data.values():
            for k, v in list(case_data.items()):
                if isinstance(v, torch.Tensor):
                    case_data[k] = v.to(device)
                elif isinstance(v, dict):
                    for kk, vv in list(v.items()):
                        if isinstance(vv, torch.Tensor):
                            v[kk] = vv.to(device)

    def reset(self):
        """Reset generation state, preserving KV cache allocations.

        After a reset, ``prepare_generation()`` must be called again before
        ``generate_block()``.
        """
        if self.pipeline is not None:
            if self.pipeline.kv_cache1 is not None:
                self._reset_kv_cache_indices()
            self.pipeline.vae.model.clear_cache()
            self.pipeline.encode_vae.model.clear_cache()
        self.current_start_frame = 0
        self.conditional_dict = None
        self.taehv_cache = None

    def _reset_kv_cache_indices(self):
        """Zero the end-index counters of every ``kv_cache1`` entry in place."""
        for cache in self.pipeline.kv_cache1:
            cache["global_end_index"].fill_(0)
            cache["local_end_index"].fill_(0)

    def move_pipeline_to_device(self, device: str):
        """Move all pipeline models to target device (CPU→GPU at generation start, GPU→CPU at end).

        Also updates ``self.device`` and moves the TAEHV decoder when one is
        loaded, so no model is left behind on the previous device.
        """
        dev = torch.device(device)
        self.device = dev
        pipeline = self.pipeline
        for name in ("generator", "vae", "encode_vae", "text_encoder"):
            model = getattr(pipeline, name, None)
            if model is not None:
                model.to(device=dev)
        if self.taehv_decoder is not None:
            self.taehv_decoder.to(device=dev)