# rtx6000test / app.py — by multimodalart (HF Staff)
# Commit 2eeb78f: "Speed: text-embed cache + lighter 384x640 profile + PyAV (VAE compile opt-in)"
# (Hugging Face file-viewer header, 21.3 kB: Raw · History · Blame · Contribute · Delete)
"""DreamVerse LTX-2 — the *real* FastVideo frontend on `gradio.Server`.
This serves DreamVerse's actual Next.js UI (static-exported to `frontend/`)
and implements its `/ws` streaming protocol on `gradio.Server`'s FastAPI,
backed by the real FastVideo LTX-2 pipeline with cross-segment continuation.
Flow (matches the frontend's reducer exactly):
client connects /ws ──▶ session_init_v2 {initial_rollout_prompt, ...}
server ──▶ queue_status, gpu_assigned, ltx2_stream_start
per segment:
ltx2_segment_start ──▶ (generate LTX-2, conditioned on prev tail) ──▶
media_init {stream_id, mime} ──▶ [binary fMP4 chunks] ──▶
media_segment_complete ──▶ ltx2_segment_complete
client ──▶ rewrite_seed_prompts (streaming UI) or append_prompt {prompt}
(devtools build) steers/grows the rollout (continuation).
GPU work is isolated in `@spaces.GPU` (ZeroGPU-shaped). fMP4 muxing mirrors
DreamVerse's `av_streaming.py` (libx264 ultrafast/zerolatency, baseline, GOP
12, +frag_keyframe+empty_moov+default_base_moof) so the frontend's MSE
SourceBuffer plays it. FA4 skipped (sm_120) → TORCH_SDPA. See ZEROGPU_NOTES.md.
"""
import os
import asyncio
import json
import subprocess
import tempfile
import uuid
import wave
# Process-wide defaults, applied only when the variable is not already set.
_ENV_DEFAULTS = {
    "FASTVIDEO_ATTENTION_BACKEND": "TORCH_SDPA",
    "TOKENIZERS_PARALLELISM": "false",
    # Keep the AOTInductor/compile cache on the big overlay disk (not a small tmpfs).
    "TORCHINDUCTOR_CACHE_DIR": "/home/user/.cache/dreamverse_inductor",
}
for _env_key, _env_val in _ENV_DEFAULTS.items():
    os.environ.setdefault(_env_key, _env_val)
del _env_key, _env_val

HERE_EARLY = os.path.dirname(os.path.abspath(__file__))
# DREAMVERSE_NVFP4=1 switches the DiT to NVFP4 quantization.
NVFP4 = os.getenv("DREAMVERSE_NVFP4") == "1"
# HF bucket holding the prebaked sm_120 NVFP4 flashinfer kernels (no runtime JIT).
FP4_CACHE_REPO = os.getenv("DREAMVERSE_FP4_CACHE_REPO", "multimodalart/dreamverse-flashinfer-cache")
def _setup_nvfp4_env():
    """Point flashinfer at prebaked sm_120 NVFP4 kernels so nothing is JIT-built.

    The function does three things:

    * Sets FLASHINFER_CUDA_ARCH_LIST=12.0f and FLASHINFER_DISABLE_VERSION_CHECK=1
      (only if unset).
    * Pulls the `aot/*` files of the FP4_CACHE_REPO dataset via huggingface_hub.
    * If an `aot/` directory arrived, assigns it to
      `flashinfer.jit.env.FLASHINFER_AOT_DIR`.

    flashinfer's `build_and_load` short-circuits to `load(aot_path)` when
    `FLASHINFER_AOT_DIR/<op>/<op>.so` exists. That is why the two prebaked
    `.so` files (fp4_quantization_120f, fp4_gemm_cutlass_sm120) are loaded
    directly, with no ninja, nvcc or JIT. The JIT-cache dir is NOT relocatable
    (ninja recompiles), but the AOT dir is. flashinfer must be installed with
    --no-deps (keeps torch 2.11+cu130); see NVFP4_ZEROGPU.md.

    This is best effort: any failure is printed and flashinfer falls back to
    JIT on first use.
    """
    import pathlib

    flashinfer_env = {
        "FLASHINFER_CUDA_ARCH_LIST": "12.0f",  # arch w/o nvcc
        "FLASHINFER_DISABLE_VERSION_CHECK": "1",
    }
    for var_name, var_value in flashinfer_env.items():
        os.environ.setdefault(var_name, var_value)

    try:
        from huggingface_hub import snapshot_download

        snapshot_root = snapshot_download(repo_id=FP4_CACHE_REPO, repo_type="dataset", allow_patterns="aot/*")
        aot_dir = os.path.join(snapshot_root, "aot")
        if not os.path.isdir(aot_dir):
            print("[nvfp4] AOT dir missing in bucket; flashinfer may JIT", flush=True)
            return
        import flashinfer.jit.env as fi_env

        fi_env.FLASHINFER_AOT_DIR = pathlib.Path(aot_dir)
        print(f"[nvfp4] prebaked FP4 AOT kernels from {FP4_CACHE_REPO} (no JIT)", flush=True)
    except Exception as exc:
        print(f"[nvfp4] AOT pull failed ({exc}); flashinfer will JIT on first use", flush=True)
