# Hugging Face Space application ("Running on Zero").
import os
import time
import base64
import tempfile
import queue as pyqueue
import multiprocessing as mp
from io import BytesIO
import spaces # before torch / CUDA imports
import torch
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download, hf_hub_download
from streamdiffusionv2 import StreamDiffusionV2Pipeline
# ----------------------------------------------------------------------------
# Config
# ----------------------------------------------------------------------------
# Hub repo id and local directory handed to snapshot_download() for the base model.
WAN_REPO = "Wan-AI/Wan2.1-T2V-1.3B"
WAN_DIR = "wan_models/Wan2.1-T2V-1.3B"
# Hub repo id and local directory for the StreamDiffusionV2 checkpoints.
SDV2_REPO = "jerryfeng/StreamDiffusionV2"
CKPT_DIR = "ckpts"
CKPT_FOLDER = os.path.join(CKPT_DIR, "wan_causal_dmd_v2v") # 1.3B v2v checkpoint
# Passed to the pipeline as height/width; incoming frames are resized to (WIDTH, HEIGHT).
HEIGHT, WIDTH = 480, 832
# Seconds run_session() keeps looping; the GPU decorator requests duration=60.
SESSION_DURATION = 58
# Seconds run_session() sleeps when too few frames are buffered.
POLL_INTERVAL = 0.005
# Used whenever the prompt file is missing or empty.
DEFAULT_PROMPT = "a psychedelic neon dream, vivid saturated colors, glowing"
# Passed both to the pipeline constructor and to start_stream_session().
NOISE_SCALE = 0.8
SESSION_DIR = tempfile.gettempdir()
# Written by POST /instruction and read back by _read_prompt().
INSTRUCTION_FILE = os.path.join(SESSION_DIR, "sdv2_prompt.txt")
# First value yielded by run_session(), before any frame.
READY_SENTINEL = "__READY__"
# Fork-safe live frame queue (created before any ZeroGPU fork).
# Filled by POST /frame (bodies are dropped when it is full) and drained by run_session().
FRAME_Q = mp.get_context("fork").Queue(maxsize=512)
# ----------------------------------------------------------------------------
# Weights + pipeline at module scope (ZeroGPU snapshot preload)
# ----------------------------------------------------------------------------
# Download only the files matched by allow_patterns from WAN_REPO into WAN_DIR.
snapshot_download(
repo_id=WAN_REPO, local_dir=WAN_DIR,
allow_patterns=[
"config.json", "diffusion_pytorch_model.safetensors", "Wan2.1_VAE.pth",
"models_t5_umt5-xxl-enc-bf16.pth", "google/umt5-xxl/*",
],
)
# Download only the wan_causal_dmd_v2v/ sub-folder from SDV2_REPO into CKPT_DIR;
# CKPT_FOLDER is built from the same directory and sub-folder name.
snapshot_download(repo_id=SDV2_REPO, local_dir=CKPT_DIR,
allow_patterns=["wan_causal_dmd_v2v/*"])
# Pre-fetch the TAEHV tiny-VAE decoder weights (fast streaming decode).
# The file is downloaded to a sibling ".part" path and renamed into place only
# after the transfer finishes, so an interrupted download can never leave a
# truncated checkpoint that would pass the existence check on the next start.
_TAEHV_PATH = os.path.join(CKPT_DIR, "taew2_1.pth")
if not os.path.exists(_TAEHV_PATH):
    import urllib.request

    os.makedirs(CKPT_DIR, exist_ok=True)
    _taehv_part = _TAEHV_PATH + ".part"
    try:
        urllib.request.urlretrieve(
            "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth", _taehv_part)
        os.replace(_taehv_part, _TAEHV_PATH)
    finally:
        # Only present here when the download or the rename failed.
        if os.path.exists(_taehv_part):
            os.remove(_taehv_part)
