File size: 7,225 Bytes
e88b235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38c1a39
 
 
 
 
 
 
e88b235
 
 
 
 
 
 
 
 
 
c50961a
e88b235
 
38c1a39
e88b235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4fdd33
e88b235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c50961a
e88b235
 
 
 
 
 
c50961a
 
 
 
 
 
 
 
 
 
 
e88b235
 
 
 
 
 
 
 
 
 
 
 
 
7c7586e
 
 
 
 
e88b235
 
 
 
c50961a
e88b235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
import os
import time
import base64
import tempfile
import queue as pyqueue
import multiprocessing as mp
from io import BytesIO

import spaces  # before torch / CUDA imports

import torch
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download, hf_hub_download

from streamdiffusionv2 import StreamDiffusionV2Pipeline

# ----------------------------------------------------------------------------
# Config
# ----------------------------------------------------------------------------
# Hugging Face repo id and local download directory for the Wan2.1 base weights.
WAN_REPO = "Wan-AI/Wan2.1-T2V-1.3B"
WAN_DIR = "wan_models/Wan2.1-T2V-1.3B"
# Hugging Face repo id and local download directory for the StreamDiffusionV2
# checkpoints; CKPT_FOLDER is the sub-folder handed to the pipeline.
SDV2_REPO = "jerryfeng/StreamDiffusionV2"
CKPT_DIR = "ckpts"
CKPT_FOLDER = os.path.join(CKPT_DIR, "wan_causal_dmd_v2v")  # 1.3B v2v checkpoint

# Frame size: incoming JPEGs are resized to (WIDTH, HEIGHT) and the pipeline is
# constructed with the same height/width.
HEIGHT, WIDTH = 480, 832
# Seconds run_session keeps processing frames; this is below the duration=60
# passed to spaces.GPU on run_session.
SESSION_DURATION = 58
# Seconds run_session sleeps when not enough frames are buffered yet.
POLL_INTERVAL = 0.005
# Prompt used whenever the instruction file is missing or empty.
DEFAULT_PROMPT = "a psychedelic neon dream, vivid saturated colors, glowing"
# Passed both to the pipeline constructor and to start_stream_session.
NOISE_SCALE = 0.8

# The live prompt is exchanged through a file in the system temp directory:
# post_instruction writes it, _read_prompt reads it.
SESSION_DIR = tempfile.gettempdir()
INSTRUCTION_FILE = os.path.join(SESSION_DIR, "sdv2_prompt.txt")
# First value yielded by run_session, before any frame is produced.
READY_SENTINEL = "__READY__"

# Fork-safe live frame queue (created before any ZeroGPU fork).
# Bounded at 512 entries; post_frame drops a frame when the queue is full.
FRAME_Q = mp.get_context("fork").Queue(maxsize=512)

# ----------------------------------------------------------------------------
# Weights + pipeline at module scope (ZeroGPU snapshot preload)
# ----------------------------------------------------------------------------
# Fetch only the Wan2.1 files matched by allow_patterns into WAN_DIR.
snapshot_download(
    repo_id=WAN_REPO, local_dir=WAN_DIR,
    allow_patterns=[
        "config.json", "diffusion_pytorch_model.safetensors", "Wan2.1_VAE.pth",
        "models_t5_umt5-xxl-enc-bf16.pth", "google/umt5-xxl/*",
    ],
)
# Fetch only the wan_causal_dmd_v2v/ subtree of the StreamDiffusionV2 repo into
# CKPT_DIR, which is the folder CKPT_FOLDER points at.
snapshot_download(repo_id=SDV2_REPO, local_dir=CKPT_DIR,
                  allow_patterns=["wan_causal_dmd_v2v/*"])

# Pre-fetch the TAEHV tiny-VAE decoder weights (fast streaming decode).
# The download goes to a sibling ".part" file and is moved into place only once
# it has completed, so an interrupted download can never leave a truncated
# taew2_1.pth that would pass the existence check on the next start.
_TAEHV_PATH = os.path.join(CKPT_DIR, "taew2_1.pth")
if not os.path.exists(_TAEHV_PATH):
    import urllib.request
    _TAEHV_PARTIAL = _TAEHV_PATH + ".part"
    urllib.request.urlretrieve(
        "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth", _TAEHV_PARTIAL)
    os.replace(_TAEHV_PARTIAL, _TAEHV_PATH)