import numpy as np
import spaces
# NVFP4 mode only: set the FLASHINFER_* env vars and fetch the prebaked FP4
# kernels (see _setup_nvfp4_env). Kept ahead of the hook installs below —
# presumably so flashinfer sees them before its first use; confirm.
if NVFP4:
    _setup_nvfp4_env()
# AOTI DiT compiler — installs a hook on LTXModel._process_transformer_blocks
# that loads the prebaked (weight-less, regional) AOTInductor .pt2 on first
# forward. At module scope so the hook is live in FastVideo's spawned worker.
# (No-op in NVFP4 mode: FP4 runs eager + prebaked kernels — see aoti_dit.)
import aoti_dit
aoti_dit.install()
# Memoize Gemma text embeddings across same-prompt rollout segments.
import text_cache
text_cache.install()
from gradio import Server
from fastapi import WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse, JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
def _env_int(name: str, default: str) -> int:
    """Return environment variable `name` parsed as int (string `default` if unset)."""
    return int(os.getenv(name, default))


MODEL_ID = os.getenv("LTX2_MODEL_PATH", "FastVideo/LTX2-Distilled-Diffusers")
HERE = os.path.dirname(os.path.abspath(__file__))
FRONTEND = os.path.join(HERE, "frontend")

# Render profile, kept modest for single-GPU latency and overridable via
# DREAMVERSE_* env vars. Required shape constraints: height and width are
# multiples of 32, and (frames - 1) % 8 == 0. The 384x640 default has 62.5% of
# the pixels of 512x768. That cuts both the DiT sequence length and the VAE
# decode without shortening the clip.
GEN_HEIGHT = _env_int("DREAMVERSE_HEIGHT", "384")
GEN_WIDTH = _env_int("DREAMVERSE_WIDTH", "640")
GEN_FRAMES = _env_int("DREAMVERSE_FRAMES", "65")
GEN_STEPS = _env_int("DREAMVERSE_STEPS", "5")
GEN_FPS = 24
GPU_DURATION = _env_int("DREAMVERSE_GPU_DURATION", "300")  # passed to @spaces.GPU

