"""DreamVerse LTX-2 — the *real* FastVideo frontend on `gradio.Server`.

This serves DreamVerse's actual Next.js UI (static-exported to `frontend/`) and
implements its `/ws` streaming protocol on `gradio.Server`'s FastAPI, backed by
the real FastVideo LTX-2 pipeline with cross-segment continuation.

Flow (matches the frontend's reducer exactly):

    client connects /ws ──▶ session_init_v2 {initial_rollout_prompt, ...}
    server ──▶ queue_status, gpu_assigned, ltx2_stream_start
    per segment:
        ltx2_segment_start
          ──▶ (generate LTX-2, conditioned on prev tail)
          ──▶ media_init {stream_id, mime}
          ──▶ [binary fMP4 chunks]
          ──▶ media_segment_complete
          ──▶ ltx2_segment_complete
    client ──▶ append_prompt {prompt}    grows the rollout (continuation).

GPU work is isolated in `@spaces.GPU` (ZeroGPU-shaped). fMP4 muxing mirrors
DreamVerse's `av_streaming.py` (libx264 ultrafast/zerolatency, baseline, GOP 12,
+frag_keyframe+empty_moov+default_base_moof) so the frontend's MSE SourceBuffer
plays it. FA4 skipped (sm_120) → TORCH_SDPA. See ZEROGPU_NOTES.md.
"""
import asyncio
import json
import os
import subprocess
import tempfile
import uuid
import wave

# These defaults must be in the environment before the heavy imports below run.
os.environ.setdefault("FASTVIDEO_ATTENTION_BACKEND", "TORCH_SDPA")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
# Keep the AOTInductor/compile cache on the big overlay disk (not a small tmpfs).
os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", "/home/user/.cache/dreamverse_inductor")

HERE_EARLY = os.path.dirname(os.path.abspath(__file__))
NVFP4 = os.getenv("DREAMVERSE_NVFP4") == "1"
# HF bucket holding the prebaked sm_120 NVFP4 flashinfer kernels (no runtime JIT).
FP4_CACHE_REPO = os.getenv("DREAMVERSE_FP4_CACHE_REPO",
                           "multimodalart/dreamverse-flashinfer-cache")


def _setup_nvfp4_env():
    """Point flashinfer at prebaked sm_120 FP4 kernels (ZeroGPU-clean NVFP4).

    Uses flashinfer's AOT path with PREBAKED kernels, so the runtime loads the
    `.so` files directly — no ninja, no nvcc, no JIT.

    flashinfer's `build_and_load` short-circuits to `load(aot_path)` when the
    kernel `.so` exists under `FLASHINFER_AOT_DIR/<name>/<name>.so`. We pull the
    two prebaked `.so` (fp4_quantization_120f, fp4_gemm_cutlass_sm120) from the
    HF bucket and point flashinfer's AOT dir at them. The JIT-cache dir is NOT
    relocatable (ninja recompiles), but the AOT path IS — verified 0.2s load.
    flashinfer must be installed --no-deps (keeps torch 2.11+cu130).
    See NVFP4_ZEROGPU.md.

    Best effort: any failure is logged and flashinfer falls back to JIT.
    """
    import pathlib

    os.environ.setdefault("FLASHINFER_CUDA_ARCH_LIST", "12.0f")  # arch w/o nvcc
    os.environ.setdefault("FLASHINFER_DISABLE_VERSION_CHECK", "1")
    try:
        from huggingface_hub import snapshot_download

        snapshot_root = snapshot_download(repo_id=FP4_CACHE_REPO,
                                          repo_type="dataset",
                                          allow_patterns="aot/*")
        aot_dir = os.path.join(snapshot_root, "aot")
        if not os.path.isdir(aot_dir):
            print("[nvfp4] AOT dir missing in bucket; flashinfer may JIT", flush=True)
            return

        import flashinfer.jit.env as _fienv

        _fienv.FLASHINFER_AOT_DIR = pathlib.Path(aot_dir)
        print(f"[nvfp4] prebaked FP4 AOT kernels from {FP4_CACHE_REPO} (no JIT)", flush=True)
    except Exception as e:
        print(f"[nvfp4] AOT pull failed ({e}); flashinfer will JIT on first use", flush=True)


