prithivMLmods commited on
Commit
4eeed04
ยท
verified ยท
1 Parent(s): 9cc968a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +508 -248
app.py CHANGED
@@ -1,264 +1,524 @@
1
- import argparse
2
- import gc
3
  import os
4
- import sys
 
 
 
 
5
 
6
  import gradio as gr
7
- import torch
8
  import spaces
9
- from omegaconf import OmegaConf
10
-
11
- # Handle cubvh dynamic installation to avoid build isolation errors on HF Spaces
12
- try:
13
- import cubvh
14
- except ImportError:
15
- import subprocess
16
- print("Installing cubvh...")
17
- subprocess.check_call([sys.executable, "-m", "pip", "install", "git+https://github.com/ashawkey/cubvh.git", "--no-build-isolation"])
18
-
19
- # Add project root to path so we can run from anywhere
20
- sys.path.append(os.path.dirname(os.path.abspath(__file__)))
21
-
22
- from ultrashape.rembg import BackgroundRemover
23
- from ultrashape.utils.misc import instantiate_from_config
24
- from ultrashape.surface_loaders import SharpEdgeSurfaceLoader
25
- from ultrashape.utils import voxelize_from_point
26
- from ultrashape.pipelines import UltraShapePipeline
27
-
28
- # Global variables to cache the model
29
- MODEL_CACHE = {}
30
-
31
-
32
- def get_pipeline_cached(config_path, ckpt_path, device='cuda', low_vram=False):
33
- # Check if we have a valid cached pipeline for this checkpoint
34
- if "pipeline" in MODEL_CACHE and MODEL_CACHE.get("ckpt_path") == ckpt_path:
35
- print("Using cached pipeline...")
36
- return MODEL_CACHE["pipeline"], MODEL_CACHE["config"]
37
-
38
- # Clear old cache if it exists (e.g. different checkpoint)
39
- if MODEL_CACHE:
40
- print("Clearing old model cache...")
41
- MODEL_CACHE.clear()
42
- gc.collect()
43
- torch.cuda.empty_cache()
44
-
45
- print(f"Loading config from {config_path}...")
46
- config = OmegaConf.load(config_path)
47
-
48
- print("Instantiating VAE...")
49
- vae = instantiate_from_config(config.model.params.vae_config)
50
-
51
- print("Instantiating DiT...")
52
- dit = instantiate_from_config(config.model.params.dit_cfg)
53
-
54
- print("Instantiating Conditioner...")
55
- conditioner = instantiate_from_config(config.model.params.conditioner_config)
56
-
57
- print("Instantiating Scheduler & Processor...")
58
- scheduler = instantiate_from_config(config.model.params.scheduler_cfg)
59
- image_processor = instantiate_from_config(config.model.params.image_processor_cfg)
60
-
61
- print(f"Loading weights from {ckpt_path}...")
62
- weights = torch.load(ckpt_path, map_location='cpu')
63
-
64
- vae.load_state_dict(weights['vae'], strict=True)
65
- dit.load_state_dict(weights['dit'], strict=True)
66
- conditioner.load_state_dict(weights['conditioner'], strict=True)
67
-
68
- vae.eval().to(device)
69
- dit.eval().to(device)
70
- conditioner.eval().to(device)
71
-
72
- if hasattr(vae, 'enable_flashvdm_decoder'):
73
- vae.enable_flashvdm_decoder()
74
-
75
- print("Creating Pipeline...")
76
- pipeline = UltraShapePipeline(
77
- vae=vae,
78
- model=dit,
79
- scheduler=scheduler,
80
- conditioner=conditioner,
81
- image_processor=image_processor
82
- )
83
 
84
- if low_vram:
85
- pipeline.enable_model_cpu_offload()
86
-
87
- MODEL_CACHE["pipeline"] = pipeline
88
- MODEL_CACHE["config"] = config
89
- MODEL_CACHE["ckpt_path"] = ckpt_path
90
-
91
- return pipeline, config
92
-
93
-
94
- @spaces.GPU
95
- def predict(
96
- image_input,
97
- mesh_input,
98
- steps,
99
- scale,
100
- octree_res,
101
- num_latents,
102
- chunk_size,
103
- seed,
104
- remove_bg,
105
- ckpt_path,
106
- low_vram
 
 
 
 
 
 
107
  ):