# Continuation conditioning (mirrors dreamverse/video_generation.py): the
# number of tail frames carried into the next segment (8k + 1).
COND_FRAMES = 9
# Codecs string announced to the frontend's MSE SourceBuffer: H.264 Baseline
# (avc1.42E01E) + AAC-LC (mp4a.40.2), matching what encode_fmp4 muxes.
MIME = 'video/mp4; codecs="avc1.42E01E,mp4a.40.2"'
FFMPEG = os.getenv("FASTVIDEO_FFMPEG_BIN", "ffmpeg")
app = Server()
# Lazily built VideoGenerator and the thread lock guarding its construction;
# both start unset and are filled in by _get_generator().
_generator = _gen_lock_thread = None
# Single GPU: one asyncio lock serializes generation across all sockets.
_gpu_lock = asyncio.Lock()
# --------------------------------------------------------------------------
# Model (lazy; spawn-safe — never load at module scope, see ZEROGPU_NOTES.md)
# --------------------------------------------------------------------------
def _get_generator():
    """Return the process-wide FastVideo `VideoGenerator`, building it on first use.

    Never called at import time, which keeps the module spawn-safe. The build
    runs under a thread lock and re-checks `_generator` inside it.
    """
    global _generator, _gen_lock_thread
    if _generator is None:
        import threading

        # Lock objects are always truthy, so `or` only creates one when unset.
        _gen_lock_thread = _gen_lock_thread or threading.Lock()
        with _gen_lock_thread:
            if _generator is None:
                _generator = _build_generator()
    return _generator


def _build_generator():
    """Construct the VideoGenerator, with engine options chosen from env flags."""
    from fastvideo import VideoGenerator
    from fastvideo.api.schema import (CompileConfig, EngineConfig,
                                      GeneratorConfig, QuantizationConfig)

    mode = "NVFP4 (FP4 DiT, prebaked kernels)" if NVFP4 else "bf16 + AOTI"
    print(f"[dreamverse] loading {MODEL_ID} [{mode}] (first segment — slow) ...", flush=True)
    engine = EngineConfig(num_gpus=1)
    # VAE/text-encoder compile is OFF by default. It is JIT torch.compile,
    # which costs ~180s of first-call compile here (VAE decode + the encode
    # path used for continuation) for only ~0.5s warm gain. JIT also
    # recompiles on every ZeroGPU cold start, breaking the AOT/clean property.
    # Opt in with DREAMVERSE_COMPILE_VAE=1 on a dedicated GPU. (The
    # ZeroGPU-clean way to speed the VAE is to AOTI it, like the DiT.)
    if os.getenv("DREAMVERSE_COMPILE_VAE") == "1":
        engine.compile = CompileConfig(
            enabled=False, text_encoder_enabled=True, vae_enabled=True,
            backend="inductor", fullgraph=False, mode=None, dynamic=True,
        )
        print("[dreamverse] VAE+text-encoder torch.compile ON (first calls slow)", flush=True)
    if NVFP4:
        engine.quantization = QuantizationConfig(transformer_quant="NVFP4")
    built = VideoGenerator.from_pretrained(
        config=GeneratorConfig(model_path=MODEL_ID, engine=engine))
    print("[dreamverse] generator ready", flush=True)
    return built
def _mock_segment(seed: int):
    """Synthetic frames+audio for control-plane testing without the GPU model.

    Returns the same (frames, audio, sample_rate) triple as generate_segment:

    * frames: GEN_FRAMES solid-colour GEN_HEIGHT x GEN_WIDTH x 3 uint8 frames.
      The colour is derived from `seed` and steps by 3 per frame.
    * audio: a silent 1-D torch waveform at 24 kHz whose length matches the
      clip duration.
    """
    import torch

    sample_rate = 24000
    n = GEN_FRAMES
    rng = np.random.default_rng(seed)
    # Draw in uint8 (same base colour per seed as before), then widen. Adding
    # `i * 3` to a uint8 array raises OverflowError under NumPy 2 (NEP 50) once
    # i * 3 > 255, i.e. for any DREAMVERSE_FRAMES above 86.
    base = rng.integers(0, 255, size=(3,), dtype=np.uint8).astype(np.int64)
    frames = []
    for i in range(n):
        frame = np.empty((GEN_HEIGHT, GEN_WIDTH, 3), dtype=np.uint8)
        frame[:] = ((base + i * 3) % 255).astype(np.uint8)
        frames.append(frame)
    audio = torch.zeros(int(n / GEN_FPS * sample_rate))
    return frames, audio, sample_rate
