Spaces:
Running on RTX PRO 6000
Running on RTX PRO 6000
multimodalart HF Staff
Speed: text-embed cache + lighter 384x640 profile + PyAV (VAE compile opt-in)
2eeb78f | """DreamVerse LTX-2 — the *real* FastVideo frontend on `gradio.Server`. | |
| This serves DreamVerse's actual Next.js UI (static-exported to `frontend/`) | |
| and implements its `/ws` streaming protocol on `gradio.Server`'s FastAPI, | |
| backed by the real FastVideo LTX-2 pipeline with cross-segment continuation. | |
| Flow (matches the frontend's reducer exactly): | |
| client connects /ws ──▶ session_init_v2 {initial_rollout_prompt, ...} | |
| server ──▶ queue_status, gpu_assigned, ltx2_stream_start | |
| per segment: | |
| ltx2_segment_start ──▶ (generate LTX-2, conditioned on prev tail) ──▶ | |
| media_init {stream_id, mime} ──▶ [binary fMP4 chunks] ──▶ | |
| media_segment_complete ──▶ ltx2_segment_complete | |
| client ──▶ append_prompt {prompt} grows the rollout (continuation). | |
| GPU work is isolated in `@spaces.GPU` (ZeroGPU-shaped). fMP4 muxing mirrors | |
| DreamVerse's `av_streaming.py` (libx264 ultrafast/zerolatency, baseline, GOP | |
| 12, +frag_keyframe+empty_moov+default_base_moof) so the frontend's MSE | |
| SourceBuffer plays it. FA4 skipped (sm_120) → TORCH_SDPA. See ZEROGPU_NOTES.md. | |
| """ | |
import asyncio
import json
import os
import subprocess
import tempfile
import threading
import uuid
import wave
import zlib
# Process-wide defaults; setdefault keeps any value the environment already
# provides. SDPA attention, tokenizer parallelism off, and the
# AOTInductor/compile cache kept on the big overlay disk (not a small tmpfs).
for _env_key, _env_default in (
    ("FASTVIDEO_ATTENTION_BACKEND", "TORCH_SDPA"),
    ("TOKENIZERS_PARALLELISM", "false"),
    ("TORCHINDUCTOR_CACHE_DIR", "/home/user/.cache/dreamverse_inductor"),
):
    os.environ.setdefault(_env_key, _env_default)
del _env_key, _env_default
HERE_EARLY = os.path.dirname(os.path.abspath(__file__))
# DREAMVERSE_NVFP4=1 selects the NVFP4 (FP4) DiT path.
NVFP4 = os.getenv("DREAMVERSE_NVFP4") == "1"
# HF bucket holding the prebaked sm_120 NVFP4 flashinfer kernels (no runtime JIT).
FP4_CACHE_REPO = os.getenv("DREAMVERSE_FP4_CACHE_REPO", "multimodalart/dreamverse-flashinfer-cache")
def _setup_nvfp4_env():
    """Point flashinfer at prebaked sm_120 NVFP4 kernels (ZeroGPU-clean, no JIT).

    flashinfer's ``build_and_load`` short-circuits to ``load(aot_path)`` when
    the kernel ``.so`` exists under ``FLASHINFER_AOT_DIR/<op>/<op>.so``. This
    pulls the prebaked ``.so`` files (fp4_quantization_120f,
    fp4_gemm_cutlass_sm120) from the ``aot/`` folder of the FP4_CACHE_REPO
    dataset and redirects flashinfer's AOT dir there. The JIT-cache dir is
    NOT relocatable (ninja recompiles), but the AOT path IS (verified 0.2s
    load). flashinfer must be installed --no-deps (keeps torch 2.11+cu130).
    See NVFP4_ZEROGPU.md.

    Best-effort: any failure is logged and flashinfer falls back to JIT.
    """
    import pathlib
    os.environ.setdefault("FLASHINFER_CUDA_ARCH_LIST", "12.0f")  # arch w/o nvcc
    os.environ.setdefault("FLASHINFER_DISABLE_VERSION_CHECK", "1")
    try:
        from huggingface_hub import snapshot_download
        snapshot_dir = snapshot_download(repo_id=FP4_CACHE_REPO, repo_type="dataset", allow_patterns="aot/*")
        aot_dir = os.path.join(snapshot_dir, "aot")
        if not os.path.isdir(aot_dir):
            print("[nvfp4] AOT dir missing in bucket; flashinfer may JIT", flush=True)
            return
        import flashinfer.jit.env as _fienv
        # Rebind the module attribute flashinfer consults for AOT kernels.
        _fienv.FLASHINFER_AOT_DIR = pathlib.Path(aot_dir)
        print(f"[nvfp4] prebaked FP4 AOT kernels from {FP4_CACHE_REPO} (no JIT)", flush=True)
    except Exception as e:
        print(f"[nvfp4] AOT pull failed ({e}); flashinfer will JIT on first use", flush=True)
