Spaces: Running on Zero
| """Readable staged video-to-video API for StreamDiffusionV2.""" | |
| from __future__ import annotations | |
| from contextlib import ExitStack | |
| from dataclasses import dataclass | |
| from importlib.resources import as_file, files | |
| from pathlib import Path | |
| from typing import Literal | |
| from diffusers.utils import export_to_video as diffusers_export_to_video | |
| import numpy as np | |
| import torch | |
| from models.util import set_seed | |
| from streamv2v.inference import ( | |
| SingleGPUInferencePipeline as StreamBatchInferencePipeline, | |
| compute_noise_scale_and_step, | |
| ) | |
| from streamv2v.inference_common import load_mp4_as_tensor, merge_cli_config, normalize_acceleration_flags | |
| from streamv2v.inference_wo_batch import SingleGPUInferencePipeline as StreamNoBatchInferencePipeline | |
# Execution modes accepted by StreamDiffusionV2Pipeline: "single" selects the
# stream manager imported from streamv2v.inference, "single-wo" the one from
# streamv2v.inference_wo_batch.
SingleMode = Literal["single", "single-wo"]
@dataclass
class VideoChunk:
    """One video chunk prepared for the encode -> denoise -> decode loop.

    Declared as a dataclass so it can be built with keyword arguments, as
    ``StreamDiffusionV2Pipeline.chunk_video`` does.
    """

    # Slice of the input video along its frame axis (dim 2 of the 5-D tensor).
    frames: torch.Tensor
    # Frame range [start_idx, end_idx) this chunk covers in the input video.
    start_idx: int
    end_idx: int
    # Stream positions, in multiples of the pipeline's frame_seq_length
    # (presumably token offsets into the KV cache; confirm).
    current_start: int
    current_end: int
@dataclass
class EncodedChunk:
    """Encoded latent chunk plus the schedule metadata needed for denoising.

    Declared as a dataclass so ``StreamDiffusionV2Pipeline.encode_chunk`` can
    build it with keyword arguments.
    """

    # Encoder latents blended with Gaussian noise using ``noise_scale``.
    noisy_latents: torch.Tensor
    # Stream positions copied from the source VideoChunk.
    current_start: int
    current_end: int
    # Noise mixing factor used for this chunk; fed into the next chunk's schedule.
    noise_scale: float
    # Step returned by compute_noise_scale_and_step; None for the first chunk.
    current_step: int | None = None
@dataclass
class DenoisedChunk:
    """Denoised latent chunk ready for VAE decoding.

    Declared as a dataclass so ``StreamDiffusionV2Pipeline.denoise_chunk`` can
    build it with keyword arguments.
    """

    # Output of the DiT denoiser for one chunk.
    denoised_pred: torch.Tensor
    # Forwarded to the manager's decoder; False only for the first chunk of a stream.
    last_frame_only: bool
def _resolve_default_config_path(resource_stack: ExitStack) -> str:
    """Return a filesystem path for the packaged ``wan_causal_dmd_v2v.yaml``.

    ``as_file`` may materialize the resource as a temporary file. Its cleanup is
    registered on ``resource_stack``, so the returned path is only guaranteed to
    exist until that stack is closed.
    """
    packaged_config = files("streamv2v.configs") / "wan_causal_dmd_v2v.yaml"
    config_file = resource_stack.enter_context(as_file(packaged_config))
    return str(config_file)
| def _resolve_device(device: str | torch.device | None) -> torch.device: | |
| cuda_available = torch.cuda.is_available() | |
| if device is None: | |
| return torch.device("cuda" if cuda_available else "cpu") | |
| resolved = torch.device(device) | |
| if resolved.type == "cuda" and not cuda_available: | |
| raise RuntimeError("CUDA is not available in the current Python environment") | |
| if resolved.type == "cuda" and resolved.index is not None: | |
| torch.cuda.set_device(resolved.index) | |
| return resolved | |
| def _normalize_video_tensor( | |
| video: str | Path | torch.Tensor, | |
| *, | |
| height: int, | |
| width: int, | |
| device: torch.device, | |
| ) -> torch.Tensor: | |
| if isinstance(video, (str, Path)): | |
| tensor = load_mp4_as_tensor(str(video), resize_hw=(height, width)).unsqueeze(0) | |
| else: | |
| tensor = video | |
| if tensor.ndim == 4: | |
| tensor = tensor.unsqueeze(0) | |
| if tensor.ndim != 5: | |
| raise ValueError("video tensor must have shape [B, C, T, H, W] or [C, T, H, W]") | |
| if tensor.dtype != torch.bfloat16: | |
| tensor = tensor.to(dtype=torch.bfloat16) | |
| return tensor.to(device) | |
def load_video(video_path: str | Path, *, height: int = 480, width: int = 832) -> torch.Tensor:
    """Load a video file as a tensor resized to ``height`` x ``width``.

    The result is expected to have shape ``[C, T, H, W]``. The pipeline adds a
    batch axis itself, so it can be passed directly to
    ``StreamDiffusionV2Pipeline``. ``video_path`` may be a ``str`` or a
    ``pathlib.Path``, matching what the pipeline's own video inputs accept.
    """
    return load_mp4_as_tensor(str(video_path), resize_hw=(height, width))
