# app.py -- uploaded by multimodalart (HF Staff) with huggingface_hub.
# Commit c50961a (verified), 7.23 kB.
import os
import time
import base64
import tempfile
import queue as pyqueue
import multiprocessing as mp
from io import BytesIO
import spaces # before torch / CUDA imports
import torch
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download, hf_hub_download
from streamdiffusionv2 import StreamDiffusionV2Pipeline
# ----------------------------------------------------------------------------
# Configuration constants
# ----------------------------------------------------------------------------

# Hub repositories and the local directories their files are downloaded into.
WAN_REPO = "Wan-AI/Wan2.1-T2V-1.3B"
WAN_DIR = "wan_models/Wan2.1-T2V-1.3B"
SDV2_REPO = "jerryfeng/StreamDiffusionV2"
CKPT_DIR = "ckpts"
# 1.3B v2v checkpoint folder inside CKPT_DIR.
CKPT_FOLDER = os.path.join(CKPT_DIR, "wan_causal_dmd_v2v")

# Frame geometry used both for resizing input frames and for the pipeline.
HEIGHT = 480
WIDTH = 832

# Session timing (seconds): total streaming budget and idle poll sleep.
SESSION_DURATION = 58
POLL_INTERVAL = 0.005

# Generation defaults.
DEFAULT_PROMPT = "a psychedelic neon dream, vivid saturated colors, glowing"
NOISE_SCALE = 0.8

# The current prompt is exchanged through a file in the system temp directory.
SESSION_DIR = tempfile.gettempdir()
INSTRUCTION_FILE = os.path.join(SESSION_DIR, "sdv2_prompt.txt")

# First value yielded by a session, telling the client it may start streaming.
READY_SENTINEL = "__READY__"

# Fork-safe live frame queue (created before any ZeroGPU fork).
FRAME_Q = mp.get_context("fork").Queue(512)
# ----------------------------------------------------------------------------
# Weights + pipeline at module scope (ZeroGPU snapshot preload)
# ----------------------------------------------------------------------------
# Everything below runs once at import time: it downloads the model files and
# builds the streaming pipeline that the request handlers later use.

# Base Wan 2.1 files; allow_patterns restricts the download to these entries.
snapshot_download(
    repo_id=WAN_REPO,
    local_dir=WAN_DIR,
    allow_patterns=[
        "config.json",
        "diffusion_pytorch_model.safetensors",
        "Wan2.1_VAE.pth",
        "models_t5_umt5-xxl-enc-bf16.pth",
        "google/umt5-xxl/*",
    ],
)

# StreamDiffusionV2 checkpoint folder (only the v2v sub-folder is fetched).
snapshot_download(
    repo_id=SDV2_REPO,
    local_dir=CKPT_DIR,
    allow_patterns=["wan_causal_dmd_v2v/*"],
)

# Pre-fetch the TAEHV tiny-VAE decoder weights (fast streaming decode).
_TAEHV_PATH = os.path.join(CKPT_DIR, "taew2_1.pth")
if not os.path.exists(_TAEHV_PATH):
    import urllib.request

    os.makedirs(CKPT_DIR, exist_ok=True)
    # Download to a side file and rename once complete: an interrupted
    # transfer must not leave a truncated file at _TAEHV_PATH, because the
    # existence check above would then treat it as a finished download.
    _taehv_partial = _TAEHV_PATH + ".part"
    urllib.request.urlretrieve(
        "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth", _taehv_partial)
    os.replace(_taehv_partial, _TAEHV_PATH)

device = torch.device("cuda")

# StreamDiffusionV2 single-GPU streaming pipeline (rolling KV + sink tokens are
# built into the model -> continuous streaming without the window-shift burst).
stream = StreamDiffusionV2Pipeline(
    checkpoint_folder=CKPT_FOLDER,
    mode="single",
    device=device,
    height=HEIGHT,
    width=WIDTH,
    step=2,  # 2 denoising steps (quality); TAEHV keeps it fast
    noise_scale=NOISE_SCALE,
    model_type="T2V-1.3B",
    use_taehv=True,  # tiny-VAE decode -> much faster per-chunk -> lower lag
)
PM = stream.pipeline_manager

# Number of input frames consumed per streaming step, and for the very first
# step (one extra frame). The concrete values noted by the original author are
# 4 and 5; they depend on the pipeline configuration -- TODO confirm.
CHUNK = PM.base_chunk_size * PM.pipeline.num_frame_per_block  # 4 px frames / chunk
FIRST_BATCH = 1 + CHUNK  # 5 px frames first
def _read_prompt():
    """Return the stripped contents of INSTRUCTION_FILE, or "" if it is missing."""
    try:
        handle = open(INSTRUCTION_FILE, encoding="utf-8")
    except FileNotFoundError:
        # No prompt has been posted yet; callers substitute their own default.
        return ""
    with handle:
        return handle.read().strip()