import numpy as np
import spaces

if NVFP4:
    _setup_nvfp4_env()

# AOTI DiT compiler — installs a hook on LTXModel._process_transformer_blocks
# that loads the prebaked (weight-less, regional) AOTInductor .pt2 on first
# forward. At module scope so the hook is live in FastVideo's spawned worker.
# (No-op in NVFP4 mode: FP4 runs eager + prebaked kernels — see aoti_dit.)
import aoti_dit

aoti_dit.install()

# Memoize Gemma text embeddings across same-prompt rollout segments.
import text_cache

text_cache.install()

from gradio import Server
from fastapi import WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse, JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles


def _env_int(name: str, default: int) -> int:
    """Read an integer setting from the environment, falling back to `default`."""
    return int(os.getenv(name, str(default)))


MODEL_ID = os.getenv("LTX2_MODEL_PATH", "FastVideo/LTX2-Distilled-Diffusers")
HERE = os.path.dirname(os.path.abspath(__file__))
FRONTEND = os.path.join(HERE, "frontend")

# Render settings (modest for single-GPU latency).
# Lighter default profile for a better realtime ratio: lower resolution cuts the
# DiT seq-len AND the VAE decode without shortening the clip. 384x640 is about
# half the pixels of 512x768.
# Constraints: H and W % 32 == 0; (num_frames - 1) % 8 == 0.
GEN_HEIGHT = _env_int("DREAMVERSE_HEIGHT", 384)
GEN_WIDTH = _env_int("DREAMVERSE_WIDTH", 640)
GEN_FRAMES = _env_int("DREAMVERSE_FRAMES", 65)
GEN_STEPS = _env_int("DREAMVERSE_STEPS", 5)
GEN_FPS = 24
GPU_DURATION = _env_int("DREAMVERSE_GPU_DURATION", 300)

# Continuation conditioning (mirrors dreamverse/video_generation.py).
COND_FRAMES = 9  # tail frames carried forward (8k+1)
MIME = 'video/mp4; codecs="avc1.42E01E,mp4a.40.2"'
FFMPEG = os.getenv("FASTVIDEO_FFMPEG_BIN", "ffmpeg")

app = Server()

import threading  # noqa: E402 — kept next to the lock it exists for

_generator = None
# Created eagerly: building the lock lazily inside _get_generator() was itself
# unsynchronized, so two threads could each create (and then hold) their own lock.
_gen_lock_thread = threading.Lock()
_gpu_lock = asyncio.Lock()  # single-GPU: serialize generation across sockets


# --------------------------------------------------------------------------
# Model (lazy; spawn-safe — never load at module scope, see ZEROGPU_NOTES.md)
# --------------------------------------------------------------------------
def _get_generator():
    """Return the process-wide FastVideo generator, loading it on first use.

    The load is guarded by `_gen_lock_thread` and re-checked inside the lock so
    concurrent first callers trigger exactly one `from_pretrained`.
    """
    global _generator
    if _generator is not None:
        return _generator
    with _gen_lock_thread:
        if _generator is not None:  # another thread finished the load meanwhile
            return _generator

        from fastvideo import VideoGenerator
        from fastvideo.api.schema import (GeneratorConfig, EngineConfig,
                                          CompileConfig, QuantizationConfig)

        mode = "NVFP4 (FP4 DiT, prebaked kernels)" if NVFP4 else "bf16 + AOTI"
        print(f"[dreamverse] loading {MODEL_ID} [{mode}] (first segment — slow) ...", flush=True)
        engine = EngineConfig(num_gpus=1)
        # VAE/text-encoder compile is OFF by default: it's JIT torch.compile,
        # which costs ~180s of first-call compile here (VAE decode + the encode
        # path used for continuation) for only ~0.5s warm gain — and JIT
        # recompiles on every ZeroGPU cold start, breaking the AOT/clean
        # property. Opt in with DREAMVERSE_COMPILE_VAE=1 on a dedicated GPU.
        # (The ZeroGPU-clean way to speed the VAE is to AOTI it, like the DiT.)
        if os.getenv("DREAMVERSE_COMPILE_VAE") == "1":
            # NOTE(review): enabled=False presumably leaves the DiT itself
            # uncompiled (it is handled by the AOTI hook) — confirm.
            engine.compile = CompileConfig(
                enabled=False,
                text_encoder_enabled=True,
                vae_enabled=True,
                backend="inductor",
                fullgraph=False,
                mode=None,
                dynamic=True,
            )
            print("[dreamverse] VAE+text-encoder torch.compile ON (first calls slow)", flush=True)
        if NVFP4:
            engine.quantization = QuantizationConfig(transformer_quant="NVFP4")
        _generator = VideoGenerator.from_pretrained(
            config=GeneratorConfig(model_path=MODEL_ID, engine=engine))
        print("[dreamverse] generator ready", flush=True)
    return _generator