| import numpy as np | |
| import spaces | |
# NVFP4 mode only: pull the prebaked FP4 flashinfer kernels and redirect
# flashinfer's AOT dir to them (the model itself is built lazily later).
if NVFP4:
    _setup_nvfp4_env()
# AOTI DiT compiler — installs a hook on LTXModel._process_transformer_blocks
# that loads the prebaked (weight-less, regional) AOTInductor .pt2 on first
# forward. At module scope so the hook is live in FastVideo's spawned worker.
# (No-op in NVFP4 mode: FP4 runs eager + prebaked kernels — see aoti_dit.)
import aoti_dit
aoti_dit.install()
# Memoize Gemma text embeddings across same-prompt rollout segments.
import text_cache
text_cache.install()
| from gradio import Server | |
| from fastapi import WebSocket, WebSocketDisconnect | |
| from fastapi.responses import HTMLResponse, JSONResponse, FileResponse | |
| from fastapi.staticfiles import StaticFiles | |
MODEL_ID = os.getenv("LTX2_MODEL_PATH", "FastVideo/LTX2-Distilled-Diffusers")
HERE = os.path.dirname(os.path.abspath(__file__))
FRONTEND = os.path.join(HERE, "frontend")

# Render settings (modest for single-GPU latency). Lighter default profile for
# a better realtime ratio: lower res cuts the DiT seq-len AND the VAE decode
# without shortening the clip; 384x640 ~= half the pixels of 512x768.
# Constraints: H, W % 32 == 0 and (num_frames - 1) % 8 == 0.
GEN_HEIGHT = int(os.getenv("DREAMVERSE_HEIGHT", "384"))
GEN_WIDTH = int(os.getenv("DREAMVERSE_WIDTH", "640"))
GEN_FRAMES = int(os.getenv("DREAMVERSE_FRAMES", "65"))
GEN_STEPS = int(os.getenv("DREAMVERSE_STEPS", "5"))
GEN_FPS = 24
# Sent to the client as `session_timeout` in the gpu_assigned message.
GPU_DURATION = int(os.getenv("DREAMVERSE_GPU_DURATION", "300"))

# Continuation conditioning (mirrors dreamverse/video_generation.py).
COND_FRAMES = 9  # tail frames carried forward (8k+1)
# Codec string announced in media_init: H.264 Constrained Baseline L3.0 + AAC-LC.
MIME = 'video/mp4; codecs="avc1.42E01E,mp4a.40.2"'
FFMPEG = os.getenv("FASTVIDEO_FFMPEG_BIN", "ffmpeg")

