Spaces:
Running on Zero
Running on Zero
| import os | |
| import time | |
| import base64 | |
| import tempfile | |
| import queue as pyqueue | |
| import multiprocessing as mp | |
| from io import BytesIO | |
| import spaces # before torch / CUDA imports | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from huggingface_hub import snapshot_download, hf_hub_download | |
| from streamdiffusionv2 import StreamDiffusionV2Pipeline | |
# ----------------------------------------------------------------------------
# Config
# ----------------------------------------------------------------------------
# Hugging Face model sources and their local download locations.
WAN_REPO = "Wan-AI/Wan2.1-T2V-1.3B"
WAN_DIR = "wan_models/Wan2.1-T2V-1.3B"
SDV2_REPO = "jerryfeng/StreamDiffusionV2"
CKPT_DIR = "ckpts"
CKPT_FOLDER = os.path.join(CKPT_DIR, "wan_causal_dmd_v2v")  # 1.3B v2v checkpoint

# Frame resolution (input frames are resized to WIDTH x HEIGHT).
HEIGHT = 480
WIDTH = 832

# Session timing, both in seconds.
SESSION_DURATION = 58  # length of one streaming session
POLL_INTERVAL = 0.005  # sleep while waiting for enough input frames

# Generation defaults.
DEFAULT_PROMPT = "a psychedelic neon dream, vivid saturated colors, glowing"
NOISE_SCALE = 0.8

# Live-prompt hand-off file: written by the instruction handler, re-read by
# the running session.
SESSION_DIR = tempfile.gettempdir()
INSTRUCTION_FILE = os.path.join(SESSION_DIR, "sdv2_prompt.txt")

# First value yielded by a session, signalling the client to start streaming.
READY_SENTINEL = "__READY__"

# Fork-safe live frame queue (created before any ZeroGPU fork).
_FORK_CTX = mp.get_context("fork")
FRAME_Q = _FORK_CTX.Queue(maxsize=512)
# ----------------------------------------------------------------------------
# Weights + pipeline at module scope (ZeroGPU snapshot preload)
# ----------------------------------------------------------------------------
snapshot_download(
    repo_id=WAN_REPO,
    local_dir=WAN_DIR,
    allow_patterns=[
        "config.json",
        "diffusion_pytorch_model.safetensors",
        "Wan2.1_VAE.pth",
        "models_t5_umt5-xxl-enc-bf16.pth",
        "google/umt5-xxl/*",
    ],
)
snapshot_download(
    repo_id=SDV2_REPO,
    local_dir=CKPT_DIR,
    allow_patterns=["wan_causal_dmd_v2v/*"],
)

# Pre-fetch the TAEHV tiny-VAE decoder weights (fast streaming decode).
_TAEHV_URL = "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth"
_TAEHV_PATH = os.path.join(CKPT_DIR, "taew2_1.pth")
if not os.path.exists(_TAEHV_PATH):
    import urllib.request

    # Download to a side file and rename only on success. Writing straight to
    # _TAEHV_PATH would let an interrupted download leave a truncated file
    # that the exists() check above accepts on every later start.
    _taehv_tmp = _TAEHV_PATH + ".part"
    try:
        urllib.request.urlretrieve(_TAEHV_URL, _taehv_tmp)
        os.replace(_taehv_tmp, _TAEHV_PATH)
    finally:
        # Only still present if the download or rename failed.
        if os.path.exists(_taehv_tmp):
            os.remove(_taehv_tmp)

device = torch.device("cuda")

# StreamDiffusionV2 single-GPU streaming pipeline (rolling KV + sink tokens are
# built into the model -> continuous streaming without the window-shift burst).
stream = StreamDiffusionV2Pipeline(
    checkpoint_folder=CKPT_FOLDER,
    mode="single",
    device=device,
    height=HEIGHT,
    width=WIDTH,
    step=2,  # 2 denoising steps (quality); TAEHV keeps it fast
    noise_scale=NOISE_SCALE,
    model_type="T2V-1.3B",
    use_taehv=True,  # tiny-VAE decode -> much faster per-chunk -> lower lag
)
PM = stream.pipeline_manager