def _mock_segment(seed: int):
    """Synthetic frames+audio for control-plane testing without the GPU model.

    Returns `(frames, audio, sample_rate)`: `GEN_FRAMES` solid-colour uint8
    frames whose colour drifts per frame, a silent torch audio tensor of the
    matching duration, and the 24000 Hz sample rate.
    """
    import torch

    n = GEN_FRAMES
    rng = np.random.default_rng(seed)
    base = rng.integers(0, 255, size=(3,), dtype=np.uint8)
    frames = []
    for i in range(n):
        # Reduce the step to uint8 range first: adding a Python int > 255 to a
        # uint8 array raises OverflowError under NumPy 2 promotion rules, which
        # broke mock mode for DREAMVERSE_FRAMES >= 87. Array arithmetic wraps.
        step = np.uint8((i * 3) % 256)
        frame = np.empty((GEN_HEIGHT, GEN_WIDTH, 3), dtype=np.uint8)
        frame[:] = (base + step) % 255
        frames.append(frame)
    audio = torch.zeros(int(n / GEN_FPS * 24000))
    return frames, audio, 24000


@spaces.GPU(duration=GPU_DURATION)
def generate_segment(prompt: str, cond_images, seed: int):
    """Run one LTX-2 segment.

    Args:
        prompt: text prompt for this segment.
        cond_images: list[PIL] tail frames of the previous segment, or None
            for an unconditioned first segment.
        seed: generation seed.

    Returns:
        (frames: list[np.uint8 HxWx3], audio, sample_rate). `frames` is always
        a list (possibly empty), even if the pipeline returns an array.

    Raises:
        RuntimeError: the pipeline did not return a dict result.
    """
    if os.getenv("DREAMVERSE_MOCK") == "1":
        return _mock_segment(seed)

    gen = _get_generator()
    kwargs = dict(
        prompt=prompt,
        negative_prompt="",
        height=GEN_HEIGHT,
        width=GEN_WIDTH,
        num_frames=GEN_FRAMES,
        num_inference_steps=GEN_STEPS,
        fps=GEN_FPS,
        guidance_scale=1.0,
        seed=seed,
        save_video=False,
        return_frames=True,
    )
    if cond_images is not None:
        # Continuation: condition on the previous segment's tail frames.
        kwargs["ltx2_video_conditions"] = [(cond_images, 0, 1.0)]
        kwargs["ltx2_images"] = None
        kwargs["image_path"] = None

    res = gen.generate_video(**kwargs)
    if not isinstance(res, dict):
        raise RuntimeError("unexpected generation result")
    # `frames or []` raises on an ndarray (ambiguous truth value); normalise
    # explicitly so callers can rely on plain list semantics.
    frames = res.get("frames")
    frames = [] if frames is None else list(frames)
    return frames, res.get("audio"), res.get("audio_sample_rate")