app = Server()
_generator = None
# Guards one-time construction of _generator. Created eagerly at import:
# creating it lazily (`if lock is None: lock = Lock()`) is itself a
# check-then-set race in which two threads can each get their own lock and
# both load the model.
_gen_lock_thread = threading.Lock()
_gpu_lock = asyncio.Lock()  # single-GPU: serialize generation across sockets
| # -------------------------------------------------------------------------- | |
| # Model (lazy; spawn-safe — never load at module scope, see ZEROGPU_NOTES.md) | |
| # -------------------------------------------------------------------------- | |
def _get_generator():
    """Return the shared FastVideo ``VideoGenerator``, loading it on first call.

    Loading is deferred until the first real segment (spawn-safe: never at
    module scope, see ZEROGPU_NOTES.md). Construction runs under
    ``_gen_lock_thread`` with a re-check inside the lock; the lock is created
    here if the module has not already provided one.
    """
    global _generator, _gen_lock_thread
    if _generator is None:
        import threading
        if _gen_lock_thread is None:
            _gen_lock_thread = threading.Lock()
        with _gen_lock_thread:
            if _generator is None:
                _generator = _build_generator()
    return _generator


def _build_generator():
    """Construct the VideoGenerator for MODEL_ID (slow; called once by _get_generator)."""
    from fastvideo import VideoGenerator
    from fastvideo.api.schema import (GeneratorConfig, EngineConfig,
                                      CompileConfig, QuantizationConfig)
    if NVFP4:
        mode = "NVFP4 (FP4 DiT, prebaked kernels)"
    else:
        mode = "bf16 + AOTI"
    print(f"[dreamverse] loading {MODEL_ID} [{mode}] (first segment — slow) ...", flush=True)
    engine = EngineConfig(num_gpus=1)
    # VAE/text-encoder compile is OFF by default. It is JIT torch.compile,
    # which costs ~180s of first-call compile here (VAE decode + the encode
    # path used for continuation) for only ~0.5s warm gain. JIT also
    # recompiles on every ZeroGPU cold start, breaking the AOT/clean property.
    # Opt in with DREAMVERSE_COMPILE_VAE=1 on a dedicated GPU. (The
    # ZeroGPU-clean way to speed up the VAE is to AOTI it, like the DiT.)
    compile_vae = os.getenv("DREAMVERSE_COMPILE_VAE") == "1"
    if compile_vae:
        engine.compile = CompileConfig(
            enabled=False, text_encoder_enabled=True, vae_enabled=True,
            backend="inductor", fullgraph=False, mode=None, dynamic=True,
        )
        print("[dreamverse] VAE+text-encoder torch.compile ON (first calls slow)", flush=True)
    if NVFP4:
        engine.quantization = QuantizationConfig(transformer_quant="NVFP4")
    generator = VideoGenerator.from_pretrained(
        config=GeneratorConfig(model_path=MODEL_ID, engine=engine))
    print("[dreamverse] generator ready", flush=True)
    return generator
def _mock_segment(seed: int):
    """Return synthetic ``(frames, audio, sample_rate)`` without the GPU model.

    Used for control-plane testing (DREAMVERSE_MOCK=1). Produces GEN_FRAMES
    solid-colour GEN_HEIGHT x GEN_WIDTH x 3 uint8 frames whose colour ramps
    by 3 per frame (mod 255) from a seed-derived base. Audio is a 1-D torch
    tensor of zeros covering GEN_FRAMES / GEN_FPS seconds at 24000 Hz.
    """
    import torch
    n = GEN_FRAMES
    rng = np.random.default_rng(seed)
    # Draw as uint8 (same random stream as before), then widen. Adding the
    # ramp in uint8 wraps at 256 before `% 255` is applied, and under NumPy 2
    # scalar promotion an offset above 255 raises OverflowError
    # (GEN_FRAMES >= 87).
    base = rng.integers(0, 255, size=(3,), dtype=np.uint8).astype(np.int64)
    frames = []
    for i in range(n):
        color = ((base + i * 3) % 255).astype(np.uint8)
        frames.append(np.full((GEN_HEIGHT, GEN_WIDTH, 3), color, dtype=np.uint8))
    audio = torch.zeros(int(n / GEN_FPS * 24000))
    return frames, audio, 24000