def _decode_jpeg_to_tensor(jpeg_bytes):
    """Decode encoded image bytes into a float32 ``[C, H, W]`` tensor in [-1, 1].

    The image is converted to RGB and resized to ``(WIDTH, HEIGHT)`` with
    bicubic resampling before being rescaled from [0, 255] to [-1, 1].

    Args:
        jpeg_bytes: Encoded image data (JPEG is what the client is expected to
            send; anything ``PIL.Image.open`` can read is accepted).

    Returns:
        A CPU float32 tensor of shape ``[3, HEIGHT, WIDTH]``.

    Raises:
        OSError: If the bytes cannot be identified or decoded as an image
            (``PIL.UnidentifiedImageError`` is an ``OSError`` subclass).
    """
    im = Image.open(BytesIO(jpeg_bytes)).convert("RGB").resize((WIDTH, HEIGHT), Image.BICUBIC)
    # Converting to float32 here yields a fresh, writable array. A plain
    # np.asarray(im) is read-only, which makes torch.from_numpy emit a
    # "non-writable array" warning and would need a second copy for .float().
    pixels = np.asarray(im, dtype=np.float32)
    chw = torch.from_numpy(pixels).permute(2, 0, 1) / 255.0
    return chw * 2.0 - 1.0
def _frames_to_video_tensor(frame_list):
    """Stack per-frame ``[C, H, W]`` tensors into one ``[1, C, T, H, W]`` bf16 tensor on ``device``."""
    # Stacking along dim=1 inserts the time axis right after the channel axis.
    clip = torch.stack(frame_list, dim=1)  # [C, T, H, W]
    batched = clip[None]  # prepend a batch axis of size 1
    return batched.to(dtype=torch.bfloat16, device=device)
def _to_data_uri(frame01):
    """Encode one frame with values in [0, 1] as a base64 JPEG ``data:`` URI string."""
    # Clamp to [0, 1], scale to [0, 255] and truncate to 8-bit pixels.
    pixels = (np.clip(frame01, 0, 1) * 255.0).astype(np.uint8)
    jpeg = BytesIO()
    Image.fromarray(pixels).save(jpeg, format="JPEG", quality=70)
    encoded = base64.b64encode(jpeg.getvalue()).decode()
    return f"data:image/jpeg;base64,{encoded}"
# ----------------------------------------------------------------------------
# Gradio Server
# ----------------------------------------------------------------------------
from gradio import Server
from fastapi import Request
from fastapi.responses import HTMLResponse
# Single application object: the streaming API (`@app.api`) and the plain HTTP
# routes (`@app.post`, `@app.get`) defined below all register on it, and it is
# started by `app.launch(...)` at the end of the file.
app = Server()
def _discard_queued_frames():
    """Empty FRAME_Q without processing anything, dropping frames left from earlier."""
    try:
        while True:
            FRAME_Q.get_nowait()
    except pyqueue.Empty:
        pass


def _apply_prompt(prompt):
    """Re-encode ``prompt`` for the running pipeline and invalidate its cross-attn cache."""
    cond = PM.pipeline.text_encoder(text_prompts=[prompt])
    cond["prompt_embeds"] = cond["prompt_embeds"].repeat(PM.pipeline.batch_size, 1, 1)
    PM.pipeline.conditional_dict = cond
    # Mark every cross-attention cache entry as uninitialised. Per the original
    # author's note this leaves the rolling self-attn KV untouched, which keeps
    # temporal continuity.
    for blk in PM.pipeline.crossattn_cache:
        blk["is_init"] = False


def _decode_frames(raw_frames):
    """Decode raw frames; return ``(tensors, raw_bytes)`` for those that decoded."""
    tensors, decodable = [], []
    for raw in raw_frames:
        try:
            tensors.append(_decode_jpeg_to_tensor(raw))
        except (OSError, ValueError) as exc:
            # /frame accepts arbitrary request bodies; a corrupt upload must
            # not terminate the whole streaming session.
            print(f"[sdv2] dropping undecodable frame: {exc}", flush=True)
            continue
        decodable.append(raw)
    return tensors, decodable