def export_video(video: np.ndarray, output_path: str, *, fps: int = 16) -> str:
    """Write a `[T, H, W, C]` float video array to an mp4 file.

    Missing parent directories of ``output_path`` are created first. Returns
    the output path as a string.
    """
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination_str = str(destination)
    diffusers_export_to_video(video, destination_str, fps=fps)
    return destination_str
class StreamDiffusionV2Pipeline:
    """Readable staged single-GPU API that mirrors the offline inference flow.

    Typical use::

        pipe = StreamDiffusionV2Pipeline(checkpoint_folder)
        pipe.prepare("a prompt")
        frames = pipe(video)  # encode -> denoise -> decode

    Each stage is also exposed on its own: ``chunk_video``, ``encode_chunk`` /
    ``encode_video``, ``denoise_chunk`` / ``denoise_chunks`` and
    ``decode_chunk`` / ``decode_chunks``. The object is a context manager, and
    ``close()`` releases the handle on the packaged default config, if one was
    used.
    """

    def __init__(
        self,
        checkpoint_folder: str,
        *,
        mode: SingleMode = "single",
        config_path: str | None = None,
        device: str | torch.device | None = None,
        noise_scale: float = 0.8,
        height: int = 480,
        width: int = 832,
        fps: int = 16,
        step: int = 2,
        seed: int = 0,
        model_type: str = "T2V-1.3B",
        use_taehv: bool = False,
        use_tensorrt: bool = False,
        fast: bool = False,
        profile: bool = False,
    ) -> None:
        """Build the stream manager and load the model weights.

        Args:
            checkpoint_folder: Folder passed to the manager's ``load_model``.
            mode: ``"single"`` uses the manager from ``streamv2v.inference``;
                ``"single-wo"`` uses the one from ``streamv2v.inference_wo_batch``.
            config_path: YAML config merged with these arguments. Defaults to
                the packaged ``wan_causal_dmd_v2v.yaml``.
            device: Target device. ``None`` picks CUDA when available.
            noise_scale: Initial noise mixing factor used when encoding.
            height, width: Size that videos loaded from a path are resized to.
            fps, step, seed, model_type, profile: Forwarded into the merged
                config. ``seed`` is also applied through ``set_seed``.
            use_taehv, use_tensorrt, fast: Acceleration flags, normalized by
                ``normalize_acceleration_flags`` before use.

        Raises:
            ValueError: If ``mode`` is not ``"single"`` or ``"single-wo"``.
            RuntimeError: If a CUDA device is requested but CUDA is unavailable.
        """
        if mode not in {"single", "single-wo"}:
            raise ValueError("StreamDiffusionV2Pipeline only supports 'single' and 'single-wo'")
        self._resource_stack = ExitStack()
        self.mode = mode
        self.device = _resolve_device(device)
        self.checkpoint_folder = checkpoint_folder
        self.noise_scale = float(noise_scale)
        self.height = int(height)
        self.width = int(width)
        self.fps = int(fps)
        self.seed = int(seed)
        self.step = int(step)
        self.profile = bool(profile)
        self.model_type = model_type
        self.prompt: str | None = None
        # Remember what the caller passed (None = packaged default) so that
        # enable_acceleration() can rebuild from the same source.
        self._explicit_config_path = config_path

        # Register resources on a local stack and hand them to self only after
        # construction succeeds. A failure while loading the model then does
        # not leak the temporary config handle.
        with ExitStack() as setup_stack:
            resolved_config_path = config_path or _resolve_default_config_path(setup_stack)
            self.config_path = resolved_config_path

            flags = normalize_acceleration_flags(
                {
                    "use_taehv": use_taehv,
                    "use_tensorrt": use_tensorrt,
                    "fast": fast,
                }
            )
            self.use_taehv = bool(flags["use_taehv"])
            self.use_tensorrt = bool(flags["use_tensorrt"])
            self.fast = bool(flags["fast"])

            config_args = {
                "config_path": resolved_config_path,
                "checkpoint_folder": checkpoint_folder,
                "noise_scale": noise_scale,
                "height": height,
                "width": width,
                "fps": fps,
                "step": step,
                "seed": seed,
                "model_type": model_type,
                "profile": profile,
                "use_taehv": self.use_taehv,
                "use_tensorrt": self.use_tensorrt,
                "fast": self.fast,
                "t2v": False,
                "target_fps": None,
                "fixed_noise_scale": False,
                "num_frames": 81,
            }
            self.config = merge_cli_config(resolved_config_path, config_args)

            manager_cls = (
                StreamBatchInferencePipeline if mode == "single" else StreamNoBatchInferencePipeline
            )
            # Inference only: autograd is disabled process-wide.
            torch.set_grad_enabled(False)
            set_seed(self.seed)
            self.pipeline_manager = manager_cls(self.config, self.device)
            self.pipeline_manager.load_model(checkpoint_folder)

            # Input frames per streamed chunk (the first chunk takes one extra frame).
            self.chunk_size = 4 * self.config.num_frame_per_block
            self.num_steps = len(self.pipeline_manager.pipeline.denoising_step_list)
            self._next_chunk_index = 0

            # Construction succeeded: transfer ownership of the resources to self.
            self._resource_stack = setup_stack.pop_all()

    def close(self) -> None:
        """Release resources registered on this pipeline's resource stack.

        The only resource registered there is the handle on the packaged
        default config, if it was used. Model weights are not freed. Calling
        this more than once is harmless.
        """
        self._resource_stack.close()

    def __enter__(self) -> "StreamDiffusionV2Pipeline":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()

    def enable_acceleration(
        self,
        *,
        use_taehv: bool = False,
        use_tensorrt: bool = False,
        fast: bool = False,
    ) -> "StreamDiffusionV2Pipeline":
        """Rebuild the pipeline in place with the requested acceleration flags.

        Every other constructor setting is reused, and so is the prompt stored
        by ``prepare``. The stream state starts fresh. The model is reloaded
        before the old one is released, so peak memory briefly holds both. If
        the rebuild fails, ``self`` is left unchanged.

        Returns:
            ``self``, to allow chaining.
        """
        prompt = self.prompt
        replacement = StreamDiffusionV2Pipeline(
            checkpoint_folder=self.checkpoint_folder,
            mode=self.mode,
            # Pass the caller's original config path (None for the packaged
            # default). The replacement then owns its own handle instead of
            # borrowing a path that self.close() below may release.
            config_path=self._explicit_config_path,
            device=self.device,
            noise_scale=self.noise_scale,
            height=self.height,
            width=self.width,
            fps=self.fps,
            step=self.step,
            seed=self.seed,
            model_type=self.model_type,
            use_taehv=use_taehv,
            use_tensorrt=use_tensorrt,
            fast=fast,
            profile=self.profile,
        )
        self.close()
        self.__dict__.update(replacement.__dict__)
        self.prompt = prompt
        return self

    def prepare(self, prompt: str) -> None:
        """Reset the stream state and store the prompt for the next denoising pass."""
        self.prompt = prompt
        self.pipeline_manager.reset_stream_state(reset_vae_flags=True)
        self.pipeline_manager.processed = 0
        self._next_chunk_index = 0

    def chunk_video(self, video: str | Path | torch.Tensor) -> list[VideoChunk]:
        """Split a full input video into the same chunks used by the offline inference loop.

        The first chunk covers ``1 + chunk_size`` frames. Each later chunk
        covers ``chunk_size`` frames. Trailing frames that do not fill a whole
        chunk are dropped.

        Raises:
            ValueError: If the video has fewer than ``1 + chunk_size`` frames.
        """
        input_video = _normalize_video_tensor(
            video,
            height=self.height,
            width=self.width,
            device=self.device,
        )
        total_frames = input_video.shape[2]
        first_chunk_frames = 1 + self.chunk_size
        if total_frames < first_chunk_frames:
            raise ValueError(f"video must contain at least {first_chunk_frames} frames")

        seq_len = self.pipeline_manager.pipeline.frame_seq_length
        blocks_per_chunk = self.chunk_size // 4
        # The first chunk spans (1 + blocks_per_chunk) * seq_len positions.
        # Presumably the encoder maps 1 + 4k frames to 1 + k latent frames; confirm.
        current_end = seq_len * (1 + blocks_per_chunk)
        chunks: list[VideoChunk] = [
            VideoChunk(
                frames=input_video[:, :, 0:first_chunk_frames],
                start_idx=0,
                end_idx=first_chunk_frames,
                current_start=0,
                current_end=current_end,
            )
        ]

        start_idx = first_chunk_frames
        while start_idx + self.chunk_size <= total_frames:
            end_idx = start_idx + self.chunk_size
            current_start = current_end
            current_end = current_start + blocks_per_chunk * seq_len
            chunks.append(
                VideoChunk(
                    frames=input_video[:, :, start_idx:end_idx],
                    start_idx=start_idx,
                    end_idx=end_idx,
                    current_start=current_start,
                    current_end=current_end,
                )
            )
            start_idx = end_idx
        return chunks

    def encode_chunk(
        self,
        input_video: str | Path | torch.Tensor,
        chunk: VideoChunk,
        *,
        previous_noise_scale: float | None = None,
        initial_noise_scale: float | None = None,
    ) -> EncodedChunk:
        """Encode one chunk in the same style as the offline inference loop.

        The chunk's frames are encoded and blended with Gaussian noise as
        ``noise * s + latents * (1 - s)``. For the first chunk
        (``start_idx == 0``), ``s`` is ``previous_noise_scale``. For later
        chunks, ``s`` and ``current_step`` come from
        ``compute_noise_scale_and_step`` over the full video.

        Args:
            input_video: The full video. It is only read for non-first chunks.
                Pass an already-loaded tensor to avoid re-decoding a file on
                every call.
            chunk: Chunk produced by ``chunk_video``.
            previous_noise_scale: Noise scale of the previous chunk. Defaults
                to ``self.noise_scale``.
            initial_noise_scale: Noise scale the stream started with. Defaults
                to ``self.noise_scale``.
        """
        noise_scale = self.noise_scale if previous_noise_scale is None else float(previous_noise_scale)
        init_noise_scale = self.noise_scale if initial_noise_scale is None else float(initial_noise_scale)
        current_step = None
        if chunk.start_idx != 0:
            # Only non-first chunks need the full video for the noise schedule.
            full_video = _normalize_video_tensor(
                input_video,
                height=self.height,
                width=self.width,
                device=self.device,
            )
            noise_scale, current_step = compute_noise_scale_and_step(
                full_video,
                chunk.end_idx,
                self.chunk_size,
                noise_scale,
                init_noise_scale,
            )

        latents = self.pipeline_manager._timed_stream_encode(chunk.frames)
        # Swap dims 1 and 2 of the encoder output into the layout the denoiser
        # takes (presumably channels <-> time; confirm against the encoder).
        latents = latents.transpose(2, 1).contiguous().to(dtype=torch.bfloat16)
        noise = torch.randn_like(latents)
        return EncodedChunk(
            noisy_latents=noise * noise_scale + latents * (1 - noise_scale),
            current_start=chunk.current_start,
            current_end=chunk.current_end,
            noise_scale=float(noise_scale),
            current_step=current_step,
        )

    def encode_video(self, video: str | Path | torch.Tensor) -> list[EncodedChunk]:
        """Encode a full input video into noisy latent chunks.

        The video is normalized (and, for a path, decoded) exactly once. Each
        chunk's resulting noise scale is carried into the next chunk.
        """
        full_video = _normalize_video_tensor(
            video,
            height=self.height,
            width=self.width,
            device=self.device,
        )
        noise_scale = float(self.noise_scale)
        init_noise_scale = noise_scale
        encoded: list[EncodedChunk] = []
        for chunk in self.chunk_video(full_video):
            encoded_chunk = self.encode_chunk(
                full_video,
                chunk,
                previous_noise_scale=noise_scale,
                initial_noise_scale=init_noise_scale,
            )
            noise_scale = encoded_chunk.noise_scale
            encoded.append(encoded_chunk)
        return encoded

    def denoise_chunks(self, chunks: list[EncodedChunk]) -> list[DenoisedChunk]:
        """Run DiT denoising over the encoded chunks.

        The stream state is reset first, via ``prepare`` with the stored
        prompt, so repeated calls start from scratch. Chunks for which
        ``denoise_chunk`` returns ``None`` are omitted from the result.

        Raises:
            ValueError: If ``chunks`` is empty.
            RuntimeError: If ``prepare(prompt)`` has not been called.
        """
        if not chunks:
            raise ValueError("chunks must not be empty")
        if self.prompt is None:
            raise RuntimeError("Call prepare(prompt) before denoise_chunks(...)")
        self.prepare(self.prompt)
        outputs: list[DenoisedChunk] = []
        for chunk in chunks:
            denoised_chunk = self.denoise_chunk(chunk)
            if denoised_chunk is not None:
                outputs.append(denoised_chunk)
        return outputs

    def denoise_chunk(self, chunk: EncodedChunk) -> DenoisedChunk | None:
        """Run DiT on one encoded chunk and return a decodable latent when available.

        The first chunk after ``prepare`` initializes the stream through the
        manager's ``prepare_pipeline`` and is marked
        ``last_frame_only=False``. Later chunks go through ``inference_stream``
        ("single") or ``inference_wo_batch`` ("single-wo") and are marked
        ``last_frame_only=True``. In "single" mode, ``None`` is returned until
        the manager's ``processed`` count reaches ``num_steps``. Presumably the
        batched denoiser needs that many chunks in flight before its output is
        ready; confirm.

        Raises:
            RuntimeError: If ``prepare(prompt)`` has not been called.
        """
        if self.prompt is None:
            raise RuntimeError("Call prepare(prompt) before denoise_chunk(...)")
        manager = self.pipeline_manager
        stream = manager.pipeline

        if self._next_chunk_index == 0:
            # The non-batched manager is told explicitly not to batch-denoise.
            extra_kwargs = {} if self.mode == "single" else {"batch_denoise": False}
            denoised_pred = manager.prepare_pipeline(
                text_prompts=[self.prompt],
                noise=chunk.noisy_latents,
                current_start=chunk.current_start,
                current_end=chunk.current_end,
                **extra_kwargs,
            )
            self._next_chunk_index += 1
            return DenoisedChunk(denoised_pred=denoised_pred, last_frame_only=False)

        current_start = chunk.current_start
        current_end = chunk.current_end
        # Once the start position (in frame_seq_length units) reaches t_refresh,
        # rewind the window to the last frame slot of the KV cache, mirroring
        # the offline loop.
        if current_start // stream.frame_seq_length >= manager.t_refresh:
            current_start = stream.kv_cache_length - stream.frame_seq_length
            current_end = current_start + (self.chunk_size // 4) * stream.frame_seq_length

        run_inference = stream.inference_stream if self.mode == "single" else stream.inference_wo_batch
        denoised_pred = run_inference(
            noise=chunk.noisy_latents,
            current_start=current_start,
            current_end=current_end,
            current_step=chunk.current_step,
        )
        manager.processed += 1
        self._next_chunk_index += 1
        if self.mode == "single" and manager.processed < self.num_steps:
            return None
        return DenoisedChunk(denoised_pred=denoised_pred, last_frame_only=True)

    def decode_chunks(self, chunks: list[DenoisedChunk]) -> np.ndarray:
        """Decode denoised latent chunks into a `[T, H, W, C]` video array.

        Raises:
            ValueError: If ``chunks`` is empty.
        """
        if not chunks:
            raise ValueError("chunks must not be empty")
        decoded = [self.decode_chunk(chunk) for chunk in chunks]
        return np.concatenate(decoded, axis=0)

    def decode_chunk(self, chunk: DenoisedChunk) -> np.ndarray:
        """Decode one denoised latent chunk into `[T, H, W, C]` frames."""
        return self.pipeline_manager._decode_video_array(
            chunk.denoised_pred,
            last_frame_only=chunk.last_frame_only,
        )

    def __call__(self, video: str | Path | torch.Tensor) -> np.ndarray:
        """Run the full staged pipeline after `prepare(prompt)` has been called.

        Raises:
            RuntimeError: If ``prepare(prompt)`` has not been called. This is
                checked before any encoding work is done.
        """
        if self.prompt is None:
            raise RuntimeError("Call prepare(prompt) before running the pipeline")
        encoded = self.encode_video(video)
        denoised = self.denoise_chunks(encoded)
        return self.decode_chunks(denoised)