import os
import time
import base64
import tempfile
import queue as pyqueue
import multiprocessing as mp
from io import BytesIO

import spaces  # must come before torch / CUDA imports

import torch
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download, hf_hub_download

from streamdiffusionv2 import StreamDiffusionV2Pipeline

# ----------------------------------------------------------------------------
# Config
# ----------------------------------------------------------------------------
# Model sources (Hugging Face repos) and where they are materialised locally.
WAN_REPO = "Wan-AI/Wan2.1-T2V-1.3B"
WAN_DIR = "wan_models/Wan2.1-T2V-1.3B"
SDV2_REPO = "jerryfeng/StreamDiffusionV2"
CKPT_DIR = "ckpts"
CKPT_FOLDER = os.path.join(CKPT_DIR, "wan_causal_dmd_v2v")  # 1.3B v2v checkpoint

# Frame geometry, plus session length / polling period in seconds.
HEIGHT, WIDTH = 480, 832
SESSION_DURATION = 58
POLL_INTERVAL = 0.005

# Generation defaults.
DEFAULT_PROMPT = "a psychedelic neon dream, vivid saturated colors, glowing"
NOISE_SCALE = 0.8

# The current prompt is handed between requests through a file in the
# system temp directory.
SESSION_DIR = tempfile.gettempdir()
INSTRUCTION_FILE = os.path.join(SESSION_DIR, "sdv2_prompt.txt")
READY_SENTINEL = "__READY__"

# Fork-safe live frame queue (created before any ZeroGPU fork).
FRAME_Q = mp.get_context("fork").Queue(maxsize=512)

# ----------------------------------------------------------------------------
# Weights + pipeline at module scope (ZeroGPU snapshot preload)
# ----------------------------------------------------------------------------
snapshot_download(
    repo_id=WAN_REPO,
    local_dir=WAN_DIR,
    allow_patterns=[
        "config.json",
        "diffusion_pytorch_model.safetensors",
        "Wan2.1_VAE.pth",
        "models_t5_umt5-xxl-enc-bf16.pth",
        "google/umt5-xxl/*",
    ],
)
snapshot_download(
    repo_id=SDV2_REPO,
    local_dir=CKPT_DIR,
    allow_patterns=["wan_causal_dmd_v2v/*"],
)

# Pre-fetch the TAEHV tiny-VAE decoder weights (fast streaming decode).
_TAEHV_PATH = os.path.join(CKPT_DIR, "taew2_1.pth")
_TAEHV_URL = "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth"


def _download_once(url, dest):
    """Fetch *url* to *dest* unless *dest* already exists.

    The payload is written to ``dest + ".part"`` and only renamed onto *dest*
    once the download has finished. An interrupted fetch therefore never
    leaves a truncated file that the existence check would later accept.
    """
    if os.path.exists(dest):
        return
    import urllib.request

    partial = dest + ".part"
    try:
        urllib.request.urlretrieve(url, partial)
        os.replace(partial, dest)
    finally:
        # The temp file only survives to here if the download or the rename
        # failed; remove it so the next start begins from scratch.
        if os.path.exists(partial):
            os.remove(partial)


_download_once(_TAEHV_URL, _TAEHV_PATH)

device = torch.device("cuda")

# StreamDiffusionV2 single-GPU streaming pipeline (rolling KV + sink tokens are
# built into the model -> continuous streaming without the window-shift burst).
stream = StreamDiffusionV2Pipeline(
    checkpoint_folder=CKPT_FOLDER,
    mode="single",
    device=device,
    height=HEIGHT,
    width=WIDTH,
    step=2,  # 2 denoising steps (quality); TAEHV keeps it fast
    noise_scale=NOISE_SCALE,
    model_type="T2V-1.3B",
    use_taehv=True,  # tiny-VAE decode -> much faster per-chunk -> lower lag
)
PM = stream.pipeline_manager
CHUNK = PM.base_chunk_size * PM.pipeline.num_frame_per_block  # 4 px frames / chunk
FIRST_BATCH = 1 + CHUNK  # 5 px frames first


def _read_prompt():
    """Return the stripped contents of INSTRUCTION_FILE, or "" if it does not exist."""
    try:
        with open(INSTRUCTION_FILE, encoding="utf-8") as f:
            return f.read().strip()
    except FileNotFoundError:
        return ""


def _decode_jpeg_to_tensor(jpeg_bytes):
    """JPEG bytes -> [C, H, W] float tensor in [-1, 1], resized to WIDTH x HEIGHT.

    PIL's exceptions propagate to the caller; for unreadable or truncated
    data these are OSError subclasses such as UnidentifiedImageError.
    """
    im = Image.open(BytesIO(jpeg_bytes)).convert("RGB").resize((WIDTH, HEIGHT), Image.BICUBIC)
    # np.array (not np.asarray) gives a writable copy; torch.from_numpy warns
    # when handed the read-only buffer that np.asarray returns for PIL images.
    arr = torch.from_numpy(np.array(im)).float().permute(2, 0, 1) / 255.0
    return arr * 2.0 - 1.0


def _try_decode(jpeg_bytes):
    """Decode one client frame, or log and return None if it is not a usable image.

    /frame enqueues any request body unchecked, so a bad frame must be
    dropped here rather than abort the whole GPU session.
    """
    try:
        return _decode_jpeg_to_tensor(jpeg_bytes)
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        print(f"[sdv2] dropping undecodable frame ({len(jpeg_bytes)} bytes): {exc}", flush=True)
        return None