# --------------------------------------------------------------------------
# fMP4 muxing (ported from dreamverse/av_streaming.py)
# --------------------------------------------------------------------------
def _audio_to_int16(audio, head_trim_samples=0, keep_samples=None):
    """Convert float audio to 16-bit PCM laid out as (samples, channels).

    Args:
        audio: torch tensor or array-like of floats, or None.
        head_trim_samples: samples dropped from the start.
        keep_samples: if given, maximum number of samples kept after trimming.

    Returns:
        (pcm, channels); `(None, 1)` when `audio` is None.
    """
    import torch

    if audio is None:
        return None, 1
    if torch.is_tensor(audio):
        a = audio.detach().cpu().float().numpy()
    else:
        a = np.asarray(audio, np.float32)
    if a.ndim == 1:
        a = a[:, None]
    elif a.ndim == 2 and a.shape[0] <= 8 and a.shape[1] > a.shape[0]:
        # Heuristic: a short first axis is treated as channels-first input.
        a = a.T
    # NOTE(review): inputs with more than 2 dims are passed through unchanged —
    # confirm the pipeline never returns batched audio.
    a = np.clip(a, -1.0, 1.0)
    a = (a * 32767.0).astype(np.int16)
    if head_trim_samples > 0:
        a = a[head_trim_samples:]
    if keep_samples is not None:
        a = a[:keep_samples]
    return a, a.shape[1]


def encode_fmp4(frames, audio, sample_rate, head_trim_frames=0):
    """Encode frames+audio to a fragmented MP4 (bytes), MSE-appendable.

    Args:
        frames: sequence of HxWx3 RGB frames.
        audio: audio accepted by `_audio_to_int16`, or None for video only.
        sample_rate: audio sample rate; 24000 is used when falsy.
        head_trim_frames: leading frames to drop (the matching span of audio
            is dropped too).

    Returns:
        The fMP4 byte string produced by ffmpeg.

    Raises:
        ValueError: no frames remain after the head trim.
        RuntimeError: ffmpeg exited non-zero or produced no output.
    """
    if head_trim_frames > 0:
        frames = frames[head_trim_frames:]
    frames = [np.ascontiguousarray(f, dtype=np.uint8) for f in frames]
    if not frames:
        raise ValueError("no frames left to encode after head trim")
    h, w = frames[0].shape[0], frames[0].shape[1]
    sr = int(sample_rate or 24000)
    trim_samp = int(round(head_trim_frames / GEN_FPS * sr)) if head_trim_frames > 0 else 0
    keep_samp = int(round(len(frames) / GEN_FPS * sr))
    audio_i16, nch = _audio_to_int16(audio, trim_samp, keep_samp)

    wav_path = None
    try:
        # Everything from temp-file creation onwards sits inside the try so the
        # WAV is removed even if writing it or building the raw buffer fails.
        a_inputs = []
        if audio_i16 is not None and audio_i16.shape[0] > 0:
            fd, wav_path = tempfile.mkstemp(suffix=".wav")
            os.close(fd)
            with wave.open(wav_path, "wb") as wf:
                wf.setnchannels(nch)
                wf.setsampwidth(2)
                wf.setframerate(sr)
                wf.writeframes(audio_i16.tobytes())
            a_inputs = ["-i", wav_path]

        cmd = [
            FFMPEG, "-hide_banner", "-loglevel", "error", "-y",
            "-f", "rawvideo", "-pix_fmt", "rgb24", "-s:v", f"{w}x{h}",
            "-r", str(GEN_FPS), "-i", "pipe:0",
            *a_inputs,
            "-c:v", "libx264", "-preset", "ultrafast", "-tune", "zerolatency",
            "-profile:v", "baseline", "-g", "12", "-keyint_min", "12",
            "-x264-params", "scenecut=0", "-pix_fmt", "yuv420p",
        ]
        if a_inputs:
            cmd += ["-c:a", "aac", "-shortest"]
        cmd += [
            "-movflags", "+frag_keyframe+empty_moov+default_base_moof",
            "-frag_duration", "250000", "-muxdelay", "0", "-muxpreload", "0",
            "-f", "mp4", "pipe:1",
        ]
        raw = b"".join(f.tobytes() for f in frames)
        proc = subprocess.run(cmd, input=raw, stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE, check=False)
    finally:
        if wav_path:
            try:
                os.unlink(wav_path)
            except FileNotFoundError:
                pass
    if proc.returncode != 0 or not proc.stdout:
        raise RuntimeError(f"ffmpeg failed: {proc.stderr.decode('utf-8','ignore')[:300]}")
    return proc.stdout