# All tensors built by _frames_to_video_tensor and the pipeline itself target
# this device.
device = torch.device("cuda")

# StreamDiffusionV2 single-GPU streaming pipeline (rolling KV + sink tokens are
# built into the model -> continuous streaming without the window-shift burst).
# Constructed at import time, together with the downloads above.
stream = StreamDiffusionV2Pipeline(
    checkpoint_folder=CKPT_FOLDER,
    mode="single",
    device=device,
    height=HEIGHT,
    width=WIDTH,
    step=2,            # 2 denoising steps (quality); TAEHV keeps it fast
    noise_scale=NOISE_SCALE,
    model_type="T2V-1.3B",
    use_taehv=True,    # tiny-VAE decode -> much faster per-chunk -> lower lag
)
# Shorthand for the manager object that run_session drives.
PM = stream.pipeline_manager
# CHUNK is the number of input frames run_session feeds per streaming step;
# FIRST_BATCH is the number it feeds to start a session.
CHUNK = PM.base_chunk_size * PM.pipeline.num_frame_per_block   # 4 px frames / chunk
FIRST_BATCH = 1 + CHUNK                                        # 5 px frames first


def _read_prompt():
    """Return the stripped contents of INSTRUCTION_FILE, or "" if it is absent."""
    try:
        handle = open(INSTRUCTION_FILE, encoding="utf-8")
    except FileNotFoundError:
        # No instruction has been written yet.
        return ""
    with handle:
        return handle.read().strip()


def _decode_jpeg_to_tensor(jpeg_bytes):
    """Decode JPEG bytes into a float tensor laid out [C, H, W] in [-1, 1].

    The image is converted to RGB and resized to (WIDTH, HEIGHT) with bicubic
    resampling before being rescaled from 0..255 to -1..1.
    """
    picture = Image.open(BytesIO(jpeg_bytes))
    picture = picture.convert("RGB")
    picture = picture.resize((WIDTH, HEIGHT), Image.BICUBIC)
    pixels = torch.from_numpy(np.asarray(picture)).float()
    unit_range = pixels.permute(2, 0, 1) / 255.0   # HWC -> CHW, 0..1
    return unit_range * 2.0 - 1.0


def _frames_to_video_tensor(frame_list):
    """Stack a list of [C, H, W] frames into a [1, C, T, H, W] bf16 tensor on `device`."""
    stacked = torch.stack(frame_list, dim=1)   # [C, T, H, W]
    batched = stacked.unsqueeze(0)             # [1, C, T, H, W]
    return batched.to(device=device, dtype=torch.bfloat16)


def _to_data_uri(frame01):
    """Encode one frame as a base64 JPEG data URI.

    Values are clipped to [0, 1], scaled to 0..255 and cast to uint8 before
    being saved as a quality-70 JPEG.
    """
    scaled = np.clip(frame01, 0, 1) * 255.0
    picture = Image.fromarray(scaled.astype(np.uint8))
    sink = BytesIO()
    picture.save(sink, format="JPEG", quality=70)
    encoded = base64.b64encode(sink.getvalue()).decode()
    return "data:image/jpeg;base64," + encoded


# ----------------------------------------------------------------------------
# Gradio Server
# ----------------------------------------------------------------------------
from gradio import Server
from fastapi import Request
from fastapi.responses import HTMLResponse

# Single application object; the API endpoint and HTTP routes below are
# registered on it through its decorators.
app = Server()