@spaces.GPU(duration=GPU_DURATION)
def generate_segment(prompt: str, cond_images, seed: int):
    """Generate one LTX-2 segment (GPU work, wrapped by `spaces.GPU`).

    Args:
        prompt: text prompt for this segment.
        cond_images: list of PIL tail frames from the previous segment to
            continue from, or None for a fresh start.
        seed: generation seed.

    Returns:
        (frames, audio, sample_rate). `frames` falls back to [] when the
        generator returned none.

    Raises:
        RuntimeError: if the generator does not return a dict.

    With DREAMVERSE_MOCK=1 a synthetic segment is returned and no model loads.
    """
    if os.getenv("DREAMVERSE_MOCK") == "1":
        return _mock_segment(seed)
    generator = _get_generator()
    request = {
        "prompt": prompt,
        "negative_prompt": "",
        "height": GEN_HEIGHT,
        "width": GEN_WIDTH,
        "num_frames": GEN_FRAMES,
        "num_inference_steps": GEN_STEPS,
        "fps": GEN_FPS,
        "guidance_scale": 1.0,
        "seed": seed,
        "save_video": False,
        "return_frames": True,
    }
    if cond_images is not None:
        # Continuation: condition on the previous segment's tail frames.
        request.update(
            ltx2_video_conditions=[(cond_images, 0, 1.0)],
            ltx2_images=None,
            image_path=None,
        )
    result = generator.generate_video(**request)
    if not isinstance(result, dict):
        raise RuntimeError("unexpected generation result")
    return result.get("frames") or [], result.get("audio"), result.get("audio_sample_rate")
# --------------------------------------------------------------------------
# fMP4 muxing (ported from dreamverse/av_streaming.py)
# --------------------------------------------------------------------------
def _audio_to_int16(audio, head_trim_samples=0, keep_samples=None):
import torch
if audio is None:
return None, 1
a = audio.detach().cpu().float().numpy() if torch.is_tensor(audio) else np.asarray(audio, np.float32)
if a.ndim == 1:
a = a[:, None]
elif a.ndim == 2 and a.shape[0] <= 8 and a.shape[1] > a.shape[0]:
a = a.T
a = np.clip(a, -1.0, 1.0)
a = (a * 32767.0).astype(np.int16)
if head_trim_samples > 0:
a = a[head_trim_samples:]
if keep_samples is not None:
a = a[:keep_samples]
return a, a.shape[1]
def encode_fmp4(frames, audio, sample_rate, head_trim_frames=0):
    """Encode RGB frames plus audio into one fragmented, MSE-appendable MP4.

    Args:
        frames: sequence of same-sized HxWx3 RGB frames (converted to uint8).
        audio: waveform accepted by `_audio_to_int16`, or None.
        sample_rate: audio sample rate in Hz. A falsy value means 24000.
        head_trim_frames: number of leading frames, and the matching span of
            audio, to drop. Used to cut the re-generated conditioning frames.

    Returns:
        bytes of a fragmented MP4 (libx264 baseline + AAC, keyframe-aligned
        fragments with an empty moov) suitable for SourceBuffer.appendBuffer.

    Raises:
        ValueError: if no frames remain after the head trim.
        RuntimeError: if ffmpeg exits non-zero or produces no output.
    """
    if head_trim_frames > 0:
        frames = frames[head_trim_frames:]
    frames = [np.ascontiguousarray(f, dtype=np.uint8) for f in frames]
    if not frames:
        raise ValueError("encode_fmp4: no frames left to encode after head trim")
    h, w = frames[0].shape[0], frames[0].shape[1]
    sr = int(sample_rate or 24000)
    trim_samp = int(round(head_trim_frames / GEN_FPS * sr)) if head_trim_frames > 0 else 0
    keep_samp = int(round(len(frames) / GEN_FPS * sr))
    audio_i16, nch = _audio_to_int16(audio, trim_samp, keep_samp)

    wav_path = None
    try:
        if audio_i16 is not None and audio_i16.shape[0] > 0:
            fd, wav_path = tempfile.mkstemp(suffix=".wav")
            os.close(fd)
            with wave.open(wav_path, "wb") as wf:
                wf.setnchannels(nch)
                wf.setsampwidth(2)
                wf.setframerate(sr)
                wf.writeframes(audio_i16.tobytes())
            a_inputs = ["-i", wav_path]
        else:
            # MIME always declares an AAC track (mp4a.40.2). MSE implementations
            # can reject an init segment missing a declared track, so mux
            # silence of the clip's length instead of omitting audio.
            a_inputs = ["-f", "lavfi", "-t", f"{len(frames) / GEN_FPS:.6f}",
                        "-i", f"anullsrc=r={sr}:cl=mono"]
        cmd = [
            FFMPEG, "-hide_banner", "-loglevel", "error", "-y",
            "-f", "rawvideo", "-pix_fmt", "rgb24", "-s:v", f"{w}x{h}",
            "-r", str(GEN_FPS), "-i", "pipe:0", *a_inputs,
            "-c:v", "libx264", "-preset", "ultrafast", "-tune", "zerolatency",
            "-profile:v", "baseline", "-g", "12", "-keyint_min", "12",
            "-x264-params", "scenecut=0", "-pix_fmt", "yuv420p",
            "-c:a", "aac", "-shortest",
            "-movflags", "+frag_keyframe+empty_moov+default_base_moof",
            "-frag_duration", "250000", "-muxdelay", "0", "-muxpreload", "0",
            "-f", "mp4", "pipe:1",
        ]
        raw = b"".join(f.tobytes() for f in frames)
        proc = subprocess.run(cmd, input=raw, stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE, check=False)
    finally:
        # Also reached if WAV writing fails, so the temp file never leaks.
        if wav_path:
            try:
                os.unlink(wav_path)
            except FileNotFoundError:
                pass
    if proc.returncode != 0 or not proc.stdout:
        raise RuntimeError(f"ffmpeg failed: {proc.stderr.decode('utf-8','ignore')[:300]}")
    return proc.stdout