def generate_segment(prompt: str, cond_images, seed: int):
    """Run one LTX-2 segment.

    Args:
        prompt: text prompt for this segment.
        cond_images: the previous segment's tail frames as a list of PIL
            images (continuation), or None for a fresh start.
        seed: generation seed.

    Returns:
        ``(frames, audio, sample_rate)``. ``frames`` is always a plain list
        (one entry per frame, expected uint8 HxWx3). ``audio`` and
        ``sample_rate`` are the result's "audio" / "audio_sample_rate"
        entries and may be None. With DREAMVERSE_MOCK=1 the synthetic
        ``_mock_segment`` output is returned instead.

    Raises:
        RuntimeError: if the pipeline does not return a dict.
    """
    if os.getenv("DREAMVERSE_MOCK") == "1":
        return _mock_segment(seed)
    gen = _get_generator()
    kwargs = dict(
        prompt=prompt, negative_prompt="",
        height=GEN_HEIGHT, width=GEN_WIDTH, num_frames=GEN_FRAMES,
        num_inference_steps=GEN_STEPS, fps=GEN_FPS, guidance_scale=1.0,
        seed=seed, save_video=False, return_frames=True,
    )
    if cond_images is not None:
        # Continuation: condition on the previous segment's tail frames.
        kwargs["ltx2_video_conditions"] = [(cond_images, 0, 1.0)]
        kwargs["ltx2_images"] = None
        kwargs["image_path"] = None
    res = gen.generate_video(**kwargs)
    if not isinstance(res, dict):
        raise RuntimeError("unexpected generation result")
    frames = res.get("frames")
    # `frames or []` raises "truth value is ambiguous" if the pipeline hands
    # back a stacked array/tensor, so normalize to a list explicitly.
    frames = [] if frames is None else list(frames)
    return frames, res.get("audio"), res.get("audio_sample_rate")