def _frames_to_video_tensor(frame_list):
    """list of [C,H,W] -> [B, C, T, H, W] bf16 on device."""
    vid = torch.stack(frame_list, dim=1).unsqueeze(0)  # [1, C, T, H, W]
    return vid.to(device=device, dtype=torch.bfloat16)


def _to_data_uri(frame01):
    """Encode one frame with values in [0, 1] as a base64 JPEG data URI (quality 70)."""
    im = Image.fromarray((np.clip(frame01, 0, 1) * 255.0).astype(np.uint8))
    buf = BytesIO()
    im.save(buf, format="JPEG", quality=70)
    return "data:image/jpeg;base64," + base64.b64encode(buf.getvalue()).decode()


def _apply_prompt(prompt):
    """Re-encode *prompt* into the running pipeline and invalidate its cross-attn cache.

    The rolling self-attention KV state is not touched, which keeps temporal
    continuity across a live prompt switch.
    """
    cond = PM.pipeline.text_encoder(text_prompts=[prompt])
    cond["prompt_embeds"] = cond["prompt_embeds"].repeat(PM.pipeline.batch_size, 1, 1)
    PM.pipeline.conditional_dict = cond
    for blk in PM.pipeline.crossattn_cache:
        blk["is_init"] = False


# ----------------------------------------------------------------------------
# Gradio Server
# ----------------------------------------------------------------------------
from gradio import Server
from fastapi import HTTPException, Request
from fastapi.responses import HTMLResponse

app = Server()


@app.api(name="run_session")
@spaces.GPU(duration=60, size="xlarge")
@torch.inference_mode()
def run_session() -> str:
    """Stream edited frames for one GPU session of about SESSION_DURATION seconds.

    First yields READY_SENTINEL, which tells the client to start posting
    frames. Then yields one JPEG data URI per generated frame. When the
    deadline passes, the last frame (if any) is yielded once more.
    """
    # Drain stale frames, then signal the client to start streaming.
    try:
        while True:
            FRAME_Q.get_nowait()
    except pyqueue.Empty:
        pass
    yield READY_SENTINEL

    cur_prompt = _read_prompt() or DEFAULT_PROMPT
    buffer = []
    session = None
    # Monotonic clock: the deadline must not jump with wall-clock adjustments.
    deadline = time.monotonic() + SESSION_DURATION
    last = None

    while time.monotonic() < deadline:
        # Live prompt change. Once streaming, re-encode the text and reset only
        # the cross-attn cache. Before the session exists, just adopt the new
        # prompt so start_stream_session is seeded with it directly.
        new_prompt = _read_prompt() or DEFAULT_PROMPT
        if new_prompt != cur_prompt:
            cur_prompt = new_prompt
            if session is not None:
                _apply_prompt(cur_prompt)

        # Pull at most 256 queued frames per iteration.
        for _ in range(256):
            try:
                buffer.append(FRAME_Q.get_nowait())
            except pyqueue.Empty:
                break

        need = FIRST_BATCH if session is None else CHUNK
        if len(buffer) < need:
            time.sleep(POLL_INTERVAL)
            continue

        # Low latency: edit the FRESHEST frames and drop any backlog that piled
        # up during the previous chunk's compute, so the output tracks "now".
        decoded = [(raw, _try_decode(raw)) for raw in buffer[-need:]]
        good = [(raw, tensor) for raw, tensor in decoded if tensor is not None]
        if len(good) < need:
            # Some frames were corrupt: keep the valid ones and wait for more.
            buffer = [raw for raw, _ in good]
            continue
        buffer = []
        vid = _frames_to_video_tensor([tensor for _, tensor in good])

        t0 = time.monotonic()
        if session is None:
            session, init_video = PM.start_stream_session(cur_prompt, vid, NOISE_SCALE)
            outs = [init_video]
        else:
            outs = PM.run_stream_batch(session, vid)
        dt = time.monotonic() - t0

        n = 0
        for arr in outs:  # each arr: [T, H, W, C] in [0,1]
            for fr in arr:
                last = _to_data_uri(fr)
                yield last
                n += 1
        if n:
            print(f"[sdv2] {n} frames in {dt:.2f}s ({n/max(1e-3,dt):.1f} fps)", flush=True)

    if last is not None:
        yield last


@app.post("/frame")
async def post_frame(request: Request):
    """Enqueue the raw request body as one frame; silently drop it if the queue is full."""
    body = await request.body()
    if body:
        try:
            FRAME_Q.put_nowait(body)
        except pyqueue.Full:
            pass
    return {"ok": True}


@app.post("/instruction")
async def post_instruction(request: Request):
    """Store the JSON field ``instruction`` (stripped) as the live prompt.

    Malformed requests get a 400 and leave the current prompt unchanged.
    A missing, null or empty instruction clears the file, and run_session
    then falls back to DEFAULT_PROMPT.
    """
    try:
        data = await request.json()
    except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError
        raise HTTPException(status_code=400, detail="request body must be JSON") from exc
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="expected a JSON object")
    text = data.get("instruction") or ""
    if not isinstance(text, str):
        raise HTTPException(status_code=400, detail="'instruction' must be a string")
    text = text.strip()

    # Write-then-rename so readers never observe a half-written prompt.
    tmp = INSTRUCTION_FILE + ".tmp"
    with open(tmp, "w", encoding="utf-8") as f:
        f.write(text)
    os.replace(tmp, INSTRUCTION_FILE)
    return {"ok": True}


@app.get("/", response_class=HTMLResponse)
async def homepage():
    """Serve index.html from the directory containing this file."""
    here = os.path.dirname(os.path.abspath(__file__))
    with open(os.path.join(here, "index.html"), encoding="utf-8") as f:
        return f.read()


app.launch(show_error=True)