@app.api(name="run_session")
@spaces.GPU(duration=60, size="xlarge")
@torch.inference_mode()
def run_session() -> str:
    """Run one streaming session and yield its output frames.

    This is a generator. It first empties FRAME_Q and yields READY_SENTINEL,
    then for SESSION_DURATION seconds repeatedly:

    * picks up prompt changes from the instruction file,
    * collects queued JPEG frames posted to /frame,
    * once enough frames are buffered (FIRST_BATCH for the first call, CHUNK
      afterwards), runs the newest ones through the pipeline manager, and
    * yields each resulting frame as a JPEG data URI.

    After the deadline the most recent frame, if any, is yielded once more.

    Yields:
        str: READY_SENTINEL first, then ``data:image/jpeg;base64,...`` URIs.
    """
    # Drain stale frames, then signal the client to start streaming.
    try:
        while True:
            FRAME_Q.get_nowait()
    except pyqueue.Empty:
        pass
    yield READY_SENTINEL

    cur_prompt = _read_prompt() or DEFAULT_PROMPT
    buffer = []
    session = None
    # Monotonic clock: the session length must not change if the wall clock
    # is adjusted while the session is running.
    deadline = time.monotonic() + SESSION_DURATION
    last = None

    while time.monotonic() < deadline:
        new_prompt = _read_prompt() or DEFAULT_PROMPT
        if new_prompt != cur_prompt:
            # Always track the latest prompt, including before the session has
            # started, so the first batch is generated with the prompt the
            # user currently has instead of the one read at session start.
            cur_prompt = new_prompt
            if session is not None:
                # Live prompt change: re-encode text and reset the cross-attn
                # cache, WITHOUT touching the rolling self-attn KV (keeps
                # temporal continuity).
                cond = PM.pipeline.text_encoder(text_prompts=[cur_prompt])
                cond["prompt_embeds"] = cond["prompt_embeds"].repeat(PM.pipeline.batch_size, 1, 1)
                PM.pipeline.conditional_dict = cond
                for blk in PM.pipeline.crossattn_cache:
                    blk["is_init"] = False

        # Pull at most 256 queued frames per iteration so a fast producer
        # cannot keep this loop from reaching the compute step.
        drained = 0
        while drained < 256:
            try:
                buffer.append(FRAME_Q.get_nowait())
            except pyqueue.Empty:
                break
            drained += 1

        need = FIRST_BATCH if session is None else CHUNK
        if len(buffer) < need:
            time.sleep(POLL_INTERVAL)
            continue

        # Low latency: edit the FRESHEST frames and drop any backlog that piled
        # up during the previous chunk's compute, so the output tracks "now".
        chunk_bytes = buffer[-need:]
        buffer = []
        frames = [_decode_jpeg_to_tensor(b) for b in chunk_bytes]
        vid = _frames_to_video_tensor(frames)

        t0 = time.time()
        if session is None:
            session, init_video = PM.start_stream_session(cur_prompt, vid, NOISE_SCALE)
            outs = [init_video]
        else:
            outs = PM.run_stream_batch(session, vid)
        dt = time.time() - t0

        n = 0
        for arr in outs:                       # each arr: [T, H, W, C] in [0,1]
            for fr in arr:
                last = _to_data_uri(fr)
                yield last
                n += 1
        if n:
            print(f"[sdv2] {n} frames in {dt:.2f}s ({n/max(1e-3,dt):.1f} fps)", flush=True)

    # Re-emit the most recent frame once the deadline has passed.
    # NOTE(review): presumably so the stream's final value is a frame -- confirm
    # against the client in index.html.
    if last is not None:
        yield last


@app.post("/frame")
async def post_frame(request: Request):
    """Accept one raw frame in the request body and enqueue it for run_session.

    Empty bodies are ignored, and a frame is silently dropped when FRAME_Q is
    full; the response is ``{"ok": True}`` in every case.
    """
    payload = await request.body()
    if not payload:
        return {"ok": True}
    try:
        FRAME_Q.put_nowait(payload)
    except pyqueue.Full:
        # Best effort: drop the frame rather than block the request.
        pass
    return {"ok": True}


@app.post("/instruction")
async def post_instruction(request: Request):
    """Store the "instruction" field of the JSON body as the live prompt.

    The stripped text is written to a sibling ".tmp" file and then moved over
    INSTRUCTION_FILE with os.replace, so readers never see a partial write.
    """
    payload = await request.json()
    raw = payload.get("instruction", "")
    cleaned = (raw if raw else "").strip()
    scratch_path = INSTRUCTION_FILE + ".tmp"
    with open(scratch_path, "w", encoding="utf-8") as handle:
        handle.write(cleaned)
    os.replace(scratch_path, INSTRUCTION_FILE)
    return {"ok": True}


@app.get("/", response_class=HTMLResponse)
async def homepage():
    """Serve the index.html that sits next to this source file."""
    base_dir = os.path.dirname(os.path.abspath(__file__))
    index_path = os.path.join(base_dir, "index.html")
    with open(index_path, encoding="utf-8") as page:
        markup = page.read()
    return markup


# Starts the server at import time (there is no __main__ guard), so this module
# is meant to be executed as a script.
# NOTE(review): show_error=True presumably surfaces handler exceptions to the
# client -- confirm against the Gradio launch() documentation.
app.launch(show_error=True)