| # -------------------------------------------------------------------------- | |
| # fMP4 muxing (ported from dreamverse/av_streaming.py) | |
| # -------------------------------------------------------------------------- | |
| def _audio_to_int16(audio, head_trim_samples=0, keep_samples=None): | |
| import torch | |
| if audio is None: | |
| return None, 1 | |
| a = audio.detach().cpu().float().numpy() if torch.is_tensor(audio) else np.asarray(audio, np.float32) | |
| if a.ndim == 1: | |
| a = a[:, None] | |
| elif a.ndim == 2 and a.shape[0] <= 8 and a.shape[1] > a.shape[0]: | |
| a = a.T | |
| a = np.clip(a, -1.0, 1.0) | |
| a = (a * 32767.0).astype(np.int16) | |
| if head_trim_samples > 0: | |
| a = a[head_trim_samples:] | |
| if keep_samples is not None: | |
| a = a[:keep_samples] | |
| return a, a.shape[1] | |
def encode_fmp4(frames, audio, sample_rate, head_trim_frames=0):
    """Encode frames + audio to fragmented-MP4 bytes that MSE can append.

    Args:
        frames: sequence of same-size HxWx3 RGB frames (uint8-convertible).
        audio: anything ``_audio_to_int16`` accepts, or None.
        sample_rate: audio rate in Hz; falsy means 24000.
        head_trim_frames: leading frames to drop (the continuation overlap).
            The matching span of audio is dropped too.

    Returns:
        fMP4 bytes (libx264 baseline, GOP 12, AAC audio,
        frag_keyframe+empty_moov+default_base_moof) as produced by ffmpeg.

    Raises:
        ValueError: if no frames remain after trimming.
        RuntimeError: if ffmpeg exits non-zero or writes nothing.
    """
    if head_trim_frames > 0:
        frames = frames[head_trim_frames:]
    frames = [np.ascontiguousarray(f, dtype=np.uint8) for f in frames]
    if not frames:
        raise ValueError(f"no frames to encode (head_trim_frames={head_trim_frames})")
    h, w = frames[0].shape[0], frames[0].shape[1]
    sr = int(sample_rate or 24000)
    trim_samp = int(round(head_trim_frames / GEN_FPS * sr)) if head_trim_frames > 0 else 0
    keep_samp = int(round(len(frames) / GEN_FPS * sr))
    audio_i16, nch = _audio_to_int16(audio, trim_samp, keep_samp)
    if audio_i16 is None:
        audio_i16, nch = np.zeros((0, 1), dtype=np.int16), 1
    short = keep_samp - audio_i16.shape[0]
    if short > 0:
        # Pad with silence to the video length. `-shortest` would otherwise
        # cut trailing video frames whenever audio runs short, and a
        # video-only init segment would not match the AAC track declared in
        # the stream's MIME codecs.
        audio_i16 = np.pad(audio_i16, ((0, short), (0, 0)))
    wav_path = None
    try:
        # The WAV is created inside the try so the finally removes it even
        # if writing it fails.
        a_inputs = []
        if audio_i16.shape[0] > 0:
            fd, wav_path = tempfile.mkstemp(suffix=".wav")
            os.close(fd)
            with wave.open(wav_path, "wb") as wf:
                wf.setnchannels(nch)
                wf.setsampwidth(2)  # int16 PCM
                wf.setframerate(sr)
                wf.writeframes(audio_i16.tobytes())
            a_inputs = ["-i", wav_path]
        cmd = [
            FFMPEG, "-hide_banner", "-loglevel", "error", "-y",
            "-f", "rawvideo", "-pix_fmt", "rgb24", "-s:v", f"{w}x{h}",
            "-r", str(GEN_FPS), "-i", "pipe:0", *a_inputs,
            "-c:v", "libx264", "-preset", "ultrafast", "-tune", "zerolatency",
            "-profile:v", "baseline", "-g", "12", "-keyint_min", "12",
            "-x264-params", "scenecut=0", "-pix_fmt", "yuv420p",
        ]
        if a_inputs:
            cmd += ["-c:a", "aac", "-shortest"]
        cmd += [
            "-movflags", "+frag_keyframe+empty_moov+default_base_moof",
            "-frag_duration", "250000", "-muxdelay", "0", "-muxpreload", "0",
            "-f", "mp4", "pipe:1",
        ]
        raw = b"".join(f.tobytes() for f in frames)
        proc = subprocess.run(cmd, input=raw, stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE, check=False)
    finally:
        if wav_path:
            try:
                os.unlink(wav_path)
            except FileNotFoundError:
                pass
    if proc.returncode != 0 or not proc.stdout:
        raise RuntimeError(f"ffmpeg failed: {proc.stderr.decode('utf-8','ignore')[:300]}")
    return proc.stdout
def tail_pil(frames, n=COND_FRAMES):
    """Return the last ``n`` frames as PIL images for continuation conditioning.

    ``frames`` may be a list of HxWx3 arrays or a stacked array. Returns None
    when ``frames`` is None or empty, or has fewer than ``n`` entries.
    """
    from PIL import Image
    # len() instead of truthiness: `not frames` raises on a multi-element ndarray.
    if frames is None or len(frames) == 0 or len(frames) < n:
        return None
    return [Image.fromarray(np.ascontiguousarray(f)) for f in frames[len(frames) - n:]]