108
- # Aggressive memory cleanup at start
109
- gc.collect()
110
- torch.cuda.empty_cache()
111
-
112
- try:
113
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
114
-
115
- # Resolve config path relative to this script
116
- base_dir = os.path.dirname(os.path.abspath(__file__))
117
- config_path = os.path.join(base_dir, "configs", "infer_dit_refine.yaml")
118
-
119
- if not os.path.exists(config_path):
120
- raise gr.Error(f"Config not found at {config_path}")
121
-
122
- if not os.path.exists(ckpt_path):
123
- raise gr.Error(f"Checkpoint not found at '{ckpt_path}'. Please provide a valid checkpoint path or download the model weights.")
124
-
125
- pipeline, config = get_pipeline_cached(config_path, ckpt_path, device, low_vram)
126
-
127
- voxel_res = config.model.params.vae_config.params.voxel_query_res
128
-
129
- print(f"Initializing Surface Loader (Token Num: {num_latents})...")
130
- loader = SharpEdgeSurfaceLoader(
131
- num_sharp_points=204800,
132
- num_uniform_points=204800,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  )
134
 
135
- print(f"Processing inputs...")
136
- if image_input is None:
137
- raise gr.Error("Image input is required")
138
- if mesh_input is None:
139
- raise gr.Error("Mesh input is required")
140
-
141
- image = image_input.convert("RGBA")
142
-
143
- if remove_bg or image.mode != 'RGBA':
144
- rembg = BackgroundRemover()
145
- image = rembg(image)
146
-
147
- # Handle mesh input - Gradio Model3D returns path to file
148
- surface = loader(mesh_input, normalize_scale=scale).to(device, dtype=torch.float16)
149
- pc = surface[:, :, :3] # [B, N, 3]
150
-
151
- # Voxelize
152
- _, voxel_idx = voxelize_from_point(pc, num_latents, resolution=voxel_res)
153
-
154
- print("Running diffusion process...")
155
- gen_device = "cpu" if low_vram else device
156
- generator = torch.Generator(gen_device).manual_seed(int(seed))
157
-
158
- with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
159
- mesh_out_list, _ = pipeline(
160
- image=image,
161
- voxel_cond=voxel_idx,
162
- generator=generator,
163
- box_v=1.0,
164
- mc_level=0.0,
165
- octree_resolution=int(octree_res),
166
- num_chunks=int(chunk_size),
167
- num_inference_steps=int(steps)
168
- )
169
 
170
- # Save output
171
- output_dir = os.path.join(base_dir, "outputs_gradio")
172
- os.makedirs(output_dir, exist_ok=True)
173
- base_name = "output"
174
- save_path = os.path.join(output_dir, f"{base_name}_refined.glb")
175
-
176
- mesh_out = mesh_out_list[0]
177
- mesh_out.export(save_path)
178
- print(f"Successfully saved to {save_path}")
179
-
180
- return save_path
181
-
182
- except Exception as e:
183
- import traceback
184
- traceback.print_exc()
185
- raise gr.Error(str(e))
186
- finally:
187
- # Aggressive memory cleanup at end
188
- gc.collect()
189
- torch.cuda.empty_cache()
190
-
191
-
192
- def main():
193
- parser = argparse.ArgumentParser(description="UltraShape Gradio App")
194
- parser.add_argument("--ckpt", type=str, default="checkpoints/ultrashape.pt", help="Path to split checkpoint (.pt)")
195
- parser.add_argument("--share", action="store_true", help="Share the gradio app")
196
- parser.add_argument("--low_vram", action="store_true", help="Optimize for low VRAM usage")
197
-
198
- args = parser.parse_args()
199
-
200
- # Define a clean & modern Gradio Theme
201
- custom_theme = gr.themes.Soft(
202
- primary_hue="blue",
203
- secondary_hue="indigo",
204
- ).set(
205
- button_primary_background_fill="*primary_500",
206
- button_primary_background_fill_hover="*primary_600",
207
- )
208
 