def tail_pil(frames, n=COND_FRAMES):
    """Return the last `n` frames as PIL images, or None if fewer are available."""
    from PIL import Image

    if frames is None or len(frames) == 0 or len(frames) < n:
        return None
    return [Image.fromarray(np.ascontiguousarray(f)) for f in frames[len(frames) - n:]]


# --------------------------------------------------------------------------
# WebSocket: the DreamVerse streaming protocol
# --------------------------------------------------------------------------
# DreamVerse's "streaming" product mode is a CONTINUOUS auto-rollout: the
# backend keeps generating segments that flow into the player, and the user
# *steers* the ongoing rollout by submitting prompts (which the frontend
# sends as `rewrite_seed_prompts`, not `append_prompt`). So we run a
# background rollout loop per connection and a concurrent receive loop.
# Segments generated per "budget stretch": the rollout keeps generating the
# current prompt (each segment continuing the last) until it reaches the
# rolling target, then idles. Every edit (rewrite) extends the target, so the
# rollout is effectively continuous while the user keeps steering.
MAX_SEGMENTS = int(os.getenv("DREAMVERSE_MAX_SEGMENTS", "12"))

# Size of each binary WebSocket frame used to ship an encoded segment.
_WS_CHUNK_BYTES = 1 << 20


class Session:
    """Per-connection state shared by the receive loop and the rollout task."""

    def __init__(self, ws: WebSocket):
        self.ws = ws
        self.prompt = ""             # prompt used for the next segment
        self.prev_frames = None      # frames of the last segment (continuation source)
        self.seg_idx = 0             # index of the most recently started segment
        self.target = 0              # rollout stops once seg_idx reaches this
        self.epoch = 0               # bumped on every reset; invalidates in-flight results
        self.stream_started = False  # ltx2_stream_start is sent once per connection
        self.paused = False
        self.running = False         # cooperative stop flag read by _rollout
        self.task = None             # the asyncio.Task running _rollout, if any
        self.send_lock = asyncio.Lock()  # serialize sends (rollout + receive)

    async def send(self, **payload):
        """Send one JSON control message."""
        async with self.send_lock:
            await self.ws.send_text(json.dumps(payload))

    async def send_bytes(self, data: bytes):
        """Send one binary media frame."""
        async with self.send_lock:
            await self.ws.send_bytes(data)

    def reset_rollout(self):
        """Forget continuation state and invalidate any segment still in flight."""
        self.prev_frames, self.seg_idx = None, 0
        self.epoch += 1


def _as_text(value) -> str:
    """Return `value` stripped if it is a string, else an empty string."""
    return value.strip() if isinstance(value, str) else ""


def _segment_seed(prompt: str, idx: int) -> int:
    """Derive a seed in [0, 100000) that is stable across processes.

    The builtin `hash()` of a str is salted per interpreter, so it produced a
    different seed for the same (prompt, idx) after every restart.
    """
    import zlib

    return zlib.crc32(f"{idx}\x00{prompt}".encode("utf-8")) % 100000


def _ensure_rollout(sess: Session) -> None:
    """Make sure exactly one rollout task is serving `sess`.

    Decides on the task handle rather than `sess.running`: a reset clears
    `running` while the old task may still be mid-segment, and spawning a
    second task at that point interleaves two segments on one socket. A live
    task is simply re-armed instead.
    """
    if not sess.prompt:
        return
    if sess.task is not None and not sess.task.done():
        sess.running = True
        return
    sess.task = asyncio.create_task(_rollout(sess))