| # -------------------------------------------------------------------------- | |
| # WebSocket: the DreamVerse streaming protocol | |
| # -------------------------------------------------------------------------- | |
# DreamVerse's "streaming" product mode is a CONTINUOUS auto-rollout. The
# backend keeps generating segments that flow into the player, and the user
# *steers* the ongoing rollout by submitting prompts (the frontend sends
# these as `rewrite_seed_prompts`, not `append_prompt`). Hence one background
# rollout loop per connection plus a concurrent receive loop.
#
# MAX_SEGMENTS is the budget per "stretch". The rollout keeps generating the
# current prompt (each segment continuing the last) until it reaches the
# rolling target, then idles. Every edit (rewrite) moves the target to
# MAX_SEGMENTS past the current segment, so the rollout stays continuous
# while the user keeps steering.
MAX_SEGMENTS = int(os.environ.get("DREAMVERSE_MAX_SEGMENTS", "12"))
class Session:
    """Per-connection streaming state shared by the receive loop and the rollout.

    Fields are mutated directly by ``ws_endpoint`` and ``_rollout``. Outgoing
    messages go through ``send`` / ``send_bytes``, which hold ``send_lock``
    so sends from the two coroutines are serialized.
    """

    def __init__(self, ws: WebSocket):
        self.ws = ws
        # Rollout content: the active prompt and the previous segment's
        # frames (source of the continuation tail).
        self.prompt = ""
        self.prev_frames = None
        # Progress: last segment index emitted vs. the rolling budget.
        self.seg_idx = 0
        self.target = 0
        # Control flags.
        self.stream_started = False
        self.paused = False
        self.running = False
        # Handle of the most recently spawned rollout task (None until then).
        self.task = None
        self.send_lock = asyncio.Lock()  # serialize sends (rollout + receive)

    async def send(self, **payload):
        """Send the keyword arguments as a single JSON text frame."""
        message = json.dumps(payload)
        async with self.send_lock:
            await self.ws.send_text(message)

    async def send_bytes(self, data: bytes):
        """Send ``data`` as a single binary frame."""
        async with self.send_lock:
            await self.ws.send_bytes(data)
def _stable_seed(prompt: str, idx: int) -> int:
    """Return a deterministic seed in [0, 100000) for a (prompt, segment) pair.

    ``hash()`` of a str is salted per process (PYTHONHASHSEED), so it differs
    across restarts. CRC32 of the repr is stable everywhere.
    """
    return zlib.crc32(repr((prompt, idx)).encode("utf-8")) % 100000


async def _rollout(sess: Session):
    """Generate continuation segments up to ``sess.target``, then return.

    Each iteration:
    - takes the session's current prompt;
    - conditions on the previous segment's last COND_FRAMES frames, if any;
    - generates under the module-level ``_gpu_lock`` and muxes the result to
      fMP4 (dropping the conditioning overlap);
    - streams it as media_init -> binary chunks -> media_segment_complete ->
      ltx2_segment_complete.

    While paused or without a prompt the loop polls.

    ``ltx2_stream_start`` is emitted once by the caller (not here), so
    restarting the loop after an edit does NOT reset the player.

    The loop stops when ``sess.running`` is cleared (disconnect / end /
    reset) or when a newer rollout task has replaced this one as
    ``sess.task``. A segment that finishes after either event is discarded
    rather than streamed, so it cannot clobber reset state. Errors other than
    a client disconnect are logged and reported to the client.
    """
    me = asyncio.current_task()

    def owns_session():
        # sess.task is None only if we were awaited directly, not spawned.
        return sess.task is None or sess.task is me

    if not owns_session():
        return  # superseded before we even started
    chunk = 1 << 20  # binary websocket frame size for the fMP4 payload
    sess.running = True
    try:
        while sess.running and owns_session() and sess.seg_idx < sess.target:
            if sess.paused or not sess.prompt:
                await asyncio.sleep(0.15)
                continue
            sess.seg_idx += 1
            idx, prompt = sess.seg_idx, sess.prompt
            await sess.send(type="prompt_received", prompt_id=f"seg{idx}")
            await sess.send(type="prompt_ready", prompt_id=f"seg{idx}", prompt=prompt, source="user_raw")
            await sess.send(type="ltx2_segment_start", segment_idx=idx,
                            total_segments=sess.target, prompt=prompt, source="user_raw")
            prev = sess.prev_frames
            cond = tail_pil(prev) if prev is not None else None
            head_trim = COND_FRAMES if cond is not None else 0
            seed = _stable_seed(prompt, idx)
            async with _gpu_lock:
                frames, audio, sr = await asyncio.to_thread(generate_segment, prompt, cond, seed)
            if not sess.running or not owns_session():
                break  # ended/reset/superseded while generating: drop result
            if frames is None or len(frames) == 0:
                await sess.send(type="error", message="Generation returned no frames")
                break
            sess.prev_frames = frames
            fmp4 = await asyncio.to_thread(encode_fmp4, frames, audio, sr, head_trim)
            if not sess.running or not owns_session():
                break  # ended/reset/superseded while encoding: don't stream
            sid = f"seg{idx:03d}-{uuid.uuid4().hex[:8]}"
            await sess.send(type="media_init", segment_idx=idx, stream_id=sid, mime=MIME)
            for offset in range(0, len(fmp4), chunk):
                await sess.send_bytes(fmp4[offset:offset + chunk])
            await sess.send(type="media_segment_complete", segment_idx=idx, stream_id=sid)
            await sess.send(type="ltx2_segment_complete", segment_idx=idx, total_segments=sess.target)
            print(f"[dreamverse] segment {idx}/{sess.target} ({len(frames)}f, {len(fmp4)//1024}KB) prompt={prompt[:40]!r}", flush=True)
    except WebSocketDisconnect:
        pass  # client is gone; nobody to report to
    except Exception as e:
        # Includes RuntimeError from generate_segment / encode_fmp4, which must
        # surface to the client rather than silently stopping the rollout.
        print(f"[dreamverse] rollout error: {e}", flush=True)
        try:
            await sess.send(type="error", message=str(e))
        except Exception:
            pass  # socket already closed
    finally:
        # Only the current owner may clear the flag; a superseded task must
        # not stop the rollout that replaced it.
        if owns_session():
            sess.running = False