209
- # Define Gradio Interface
210
- with gr.Blocks(title="UltraShape 1.0", theme=custom_theme) as demo:
211
- gr.Markdown(
212
- """
213
- <div style="text-align: center;">
214
- <h1>UltraShape 1.0</h1>
215
- <h3>High-Fidelity 3D Shape Generation via Scalable Geometric Refinement</h3>
216
- <p>Upload a reference image and a coarse mesh to generate a refined high-quality 3D shape.</p>
217
- </div>
218
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  )
220
 
221
- with gr.Row():
222
- with gr.Column(scale=1):
223
- gr.Markdown("### Input Parameters")
224
- image_input = gr.Image(type="pil", label="Reference Image", image_mode="RGBA")
225
- mesh_input = gr.Model3D(label="Coarse Mesh (.glb, .obj)")
226
-
227
- with gr.Accordion("Advanced Options", open=False):
228
- steps = gr.Slider(minimum=1, maximum=200, value=50, step=1, label="Inference Steps (12-50)")
229
- scale = gr.Slider(minimum=0.1, maximum=2.0, value=0.99, label="Mesh Normalization Scale")
230
- octree_res = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, label="Octree Resolution")
231
- num_latents = gr.Slider(minimum=1024, maximum=32768, value=32768, step=128,
232
- label="Number of Latent Tokens (Decrease if OOM)")
233
- chunk_size = gr.Slider(minimum=512, maximum=10000, value=2048, step=512,
234
- label="Chunk Size (Decrease if OOM)")
235
- seed = gr.Number(value=42, label="Random Seed")
236
- remove_bg = gr.Checkbox(label="Remove Background from Image", value=False)
237
-
238
- run_btn = gr.Button("Generate Refined Shape", variant="primary", size="lg")
239
-
240
- with gr.Column(scale=1):
241
- gr.Markdown("### Refined Output")
242
- output_model = gr.Model3D(label="High-Fidelity Mesh", interactive=False)
243
-
244
- gr.Markdown(
245
- """
246
- *Note: If you encounter Out-of-Memory (OOM) errors, try checking the 'Advanced Options' and lowering the `Number of Latent Tokens` (e.g., 8192) and `Chunk Size` (e.g., 2000), or run the app with the `--low_vram` flag.*
247
- """
248
- )
249
 
250
- run_btn.click(
251
- fn=lambda img, mesh, s, sc, oct, nml, chk, sd, rm: predict(
252
- img, mesh, s, sc, oct, nml, chk, sd, rm, args.ckpt, args.low_vram
253
- ),
254
- inputs=[
255
- image_input, mesh_input, steps, scale, octree_res,
256
- num_latents, chunk_size, seed, remove_bg
257
- ],
258
- outputs=[output_model]
259
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
 
261
- demo.launch(share=args.share, server_name='0.0.0.0', server_port=7860)
262
 
263
  if __name__ == "__main__":
264
- main()
 
 
 
 
1
  import os
2
+ import json
3
+ import tempfile
4
+ import subprocess
5
+ import shutil
6
+ from pathlib import Path
7
 
8
  import gradio as gr
 
9
  import spaces
10
+ import numpy as np
11
+ import torch
12
+
13
+
14
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
15
+ # Lazy model loading (done inside GPU decorator)
16
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
17
+ _model_loaded = False
18
+
19
+
20
+ def _ensure_models():
21
+ global _model_loaded
22
+ if _model_loaded:
23
+ return
24
+ # Models are expected to be pre-downloaded to ./checkpoints/
25
+ # on the Space via the HF repo or a setup script.
26
+ _model_loaded = True
27
+
28
+
29
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
30
+ # Core inference helpers
31
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
32
+
33
+ def _write_caption_txt(image_path: str, caption: str) -> str:
34
+ """Write a .txt caption file beside the image and return the directory."""
35
+ img_path = Path(image_path)
36
+ txt_path = img_path.with_suffix(".txt")
37
+ txt_path.write_text(caption)
38
+ return str(img_path.parent)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
+
41
+ def _run(cmd: list[str], desc: str = "") -> tuple[bool, str]:
42
+ """Run a shell command and return (success, stderr/stdout)."""
43
+ print(f"[Lyra] {desc}: {' '.join(cmd)}")
44
+ result = subprocess.run(
45
+ cmd,
46
+ capture_output=True,
47
+ text=True,
48
+ env={**os.environ, "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"},
49
+ )
50
+ log = result.stdout + "\n" + result.stderr
51
+ return result.returncode == 0, log
52
+
53
+
54
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€๏ฟฝ๏ฟฝโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
55
+ # Zoom-in / Zoom-out trajectory (Option 1)
56
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
57
+
58
+ @spaces.GPU(duration=900)
59
+ def run_zoomgs(
60
+ image,
61
+ caption: str,
62
+ sample_id: int,
63
+ zoom_in_strength: float,
64
+ zoom_out_strength: float,
65
+ num_frames_in: int,
66
+ num_frames_out: int,
67
+ use_dmd: bool,
68
+ run_reconstruction: bool,
69
  ):
70
+ _ensure_models()
71
+
72
+ with tempfile.TemporaryDirectory() as tmp:
73
+ # Save uploaded image + caption
74
+ img_path = Path(tmp) / "input.png"
75
+ caption_path = Path(tmp) / "input.txt"
76
+
77
+ from PIL import Image
78
+ Image.fromarray(image).save(img_path)
79
+ caption_path.write_text(caption.strip() or "A scenic outdoor environment.")
80
+
81
+ output_dir = Path(tmp) / "outputs" / "zoomgs"
82
+ output_dir.mkdir(parents=True, exist_ok=True)
83
+
84
+ cmd = [
85
+ "python", "-m", "lyra_2._src.inference.lyra2_zoomgs_inference",
86
+ "--input_image_path", str(tmp),
87
+ "--sample_id", "0", # we always name it input.png โ†’ id 0 equivalent
88
+ "--experiment", "lyra2",
89
+ "--checkpoint_dir", "checkpoints/model",
90
+ "--prompt_dir", str(tmp),
91
+ "--output_path", str(output_dir),
92
+ "--num_frames_zoom_in", str(num_frames_in),
93
+ "--num_frames_zoom_out", str(num_frames_out),
94
+ "--zoom_in_strength", str(zoom_in_strength),
95
+ "--zoom_out_strength", str(zoom_out_strength),
96
+ ]
97
+ if use_dmd:
98
+ cmd.append("--use_dmd")
99
+
100
+ ok, log = _run(["env", "PYTHONPATH=."] + cmd[1:] if cmd[0] == "python" else cmd,
101
+ "ZoomGS video generation")
102
+
103
+ # Locate output video
104
+ video_path = output_dir / "0" / "videos" / "0.mp4"
105
+ if not video_path.exists():
106
+ # Fallback: search recursively
107
+ candidates = list(output_dir.rglob("*.mp4"))
108
+ video_path = candidates[0] if candidates else None
109
+
110
+ gs_video = None
111
+ if run_reconstruction and video_path and video_path.exists():
112
+ ok2, log2 = _run(
113
+ ["python", "-m", "lyra_2._src.inference.vipe_da3_gs_recon",
114
+ "--input_video_path", str(video_path)],
115
+ "GS reconstruction",
116
+ )
117
+ log += "\n" + log2
118
+ ply_candidates = list(output_dir.rglob("gs_trajectory.mp4"))
119
+ if ply_candidates:
120
+ gs_video = str(ply_candidates[0])
121
+
122
+ return (
123
+ str(video_path) if video_path and video_path.exists() else None,
124
+ gs_video,
125
+ log[-4000:],
126
  )
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
130
+ # Custom trajectory (Option 2)
131
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
+ @spaces.GPU(duration=900)
134
+ def run_custom_traj(
135
+ image,
136
+ trajectory_file,
137
+ captions_json: str,
138
+ num_frames: int,
139
+ pose_scale: float,
140
+ use_dmd: bool,
141
+ run_reconstruction: bool,
142
+ ):
143
+ _ensure_models()
144
+
145
+ with tempfile.TemporaryDirectory() as tmp:
146
+ from PIL import Image
147
+ img_path = Path(tmp) / "first_frame.png"
148
+ Image.fromarray(image).save(img_path)
149
+
150
+ traj_path = Path(tmp) / "trajectory.npz"
151
+ shutil.copy(trajectory_file.name, traj_path)
152
+
153
+ captions_path = Path(tmp) / "captions.json"
154
+ try:
155
+ json.loads(captions_json) # validate
156
+ captions_path.write_text(captions_json)
157
+ except json.JSONDecodeError:
158
+ captions_path.write_text(json.dumps({"0": captions_json}))
159
+
160
+ output_dir = Path(tmp) / "outputs" / "custom"
161
+ output_dir.mkdir(parents=True, exist_ok=True)
162
+
163
+ cmd = [
164
+ "python", "-m", "lyra_2._src.inference.lyra2_custom_traj_inference",
165
+ "--input_image_path", str(img_path),
166
+ "--trajectory_path", str(traj_path),
167
+ "--experiment", "lyra2",
168
+ "--checkpoint_dir", "checkpoints/model",
169
+ "--captions_path", str(captions_path),
170
+ "--num_frames", str(num_frames),
171
+ "--output_path", str(output_dir),
172
+ "--pose_scale", str(pose_scale),
173
+ ]
174
+ if use_dmd:
175
+ cmd.append("--use_dmd")
176
+
177
+ ok, log = _run(cmd, "Custom trajectory video generation")
178
+
179
+ video_candidates = list(output_dir.rglob("*.mp4"))
180
+ video_path = video_candidates[0] if video_candidates else None
181
+
182
+ gs_video = None
183
+ if run_reconstruction and video_path:
184
+ ok2, log2 = _run(
185
+ ["python", "-m", "lyra_2._src.inference.vipe_da3_gs_recon",
186
+ "--input_video_path", str(video_path)],
187
+ "GS reconstruction",
188
+ )
189
+ log += "\n" + log2
190
+ ply_candidates = list(output_dir.rglob("gs_trajectory.mp4"))
191
+ if ply_candidates:
192
+ gs_video = str(ply_candidates[0])
193
+
194
+ return (
195
+ str(video_path) if video_path else None,
196
+ gs_video,
197
+ log[-4000:],
198
  )
199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
202
+ # UI
203
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
204
+
205
+ CSS = """
206
+ /* โ”€โ”€ Global reset & fonts โ”€โ”€ */
207
+ @import url('https://fonts.googleapis.com/css2?family=Syne:wght@400;600;700;800&family=DM+Mono:wght@300;400;500&display=swap');
208
+
209
+ :root {
210
+ --bg: #0a0c10;
211
+ --surface: #111318;
212
+ --border: #1e2230;
213
+ --accent: #5affb0;
214
+ --accent2: #a78bfa;
215
+ --text: #e8eaf0;
216
+ --muted: #5a5f72;
217
+ --radius: 12px;
218
+ --font-head: 'Syne', sans-serif;
219
+ --font-mono: 'DM Mono', monospace;
220
+ }
221
+
222
+ body, .gradio-container {
223
+ background: var(--bg) !important;
224
+ color: var(--text) !important;
225
+ font-family: var(--font-head) !important;
226
+ }
227
+
228
+ /* Header banner */
229
+ #header {
230
+ background: linear-gradient(135deg, #0d1117 0%, #161b27 60%, #0f1520 100%);
231
+ border: 1px solid var(--border);
232
+ border-radius: var(--radius);
233
+ padding: 32px 40px 28px;
234
+ margin-bottom: 24px;
235
+ position: relative;
236
+ overflow: hidden;
237
+ }
238
+ #header::before {
239
+ content: '';
240
+ position: absolute;
241
+ inset: 0;
242
+ background: radial-gradient(ellipse 70% 60% at 80% 50%, rgba(94,255,176,0.06) 0%, transparent 70%),
243
+ radial-gradient(ellipse 50% 80% at 20% 80%, rgba(167,139,250,0.06) 0%, transparent 70%);
244
+ pointer-events: none;
245
+ }
246
+ #header h1 {
247
+ font-size: 2.4rem;
248
+ font-weight: 800;
249
+ letter-spacing: -0.02em;
250
+ margin: 0 0 8px;
251
+ background: linear-gradient(90deg, var(--accent) 0%, var(--accent2) 100%);
252
+ -webkit-background-clip: text;
253
+ -webkit-text-fill-color: transparent;
254
+ }
255
+ #header p {
256
+ color: var(--muted);
257
+ font-family: var(--font-mono);
258
+ font-size: 0.85rem;
259
+ margin: 0;
260
+ letter-spacing: 0.02em;
261
+ }
262
+ #header .badge {
263
+ display: inline-block;
264
+ margin-right: 8px;
265
+ padding: 3px 10px;
266
+ background: rgba(94,255,176,0.1);
267
+ border: 1px solid rgba(94,255,176,0.25);
268
+ border-radius: 20px;
269
+ color: var(--accent);
270
+ font-size: 0.75rem;
271
+ font-family: var(--font-mono);
272
+ }
273
+
274
+ /* Tabs */
275
+ .tab-nav button {
276
+ background: transparent !important;
277
+ border: none !important;
278
+ border-bottom: 2px solid transparent !important;
279
+ color: var(--muted) !important;
280
+ font-family: var(--font-head) !important;
281
+ font-weight: 600 !important;
282
+ font-size: 0.95rem !important;
283
+ padding: 10px 20px !important;
284
+ transition: all .2s !important;
285
+ }
286
+ .tab-nav button.selected, .tab-nav button:hover {
287
+ color: var(--accent) !important;
288
+ border-bottom-color: var(--accent) !important;
289
+ background: transparent !important;
290
+ }
291
+
292
+ /* Panels / blocks */
293
+ .gr-panel, .gr-box, .gradio-group {
294
+ background: var(--surface) !important;
295
+ border: 1px solid var(--border) !important;
296
+ border-radius: var(--radius) !important;
297
+ }
298
+
299
+ /* Inputs */
300
+ input, textarea, .gr-input, .gr-textbox textarea {
301
+ background: #0d0f14 !important;
302
+ border: 1px solid var(--border) !important;
303
+ color: var(--text) !important;
304
+ font-family: var(--font-mono) !important;
305
+ border-radius: 8px !important;
306
+ }
307
+ input:focus, textarea:focus {
308
+ border-color: var(--accent) !important;
309
+ box-shadow: 0 0 0 2px rgba(94,255,176,0.12) !important;
310
+ }
311
+
312
+ /* Sliders */
313
+ input[type=range] { accent-color: var(--accent) !important; }
314
+
315
+ /* Buttons */
316
+ button.primary, .gr-button-primary {
317
+ background: linear-gradient(135deg, var(--accent) 0%, #38d9a9 100%) !important;
318
+ color: #0a0c10 !important;
319
+ font-family: var(--font-head) !important;
320
+ font-weight: 700 !important;
321
+ border: none !important;
322
+ border-radius: 8px !important;
323
+ padding: 12px 28px !important;
324
+ font-size: 0.95rem !important;
325
+ letter-spacing: 0.01em !important;
326
+ transition: opacity .2s !important;
327
+ }
328
+ button.primary:hover { opacity: 0.85 !important; }
329
+
330
+ button.secondary, .gr-button-secondary {
331
+ background: transparent !important;
332
+ border: 1px solid var(--border) !important;
333
+ color: var(--muted) !important;
334
+ font-family: var(--font-head) !important;
335
+ border-radius: 8px !important;
336
+ }
337
+
338
+ /* Labels */
339
+ label, .gr-form > label, .block > label span {
340
+ color: var(--muted) !important;
341
+ font-family: var(--font-mono) !important;
342
+ font-size: 0.8rem !important;
343
+ letter-spacing: 0.04em !important;
344
+ text-transform: uppercase !important;
345
+ }
346
+
347
+ /* Log box */
348
+ #log-box textarea {
349
+ font-size: 0.78rem !important;
350
+ color: #7af0b0 !important;
351
+ background: #060709 !important;
352
+ }
353
+
354
+ /* Accordion */
355
+ .gr-accordion { border-color: var(--border) !important; }
356
+
357
+ /* Info note */
358
+ .info-note {
359
+ background: rgba(167,139,250,0.07);
360
+ border: 1px solid rgba(167,139,250,0.2);
361
+ border-radius: 8px;
362
+ padding: 12px 16px;
363
+ font-family: var(--font-mono);
364
+ font-size: 0.8rem;
365
+ color: #c4b5fd;
366
+ line-height: 1.6;
367
+ }
368
+ """
369
+
370
+
371
+ def build_app():
372
+ with gr.Blocks(css=CSS, title="Lyra 2.0 โ€” Explorable 3D Worlds") as demo:
373
+
374
+ # โ”€โ”€ Header โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
375
+ gr.HTML("""
376
+ <div id="header">
377
+ <h1>โœฆ Lyra 2.0</h1>
378
+ <p>
379
+ <span class="badge">NVIDIA Research</span>
380
+ <span class="badge">3D Gaussian Splatting</span>
381
+ <span class="badge">arXiv 2604.13036</span>
382
+ </p>
383
+ <p style="margin-top:14px; color:#8892a4; font-size:0.9rem; font-family:'Syne',sans-serif;">
384
+ Generate persistent, explorable 3D worlds from a single image.
385
+ Walk through scenes, revisit areas โ€” no spatial forgetting, no temporal drift.
386
+ </p>
387
+ </div>
388
+ """)
389
+
390
+ # โ”€โ”€ Tabs โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
391
+ with gr.Tabs():
392
+
393
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
394
+ # TAB 1 โ€” Zoom-in / Zoom-out
395
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
396
+ with gr.Tab("๐Ÿ”ญ Zoom Trajectory"):
397
+ gr.HTML('<div class="info-note">Generate a zoom-in โ†’ zoom-out exploration video from a single image, then optionally lift it to a 3D Gaussian Splatting scene.</div>')
398
+
399
+ with gr.Row():
400
+ with gr.Column(scale=1):
401
+ z_image = gr.Image(label="Input Image", type="numpy", height=280)
402
+ z_caption = gr.Textbox(
403
+ label="Scene Caption",
404
+ placeholder="A sunlit forest clearing with tall pine treesโ€ฆ",
405
+ lines=2,
406
+ )
407
+ with gr.Accordion("Advanced Options", open=False):
408
+ with gr.Row():
409
+ z_in_str = gr.Slider(0.1, 3.0, value=0.5, step=0.1, label="Zoom-in Strength")
410
+ z_out_str = gr.Slider(0.1, 3.0, value=1.5, step=0.1, label="Zoom-out Strength")
411
+ with gr.Row():
412
+ z_frames_in = gr.Slider(81, 401, value=81, step=80, label="Frames Zoom-in (1+80k)")
413
+ z_frames_out = gr.Slider(81, 401, value=241, step=80, label="Frames Zoom-out (1+80k)")
414
+ with gr.Row():
415
+ z_dmd = gr.Checkbox(label="โšก Fast Mode (DMD ร—15 speedup, lower quality)", value=False)
416
+ z_recon = gr.Checkbox(label="๐ŸงŠ Run 3DGS Reconstruction after video", value=True)
417
+ z_btn = gr.Button("Generate World", variant="primary")
418
+
419
+ with gr.Column(scale=1):
420
+ z_video = gr.Video(label="Generated Exploration Video", height=280)
421
+ z_gs_vid = gr.Video(label="3DGS Flythrough (if reconstruction enabled)", height=280)
422
+ z_log = gr.Textbox(label="Log", lines=6, interactive=False, elem_id="log-box")
423
+
424
+ z_btn.click(
425
+ fn=run_zoomgs,
426
+ inputs=[z_image, z_caption, gr.State(0),
427
+ z_in_str, z_out_str, z_frames_in, z_frames_out,
428
+ z_dmd, z_recon],
429
+ outputs=[z_video, z_gs_vid, z_log],
430
+ )
431
+
432
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
433
+ # TAB 2 โ€” Custom Trajectory
434
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
435
+ with gr.Tab("๐ŸŽฎ Custom Trajectory"):
436
+ gr.HTML('<div class="info-note">Provide your own camera trajectory (.npz with <code>w2c</code>, <code>intrinsics</code>, <code>image_height</code>, <code>image_width</code>) and per-chunk captions (JSON keyed by frame index, e.g. <code>{"0": "โ€ฆ", "81": "โ€ฆ"}</code>).</div>')
437
+
438
+ with gr.Row():
439
+ with gr.Column(scale=1):
440
+ c_image = gr.Image(label="First Frame", type="numpy", height=240)
441
+ c_traj = gr.File(label="Trajectory (.npz)", file_types=[".npz"])
442
+ c_captions = gr.Textbox(
443
+ label='Per-chunk Captions (JSON or single string)',
444
+ placeholder='{"0": "A grand hall interior", "81": "Corridor leading outside"}',
445
+ lines=3,
446
+ )
447
+ with gr.Accordion("Advanced Options", open=False):
448
+ with gr.Row():
449
+ c_frames = gr.Slider(81, 961, value=481, step=80, label="Num Frames (1+80k)")
450
+ c_pose_scale = gr.Slider(0.1, 5.0, value=1.0, step=0.1, label="Pose Scale")
451
+ with gr.Row():
452
+ c_dmd = gr.Checkbox(label="โšก Fast Mode (DMD)", value=False)
453
+ c_recon = gr.Checkbox(label="๐ŸงŠ Run 3DGS Reconstruction", value=True)
454
+ c_btn = gr.Button("Generate World", variant="primary")
455
+
456
+ with gr.Column(scale=1):
457
+ c_video = gr.Video(label="Generated Video", height=260)
458
+ c_gs_vid = gr.Video(label="3DGS Flythrough", height=260)
459
+ c_log = gr.Textbox(label="Log", lines=6, interactive=False, elem_id="log-box")
460
+
461
+ c_btn.click(
462
+ fn=run_custom_traj,
463
+ inputs=[c_image, c_traj, c_captions,
464
+ c_frames, c_pose_scale, c_dmd, c_recon],
465
+ outputs=[c_video, c_gs_vid, c_log],
466
+ )
467
+
468
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
469
+ # TAB 3 โ€” Model Info
470
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
471
+ with gr.Tab("โ„น๏ธ About"):
472
+ gr.Markdown("""
473
+ ## Lyra 2.0 โ€” Explorable Generative 3D Worlds
474
+
475
+ **NVIDIA Research** ยท [Paper](https://arxiv.org/abs/2604.13036) ยท [Project Page](https://research.nvidia.com/labs/sil/projects/lyra2/) ยท [HuggingFace](https://huggingface.co/nvidia/Lyra-2.0)
476
+
477
+ ### How it works
478
+
479
+ Lyra 2.0 solves two fundamental failure modes of long-horizon 3D world generation:
480
+
481
+ | Problem | Solution |
482
+ |---|---|
483
+ | **Spatial Forgetting** โ€” previously seen regions fall out of context and are hallucinated on revisit | Per-frame 3D geometry used for information routing โ€” retrieve past frames and establish dense correspondences |
484
+ | **Temporal Drifting** โ€” autoregressive errors accumulate and distort appearance/geometry | Self-augmented training histories expose the model to its own degraded outputs, teaching correction not propagation |
485
+
486
+ The generated video is then lifted to a **3D Gaussian Splatting** scene via VIPE pose estimation + Depth Anything 3 depth.
487
+
488
+ ### GPU Requirements
489
+
490
+ - Recommended: **H100 80 GB** (or A100 80 GB)
491
+ - ~9 min per 80 frames at full quality ยท ~35 s with `--use_dmd` (DMD fast mode)
492
+ - GS reconstruction adds ~1 min on top
493
+
494
+ ### Checkpoint Setup
495
+
496
+ Checkpoints are expected at `./checkpoints/model/`.
497
+ Download from HuggingFace:
498
+
499
+ ```bash
500
+ huggingface-cli download nvidia/Lyra-2.0 \\
501
+ --include "checkpoints/*" \\
502
+ --local-dir .
503
+ ```
504
+
505
+ ### Citation
506
+
507
+ ```bibtex
508
+ @article{shen2026lyra2,
509
+ title={Lyra 2.0: Explorable Generative 3D Worlds},
510
+ author={Shen, Tianchang and Bahmani, Sherwin and He, Kai and ...},
511
+ journal={arXiv preprint arXiv:2604.13036},
512
+ year={2026}
513
+ }
514
+ ```
515
+
516
+ *Model weights released under [NVIDIA Internal Scientific Research and Development Model License](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-internal-scientific-research-and-development-model-license/).*
517
+ """)
518
+
519
+ return demo
520
 
 
521
 
522
  if __name__ == "__main__":
523
+ demo = build_app()
524
+ demo.launch()