def tail_pil(frames, n=COND_FRAMES):
    """Convert the last `n` frames to PIL images for continuation conditioning.

    Returns None when `frames` is empty/None or holds fewer than `n` frames.
    Otherwise returns the last `n` frames, in order, as `PIL.Image` objects.
    """
    from PIL import Image

    enough = bool(frames) and len(frames) >= n
    if not enough:
        return None
    start = len(frames) - n
    return [Image.fromarray(np.ascontiguousarray(frame)) for frame in frames[start:]]
# --------------------------------------------------------------------------
# WebSocket: the DreamVerse streaming protocol
# --------------------------------------------------------------------------
# DreamVerse's "streaming" product mode is a CONTINUOUS auto-rollout: the
# backend keeps generating segments that flow into the player, and the user
# *steers* the ongoing rollout by submitting prompts (which the frontend
# sends as `rewrite_seed_prompts`, not `append_prompt`). So we run a
# background rollout loop per connection and a concurrent receive loop.
# Segments generated per "budget stretch": the rollout keeps generating the
# current prompt (each segment continuing the last) until it reaches the
# rolling target, then idles. Every edit (rewrite) extends the target, so the
# rollout is effectively continuous while the user keeps steering.
MAX_SEGMENTS = int(os.environ.get("DREAMVERSE_MAX_SEGMENTS", "12"))  # segments per budget stretch
class Session:
    """Per-websocket state for one DreamVerse rollout.

    It holds:

    * the current prompt
    * the previous segment's frames (continuation context)
    * the segment counter and rolling target
    * lifecycle flags
    * the rollout task

    Every outgoing message goes through `send` or `send_bytes`. Both share one
    lock, so the rollout loop and the receive loop serialize their writes.
    """

    def __init__(self, ws: WebSocket):
        self.ws = ws
        # Rollout content and progress.
        self.prompt = ""
        self.prev_frames = None
        self.seg_idx = self.target = 0
        # Lifecycle flags and the background rollout task.
        self.stream_started = self.paused = self.running = False
        self.task = None
        # Serializes sends from the rollout task and the receive loop.
        self.send_lock = asyncio.Lock()

    async def send(self, **payload):
        """Send the keyword arguments as one JSON text frame."""
        message = json.dumps(payload)
        async with self.send_lock:
            await self.ws.send_text(message)

    async def send_bytes(self, data: bytes):
        """Send `data` as one binary frame."""
        async with self.send_lock:
            await self.ws.send_bytes(data)