async def _rollout(sess: Session):
    """Generate continuation segments up to the rolling target, then idle.

    `ltx2_stream_start` is emitted once by the caller (not here), so restarting
    the loop after an edit does NOT reset the player.
    """
    sess.running = True
    try:
        while sess.running and sess.seg_idx < sess.target:
            if sess.paused or not sess.prompt:
                await asyncio.sleep(0.15)
                continue

            sess.seg_idx += 1
            idx, prompt, epoch = sess.seg_idx, sess.prompt, sess.epoch
            prompt_id = f"seg{idx}"
            await sess.send(type="prompt_received", prompt_id=prompt_id)
            await sess.send(type="prompt_ready", prompt_id=prompt_id,
                            prompt=prompt, source="user_raw")
            await sess.send(type="ltx2_segment_start", segment_idx=idx,
                            total_segments=sess.target, prompt=prompt,
                            source="user_raw")

            cond = tail_pil(sess.prev_frames) if sess.prev_frames else None
            head_trim = COND_FRAMES if cond is not None else 0
            seed = _segment_seed(prompt, idx)

            # Generation/encode failures are reported here. The outer handler
            # silently drops RuntimeError (meant for sends on a closed socket),
            # which used to hide ffmpeg and pipeline failures as well.
            try:
                async with _gpu_lock:
                    frames, audio, sr = await asyncio.to_thread(
                        generate_segment, prompt, cond, seed)
                frames = [] if frames is None else list(frames)
                fmp4 = b""
                if frames:
                    fmp4 = await asyncio.to_thread(
                        encode_fmp4, frames, audio, sr, head_trim)
            except Exception as e:
                print(f"[dreamverse] rollout error: {e}", flush=True)
                await sess.send(type="error", message=str(e))
                break
            if not frames:
                await sess.send(type="error", message="Generation returned no frames")
                break
            # Only seed the next continuation if no reset happened meanwhile;
            # otherwise these frames belong to the previous rollout.
            if sess.epoch == epoch:
                sess.prev_frames = frames

            sid = f"seg{idx:03d}-{uuid.uuid4().hex[:8]}"
            await sess.send(type="media_init", segment_idx=idx, stream_id=sid, mime=MIME)
            for start in range(0, len(fmp4), _WS_CHUNK_BYTES):
                await sess.send_bytes(fmp4[start:start + _WS_CHUNK_BYTES])
            await sess.send(type="media_segment_complete", segment_idx=idx, stream_id=sid)
            await sess.send(type="ltx2_segment_complete", segment_idx=idx,
                            total_segments=sess.target)
            print(f"[dreamverse] segment {idx}/{sess.target} ({len(frames)}f, {len(fmp4)//1024}KB) prompt={prompt[:40]!r}", flush=True)
    except (WebSocketDisconnect, RuntimeError):
        # Client went away mid-send; nothing left to report to.
        pass
    except Exception as e:
        print(f"[dreamverse] rollout error: {e}", flush=True)
        try:
            await sess.send(type="error", message=str(e))
        except Exception:
            pass
    finally:
        sess.running = False


def _first_prompt(data):
    """Pick the initial prompt: `initial_rollout_prompt`, else the first curated one."""
    prompt = _as_text(data.get("initial_rollout_prompt"))
    if not prompt:
        curated = data.get("curated_prompts")
        if isinstance(curated, list) and curated:
            prompt = _as_text(curated[0])
    return prompt