@app.api(name="run_session")
@spaces.GPU(duration=60, size="xlarge")
@torch.inference_mode()
def run_session() -> str:
    """Run one streaming session, yielding output frames as JPEG data URIs.

    This is a generator. It first discards any frames already waiting in
    FRAME_Q and yields READY_SENTINEL. Then, for up to SESSION_DURATION
    seconds, it collects input frames from FRAME_Q, feeds the newest ones to
    the pipeline (FIRST_BATCH frames to open the session, CHUNK frames per
    step afterwards) and yields every produced frame. The prompt is re-read
    from INSTRUCTION_FILE on every loop iteration so it can change mid-session.

    Yields:
        READY_SENTINEL once, then one data-URI string per output frame; the
        most recent frame is yielded one more time when the session ends.
    """
    # NOTE(review): ``-> str`` describes each yielded item, not the generator
    # object; it is left unchanged because the API layer may read it -- confirm
    # before changing.

    # Drain stale frames, then signal the client to start streaming.
    _discard_queued_frames()
    yield READY_SENTINEL

    cur_prompt = _read_prompt() or DEFAULT_PROMPT
    buffer = []
    session = None
    last = None
    # Monotonic clock: the time budget must not be affected by wall-clock jumps.
    deadline = time.monotonic() + SESSION_DURATION
    while time.monotonic() < deadline:
        # Live prompt change. Before the session exists just remember the new
        # prompt so the first batch starts with it; afterwards re-encode it.
        new_prompt = _read_prompt() or DEFAULT_PROMPT
        if new_prompt != cur_prompt:
            cur_prompt = new_prompt
            if session is not None:
                _apply_prompt(cur_prompt)

        # Pull at most 256 queued frames per iteration.
        for _ in range(256):
            try:
                buffer.append(FRAME_Q.get_nowait())
            except pyqueue.Empty:
                break

        need = FIRST_BATCH if session is None else CHUNK
        if len(buffer) < need:
            time.sleep(POLL_INTERVAL)
            continue

        # Low latency: edit the FRESHEST frames and drop any backlog that piled
        # up during the previous chunk's compute, so the output tracks "now".
        newest, buffer = buffer[-need:], []
        frames, decodable = _decode_frames(newest)
        if len(frames) < need:
            # Some frames were corrupt: keep the good ones and wait for more.
            buffer = decodable
            continue
        vid = _frames_to_video_tensor(frames)

        t0 = time.perf_counter()
        if session is None:
            session, init_video = PM.start_stream_session(cur_prompt, vid, NOISE_SCALE)
            outs = [init_video]
        else:
            outs = PM.run_stream_batch(session, vid)
        dt = time.perf_counter() - t0

        n = 0
        for arr in outs:  # each arr: [T, H, W, C] in [0,1]
            for fr in arr:
                last = _to_data_uri(fr)
                yield last
                n += 1
        if n:
            print(f"[sdv2] {n} frames in {dt:.2f}s ({n/max(1e-3,dt):.1f} fps)", flush=True)

    # Repeat the final frame so the stream ends on the latest image.
    if last is not None:
        yield last
@app.post("/frame")
async def post_frame(request: Request):
    """Accept one raw input frame in the request body and enqueue it on FRAME_Q.

    Empty bodies are ignored, and a frame is silently dropped when the queue
    is full; the response is ``{"ok": True}`` in every case.
    """
    payload = await request.body()
    if not payload:
        return {"ok": True}
    try:
        FRAME_Q.put_nowait(payload)
    except pyqueue.Full:
        # Best effort: never block the request when the queue is saturated.
        pass
    return {"ok": True}
@app.post("/instruction")
async def post_instruction(request: Request):
    """Store the prompt sent by the client in INSTRUCTION_FILE.

    Expects a JSON object with an optional ``"instruction"`` string. A missing
    or falsy value is stored as an empty string.

    Returns:
        ``{"ok": True}`` once the prompt is written, or
        ``{"ok": False, "error": ...}`` when the body is not valid JSON, is not
        a JSON object, or carries a non-string instruction.
    """
    try:
        data = await request.json()
    except ValueError:
        # Covers malformed JSON and undecodable bytes (both ValueError subclasses).
        return {"ok": False, "error": "request body must be valid JSON"}
    if not isinstance(data, dict):
        return {"ok": False, "error": "request body must be a JSON object"}

    raw = data.get("instruction", "") or ""
    if not isinstance(raw, str):
        return {"ok": False, "error": "'instruction' must be a string"}
    text = raw.strip()

    # Write to a side file and rename, so a reader never sees a half-written prompt.
    tmp = INSTRUCTION_FILE + ".tmp"
    with open(tmp, "w", encoding="utf-8") as f:
        f.write(text)
    os.replace(tmp, INSTRUCTION_FILE)
    return {"ok": True}
@app.get("/", response_class=HTMLResponse)
async def homepage():
    """Return the contents of the ``index.html`` file located next to this module."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    index_path = os.path.join(module_dir, "index.html")
    with open(index_path, encoding="utf-8") as page:
        html = page.read()
    return html
# Start the server as soon as this module is executed, passing
# show_error=True through to launch(). Everything above (downloads, pipeline
# construction, route registration) has already run by this point.
app.launch(show_error=True)