# CUDA device used for the pipeline and for every input tensor built by
# _frames_to_video_tensor().
device = torch.device("cuda")
# StreamDiffusionV2 single-GPU streaming pipeline (rolling KV + sink tokens are
# built into the model -> continuous streaming without the window-shift burst).
# Constructed once at import time, not per request.
stream = StreamDiffusionV2Pipeline(
checkpoint_folder=CKPT_FOLDER,
mode="single",
device=device,
height=HEIGHT,
width=WIDTH,
step=2, # 2 denoising steps (quality); TAEHV keeps it fast
noise_scale=NOISE_SCALE,
model_type="T2V-1.3B",
use_taehv=True, # tiny-VAE decode -> much faster per-chunk -> lower lag
)
# Shorthand for the manager that run_session() drives
# (start_stream_session / run_stream_batch / pipeline).
PM = stream.pipeline_manager
# Number of input frames run_session() feeds per steady-state batch.
CHUNK = PM.base_chunk_size * PM.pipeline.num_frame_per_block # 4 px frames / chunk
# Number of input frames run_session() requires before starting the session.
FIRST_BATCH = 1 + CHUNK # 5 px frames first
def _read_prompt():
    """Return the stripped contents of INSTRUCTION_FILE, or "" if it does not exist."""
    prompt = ""
    try:
        with open(INSTRUCTION_FILE, encoding="utf-8") as src:
            prompt = src.read().strip()
    except FileNotFoundError:
        # No prompt has been posted yet; callers fall back to their default.
        pass
    return prompt
def _decode_jpeg_to_tensor(jpeg_bytes):
    """Decode JPEG bytes into a float tensor laid out [C, H, W] with values in [-1, 1].

    The image is converted to RGB and resized to (WIDTH, HEIGHT) with bicubic
    resampling before being scaled from 0..255 to -1..1.
    """
    image = Image.open(BytesIO(jpeg_bytes))
    image = image.convert("RGB").resize((WIDTH, HEIGHT), Image.BICUBIC)
    hwc = torch.from_numpy(np.asarray(image)).float()
    chw01 = hwc.permute(2, 0, 1) / 255.0  # channels first, 0..1
    return chw01 * 2.0 - 1.0
def _frames_to_video_tensor(frame_list):
    """Stack per-frame tensors into one bf16 video tensor on the module-level `device`.

    Frames are expected to be [C, H, W]; they are stacked along a new axis 1
    (time) and a leading batch axis of size 1 is added -> [1, C, T, H, W].
    """
    stacked = torch.stack(frame_list, dim=1)  # [C, T, H, W]
    batched = torch.unsqueeze(stacked, 0)  # [1, C, T, H, W]
    return batched.to(device=device, dtype=torch.bfloat16)
def _to_data_uri(frame01):
    """Encode one frame with values in [0, 1] as a base64 JPEG data URI (quality 70).

    Values are clipped to [0, 1], scaled by 255 and cast to uint8 before encoding.
    """
    pixels = (np.clip(frame01, 0, 1) * 255.0).astype(np.uint8)
    jpeg = BytesIO()
    Image.fromarray(pixels).save(jpeg, format="JPEG", quality=70)
    encoded = base64.b64encode(jpeg.getvalue()).decode()
    return f"data:image/jpeg;base64,{encoded}"