@app.websocket("/ws")
async def ws_endpoint(ws: WebSocket):
    """Receive loop for one client: applies control messages to its Session.

    Generation runs in a separate task (`_rollout`); this coroutine only
    mutates session state and (re)arms that task.
    """
    await ws.accept()
    sess = Session(ws)
    try:
        while True:
            data = json.loads(await ws.receive_text())
            if not isinstance(data, dict):
                continue  # valid JSON but not a protocol message
            mtype = data.get("type")

            if mtype in ("session_init_v2", "project_init_v1"):
                await sess.send(type="queue_status", queue_position=0, status="assigned")
                await sess.send(type="gpu_assigned", gpu_id=0, session_timeout=GPU_DURATION)
                sess.prompt = _first_prompt(data)
                sess.reset_rollout()
                sess.target = MAX_SEGMENTS
                if not sess.stream_started:
                    await sess.send(type="ltx2_stream_start", total_segments=sess.target)
                    sess.stream_started = True
                _ensure_rollout(sess)

            elif mtype == "rewrite_seed_prompts":
                # The "edit" path in streaming mode: steer the rollout. Update
                # the prompt and EXTEND the budget so the rollout keeps going
                # (applies to the next segment); restart the loop only if idle.
                new = _as_text(data.get("rewrite_instruction"))
                if not new:
                    pw = data.get("prompt_window_prompts") or data.get("prompts") or []
                    if isinstance(pw, list) and pw:
                        new = _as_text(pw[0])
                if new:
                    sess.prompt = new
                    print(f"[dreamverse] steer -> {new[:50]!r}", flush=True)
                sess.target = sess.seg_idx + MAX_SEGMENTS
                await sess.send(type="seed_prompts_updated", prompts=[sess.prompt],
                                reason="rewrite")
                await sess.send(type="rewrite_seed_prompts_complete")
                _ensure_rollout(sess)

            elif mtype == "append_prompt":  # devtools/demo build path
                new = _as_text(data.get("prompt"))
                if new:
                    sess.prompt = new
                sess.target = sess.seg_idx + MAX_SEGMENTS
                _ensure_rollout(sess)

            elif mtype == "set_generation_paused":
                sess.paused = bool(data.get("paused"))

            elif mtype == "restart_generation":
                sess.target = sess.seg_idx + MAX_SEGMENTS
                _ensure_rollout(sess)

            elif mtype in ("end_project_keep_session", "reset_to_seed_prompts"):
                # Ask the rollout to stop after its current segment.
                sess.running = False
                sess.reset_rollout()
            # set_auto_extension / set_loop_generation acknowledged implicitly.
    except (WebSocketDisconnect, json.JSONDecodeError):
        pass
    except Exception as e:
        print(f"[dreamverse] ws error: {e}", flush=True)
    finally:
        # The task is not cancelled: cancelling would release _gpu_lock while
        # the worker thread is still generating on the GPU.
        sess.running = False


# --------------------------------------------------------------------------
# Aux endpoints the frontend probes
# --------------------------------------------------------------------------
@app.get("/healthz")
async def healthz():
    """Liveness probe; also reports whether the model has been loaded."""
    return {"status": "ok", "model": MODEL_ID, "loaded": _generator is not None}


@app.get("/readyz")
async def readyz():
    """Readiness probe."""
    return JSONResponse({"ready": True})


@app.get("/models")
async def models():
    """List the single model this server exposes."""
    return {
        "models": [{"id": "fast-ltx2", "name": "FastLTX2", "model_path": MODEL_ID}],
        "default": "fast-ltx2",
    }


@app.get("/curated-presets")
async def curated_presets():
    """Serve the bundled story presets, or an empty list if the file is absent."""
    p = os.path.join(FRONTEND, "prompts",
                     "selected_ltx2_continuation_story_presets.json")
    if os.path.exists(p):
        return FileResponse(p, media_type="application/json")
    return JSONResponse([])


@app.get("/prompt-system-config")
async def prompt_system_config():
    """Tell the frontend that prompt enhancement and auto-extension are off."""
    return {"enhancement_enabled": False, "auto_extension_enabled": False}


# Root: serve the real DreamVerse frontend.
@app.get("/", response_class=HTMLResponse)
async def homepage():
    """Return the exported frontend's index.html."""
    with open(os.path.join(FRONTEND, "index.html"), "r", encoding="utf-8") as f:
        return f.read()


# Bundled JS/CSS/fonts under /_next (specific mount — does NOT shadow
# gradio's own /gradio_api/* routes added during launch()).
app.mount("/_next", StaticFiles(directory=os.path.join(FRONTEND, "_next")), name="next")

# Root-level static assets (favicon, svgs, k2.png, index.txt) — registered as
# explicit single-segment routes so they don't shadow gradio's API paths.
def _make_file_route(full_path: str):
    """Build a GET handler that always serves the single file at `full_path`.

    A factory is used so each handler closes over its own path instead of the
    loop variable.
    """

    async def _route():
        return FileResponse(full_path)

    return _route


for _f in os.listdir(FRONTEND):
    _full = os.path.join(FRONTEND, _f)
    # index.html is already served by the "/" route; directories are skipped.
    if _f == "index.html" or not os.path.isfile(_full):
        continue
    app.add_api_route(f"/{_f}", _make_file_route(_full), methods=["GET"])


if __name__ == "__main__":
    app.launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("PORT", "7860")),
        show_error=True,
    )