# Input frames consumed per streaming step, and for the first (session-start)
# step, which takes one extra frame.
CHUNK = PM.base_chunk_size * PM.pipeline.num_frame_per_block  # 4 px frames / chunk
FIRST_BATCH = 1 + CHUNK  # 5 px frames first
def _read_prompt():
    """Return the live prompt stored in INSTRUCTION_FILE, whitespace-stripped.

    A missing file means no prompt has been posted yet, which yields "".
    """
    try:
        handle = open(INSTRUCTION_FILE, encoding="utf-8")
    except FileNotFoundError:
        return ""
    with handle:
        text = handle.read()
    return text.strip()
def _decode_jpeg_to_tensor(jpeg_bytes):
    """JPEG bytes -> float32 tensor [C, H, W] with values in [-1, 1].

    The image is converted to RGB and bicubic-resized to (WIDTH, HEIGHT), so
    the result is [3, HEIGHT, WIDTH].
    """
    im = Image.open(BytesIO(jpeg_bytes)).convert("RGB").resize((WIDTH, HEIGHT), Image.BICUBIC)
    # np.asarray(im) on a PIL image is a read-only view, and torch.from_numpy
    # warns about non-writable arrays. Casting to float32 here yields a fresh,
    # writable array with the same values the old .float() produced.
    pixels = np.asarray(im, dtype=np.float32)
    arr = torch.from_numpy(pixels).permute(2, 0, 1) / 255.0  # HWC -> CHW, [0, 1]
    return arr * 2.0 - 1.0
def _frames_to_video_tensor(frame_list):
    """Stack [C, H, W] frames into one bf16 video tensor on ``device``.

    Frames are stacked along a new time axis after the channel axis, and a
    leading batch axis of size 1 is added: the result is [1, C, T, H, W]
    with T == len(frame_list).
    """
    clip = torch.stack(frame_list, dim=1)  # [C, T, H, W]
    batched = clip[None]  # [1, C, T, H, W]
    return batched.to(device=device, dtype=torch.bfloat16)
def _to_data_uri(frame01, quality=70):
    """Encode one float image as a JPEG ``data:`` URI string.

    Args:
        frame01: image values nominally in [0, 1]; anything outside is
            clipped. Assumed [H, W, 3] float numpy data (TODO confirm against
            the pipeline output).
        quality: JPEG quality passed to PIL's ``save``. Defaults to 70, the
            previously hard-coded value.

    Returns:
        ``"data:image/jpeg;base64,<payload>"``.
    """
    # Round rather than truncate: astype(uint8) alone floors, mapping e.g.
    # 0.999 to 254 and biasing the whole frame darker.
    pixels = np.rint(np.clip(frame01, 0, 1) * 255.0).astype(np.uint8)
    buf = BytesIO()
    Image.fromarray(pixels).save(buf, format="JPEG", quality=quality)
    return "data:image/jpeg;base64," + base64.b64encode(buf.getvalue()).decode()
# ----------------------------------------------------------------------------
# Gradio Server
# ----------------------------------------------------------------------------
from fastapi import Request
from fastapi.responses import HTMLResponse
from gradio import Server

# Server instance the handlers below belong to; launched at the end of the module.
app = Server()
def run_session() -> str:
    """Run the streaming v2v pipeline on client-posted frames for SESSION_DURATION s.

    Yields READY_SENTINEL first, after discarding frames queued before the
    session. It then yields one ``data:image/jpeg;base64,...`` string per
    output frame:
      * The first model call consumes FIRST_BATCH input frames and starts a
        stream session.
      * Every later call consumes CHUNK frames.
    When the deadline passes, the last output frame (if any) is yielded once
    more.

    The prompt is re-read from INSTRUCTION_FILE on every loop iteration.
    Changes made before the session starts are used to start it. Later
    changes are swapped in live.

    NOTE(review): this is a generator although annotated ``-> str``. The
    annotation is kept in case the serving framework inspects it. Confirm
    before changing it to ``Iterator[str]``.
    """
    # Drain stale frames, then signal the client to start streaming.
    _drain_frames()
    yield READY_SENTINEL

    cur_prompt = _read_prompt() or DEFAULT_PROMPT
    buffer = []
    session = None
    # monotonic(): the session length must not be affected by wall-clock jumps.
    deadline = time.monotonic() + SESSION_DURATION
    last = None
    while time.monotonic() < deadline:
        new_prompt = _read_prompt() or DEFAULT_PROMPT
        if new_prompt != cur_prompt:
            # Always track the newest prompt. Before the session exists it is
            # simply handed to start_stream_session below. Previously a change
            # in that window was ignored until one chunk after the start.
            cur_prompt = new_prompt
            if session is not None:
                _set_live_prompt(cur_prompt)

        buffer.extend(_drain_frames(limit=256))

        need = FIRST_BATCH if session is None else CHUNK
        if len(buffer) < need:
            time.sleep(POLL_INTERVAL)
            continue

        # Low latency: edit the FRESHEST frames and drop any backlog that piled
        # up during the previous chunk's compute, so the output tracks "now".
        chunk_bytes = buffer[-need:]
        buffer = []
        frames = [_decode_jpeg_to_tensor(b) for b in chunk_bytes]
        vid = _frames_to_video_tensor(frames)

        t0 = time.perf_counter()
        if session is None:
            session, init_video = PM.start_stream_session(cur_prompt, vid, NOISE_SCALE)
            outs = [init_video]
        else:
            outs = PM.run_stream_batch(session, vid)
        dt = time.perf_counter() - t0

        n = 0
        for arr in outs:  # each arr: [T, H, W, C] in [0,1]
            for fr in arr:
                last = _to_data_uri(fr)
                yield last
                n += 1
        if n:
            print(f"[sdv2] {n} frames in {dt:.2f}s ({n/max(1e-3,dt):.1f} fps)", flush=True)

    if last is not None:
        yield last


