Spaces:
Runtime error
Runtime error
| """Gradio app for RealWonder interactive demo (HuggingFace Space). | |
| Replaces Flask + SocketIO with a Gradio Blocks interface that streams | |
| generated frames in real-time via Gradio's generator support. | |
| ZeroGPU-compatible: GPU is held for the duration of each generation call. | |
| Download checkpoint before running: | |
| huggingface-cli download ziyc/realwonder \ | |
| --include "Realwonder-Distilled-AR-I2V-Flow/*" \ | |
| --local-dir ckpts/ | |
| """ | |
| import os | |
| os.environ['SETUPTOOLS_USE_DISTUTILS'] = 'stdlib' | |
| os.environ.setdefault('PYOPENGL_PLATFORM', 'egl') # headless EGL for Genesis/pyrender | |
| # Patch gradio_client bug: get_type() does `"const" in schema` without checking | |
| # whether schema is a bool first (valid JSON Schema: additionalProperties: false). | |
| # This crashes the /info API endpoint. Fix: intercept boolean schemas early. | |
try:
    import gradio_client.utils as _gc_utils

    _orig_j2p = _gc_utils._json_schema_to_python_type

    def _patched_j2p(schema, defs=None):
        """Map a boolean JSON Schema (``true`` / ``false``) to ``"bool"``.

        Every other schema is delegated unchanged to the original
        gradio_client implementation.
        """
        if isinstance(schema, bool):
            return "bool"
        return _orig_j2p(schema, defs)

    _gc_utils._json_schema_to_python_type = _patched_j2p
except Exception as e:
    # Best-effort: a missing or changed gradio_client must not stop the app,
    # but report the skipped patch instead of hiding it.
    print(f"[patch] gradio_client bool-schema patch skipped: {e}")
| # Patch Genesis from_torch: in PyTorch 2.5+, Tensor(existing_plain_tensor) raises | |
| # "raw Tensor object is already associated to a python object of type Tensor | |
| # which is not a subclass of the requested type" | |
| # because torch.Tensor.__new__(SubClass, existing_tensor) checks that the existing | |
| # TensorImpl's Python wrapper is a subclass of SubClass. torch.Tensor is the parent, | |
| # not a subclass of genesis.grad.Tensor, so the check fails. | |
| # Fix: use torch.Tensor._make_subclass(cls, t) which is the proper PyTorch API for | |
| # creating a subclass view of an existing tensor regardless of the wrapper type. | |
| def _patch_genesis_from_torch(): | |
| try: | |
| import genesis | |
| import genesis.grad.creation_ops as _gc_ops | |
| import genesis.grad.tensor as _gt_mod | |
| _Tensor = _gt_mod.Tensor | |
| _gs = genesis | |
| def _patched_from_torch(torch_tensor, dtype=None, requires_grad=False, detach=True, scene=None): | |
| if dtype is None: | |
| dtype = torch_tensor.dtype | |
| if dtype in (float, torch.float32, torch.float64): | |
| dtype = _gs.tc_float | |
| elif dtype in (int, torch.int32, torch.int64): | |
| dtype = _gs.tc_int | |
| elif dtype in (bool, torch.bool): | |
| dtype = torch.bool | |
| else: | |
| _gs.raise_exception(f"Unsupported dtype: {dtype}") | |
| if torch_tensor.requires_grad and (not detach) and (not requires_grad): | |
| requires_grad = True | |
| # Perform ALL tensor operations on plain torch.Tensor objects BEFORE | |
| # wrapping as genesis.grad.Tensor. This avoids __torch_function__ | |
| # interference from ZeroGPU (spaces/zero/torch/patching.py), which | |
| # intercepts operations on tensor subclasses and then fails when | |
| # PyTorch tries to restore the subclass type via as_subclass(). | |
| t = torch_tensor.to(device=_gs.device, dtype=dtype).clone() | |
| if detach: | |
| t = t.detach() | |
| # _make_subclass uses MAYBE_UNINITIALIZED status, bypassing the | |
| # "already associated" check that Tensor(existing_tensor) triggers. | |
| gs_tensor = torch.Tensor._make_subclass(_Tensor, t, requires_grad) | |
| gs_tensor.scene = scene | |
| gs_tensor.uid = _gs.UID() | |
| gs_tensor.parents = [] | |
| return gs_tensor | |
| _gc_ops.from_torch = _patched_from_torch | |
| print("[patch] Genesis from_torch patched (_make_subclass fix for PyTorch 2.5+)") | |
| except Exception as e: | |
| print(f"[patch] Genesis from_torch patch skipped: {e}") | |
| import base64 | |
| import io | |
| import threading | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from queue import Queue, Full as QueueFull, Empty as QueueEmpty | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| import gradio as gr | |
| # ZeroGPU (HuggingFace Spaces): import spaces with a no-op fallback for | |
| # local development where the spaces package is not installed. | |
try:
    import spaces
except ImportError:
    class spaces:  # noqa: N801
        """Local stand-in for the ``spaces`` package: ``GPU`` is a no-op decorator.

        Supports both ``@spaces.GPU`` and ``@spaces.GPU(duration=...)``.
        """

        @staticmethod
        def GPU(fn=None, *, duration=None):
            """Return ``fn`` untouched, or an identity decorator when ``fn`` is omitted."""
            if fn is None:
                return lambda f: f
            return fn
| from config import ( | |
| FRAMES_PER_BLOCK, FRAMES_PER_BLOCK_PIXEL, FRAMES_FIRST_BLOCK_PIXEL, | |
| FPS, LATENT_H, LATENT_W, LATENT_C, | |
| DEFAULT_HEIGHT, DEFAULT_WIDTH, TEMPORAL_FACTOR, | |
| load_case_sdedit_config, | |
| ) | |
| from simulation_engine import InteractiveSimulator | |
| from noise_warper_stream import StreamingNoiseWarper | |
| from video_generator import StreamingVideoGenerator | |
| from case_handlers.base import get_demo_case_handler | |
| import case_handlers # trigger registration | |
| from gpu_profiler import log_gpu, set_gpu_logging | |
| from simulation.utils import resize_and_crop_pil | |
| # --------------------------------------------------------------------------- | |
| # HuggingFace Space configuration | |
| # --------------------------------------------------------------------------- | |
# Root folder scanned for cases: each sub-folder with a config.yaml is a case.
DEMO_DATA_DIR = Path("./demo_data")
# Searched recursively for *.pt / *.pth by _find_checkpoint().
CHECKPOINT_DIR = Path("ckpts/Realwonder-Distilled-AR-I2V-Flow")
# Wan2.1 base models; Wan2.1_VAE.pth here marks the download as complete.
WAN_MODEL_DIR = Path("wan_models/Wan2.1-Fun-V1.1-1.3B-InP")
SEED = 42  # passed as seed= to StreamingVideoGenerator
USE_EMA = False  # passed as use_ema= to StreamingVideoGenerator
ENABLE_TAEHV = False  # passed as enable_taehv= to StreamingVideoGenerator
MAX_OBJECTS = 3  # maximum objects across all cases (= number of force-control rows in the UI)
# Dropdown labels keyed by case folder name; names not listed are shown as-is.
CASE_DISPLAY_NAMES = {
    "lamp": "Lamp on River",
    "persimmon": "Falling Persimmons",
    "tree": "Breaking Tree",
    "santa_cloth": "Blowing Clothes",
}
@dataclass
class CaseBundle:
    """Everything needed to preview and generate one demo case.

    Constructed with keyword arguments by ``startup()`` (fully built) and by
    ``_on_load`` (lightweight: ``simulator`` and ``noise_warper`` are ``None``
    until ``startup()`` replaces the entry).
    """

    simulator: "InteractiveSimulator | None"
    noise_warper: "StreamingNoiseWarper | None"
    demo_case_handler: object
    preview_pil: "Image.Image | None"  # None when the first-frame image is missing
    default_prompt: str
    num_blocks: int  # number of generation blocks, from the case's SDEdit config
    first_frame_path: str
| # --------------------------------------------------------------------------- | |
| # Global state — initialized at module load before Gradio starts | |
| # --------------------------------------------------------------------------- | |
# Shared generator; assigned by startup(), None until then.
video_generator: "StreamingVideoGenerator | None" = None
cases: dict = {}  # case_name → CaseBundle
# Set to ask a running do_generate (and its worker threads) to stop; cleared
# when a new generation starts.
_stop_event = threading.Event()
# Guards the _is_generating check-and-set at the start of do_generate.
_gen_lock = threading.Lock()
_startup_lock = threading.Lock()  # NOTE(review): not used in the visible code — confirm
# True while do_generate is running; cleared in its finally block and by do_reset.
_is_generating = False
| # --------------------------------------------------------------------------- | |
| # Model download helpers | |
| # --------------------------------------------------------------------------- | |
def _ensure_models_downloaded():
    """Download the RealWonder checkpoint and Wan2.1 base models if missing.

    * The RealWonder checkpoint is fetched from ``ziyc/realwonder`` unless a
      ``*.pt`` / ``*.pth`` file already exists anywhere under
      ``CHECKPOINT_DIR``. The search is recursive, matching
      ``_find_checkpoint``. Otherwise a checkpoint stored in a sub-folder
      would go undetected and trigger a download attempt on every start.
    * The Wan2.1 base models are fetched from
      ``alibaba-pai/Wan2.1-Fun-V1.1-1.3B-InP`` unless ``Wan2.1_VAE.pth`` is
      present in ``WAN_MODEL_DIR``.
    """
    from huggingface_hub import snapshot_download

    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
    WAN_MODEL_DIR.mkdir(parents=True, exist_ok=True)

    has_checkpoint = any(CHECKPOINT_DIR.rglob("*.pt")) or any(CHECKPOINT_DIR.rglob("*.pth"))
    if not has_checkpoint:
        print("Downloading RealWonder checkpoint from ziyc/realwonder ...")
        snapshot_download(
            repo_id="ziyc/realwonder",
            allow_patterns="Realwonder-Distilled-AR-I2V-Flow/*",
            local_dir="ckpts/",
        )
        print("RealWonder checkpoint downloaded.")

    vae_path = WAN_MODEL_DIR / "Wan2.1_VAE.pth"
    if not vae_path.exists():
        print("Downloading Wan2.1 base models from alibaba-pai/Wan2.1-Fun-V1.1-1.3B-InP ...")
        snapshot_download(
            repo_id="alibaba-pai/Wan2.1-Fun-V1.1-1.3B-InP",
            local_dir=str(WAN_MODEL_DIR),
        )
        print("Wan2.1 base models downloaded.")
def _find_checkpoint():
    """Return the path of the checkpoint to load from ``CHECKPOINT_DIR``.

    The directory is searched recursively. Any ``*.pt`` file beats every
    ``*.pth`` file, and within one suffix the smallest path in sort order
    wins.

    Raises:
        FileNotFoundError: if no ``*.pt`` / ``*.pth`` file exists.
    """
    # Lazily evaluated: the *.pth search only runs when no *.pt file exists.
    per_suffix_best = (
        min(CHECKPOINT_DIR.rglob(suffix_glob), default=None)
        for suffix_glob in ("*.pt", "*.pth")
    )
    chosen = next((path for path in per_suffix_best if path is not None), None)
    if chosen is None:
        raise FileNotFoundError(
            f"No .pt/.pth checkpoint found in {CHECKPOINT_DIR}. "
            "Run: huggingface-cli download ziyc/realwonder "
            "--include 'Realwonder-Distilled-AR-I2V-Flow/*' --local-dir ckpts/"
        )
    return str(chosen)
| def _find_first_frame(case_dir, case_config): | |
| case_path = Path(case_dir) | |
| candidate = case_path / "first_frame.png" | |
| if candidate.exists(): | |
| return str(candidate) | |
| input_path = Path(case_config.get("data_path", "")) / "input.png" | |
| if input_path.exists(): | |
| return str(input_path) | |
| return str(candidate) | |
| # --------------------------------------------------------------------------- | |
| # Pipeline warmup | |
| # --------------------------------------------------------------------------- | |
def _warmup_pipeline(warmup_case_name):
    """Run dummy passes to compile CUDA kernels before first user request.

    For the case ``warmup_case_name`` this runs two simulated frames
    (physics steps + ``render_and_flow``), one noise-warp step on an all-zero
    flow, one VAE encode of an all-zero clip and three diffusion blocks. It
    then resets the simulator, noise warper and video generator it touched.

    Args:
        warmup_case_name: key into the module-level ``cases`` dict.

    NOTE(review): ``startup()`` in this file skips warmup (its "[5/6]" log
    line says so) and nothing visible here calls this function — confirm it
    is still used.
    NOTE(review): ``torch.set_grad_enabled(False)`` below changes autograd
    mode for the calling thread and is never re-enabled here.
    """
    import time
    bundle = cases[warmup_case_name]
    default_prompt = bundle.default_prompt
    print(f"[5/6] Warming up CUDA kernels for '{warmup_case_name}' (one-time cost)...")
    torch.set_grad_enabled(False)
    t0 = time.perf_counter()
    # Sim render warmup: two frames, each `frame_steps` physics steps followed
    # by one render_and_flow call.
    # NOTE(review): step() is called without extract_points here, whereas
    # do_generate passes extract_points=(i == last_i) — confirm the default
    # returns points render_and_flow can use.
    for _ in range(2):
        for _ in range(bundle.simulator.frame_steps):
            updated_points = bundle.simulator.step()
        bundle.simulator.render_and_flow(updated_points)
    # Undo the warmup frames: reset the Genesis scene, call the case handler's
    # fix_particles(), rewind the step counter and clear svr's cached
    # previous-frame / optical-flow / fragment state.
    bundle.simulator.scene.reset()
    bundle.simulator.case_handler.fix_particles()
    bundle.simulator.step_count = 0
    bundle.simulator.svr.previous_frame_data = None
    bundle.simulator.svr.optical_flow = np.array([])
    bundle.simulator.svr._last_optical_flow = None
    bundle.simulator.svr._prev_fg_frags_idx = None
    bundle.simulator.svr._prev_fg_frags_dists = None
    # Noise warp warmup: one step on an all-zero flow, then reset.
    # NOTE(review): 512x512 presumably matches the flow resolution produced
    # by render_and_flow — confirm.
    dummy_flow = np.zeros((2, 512, 512), dtype=np.float32)
    bundle.noise_warper.warp_step(dummy_flow)
    bundle.noise_warper.reset()
    t1 = time.perf_counter()
    print(f"  Sim + warp warmup: {t1 - t0:.1f}s")
    # VAE + diffusion warmup: encode an all-zero first block, then trim/pad the
    # latent along dim 1 to exactly FRAMES_PER_BLOCK frames.
    video_generator.prepare_generation(default_prompt, warmup_case_name)
    dummy_pixel = torch.zeros(
        1, 3, FRAMES_FIRST_BLOCK_PIXEL, DEFAULT_HEIGHT, DEFAULT_WIDTH,
        device=video_generator.device, dtype=torch.bfloat16,
    )
    sim_latent = video_generator.pipeline.encode_vae.cached_encode_to_latent(
        dummy_pixel, is_first=True,
    )
    if sim_latent.shape[1] > FRAMES_PER_BLOCK:
        sim_latent = sim_latent[:, :FRAMES_PER_BLOCK]
    elif sim_latent.shape[1] < FRAMES_PER_BLOCK:
        pad = FRAMES_PER_BLOCK - sim_latent.shape[1]
        sim_latent = torch.cat([sim_latent, sim_latent[:, -1:].repeat(1, pad, 1, 1, 1)], dim=1)
    dummy_noise = torch.randn(
        1, FRAMES_PER_BLOCK, LATENT_C, LATENT_H, LATENT_W,
        device=video_generator.device, dtype=torch.bfloat16,
    )
    video_generator.generate_block(block_idx=0, structured_noise=dummy_noise, sim_latent=sim_latent)
    # Blocks 1 and 2 use an all-zero sim latent with random noise.
    for blk in range(1, 3):
        _d = torch.zeros(1, FRAMES_PER_BLOCK, LATENT_C, LATENT_H, LATENT_W,
                         device=video_generator.device, dtype=torch.bfloat16)
        _n = torch.randn_like(_d)
        video_generator.generate_block(block_idx=blk, structured_noise=_n, sim_latent=_d)
    # Discard warmup state so the first real generation starts clean.
    video_generator.reset()
    video_generator.pipeline.encode_vae.model.clear_cache()
    t2 = time.perf_counter()
    print(f"  VAE + diffusion warmup: {t2 - t1:.1f}s")
    print(f"  Total warmup: {t2 - t0:.1f}s — first generation will be fast.")
    log_gpu("after pipeline warmup")
| # --------------------------------------------------------------------------- | |
# Startup — runs on CPU only: downloads models if missing, builds the video
# generator, the Genesis scenes and the noise warpers, and precomputes per-case
# data (VAE/CLIP). Kernel warmup is skipped here; do_generate moves the
# pipeline to CUDA at the start of each generation session.
| # --------------------------------------------------------------------------- | |
def _startup_load_case_configs(case_dirs):
    """Read every case's ``config.yaml`` and derive its SDEdit settings.

    Args:
        case_dirs: case folders, each containing a ``config.yaml``.

    Returns:
        ``(case_configs, sdedit_cfgs)`` — two dicts keyed by case folder name.
    """
    import yaml

    case_configs, sdedit_cfgs = {}, {}
    for case_dir in case_dirs:
        # YAML is UTF-8; do not depend on the platform's locale encoding.
        with open(case_dir / "config.yaml", encoding="utf-8") as f:
            case_config = yaml.safe_load(f)
        sdedit_cfg = load_case_sdedit_config(case_config)
        case_configs[case_dir.name] = case_config
        sdedit_cfgs[case_dir.name] = sdedit_cfg
        print(f"  Case '{case_dir.name}': SDEdit config = {sdedit_cfg}")
    return case_configs, sdedit_cfgs


def _startup_build_case(case_dir, case_config, sdedit_cfg):
    """Build the simulator, demo handler and noise warper for one case (CPU).

    Returns:
        The fully populated ``CaseBundle`` for ``case_dir``.
    """
    case_name = case_dir.name
    print(f"[2/6] Loading case '{case_name}' and building Genesis scene ...")
    log_gpu(f"before simulator init ({case_name})")
    # Case-specific override: santa_cloth is built without force fields.
    config_overrides = {"skip_force_fields": True} if case_name == "santa_cloth" else {}
    simulator = InteractiveSimulator(str(case_dir), device="cpu", config_overrides=config_overrides)
    simulator.config["debug"] = False
    log_gpu(f"after simulator init ({case_name})")

    demo_case_handler = get_demo_case_handler(case_name, simulator.config)
    demo_case_handler.set_object_masks(simulator.object_masks_b64)
    simulator.set_demo_case_handler(demo_case_handler)
    noise_warper = StreamingNoiseWarper(crop_start=simulator.crop_start)
    log_gpu(f"after noise warper init ({case_name})")

    first_frame_path = _find_first_frame(case_dir, case_config)
    preview_pil = Image.open(first_frame_path).convert("RGB")
    return CaseBundle(
        simulator=simulator,
        noise_warper=noise_warper,
        demo_case_handler=demo_case_handler,
        preview_pil=preview_pil,
        default_prompt=simulator.config.get("vgen_prompt", "A video of physical simulation"),
        num_blocks=sdedit_cfg["num_blocks"],
        first_frame_path=first_frame_path,
    )


def startup():
    """Load models, build every demo case and precompute per-case data (CPU).

    Steps, numbered as in the log output:
      1. build and ``setup()`` the ``StreamingVideoGenerator`` on CPU;
      2. for each case folder under ``DEMO_DATA_DIR``, build its simulator,
         demo handler and noise warper and register a ``CaseBundle`` in
         ``cases`` (replacing any lightweight entry);
      3. ``precompute_case`` for every case;
      4. ``finish_precompute`` to free the processor models;
      5. (warmup skipped) and 6. done.

    The global ``video_generator`` is only published after steps 1-4 have
    succeeded. A failed start-up therefore leaves it ``None``, and
    ``do_generate`` reports "models not initialized" instead of using a
    half-built pipeline.

    Raises:
        RuntimeError: if ``DEMO_DATA_DIR`` is missing or contains no case
            folder with a ``config.yaml``.
        FileNotFoundError: if no checkpoint can be found.
    """
    global video_generator

    set_gpu_logging(False)
    _ensure_models_downloaded()
    checkpoint_path = _find_checkpoint()
    if not DEMO_DATA_DIR.exists():
        raise RuntimeError(f"demo_data directory not found: {DEMO_DATA_DIR}")
    case_dirs = sorted(
        d for d in DEMO_DATA_DIR.iterdir()
        if d.is_dir() and (d / "config.yaml").exists()
    )
    if not case_dirs:
        raise RuntimeError(f"No case subdirs with config.yaml found in {DEMO_DATA_DIR}")
    print(f"Found {len(case_dirs)} case(s): {[d.name for d in case_dirs]}")

    all_case_configs, all_sdedit_cfgs = _startup_load_case_configs(case_dirs)
    # One generator serves every case: size it for the longest case and take
    # the denoising schedule from the first case.
    max_num_pixel_frames = max(cfg["num_pixel_frames"] for cfg in all_sdedit_cfgs.values())
    first_sdedit_cfg = all_sdedit_cfgs[case_dirs[0].name]

    # ---- Step 1: Video generator ----
    print(f"[1/6] Initializing video generator from {checkpoint_path} ...")
    log_gpu("before video generator init")
    # Built into a local; published to the global only once fully prepared.
    generator = StreamingVideoGenerator(
        checkpoint_path=checkpoint_path,
        num_pixel_frames=max_num_pixel_frames,
        denoising_steps=first_sdedit_cfg["denoising_step_list"],
        mask_dropin_step=first_sdedit_cfg["mask_dropin_step"],
        franka_step=first_sdedit_cfg["franka_step"],
        use_ema=USE_EMA,
        seed=SEED,
        enable_taehv=ENABLE_TAEHV,
        device="cpu",
    )
    generator.setup()
    log_gpu("after video generator setup")

    # ---- Step 2: Genesis scenes + noise warpers ----
    for case_dir in case_dirs:
        case_name = case_dir.name
        cases[case_name] = _startup_build_case(
            case_dir, all_case_configs[case_name], all_sdedit_cfgs[case_name],
        )
        print(f"  Case '{case_name}' ready.")

    # ---- Step 3: Pre-compute per-case embeddings ----
    print("[3/6] Pre-computing first frame encoding for all cases ...")
    for case_dir in case_dirs:
        case_name = case_dir.name
        bundle = cases[case_name]
        generator.precompute_case(
            case_name=case_name,
            first_frame_path=bundle.first_frame_path,
            default_prompt=bundle.default_prompt,
            sdedit_cfg=all_sdedit_cfgs[case_name],
        )
        log_gpu(f"after precompute_case ({case_name})")

    # ---- Step 4: Free processor models ----
    print("[4/6] Freeing processor models ...")
    generator.finish_precompute()
    log_gpu("after finish_precompute")
    video_generator = generator

    # ---- Step 5: Warmup ----
    # Warmup (CUDA kernel compilation) is deferred to first generation call.
    print("[5/6] Skipping warmup at CPU-only startup — CUDA kernels compile on first generation.")
    torch.cuda.empty_cache()
    print("[6/6] CPU-only startup complete — models and scenes ready. GPU transfer at generation time.")
| # --------------------------------------------------------------------------- | |
| # Tensor helpers (identical logic to original app.py) | |
| # --------------------------------------------------------------------------- | |
| def _frames_to_tensor(frames_pil): | |
| """Convert list of PIL frames to tensor [1, C, T, H, W] in [-1, 1].""" | |
| arrays = [] | |
| for f in frames_pil: | |
| arr = np.array(f.convert("RGB")).astype(np.float32) / 127.5 - 1.0 | |
| arrays.append(torch.from_numpy(arr)) | |
| tensor = torch.stack(arrays, dim=0).permute(3, 0, 1, 2).contiguous() | |
| return tensor.unsqueeze(0) | |
def _downsample_masks(masks, target_frames, crop_start=176, device="cuda"):
    """Turn per-pixel-frame masks into a boolean latent-space mask video.

    Each mask is resized to 832x832 and cropped to rows
    ``[crop_start, crop_start + DEFAULT_HEIGHT)``, then resized to
    ``(LATENT_H, LATENT_W)``. Consecutive groups of ``TEMPORAL_FACTOR``
    frames are averaged, and the result is trimmed (or padded by repeating
    its last frame) to ``target_frames`` and thresholded at 0.5.

    Args:
        masks: list of per-frame masks. Entries that are not a
            ``torch.Tensor`` (including ``None``) contribute an all-zero
            frame; 3-D tensors have a trailing size-1 axis squeezed first.
        target_frames: number of latent frames in the result.
        crop_start: first row of the vertical crop in 832x832 space.
        device: device for the intermediate and returned tensors.

    Returns:
        Bool tensor ``[1, target_frames, LATENT_H, LATENT_W]``, or ``None``
        when ``masks`` is empty or holds only ``None``.
    """
    if not masks or all(m is None for m in masks):
        return None

    def _to_latent(mask):
        # One mask -> [1, 1, LATENT_H, LATENT_W] float tensor.
        if not isinstance(mask, torch.Tensor):
            return torch.zeros(1, 1, LATENT_H, LATENT_W, device=device)
        mask = mask.to(device=device)
        if mask.dim() == 3:
            mask = mask.squeeze(-1)
        full_res = F.interpolate(
            mask.float()[None, None],
            size=(832, 832), mode="bilinear", align_corners=False,
        )
        cropped = full_res[:, :, crop_start:crop_start + DEFAULT_HEIGHT, :]
        return F.interpolate(
            cropped, size=(LATENT_H, LATENT_W),
            mode="bilinear", align_corners=False,
        )

    per_frame = torch.cat([_to_latent(mask) for mask in masks], dim=0)
    averaged = torch.cat(
        [
            per_frame[start:start + TEMPORAL_FACTOR].mean(dim=0, keepdim=True)
            for start in range(0, per_frame.shape[0], TEMPORAL_FACTOR)
        ],
        dim=0,
    )
    have = averaged.shape[0]
    if have > target_frames:
        averaged = averaged[:target_frames]
    elif have < target_frames:
        filler = averaged[-1:].repeat(target_frames - have, 1, 1, 1)
        averaged = torch.cat([averaged, filler], dim=0)
    latent_masks = averaged.squeeze(1).unsqueeze(0)
    return (latent_masks > 0.5).bool()
| # --------------------------------------------------------------------------- | |
| # Gradio event handlers | |
| # --------------------------------------------------------------------------- | |
def on_case_change(case_name):
    """Build the Gradio updates for the preview, prompt and object controls.

    Returns a flat list: ``[preview_image, prompt]`` followed by
    ``MAX_OBJECTS`` group-visibility updates, then ``MAX_OBJECTS``
    direction-radio updates, then ``MAX_OBJECTS`` strength-slider updates.
    Unknown cases (or an empty ``cases``) yield a blank preview and prompt
    with every control hidden.
    """
    if not (cases and case_name in cases):
        return ([None, ""]
                + [gr.update(visible=False)] * MAX_OBJECTS
                + [gr.update(visible=False, value="none")] * MAX_OBJECTS
                + [gr.update(visible=False, value=0.0)] * MAX_OBJECTS)

    bundle = cases[case_name]
    objects = bundle.demo_case_handler.get_ui_config()["objects"]

    groups, radios, sliders = [], [], []
    for idx in range(MAX_OBJECTS):
        if idx >= len(objects):
            groups.append(gr.update(visible=False))
            radios.append(gr.update(visible=False, value="none"))
            sliders.append(gr.update(visible=False, value=0.0))
            continue
        spec = objects[idx]
        groups.append(gr.update(visible=True))
        radios.append(gr.update(
            visible=True,
            value=spec.get("default_direction", "none"),
            label=f"Direction — {spec['label']}",
        ))
        sliders.append(gr.update(
            visible=True,
            value=spec.get("default_strength", 1.0),
            maximum=spec.get("max_strength", 2.0),
            label=f"Strength — {spec['label']}",
        ))
    return [bundle.preview_pil, bundle.default_prompt] + groups + radios + sliders
def do_generate(case_name, prompt, d0, s0, d1, s1, d2, s2):
    """Gradio generator: run the 4-stage pipeline and yield frames.

    Yields ``(frame, status)`` pairs. ``frame`` is a PIL image, or ``None``
    for status-only updates, and ``status`` is the text for the status box.

    Meant to run inside a ZeroGPU allocation (``@spaces.GPU``) for the whole
    generator lifetime. NOTE(review): no decorator is visible on this
    definition — confirm it is applied where the handler is registered.
    The pipeline and precomputed case tensors are moved to CUDA at the start
    and back to CPU in the ``finally`` block, so VRAM is released between
    generations. That block always runs once the generation has started, even
    if setup fails or the client disconnects at the first ``yield``.

    Pipeline:
      Stage 1a [thread]: Genesis physics steps -> physics_queue
      Stage 1b [thread]: SVR render + optical flow -> sim_queue
      Stage 2  [thread]: noise warping -> ready_queue
      Stage 3  [this generator]: VAE encode + SDEdit diffusion -> yield frames

    Every queue operation polls ``_stop_event``, so on Stop/Reset/cancel each
    stage exits promptly instead of blocking on a full or empty queue.

    Args:
        case_name: key into ``cases``.
        prompt: text prompt for the video generator.
        d0, d1, d2: per-object force directions from the UI radios.
        s0, s1, s2: per-object force strengths from the UI sliders.
    """
    import time
    import traceback

    global _is_generating

    # Decide under the lock, but yield only after releasing it. A generator
    # suspended at a `yield` inside `with _gen_lock` would keep the lock held
    # until Gradio resumed or closed it.
    reject_msg = None
    with _gen_lock:
        if _is_generating:
            reject_msg = "Generation already in progress. Stop or reset first."
        elif not cases or case_name not in cases:
            reject_msg = "Error: no cases loaded."
        elif video_generator is None:
            reject_msg = "Error: models not initialized. Please reload the Space."
        else:
            _is_generating = True
            _stop_event.clear()
    if reject_msg is not None:
        yield None, reject_msg
        return

    bundle = cases[case_name]
    physics_thread = render_thread = warp_thread = None
    worker_errors = []  # exceptions raised inside the pipeline threads

    def _put(q, item):
        """Put ``item`` on ``q``, retrying until it fits; give up once stopped.

        Returns True if the item was enqueued.
        """
        while not _stop_event.is_set():
            try:
                q.put(item, timeout=0.5)
                return True
            except QueueFull:
                pass
        return False

    def _get(q, timeout=None):
        """Return the next item of ``q``, or None once stopped / timed out.

        ``timeout`` is a total deadline in seconds (None = no deadline).
        ``None`` is also the end-of-stream sentinel, so callers treat all
        three outcomes as "stop consuming".
        """
        deadline = None if timeout is None else time.monotonic() + timeout
        while not _stop_event.is_set():
            if deadline is not None and time.monotonic() >= deadline:
                return None
            try:
                return q.get(timeout=0.5)
            except QueueEmpty:
                pass
        return None

    try:
        # Transfer CPU-resident state to GPU for this generation session.
        # NOTE: simulators are NOT moved to GPU — Genesis uses backend=gs.cpu
        # and simulation tensors must remain on CPU alongside Genesis state.
        video_generator.move_pipeline_to_device("cuda")
        video_generator.move_case_data_to_device("cuda")

        # Build force configs from the UI inputs (only the case's objects).
        ui_cfg = bundle.demo_case_handler.get_ui_config()
        n_obj = ui_cfg["num_objects"]
        dirs = [d0, d1, d2]
        strs = [s0, s1, s2]
        ui_forces = [
            {"obj_idx": i, "direction": dirs[i], "strength": strs[i]}
            for i in range(n_obj)
        ]
        force_configs = bundle.demo_case_handler.get_force_config_from_ui(ui_forces)
        bundle.demo_case_handler.set_forces(force_configs)
        bundle.demo_case_handler.configure_simulation(bundle.simulator)
        yield None, "Forces configured. Starting generation..."

        bundle.noise_warper.reset()
        video_generator.prepare_generation(prompt, case_name)
        frame_steps = bundle.simulator.frame_steps
        num_blocks = bundle.num_blocks
        physics_queue = Queue(maxsize=2)
        sim_queue = Queue(maxsize=2)
        ready_queue = Queue(maxsize=3)

        # ---- Stage 1a: Physics ----
        def physics_producer():
            try:
                for block_idx in range(num_blocks):
                    if _stop_event.is_set():
                        break
                    n_pixel = FRAMES_FIRST_BLOCK_PIXEL if block_idx == 0 else FRAMES_PER_BLOCK_PIXEL
                    for pf_idx in range(n_pixel):
                        if _stop_event.is_set():
                            break
                        # Only the last physics step of a frame extracts points.
                        last_i = frame_steps - 1
                        for i in range(frame_steps):
                            updated_points = bundle.simulator.step(extract_points=(i == last_i))
                        frame_id = bundle.simulator.step_count
                        item = (block_idx, n_pixel, pf_idx, updated_points, frame_id)
                        if not _put(physics_queue, item):
                            return
            except Exception as e:
                traceback.print_exc()
                worker_errors.append(e)
            finally:
                _put(physics_queue, None)

        # ---- Stage 1b: Render + optical flow ----
        def render_flow_producer():
            try:
                current_block = -1
                flows, sim_frames, fg_masks, mesh_masks = [], [], [], []
                while True:
                    item = _get(physics_queue)
                    if item is None:
                        break
                    block_idx, n_pixel, _pf_idx, updated_points, frame_id = item
                    if block_idx != current_block:
                        current_block = block_idx
                        flows, sim_frames, fg_masks, mesh_masks = [], [], [], []
                    frame_pil, flow_2hw, fg_mask, mesh_mask = bundle.simulator.render_and_flow(
                        updated_points, frame_id=frame_id,
                    )
                    frame_pil = resize_and_crop_pil(frame_pil, start_y=bundle.simulator.crop_start)
                    sim_frames.append(frame_pil)
                    flows.append(flow_2hw)
                    fg_masks.append(fg_mask)
                    mesh_masks.append(mesh_mask)
                    if len(sim_frames) == n_pixel:
                        if not _put(sim_queue, (block_idx, flows, sim_frames, fg_masks, mesh_masks)):
                            break
            except Exception as e:
                traceback.print_exc()
                worker_errors.append(e)
            finally:
                _put(sim_queue, None)

        # ---- Stage 2: Noise warping ----
        def noise_warp_stage():
            try:
                while True:
                    item = _get(sim_queue)
                    if item is None:
                        break
                    block_idx, flows, sim_frames, fg_masks, mesh_masks = item
                    for flow in flows:
                        bundle.noise_warper.warp_step(flow)
                    structured_noise, sde_noise = bundle.noise_warper.get_block_noise(block_idx)
                    ready = (block_idx, structured_noise, sde_noise,
                             sim_frames, fg_masks, mesh_masks)
                    if not _put(ready_queue, ready):
                        break
            except Exception as e:
                traceback.print_exc()
                worker_errors.append(e)
            finally:
                _put(ready_queue, None)

        physics_thread = threading.Thread(target=physics_producer, daemon=True)
        render_thread = threading.Thread(target=render_flow_producer, daemon=True)
        warp_thread = threading.Thread(target=noise_warp_stage, daemon=True)
        physics_thread.start()
        render_thread.start()
        warp_thread.start()

        # ---- Stage 3: VAE encode + diffusion (this generator) ----
        while True:
            item = _get(ready_queue, timeout=120)
            if item is None:
                break
            block_idx, structured_noise, sde_noise, sim_frames, fg_masks, mesh_masks = item
            yield None, f"Block {block_idx + 1}/{num_blocks} — Generating..."

            # VAE encode simulation frames, then trim/pad to FRAMES_PER_BLOCK.
            sim_frames_tensor = _frames_to_tensor(sim_frames)
            sim_latent = video_generator.pipeline.encode_vae.cached_encode_to_latent(
                sim_frames_tensor.to(device=video_generator.device, dtype=torch.bfloat16),
                is_first=(block_idx == 0),
            )
            if sim_latent.shape[1] > FRAMES_PER_BLOCK:
                sim_latent = sim_latent[:, :FRAMES_PER_BLOCK]
            elif sim_latent.shape[1] < FRAMES_PER_BLOCK:
                pad = FRAMES_PER_BLOCK - sim_latent.shape[1]
                sim_latent = torch.cat(
                    [sim_latent, sim_latent[:, -1:].repeat(1, pad, 1, 1, 1)], dim=1,
                )

            # Build masks
            sim_mask = _downsample_masks(fg_masks, FRAMES_PER_BLOCK,
                                         crop_start=bundle.simulator.crop_start,
                                         device=video_generator.device)
            sim_franka_mask = _downsample_masks(mesh_masks, FRAMES_PER_BLOCK,
                                                crop_start=bundle.simulator.crop_start,
                                                device=video_generator.device)

            # Diffusion denoising
            pixel_frames = video_generator.generate_block(
                block_idx=block_idx,
                structured_noise=structured_noise,
                sim_latent=sim_latent,
                sde_noise=sde_noise,
                sim_mask=sim_mask,
                sim_franka_mask=sim_franka_mask,
            )

            # Yield each decoded pixel frame, paced at FPS.
            for frame_np in pixel_frames:
                if _stop_event.is_set():
                    break
                yield Image.fromarray(frame_np), f"Block {block_idx + 1}/{num_blocks} — Streaming..."
                time.sleep(1.0 / FPS)

        if worker_errors:
            # A pipeline thread failed; do not report success.
            yield None, f"Error: {worker_errors[0]}"
        elif not _stop_event.is_set():
            yield None, "Generation complete!"
    except GeneratorExit:
        # Gradio cancelled the generator (Stop button or new request).
        _stop_event.set()
        raise
    except Exception as e:
        traceback.print_exc()
        yield None, f"Error: {e}"
    finally:
        _stop_event.set()
        try:
            for worker in (physics_thread, render_thread, warp_thread):
                if worker is not None:
                    worker.join(timeout=10)
            if video_generator is not None:
                video_generator.move_pipeline_to_device("cpu")
                video_generator.move_case_data_to_device("cpu")
            torch.cuda.empty_cache()
        finally:
            # Never leave the app locked out, even if cleanup itself fails.
            _is_generating = False
def do_stop():
    """Ask ``do_generate`` and its worker threads to stop.

    Returns:
        The status text to display while the pipeline winds down.
    """
    stop_status = "Stopping..."
    _stop_event.set()
    return stop_status
def do_reset(case_name):
    """Stop any running generation, reset case/generator state, return preview.

    The stop event is set first. The function then waits up to 15 s for the
    running ``do_generate`` to clear ``_is_generating`` itself before the
    simulator, noise warper and video generator are reset. Resetting them
    while its worker threads are still stepping the simulator or warping noise
    would race with those threads. If the generation does not wind down in
    time, the busy flag is cleared anyway so the UI is never locked out.

    Args:
        case_name: key into ``cases`` selecting the bundle to reset.

    Returns:
        ``(preview_image_or_None, status_text)``.
    """
    import time

    global _is_generating
    _stop_event.set()

    wait_s = 15.0
    deadline = time.monotonic() + wait_s
    while _is_generating and time.monotonic() < deadline:
        time.sleep(0.1)
    if _is_generating:
        print(f"[reset] generation did not stop within {wait_s:.0f}s; clearing busy flag anyway")

    bundle = cases.get(case_name) if cases else None
    if bundle is not None:
        if bundle.simulator is not None:
            bundle.simulator.reset()
        if bundle.noise_warper is not None:
            bundle.noise_warper.reset()
    if video_generator is not None:
        video_generator.reset()
    _is_generating = False
    if bundle is not None:
        return bundle.preview_pil, "Reset complete. Ready to generate."
    return None, "Reset complete."
| # --------------------------------------------------------------------------- | |
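# Page-load initializer — CPU only, no GPU needed.
# Reads configs and preview images from disk to populate the UI. Cases not yet
# registered get a lightweight CaseBundle (no simulator or noise warper);
# existing entries are left untouched. Heavy work (model loading, scene init,
# precompute) happens in startup(), which replaces these entries with fully
# built bundles.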
| # --------------------------------------------------------------------------- | |
def _on_load():
    """Populate the UI on page load from the case configs on disk (CPU only).

    Registers a lightweight ``CaseBundle`` (``simulator`` / ``noise_warper``
    left as ``None``) for every case folder not already in ``cases``. It then
    returns the Gradio updates for the scene dropdown followed by the same
    preview/prompt/control updates that ``on_case_change`` produces for the
    first case.
    """
    import yaml

    def _hidden_object_controls():
        # Group, radio and slider updates that hide every object row.
        return ([gr.update(visible=False)] * MAX_OBJECTS
                + [gr.update(visible=False, value="none")] * MAX_OBJECTS
                + [gr.update(visible=False, value=0.0)] * MAX_OBJECTS)

    if not DEMO_DATA_DIR.exists():
        return ([gr.update(choices=[], value=None), None, "Error: demo_data not found"]
                + _hidden_object_controls())

    case_dirs = sorted(
        entry for entry in DEMO_DATA_DIR.iterdir()
        if entry.is_dir() and (entry / "config.yaml").exists()
    )
    for case_dir in case_dirs:
        name = case_dir.name
        if name in cases:
            continue  # already populated (e.g. concurrent request)
        with open(case_dir / "config.yaml") as fh:
            cfg = yaml.safe_load(fh)
        sdedit_cfg = load_case_sdedit_config(cfg)
        handler = get_demo_case_handler(name, cfg)
        # Object masks come from the simulator; set lazily when startup() runs.
        frame_path = _find_first_frame(case_dir, cfg)
        preview = Image.open(frame_path).convert("RGB") if Path(frame_path).exists() else None
        prompt = cfg.get("vgen_prompt", "A video of physical simulation")
        cases[name] = CaseBundle(
            simulator=None,
            noise_warper=None,
            demo_case_handler=handler,
            preview_pil=preview,
            default_prompt=prompt,
            num_blocks=sdedit_cfg["num_blocks"],
            first_frame_path=frame_path,
        )

    names = list(cases.keys())
    choices = [(CASE_DISPLAY_NAMES.get(n, n), n) for n in names]
    first = names[0] if names else None
    case_update = gr.update(choices=choices, value=first, interactive=bool(names))
    if not first:
        return [case_update, None, ""] + _hidden_object_controls()
    return [case_update] + on_case_change(first)
| # --------------------------------------------------------------------------- | |
| # Gradio UI | |
| # --------------------------------------------------------------------------- | |
def build_demo():
    """Assemble the Gradio Blocks interface for the RealWonder demo.

    Initial widget values (scene list, prompt text, preview image and the
    per-object force rows) are seeded from the first entry of the global
    ``cases`` registry when it is non-empty; otherwise the prompt is blank,
    the preview is empty and every force row starts hidden.

    Event wiring:
      * scene dropdown change -> ``on_case_change``
      * Start                 -> ``do_generate`` (streams frames + status)
      * Stop / Reset          -> ``do_stop`` / ``do_reset``, both cancelling
                                 the running generation event
      * page load             -> ``_on_load``

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` instance.
    """
    scene_ids = list(cases.keys())
    scene_choices = [(CASE_DISPLAY_NAMES.get(sid, sid), sid) for sid in scene_ids]
    default_scene = scene_ids[0] if scene_ids else None
    default_bundle = cases[default_scene] if default_scene else None
    if default_bundle:
        default_ui = default_bundle.demo_case_handler.get_ui_config()
    else:
        default_ui = {"objects": []}

    def _make_force_row(index):
        """Create one (direction radio, strength slider) group for object *index*.

        The group is visible only when the initial UI config lists an object
        at this position. Returns the ``(group, radio, slider)`` triple.
        """
        specs = default_ui["objects"]
        spec = None
        if index < len(specs):
            spec = specs[index]
        with gr.Group(visible=spec is not None) as row_group:
            # An empty/falsy spec still yields a visible row, but with the
            # generic label and default values (mirrors the truthiness test).
            if spec:
                caption = spec["label"]
                direction = spec.get("default_direction", "none")
                ceiling = spec.get("max_strength", 2.0)
                strength = spec.get("default_strength", 1.0)
            else:
                caption = f"Object {index}"
                direction, ceiling, strength = "none", 2.0, 1.0
            direction_radio = gr.Radio(
                choices=["left", "none", "right"],
                value=direction,
                label=f"Direction — {caption}",
            )
            strength_slider = gr.Slider(
                minimum=0.0,
                maximum=ceiling,
                value=strength,
                step=0.1,
                label=f"Strength — {caption}",
            )
        return row_group, direction_radio, strength_slider

    with gr.Blocks(title="RealWonder — Interactive Video Generation") as demo:
        gr.Markdown(
            "# 🎬 RealWonder — Interactive Video Generation\n"
            "Select a scene, configure a force, and watch physics-driven video generation in real time."
        )
        with gr.Row():
            # Left column: scene picker, prompt, force rows, transport buttons.
            with gr.Column(scale=1, min_width=320):
                case_dropdown = gr.Dropdown(
                    choices=scene_choices,
                    value=default_scene,
                    label="Scene",
                )
                prompt_input = gr.Textbox(
                    value=default_bundle.default_prompt if default_bundle else "",
                    label="Prompt",
                    lines=2,
                )
                gr.Markdown("### Force Controls")
                # A fixed pool of MAX_OBJECTS rows is always created. The
                # case-change and load handlers list these groups among their
                # outputs, so they can show or hide rows per scene.
                force_rows = [_make_force_row(i) for i in range(MAX_OBJECTS)]
                with gr.Row():
                    start_btn = gr.Button("▶ Start", variant="primary")
                    stop_btn = gr.Button("■ Stop")
                    reset_btn = gr.Button("↺ Reset")
                status_box = gr.Textbox(
                    label="Status", interactive=False, lines=1, value="Ready.",
                )
            # Right column: the streamed output frame.
            with gr.Column(scale=2):
                output_image = gr.Image(
                    value=default_bundle.preview_pil if default_bundle else None,
                    label="Output",
                    type="pil",
                    height=480,
                    show_download_button=True,
                )

        group_list = [row[0] for row in force_rows]
        radio_list = [row[1] for row in force_rows]
        slider_list = [row[2] for row in force_rows]
        # Order matters: handlers return updates in group/radio/slider order.
        object_widgets = group_list + radio_list + slider_list

        case_dropdown.change(
            fn=on_case_change,
            inputs=[case_dropdown],
            outputs=[output_image, prompt_input] + object_widgets,
        )
        generation = start_btn.click(
            fn=do_generate,
            inputs=[case_dropdown, prompt_input] + radio_list + slider_list,
            outputs=[output_image, status_box],
        )
        # Both Stop and Reset abort an in-flight generation stream.
        stop_btn.click(
            fn=do_stop,
            inputs=[],
            outputs=[status_box],
            cancels=[generation],
        )
        reset_btn.click(
            fn=do_reset,
            inputs=[case_dropdown],
            outputs=[output_image, status_box],
            cancels=[generation],
        )
        demo.load(
            fn=_on_load,
            inputs=[],
            outputs=[case_dropdown, output_image, prompt_input] + object_widgets,
        )
    return demo
| # --------------------------------------------------------------------------- | |
| # Entry point | |
| # --------------------------------------------------------------------------- | |
# Download model weights at module-load time (no GPU needed — pure network/disk).
# This runs once when the Space container starts. On subsequent restarts the
# files are already on disk so snapshot_download() is a fast no-op. By doing
# this here we avoid holding a ZeroGPU allocation while waiting on downloads.
#
# Statement order below is significant:
#   1. weights must be on disk before startup() loads models;
#   2. the Genesis from_torch patch is applied before startup() so that any
#      Genesis tensor conversion during scene loading presumably goes through
#      the patched path (NOTE(review): confirm startup() touches Genesis);
#   3. build_demo() seeds its initial widgets from the global `cases`
#      registry, so it must run after startup() has loaded the scenes.
_ensure_models_downloaded()
_patch_genesis_from_torch()  # Fix Genesis from_torch for PyTorch 2.5+ compatibility
startup()  # Load all models and scenes to CPU at module level
# `demo` is bound at module level; the server is only launched below when
# this file is executed as a script.
demo = build_demo()
if __name__ == "__main__":
    # Bind on all interfaces so the server is reachable from outside the container.
    demo.launch(server_name="0.0.0.0", server_port=7860)