# ----------------------------------------------------------------------------
# Gradio Server
# ----------------------------------------------------------------------------
from gradio import Server
from fastapi import Request
from fastapi.responses import HTMLResponse
# Single server instance; the route decorators below (app.api / app.post /
# app.get) register their handlers on it, and app.launch() starts it.
app = Server()
@app.api(name="run_session")
@spaces.GPU(duration=60, size="xlarge")
@torch.inference_mode()
def run_session() -> str:
    """Run one streaming session and yield output frames as they are produced.

    This is a generator. It yields, in order:
      1. READY_SENTINEL, once stale frames have been discarded from FRAME_Q;
      2. one JPEG data URI per output frame, for SESSION_DURATION seconds;
      3. the last produced frame once more when the time budget runs out
         (nothing if no frame was ever produced).

    Input frames arrive through FRAME_Q (filled by POST /frame) and the prompt
    is re-read from INSTRUCTION_FILE on every loop iteration, falling back to
    DEFAULT_PROMPT when it is missing or empty.

    NOTE(review): the ``-> str`` annotation describes each yielded item and is
    kept unchanged; presumably the server reads it for the API schema — confirm
    before changing it.
    """
    # Drain stale frames, then signal the client to start streaming.
    try:
        while True:
            FRAME_Q.get_nowait()
    except pyqueue.Empty:
        pass
    yield READY_SENTINEL

    cur_prompt = _read_prompt() or DEFAULT_PROMPT
    buffer = []
    session = None
    deadline = time.time() + SESSION_DURATION
    last = None
    while time.time() < deadline:
        new_prompt = _read_prompt() or DEFAULT_PROMPT
        if new_prompt != cur_prompt:
            # Always adopt the new prompt, even before the session exists, so
            # start_stream_session() below never starts from a stale prompt.
            cur_prompt = new_prompt
            if session is not None:
                # Live prompt change: re-encode text and reset the cross-attn cache,
                # WITHOUT touching the rolling self-attn KV (keeps temporal continuity).
                cond = PM.pipeline.text_encoder(text_prompts=[cur_prompt])
                cond["prompt_embeds"] = cond["prompt_embeds"].repeat(
                    PM.pipeline.batch_size, 1, 1)
                PM.pipeline.conditional_dict = cond
                for blk in PM.pipeline.crossattn_cache:
                    blk["is_init"] = False

        # Pull at most 256 queued frames per iteration so the deadline and the
        # prompt file keep being checked even under a flood of input.
        for _ in range(256):
            try:
                buffer.append(FRAME_Q.get_nowait())
            except pyqueue.Empty:
                break

        need = FIRST_BATCH if session is None else CHUNK
        if len(buffer) < need:
            time.sleep(POLL_INTERVAL)
            continue

        # Low latency: edit the FRESHEST frames and drop any backlog that piled
        # up during the previous chunk's compute, so the output tracks "now".
        chunk_bytes = buffer[-need:]
        buffer = []
        try:
            frames = [_decode_jpeg_to_tensor(b) for b in chunk_bytes]
        except (OSError, ValueError, Image.DecompressionBombError) as exc:
            # /frame accepts arbitrary bytes; one undecodable body must not end
            # the whole session, so skip this chunk and wait for fresh frames.
            print(f"[sdv2] dropped undecodable chunk: {exc}", flush=True)
            continue
        vid = _frames_to_video_tensor(frames)

        t0 = time.time()
        if session is None:
            session, init_video = PM.start_stream_session(cur_prompt, vid, NOISE_SCALE)
            outs = [init_video]
        else:
            outs = PM.run_stream_batch(session, vid)
        dt = time.time() - t0  # model time only; JPEG encoding below is excluded

        n = 0
        for arr in outs:  # each arr: [T, H, W, C] in [0,1]
            for fr in arr:
                last = _to_data_uri(fr)
                yield last
                n += 1
        if n:
            print(f"[sdv2] {n} frames in {dt:.2f}s ({n/max(1e-3,dt):.1f} fps)", flush=True)

    # Repeat the final frame once as the closing value of the stream.
    if last is not None:
        yield last
@app.post("/frame")
async def post_frame(request: Request):
    """Queue the raw request body as one input frame.

    Empty bodies are ignored, and a frame is silently dropped when FRAME_Q is
    full; the response is {"ok": True} in every case.
    """
    payload = await request.body()
    if not payload:
        return {"ok": True}
    try:
        FRAME_Q.put_nowait(payload)
    except pyqueue.Full:
        # Best effort: never block the request on a full queue.
        pass
    return {"ok": True}
@app.post("/instruction")
async def post_instruction(request: Request):
    """Store the prompt sent by the client in INSTRUCTION_FILE.

    Expects a JSON object such as {"instruction": "<text>"}. A missing or falsy
    "instruction" stores an empty prompt. The text is written to a temporary
    file and renamed over INSTRUCTION_FILE so readers never see a partial write.

    Returns {"ok": True} once the prompt is stored, or {"ok": False} when the
    body is not valid JSON, is not a JSON object, or carries a non-string
    instruction; in that case the stored prompt is left untouched.
    """
    try:
        data = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError both derive from ValueError.
        return {"ok": False}
    if not isinstance(data, dict):
        return {"ok": False}

    raw = data.get("instruction", "") or ""
    if not isinstance(raw, str):
        return {"ok": False}
    text = raw.strip()

    tmp = INSTRUCTION_FILE + ".tmp"
    with open(tmp, "w", encoding="utf-8") as f:
        f.write(text)
    os.replace(tmp, INSTRUCTION_FILE)
    return {"ok": True}
@app.get("/", response_class=HTMLResponse)
async def homepage():
    """Serve index.html from the directory that contains this file."""
    page_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
    with open(page_path, encoding="utf-8") as page:
        html = page.read()
    return html
# Start the server; show_error=True is passed through to launch().
app.launch(show_error=True)