| def _first_prompt(data): | |
| p = (data.get("initial_rollout_prompt") or "").strip() | |
| if not p: | |
| cur = data.get("curated_prompts") or [] | |
| p = cur[0].strip() if cur and isinstance(cur[0], str) else "" | |
| return p | |
def _as_text(value) -> str:
    """Return ``value`` stripped if it is a str, else "" (untrusted client field)."""
    return value.strip() if isinstance(value, str) else ""


def _start_rollout_if_idle(sess: Session) -> None:
    """Spawn the rollout task unless one is running or no prompt is set."""
    if not sess.running and sess.prompt:
        sess.task = asyncio.create_task(_rollout(sess))


async def ws_endpoint(ws: WebSocket):
    """Serve one DreamVerse client: accept, then dispatch JSON control messages.

    Handled ``type`` values:
    - session_init_v2 / project_init_v1: reset state, announce the stream
      once, start the rollout.
    - rewrite_seed_prompts / append_prompt: steer to a new prompt and extend
      the segment budget.
    - set_generation_paused, restart_generation.
    - end_project_keep_session / reset_to_seed_prompts: stop and reset.

    Other types (e.g. set_auto_extension / set_loop_generation) are
    acknowledged implicitly by being ignored. Malformed or non-object
    messages are skipped instead of ending the session. On exit
    ``sess.running`` is cleared so the rollout loop stops.
    """
    await ws.accept()
    sess = Session(ws)
    try:
        while True:
            raw = await ws.receive_text()
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                # One bad frame should not tear down the whole session.
                print(f"[dreamverse] ignoring malformed ws message: {raw[:80]!r}", flush=True)
                continue
            if not isinstance(data, dict):
                continue
            mtype = data.get("type")
            if mtype in ("session_init_v2", "project_init_v1"):
                await sess.send(type="queue_status", queue_position=0, status="assigned")
                await sess.send(type="gpu_assigned", gpu_id=0, session_timeout=GPU_DURATION)
                sess.prompt = _first_prompt(data)
                sess.prev_frames, sess.seg_idx = None, 0
                sess.target = MAX_SEGMENTS
                if not sess.stream_started:
                    await sess.send(type="ltx2_stream_start", total_segments=sess.target)
                    sess.stream_started = True
                _start_rollout_if_idle(sess)
            elif mtype == "rewrite_seed_prompts":
                # The "edit" path in streaming mode: steer the rollout. Update
                # the prompt (applies to the next segment) and EXTEND the
                # budget so the rollout keeps going; restart only if idle.
                new = _as_text(data.get("rewrite_instruction"))
                if not new:
                    pw = data.get("prompt_window_prompts") or data.get("prompts") or []
                    if isinstance(pw, list) and pw:
                        new = _as_text(pw[0])
                if new:
                    sess.prompt = new
                    print(f"[dreamverse] steer -> {new[:50]!r}", flush=True)
                sess.target = sess.seg_idx + MAX_SEGMENTS
                await sess.send(type="seed_prompts_updated", prompts=[sess.prompt], reason="rewrite")
                await sess.send(type="rewrite_seed_prompts_complete")
                _start_rollout_if_idle(sess)
            elif mtype == "append_prompt":  # devtools/demo build path
                new = _as_text(data.get("prompt"))
                if new:
                    sess.prompt = new
                sess.target = sess.seg_idx + MAX_SEGMENTS
                _start_rollout_if_idle(sess)
            elif mtype == "set_generation_paused":
                sess.paused = bool(data.get("paused"))
            elif mtype == "restart_generation":
                sess.target = sess.seg_idx + MAX_SEGMENTS
                _start_rollout_if_idle(sess)
            elif mtype in ("end_project_keep_session", "reset_to_seed_prompts"):
                sess.running = False
                sess.prev_frames, sess.seg_idx = None, 0
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print(f"[dreamverse] ws error: {e}", flush=True)
    finally:
        sess.running = False