def _drain_frames(limit=None):
    """Pop queued frame payloads from FRAME_Q without blocking.

    Stops when the queue is empty or, if *limit* is given, after *limit*
    items. Returns the popped payloads in queue order.
    """
    items = []
    while limit is None or len(items) < limit:
        try:
            items.append(FRAME_Q.get_nowait())
        except pyqueue.Empty:
            break
    return items


def _set_live_prompt(prompt):
    """Switch the running stream to *prompt*.

    Re-encodes the text and resets the cross-attn cache, WITHOUT touching
    the rolling self-attn KV, which keeps temporal continuity.
    """
    cond = PM.pipeline.text_encoder(text_prompts=[prompt])
    cond["prompt_embeds"] = cond["prompt_embeds"].repeat(PM.pipeline.batch_size, 1, 1)
    PM.pipeline.conditional_dict = cond
    for blk in PM.pipeline.crossattn_cache:
        blk["is_init"] = False
async def post_frame(request: Request):
    """Queue one posted frame payload for the streaming session (best effort).

    Empty bodies are ignored. If FRAME_Q is full the frame is dropped rather
    than blocking the handler. The reply is always ``{"ok": True}``.
    """
    payload = await request.body()
    if not payload:
        return {"ok": True}
    try:
        FRAME_Q.put_nowait(payload)
    except pyqueue.Full:
        pass  # deliberate: losing a frame beats stalling the request
    return {"ok": True}
async def post_instruction(request: Request):
    """Store a new live prompt for the streaming session.

    Expects a JSON object ``{"instruction": "<text>"}``. A missing, null or
    empty ``instruction`` clears the prompt, and the session then falls back
    to DEFAULT_PROMPT.

    The stripped text is written to a temp file and moved over
    INSTRUCTION_FILE with ``os.replace``, so a concurrent reader never sees
    a half-written file.

    Raises:
        HTTPException: 400 if the body is not valid JSON, is not a JSON
            object, or ``instruction`` is a non-string value. The stored
            prompt is left unchanged in that case. Previously these inputs
            crashed with an unhandled 500.
    """
    from fastapi import HTTPException

    try:
        data = await request.json()
    except ValueError as err:  # JSONDecodeError / UnicodeDecodeError
        raise HTTPException(status_code=400, detail="body must be valid JSON") from err
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="body must be a JSON object")

    # Same falsy handling as before: None / "" / other falsy values mean "clear".
    raw = data.get("instruction") or ""
    if not isinstance(raw, str):
        raise HTTPException(status_code=400, detail='"instruction" must be a string')
    text = raw.strip()

    tmp = INSTRUCTION_FILE + ".tmp"
    with open(tmp, "w", encoding="utf-8") as f:
        f.write(text)
    os.replace(tmp, INSTRUCTION_FILE)
    return {"ok": True}
async def homepage():
    """Return the text of the ``index.html`` file that sits next to this module."""
    page_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
    with open(page_path, encoding="utf-8") as page:
        html = page.read()
    return html
# Start the server (runs at import time, as in the original script).
app.launch(
    show_error=True,
)