async def _rollout(sess: Session):
    """Stream continuation segments for `sess` until its rolling target is hit.

    Each iteration works as follows:

    * Snapshots the session's current prompt, so a steering edit applies
      from the next segment on.
    * Conditions on the previous segment's last COND_FRAMES frames when there
      are any.
    * Generates under the process-wide `_gpu_lock` and muxes the result to
      fMP4.
    * Streams it as: prompt_received -> prompt_ready -> ltx2_segment_start ->
      media_init -> binary chunks -> media_segment_complete ->
      ltx2_segment_complete.

    While paused, or while the prompt is empty, the loop polls every 0.15 s.
    It exits when `sess.running` is cleared, when `sess.seg_idx` reaches
    `sess.target`, or on error. Errors are logged and reported to the client
    as an `error` message.

    `ltx2_stream_start` is emitted once by the caller (not here), so
    restarting the loop after an edit does NOT reset the player.
    """
    import zlib

    chunk_size = 1 << 20  # max bytes per binary websocket frame

    sess.running = True
    try:
        while sess.running and sess.seg_idx < sess.target:
            if sess.paused or not sess.prompt:
                await asyncio.sleep(0.15)
                continue
            sess.seg_idx += 1
            idx, prompt = sess.seg_idx, sess.prompt
            await sess.send(type="prompt_received", prompt_id=f"seg{idx}")
            await sess.send(type="prompt_ready", prompt_id=f"seg{idx}", prompt=prompt, source="user_raw")
            await sess.send(type="ltx2_segment_start", segment_idx=idx,
                            total_segments=sess.target, prompt=prompt, source="user_raw")
            cond = tail_pil(sess.prev_frames) if sess.prev_frames else None
            head_trim = COND_FRAMES if cond is not None else 0
            # crc32 is stable across processes. Built-in str hash() is salted per
            # run, so the old seed changed on every restart.
            seed = zlib.crc32(f"{idx}\x00{prompt}".encode("utf-8")) % 100000
            async with _gpu_lock:
                frames, audio, sr = await asyncio.to_thread(generate_segment, prompt, cond, seed)
            # The ws handler resets seg_idx to 0 on re-init/end/reset. If that
            # happened mid-generation, this segment belongs to the old rollout:
            # don't stream it and don't continue from its frames.
            if sess.seg_idx != idx:
                continue
            if not frames:
                await sess.send(type="error", message="Generation returned no frames")
                break
            fmp4 = await asyncio.to_thread(encode_fmp4, frames, audio, sr, head_trim)
            if sess.seg_idx != idx:
                continue
            sess.prev_frames = frames
            sid = f"seg{idx:03d}-{uuid.uuid4().hex[:8]}"
            await sess.send(type="media_init", segment_idx=idx, stream_id=sid, mime=MIME)
            for offset in range(0, len(fmp4), chunk_size):
                await sess.send_bytes(fmp4[offset:offset + chunk_size])
            await sess.send(type="media_segment_complete", segment_idx=idx, stream_id=sid)
            await sess.send(type="ltx2_segment_complete", segment_idx=idx, total_segments=sess.target)
            print(f"[dreamverse] segment {idx}/{sess.target} ({len(frames)}f, {len(fmp4)//1024}KB) prompt={prompt[:40]!r}", flush=True)
    except WebSocketDisconnect:
        pass  # client is gone; nobody left to notify
    except Exception as e:
        # This includes RuntimeError from generate_segment/encode_fmp4, which
        # used to be swallowed silently. It presumably also covers sends on an
        # already-closed socket. Log it, then notify the client if possible.
        print(f"[dreamverse] rollout error: {e}", flush=True)
        try:
            await sess.send(type="error", message=str(e))
        except Exception:
            pass
    finally:
        sess.running = False