| # -------------------------------------------------------------------------- | |
| # Aux endpoints the frontend probes | |
| # -------------------------------------------------------------------------- | |
async def healthz():
    """Health probe: always "ok", plus the model id and whether it has been loaded."""
    loaded = _generator is not None
    return {"status": "ok", "model": MODEL_ID, "loaded": loaded}
async def readyz():
    """Readiness probe; reports ready unconditionally."""
    payload = {"ready": True}
    return JSONResponse(payload)
async def models():
    """List the single served model and mark it as the default."""
    entry = {"id": "fast-ltx2", "name": "FastLTX2", "model_path": MODEL_ID}
    return {"models": [entry], "default": entry["id"]}
async def curated_presets():
    """Serve the bundled continuation-story presets JSON, or ``[]`` if it is absent."""
    presets_path = os.path.join(FRONTEND, "prompts", "selected_ltx2_continuation_story_presets.json")
    if not os.path.exists(presets_path):
        return JSONResponse([])
    return FileResponse(presets_path, media_type="application/json")
async def prompt_system_config():
    """Report that prompt enhancement and auto-extension are both disabled."""
    return dict(enhancement_enabled=False, auto_extension_enabled=False)
# Root: serve the real DreamVerse frontend.
async def homepage():
    """Return the static-exported frontend's index.html as a string (read per request)."""
    index_path = os.path.join(FRONTEND, "index.html")
    with open(index_path, encoding="utf-8") as fh:
        html = fh.read()
    return html
# Bundled JS/CSS/fonts under /_next. This is a specific mount, so it does NOT
# shadow gradio's own /gradio_api/* routes added during launch().
app.mount("/_next", StaticFiles(directory=os.path.join(FRONTEND, "_next")), name="next")


# Root-level static assets (favicon, svgs, k2.png, index.txt) are registered
# as explicit single-segment GET routes rather than a catch-all mount, so
# they don't shadow gradio's API paths.
def _make_file_route(full_path: str):
    """Build a no-argument async handler that serves the file at ``full_path``.

    Using a factory binds each handler to its own path; a closure defined
    directly in the loop would late-bind to the last file.
    """
    async def _route():
        return FileResponse(full_path)

    return _route


for _f in os.listdir(FRONTEND):
    _full = os.path.join(FRONTEND, _f)
    if _f == "index.html" or not os.path.isfile(_full):
        continue  # skip subdirectories; index.html is returned by homepage()
    app.add_api_route(f"/{_f}", _make_file_route(_full), methods=["GET"])
# Script entry point: listen on all interfaces, port $PORT (default 7860),
# with errors shown to clients (show_error=True).
if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")), show_error=True)