def _first_prompt(data):
p = (data.get("initial_rollout_prompt") or "").strip()
if not p:
cur = data.get("curated_prompts") or []
p = cur[0].strip() if cur and isinstance(cur[0], str) else ""
return p
def _ensure_rollout(sess: Session):
    """Make sure one (and only one) rollout loop is driving `sess`.

    The session's last rollout task may still be alive, e.g. told to stop via
    `running = False` but still finishing a segment. In that case it is
    re-armed by setting `running` back to True; it re-checks that flag at the
    top of its loop. Otherwise a fresh task is started.

    A plain `not sess.running` check would spawn a second loop on the same
    session. Two loops interleave segments and race on `seg_idx` and
    `prev_frames`.
    """
    task = sess.task
    if task is not None and not task.done():
        sess.running = True
    else:
        sess.task = asyncio.create_task(_rollout(sess))


@app.websocket("/ws")
async def ws_endpoint(ws: WebSocket):
    """Serve one DreamVerse client over the `/ws` streaming protocol.

    Handled message types:

    * session_init_v2 / project_init_v1: acknowledge queue and GPU, take the
      first prompt, and reset continuation state. Emits ltx2_stream_start once
      per socket, then starts generating.
    * rewrite_seed_prompts: steer. Replaces the prompt from
      rewrite_instruction, or else the first prompt_window_prompts/prompts
      entry. Extends the target by MAX_SEGMENTS, acknowledges, and resumes.
    * append_prompt (devtools path): the same steering, taken from `prompt`.
    * set_generation_paused: pause or resume the rollout loop.
    * restart_generation: extend the target and resume.
    * end_project_keep_session / reset_to_seed_prompts: stop the loop and
      forget continuation state.

    Other types (e.g. set_auto_extension, set_loop_generation) are ignored.
    Non-JSON or non-object messages are skipped instead of closing the
    connection.
    """
    await ws.accept()
    sess = Session(ws)
    try:
        while True:
            raw = await ws.receive_text()
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                print("[dreamverse] ignoring non-JSON ws message", flush=True)
                continue
            if not isinstance(data, dict):
                continue
            mtype = data.get("type")
            if mtype in ("session_init_v2", "project_init_v1"):
                await sess.send(type="queue_status", queue_position=0, status="assigned")
                await sess.send(type="gpu_assigned", gpu_id=0, session_timeout=GPU_DURATION)
                sess.prompt = _first_prompt(data)
                sess.prev_frames, sess.seg_idx = None, 0
                sess.target = MAX_SEGMENTS
                if not sess.stream_started:
                    await sess.send(type="ltx2_stream_start", total_segments=sess.target)
                    sess.stream_started = True
                if sess.prompt:
                    _ensure_rollout(sess)
            elif mtype == "rewrite_seed_prompts":
                # The "edit" path in streaming mode: steer the rollout. Update
                # the prompt and EXTEND the budget so the rollout keeps going
                # (applies to the next segment); resume only if idle.
                new = (data.get("rewrite_instruction") or "").strip()
                if not new:
                    pw = data.get("prompt_window_prompts") or data.get("prompts") or []
                    if isinstance(pw, list) and pw and isinstance(pw[0], str):
                        new = pw[0].strip()
                if new:
                    sess.prompt = new
                    print(f"[dreamverse] steer -> {new[:50]!r}", flush=True)
                sess.target = sess.seg_idx + MAX_SEGMENTS
                await sess.send(type="seed_prompts_updated", prompts=[sess.prompt], reason="rewrite")
                await sess.send(type="rewrite_seed_prompts_complete")
                if sess.prompt:
                    _ensure_rollout(sess)
            elif mtype == "append_prompt":  # devtools/demo build path
                new = (data.get("prompt") or "").strip()
                if new:
                    sess.prompt = new
                sess.target = sess.seg_idx + MAX_SEGMENTS
                if sess.prompt:
                    _ensure_rollout(sess)
            elif mtype == "set_generation_paused":
                sess.paused = bool(data.get("paused"))
            elif mtype == "restart_generation":
                sess.target = sess.seg_idx + MAX_SEGMENTS
                if sess.prompt:
                    _ensure_rollout(sess)
            elif mtype in ("end_project_keep_session", "reset_to_seed_prompts"):
                sess.running = False
                sess.prev_frames, sess.seg_idx = None, 0
            # set_auto_extension / set_loop_generation acknowledged implicitly.
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print(f"[dreamverse] ws error: {e}", flush=True)
    finally:
        # Let any live rollout loop exit at its next iteration.
        sess.running = False
# --------------------------------------------------------------------------
# Aux endpoints the frontend probes
# --------------------------------------------------------------------------
@app.get("/healthz")
async def healthz():
    """Liveness probe: always "ok", plus whether the generator is loaded yet."""
    loaded = _generator is not None
    return {"status": "ok", "model": MODEL_ID, "loaded": loaded}
@app.get("/readyz")
async def readyz():
    """Readiness probe; reports ready unconditionally."""
    payload = {"ready": True}
    return JSONResponse(payload)
@app.get("/models")
async def models():
    """List the single selectable model (FastLTX2) and mark it as the default."""
    entry = {"id": "fast-ltx2", "name": "FastLTX2", "model_path": MODEL_ID}
    return {"models": [entry], "default": entry["id"]}
@app.get("/curated-presets")
async def curated_presets():
    """Serve the bundled continuation-story presets JSON, or [] if it is absent."""
    presets_path = os.path.join(FRONTEND, "prompts", "selected_ltx2_continuation_story_presets.json")
    if not os.path.exists(presets_path):
        return JSONResponse([])
    return FileResponse(presets_path, media_type="application/json")
@app.get("/prompt-system-config")
async def prompt_system_config():
    """Report that prompt enhancement and auto-extension are both disabled."""
    return dict(enhancement_enabled=False, auto_extension_enabled=False)
# Root: serve the static-exported DreamVerse frontend page.
@app.get("/", response_class=HTMLResponse)
async def homepage():
    """Return frontend/index.html as HTML, re-read from disk on every request."""
    index_path = os.path.join(FRONTEND, "index.html")
    with open(index_path, "r", encoding="utf-8") as fh:
        page = fh.read()
    return page
# Serve the bundled Next.js assets (JS/CSS/fonts) under /_next. This is a
# specific mount, so it does NOT shadow gradio's own /gradio_api/* routes,
# which are added during launch().
app.mount(
    "/_next",
    StaticFiles(directory=os.path.join(FRONTEND, "_next")),
    name="next",
)
# Root-level static assets (favicon, svgs, k2.png, index.txt) — registered as
# explicit single-segment routes so they don't shadow gradio's API paths.
def _make_file_route(full_path: str):
    """Return an async GET handler that serves the file at `full_path`.

    This is a factory (rather than a closure defined inside the registration
    loop), so each handler binds its own `full_path` instead of late-binding
    the loop variable.
    """
    # NOTE(review): FastAPI presumably derives the route name from the
    # handler's __name__ (`_route`); keep it stable.
    async def _route():
        return FileResponse(full_path)
    return _route
# Expose every regular file at the frontend root (favicon, svgs, k2.png,
# index.txt, ...) as its own GET route. index.html is skipped because "/"
# serves it.
for _f in os.listdir(FRONTEND):
    _full = os.path.join(FRONTEND, _f)
    if _f == "index.html" or not os.path.isfile(_full):
        continue
    app.add_api_route(f"/{_f}", _make_file_route(_full), methods=["GET"])
if __name__ == "__main__":
    # Listen on all interfaces; PORT overrides the default 7860.
    _port = int(os.getenv("PORT", "7860"))
    app.launch(server_name="0.0.0.0", server_port=_port, show_error=True)