Stable-X committed on
Commit
2f72d6c
·
verified ·
1 Parent(s): eb16685

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +423 -97
app.py CHANGED
@@ -1,113 +1,439 @@
1
  import gradio as gr
 
 
2
 
3
def greet(name: str, enthusiasm: int = 1) -> str:
    """Build a personalized greeting for *name*.

    The message ends with one base exclamation mark plus ``enthusiasm``
    extra ones. A blank or whitespace-only name yields a prompt asking
    the user to enter a name instead.

    Args:
        name: The name to greet; surrounding whitespace counts as empty.
        enthusiasm: Number of extra exclamation marks to append (default 1).

    Returns:
        The formatted greeting string.
    """
    if not name.strip():
        return "Please enter a name!"
    bangs = "!" * enthusiasm
    return f"Hello, {name}!{bangs} Welcome to your Gradio 6 application."
12
-
13
# Version 1: Enhanced with Examples and Validation
with gr.Blocks() as demo:

    # Attribution banner (plain HTML, centered).
    gr.HTML(
        """
        <div style="text-align: center; margin-bottom: 20px;">
            <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank" style="text-decoration: none; color: #007bff; font-weight: bold;">
                Built with anycoder
            </a>
        </div>
        """
    )

    gr.Markdown("# 🚀 Enhanced Gradio 6 App")
    gr.Markdown("Enter your name and choose enthusiasm level for a personalized greeting.")

    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            name_input = gr.Textbox(
                label="Your Name",
                placeholder="Type your name here...",
                autofocus=True,
                max_length=50
            )
            enthusiasm_slider = gr.Slider(
                label="Enthusiasm Level",
                minimum=1,
                maximum=5,
                value=1,
                step=1,
                info="How excited should the greeting be?"
            )
            submit_btn = gr.Button("Generate Greeting", variant="primary", size="lg")

            # Add examples (clicking one fills both inputs).
            examples = gr.Examples(
                examples=[
                    ["Alice", 3],
                    ["Bob", 1],
                    ["Charlie", 5]
                ],
                inputs=[name_input, enthusiasm_slider],
                label="Quick Examples"
            )

        # Right column: outputs.
        with gr.Column():
            greeting_output = gr.Textbox(
                label="Output",
                interactive=False,
                lines=2,
                show_copy_button=True
            )
            # Add a stats component
            char_count = gr.Number(label="Character Count", interactive=False)

    # Event listeners with Gradio 6 syntax
    # Wraps greet() so a single handler can feed both output components.
    def process_greeting(name, enthusiasm):
        result = greet(name, enthusiasm)
        return result, len(result)

    # NOTE(review): `api_visibility=` is not a documented event-listener kwarg
    # in stable Gradio releases — confirm it exists in the targeted version.
    submit_btn.click(
        fn=process_greeting,
        inputs=[name_input, enthusiasm_slider],
        outputs=[greeting_output, char_count],
        api_visibility="public"
    )

    # Real-time updates
    name_input.change(
        fn=process_greeting,
        inputs=[name_input, enthusiasm_slider],
        outputs=[greeting_output, char_count],
        api_visibility="private"
    )

    enthusiasm_slider.change(
        fn=process_greeting,
        inputs=[name_input, enthusiasm_slider],
        outputs=[greeting_output, char_count],
        api_visibility="private"
    )

# NOTE(review): in documented Gradio APIs `theme=` belongs to gr.Blocks(),
# not launch(), and `footer_links=` is not a documented launch() kwarg —
# verify against the Gradio version actually installed.
demo.launch(
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="cyan",
        neutral_hue="slate",
        font=gr.themes.GoogleFont("Inter"),
        text_size="lg",
        spacing_size="lg",
        radius_size="md"
    ).set(
        button_primary_background_fill="*primary_600",
        button_primary_background_fill_hover="*primary_700"
    ),
    footer_links=[
        {"label": "Gradio", "url": "https://gradio.app"},
        {"label": "GitHub", "url": "https://github.com/gradio-app/gradio"}
    ],
    show_error=True
)
 
1
  import gradio as gr
2
+ import spaces
3
+ from gradio_litmodel3d import LitModel3D
4
 
5
+ import os
6
+ import shutil
7
+ os.environ['SPCONV_ALGO'] = 'native'
8
+ from typing import *
9
+ import torch
10
+ import numpy as np
11
+ import imageio
12
+ from easydict import EasyDict as edict
13
+ from PIL import Image
14
+ from trellis.pipelines import TrellisVGGTTo3DPipeline
15
+ from trellis.representations import Gaussian, MeshExtractResult
16
+ from trellis.utils import render_utils, postprocessing_utils
17
+
18
+
19
+
20
# Upper bound for the seed slider / random seed generation (fits in int32).
MAX_SEED = np.iinfo(np.int32).max
# Per-session scratch directory for generated videos / GLB / PLY files,
# created next to this script and cleaned up in end_session().
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
# TMP_DIR = "tmp/Trellis-demo"
# os.environ['GRADIO_TEMP_DIR'] = 'tmp'
os.makedirs(TMP_DIR, exist_ok=True)
25
+
26
def start_session(req: gr.Request):
    """Create this session's scratch directory under TMP_DIR.

    Called on demo.load; idempotent thanks to exist_ok. The directory is
    removed again by end_session() when the tab is closed.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)
29
+
30
+
31
def end_session(req: gr.Request):
    """Remove this session's scratch directory when the tab closes.

    Called on demo.unload. ``ignore_errors=True`` makes the cleanup
    tolerant of a directory that was never created (unload firing before
    start_session ran) or was already removed — the bare rmtree would
    raise FileNotFoundError in those cases.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shutil.rmtree(user_dir, ignore_errors=True)
34
+
35
@spaces.GPU
def preprocess_image(image: Image.Image) -> Image.Image:
    """Prepare one uploaded image for 3D generation.

    Runs on image upload / example selection; delegates entirely to the
    pipeline's preprocessing (background removal etc.).

    Args:
        image (Image.Image): The user-supplied input image.

    Returns:
        Image.Image: The preprocessed image ready for 3D generation.
    """
    return pipeline.preprocess_image(image)
50
+
51
@spaces.GPU
def preprocess_videos(video: str) -> List[Image.Image]:
    """Sample roughly one frame per second from *video* and preprocess each.

    Called when a user uploads a video; the sampled frames feed the
    multi-image 3D generation pipeline.

    Args:
        video (str): Path to the input video file.

    Returns:
        List[Image.Image]: The preprocessed frames ready for 3D generation.
        (The previous annotation claimed tuples; the function has always
        returned plain images.)
    """
    reader = imageio.get_reader(video, 'ffmpeg')
    frames = []
    try:
        fps = reader.get_meta_data()['fps']
        # One frame per ~second of video; guard against fps < 1.
        step = max(int(fps), 1)
        for i, frame in enumerate(reader):
            if i % step == 0:
                img = Image.fromarray(frame)
                W, H = img.size
                # Normalize to height 512 while keeping aspect ratio.
                img = img.resize((int(W / H * 512), 512))
                frames.append(img)
    finally:
        # Release the ffmpeg subprocess even if decoding raises mid-stream;
        # the original unconditional close() leaked it on error.
        reader.close()
    return [pipeline.preprocess_image(img) for img in frames]
78
+
79
@spaces.GPU
def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
    """Prepare a gallery of images for multi-image 3D generation.

    Gallery items arrive as (image, caption) pairs; captions are discarded
    and each image is run through the pipeline's preprocessing.

    Args:
        images (List[Tuple[Image.Image, str]]): Gallery items from the UI.

    Returns:
        List[Image.Image]: The preprocessed images ready for 3D generation.
    """
    raw_images = [pair[0] for pair in images]
    return [pipeline.preprocess_image(img) for img in raw_images]
96
+
97
+
98
def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
    """Serialize a Gaussian splat and a mesh into a CPU-side dict.

    Tensors are moved to CPU numpy arrays so the result can live in a
    gr.State between requests; unpack_state() restores it onto CUDA.
    """
    gaussian_blob = dict(gs.init_params)
    # Raw parameter tensors needed to reconstruct the splat later.
    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        gaussian_blob[attr] = getattr(gs, attr).cpu().numpy()
    mesh_blob = {
        'vertices': mesh.vertices.cpu().numpy(),
        'faces': mesh.faces.cpu().numpy(),
    }
    return {'gaussian': gaussian_blob, 'mesh': mesh_blob}
113
+
114
+
115
def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
    """Rebuild the Gaussian splat and mesh (on CUDA) from a pack_state() dict.

    Args:
        state (dict): Dict produced by pack_state(), with 'gaussian' and
            'mesh' sub-dicts of numpy arrays plus the Gaussian init params.

    Returns:
        Tuple[Gaussian, edict]: The restored Gaussian object and an edict
        with 'vertices'/'faces' CUDA tensors. (The previous annotation
        claimed a third str element that was never returned.)
    """
    gs = Gaussian(
        aabb=state['gaussian']['aabb'],
        sh_degree=state['gaussian']['sh_degree'],
        mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
        scaling_bias=state['gaussian']['scaling_bias'],
        opacity_bias=state['gaussian']['opacity_bias'],
        scaling_activation=state['gaussian']['scaling_activation'],
    )
    # Restore the raw parameter tensors onto the GPU.
    gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
    gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
    gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
    gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
    gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')

    mesh = edict(
        vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
        faces=torch.tensor(state['mesh']['faces'], device='cuda'),
    )

    return gs, mesh
136
+
137
+
138
def get_seed(randomize_seed: bool, seed: int) -> int:
    """Resolve the seed to use for generation.

    Called by the generate button before the pipeline runs.

    Args:
        randomize_seed (bool): If True, draw a fresh random seed.
        seed (int): The user-specified seed, used when not randomizing.

    Returns:
        int: The seed to pass to the generation pipeline.
    """
    if randomize_seed:
        return np.random.randint(0, MAX_SEED)
    return seed
153
+
154
+
155
@spaces.GPU(duration=120)
def generate_and_extract_glb(
    multiimages: List[Tuple[Image.Image, str]],
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    multiimage_algo: Literal["multidiffusion", "stochastic"],
    mesh_simplify: float,
    texture_size: int,
    req: gr.Request,
) -> Tuple[dict, str, str, str]:
    """
    Convert the gallery images to a 3D model and extract a GLB file.

    Args:
        multiimages (List[Tuple[Image.Image, str]]): The (image, caption) gallery items.
        seed (int): The random seed.
        ss_guidance_strength (float): The guidance strength for sparse structure generation.
        ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
        slat_guidance_strength (float): The guidance strength for structured latent generation.
        slat_sampling_steps (int): The number of sampling steps for structured latent generation.
        multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.
        mesh_simplify (float): The mesh simplification factor.
        texture_size (int): The texture resolution.
        req (gr.Request): Gradio request, used to locate the per-session temp dir.
    Returns:
        dict: The packed state of the generated 3D model (for later Gaussian extraction).
        str: The path to the preview video of the 3D model.
        str: The path to the extracted GLB file (for the 3D viewer).
        str: The path to the extracted GLB file (for the download button).
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # Gallery items are (image, caption) pairs; keep only the images.
    image_files = [image[0] for image in multiimages]

    # Generate 3D model (images were already preprocessed by the upload handlers).
    outputs, _, _ = pipeline.run(
        image=image_files,
        seed=seed,
        formats=["gaussian", "mesh"],
        preprocess_image=False,
        sparse_structure_sampler_params={
            "steps": ss_sampling_steps,
            "cfg_strength": ss_guidance_strength,
        },
        slat_sampler_params={
            "steps": slat_sampling_steps,
            "cfg_strength": slat_guidance_strength,
        },
        mode=multiimage_algo,
    )

    # Render a preview video: gaussian color render side-by-side with mesh normals.
    video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
    video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
    video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
    video_path = os.path.join(user_dir, 'sample.mp4')
    imageio.mimsave(video_path, video, fps=15)

    # Extract GLB: simplify the mesh and bake a texture of the requested size.
    gs = outputs['gaussian'][0]
    mesh = outputs['mesh'][0]
    glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
    glb_path = os.path.join(user_dir, 'sample.glb')
    glb.export(glb_path)

    # Pack state for optional Gaussian extraction
    state = pack_state(gs, mesh)

    # Free cached GPU memory before the worker is reused.
    torch.cuda.empty_cache()
    # glb_path is returned twice: once for the viewer, once for the DownloadButton.
    return state, video_path, glb_path, glb_path
233
+
234
+
235
@spaces.GPU
def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
    """Write the packed Gaussian splat out as a .ply file.

    Triggered by the "Extract Gaussian" button; restores the splat from the
    packed state and saves it into the session's temp directory.

    Args:
        state (dict): Packed model state produced by pack_state().
        req (gr.Request): Gradio request object for session management.

    Returns:
        Tuple[str, str]: The .ply path twice (viewer + download button).
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    gaussian, _ = unpack_state(state)
    ply_path = os.path.join(session_dir, 'sample.ply')
    gaussian.save_ply(ply_path)
    # Release cached GPU memory held by the reconstructed tensors.
    torch.cuda.empty_cache()
    return ply_path, ply_path
255
+
256
+
257
def prepare_multi_example() -> List[Image.Image]:
    """Build the preview strips for the multi-image example gallery.

    Each example case contributes files assets/example_multi_image/<case>_1.png
    .. <case>_8.png (gaps allowed); its views are resized to height 512 and
    concatenated horizontally into one strip image.

    Returns:
        List[Image.Image]: One concatenated strip per example case.
    """
    example_dir = "assets/example_multi_image"
    # sorted() makes the example ordering deterministic across restarts;
    # the previous list(set(...)) iterated in hash order.
    cases = sorted({fname.split('_')[0] for fname in os.listdir(example_dir)})
    previews = []
    for case in cases:
        views = []
        for i in range(1, 9):
            path = f'{example_dir}/{case}_{i}.png'
            if os.path.exists(path):
                img = Image.open(path)
                W, H = img.size
                # Normalize to height 512, preserving aspect ratio.
                img = img.resize((int(W / H * 512), 512))
                views.append(np.array(img))
        if len(views) > 0:
            previews.append(Image.fromarray(np.concatenate(views, axis=1)))
    return previews
271
+
272
+
273
def split_image(image: Image.Image) -> List[Image.Image]:
    """Split a horizontally-concatenated multi-view image into single views.

    Used by the multi-image examples: views are located as runs of columns
    containing any non-transparent pixel, then each view is preprocessed.

    Args:
        image (Image.Image): Concatenated image with views separated by
            fully-transparent columns (RGBA expected).

    Returns:
        List[Image.Image]: Individual preprocessed view images.
    """
    # The split relies on the alpha channel; without this guard a non-RGBA
    # input (e.g. an RGB upload) crashed with IndexError on pixels[..., 3].
    if image.mode != 'RGBA':
        image = image.convert('RGBA')
    pixels = np.array(image)
    alpha = pixels[..., 3]
    # Per-column flag: does any row in this column have content?
    col_has_content = np.any(alpha > 0, axis=0)
    # Rising/falling edges mark view boundaries (a view flush with column 0
    # is not detected — original behavior, examples always have a margin).
    starts = np.where(~col_has_content[:-1] & col_has_content[1:])[0].tolist()
    ends = np.where(col_has_content[:-1] & ~col_has_content[1:])[0].tolist()
    views = [Image.fromarray(pixels[:, s:e + 1]) for s, e in zip(starts, ends)]
    return [preprocess_image(view) for view in views]
296
+
297
# Create interface
demo = gr.Blocks(
    title="ReconViaGen",
    css="""
    .slider .inner { width: 5px; background: #FFF; }
    .viewport { aspect-ratio: 4/3; }
    .tabs button.selected { font-size: 20px !important; color: crimson !important; }
    h1, h2, h3 { text-align: center; display: block; }
    .md_feedback li { margin-bottom: 0px !important; }
    """
)
with demo:
    # Header: title plus GitHub / website / paper badges.
    gr.Markdown("""
    # 💻 ReconViaGen
    <p align="center">
    <a title="Github" href="https://github.com/GAP-LAB-CUHK-SZ/ReconViaGen" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
    <img src="https://img.shields.io/github/stars/GAP-LAB-CUHK-SZ/ReconViaGen?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
    </a>
    <a title="Website" href="https://jiahao620.github.io/reconviagen/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
    <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
    </a>
    <a title="arXiv" href="https://jiahao620.github.io/reconviagen/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
    <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
    </a>
    </p>

    ✨This demo is partial. We will release the whole model later. Stay tuned!✨
    """)

    with gr.Row():
        # Left column: inputs and generation settings.
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.Tab(label="Input Video or Images", id=0) as multiimage_input_tab:
                    input_video = gr.Video(label="Upload Video", interactive=True, height=300)
                    # Hidden single-image component; used only as the input
                    # slot for the Examples table (fn=split_image fans it out).
                    image_prompt = gr.Image(label="Image Prompt", format="png", visible=False, image_mode="RGBA", type="pil", height=300)
                    multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
                    gr.Markdown("""
                    Input different views of the object in separate images.
                    """)

            with gr.Accordion(label="Generation Settings", open=False):
                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
                gr.Markdown("Stage 1: Sparse Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=30, step=1)
                gr.Markdown("Stage 2: Structured Latent Generation")
                with gr.Row():
                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="multidiffusion")

            with gr.Accordion(label="GLB Extraction Settings", open=False):
                mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
                texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)

            generate_btn = gr.Button("Generate & Extract GLB", variant="primary")
            # Disabled until a generation finishes (see generate_btn chain below).
            extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
            gr.Markdown("""
            *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
            """)

        # Right column: preview video, 3D viewer, download buttons.
        with gr.Column():
            video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
            model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)

            with gr.Row():
                download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
                download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)

    # Holds the packed model state between generate and extract_gaussian.
    output_buf = gr.State()

    # Example images at the bottom of the page
    with gr.Row() as multiimage_example:
        examples_multi = gr.Examples(
            examples=prepare_multi_example(),
            inputs=[image_prompt],
            fn=split_image,
            outputs=[multiimage_prompt],
            run_on_click=True,
            examples_per_page=8,
        )

    # Handlers
    # Per-session temp dir lifecycle.
    demo.load(start_session)
    demo.unload(end_session)

    # Uploading a video replaces the gallery with its sampled frames.
    input_video.upload(
        preprocess_videos,
        inputs=[input_video],
        outputs=[multiimage_prompt],
    )
    input_video.clear(
        lambda: tuple([None, None]),
        outputs=[input_video, multiimage_prompt],
    )
    # Directly uploaded gallery images are preprocessed in place.
    multiimage_prompt.upload(
        preprocess_images,
        inputs=[multiimage_prompt],
        outputs=[multiimage_prompt],
    )

    # Generate chain: resolve seed -> run pipeline -> enable buttons.
    # NOTE(review): req: gr.Request is injected by Gradio and therefore
    # intentionally absent from the inputs list.
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        generate_and_extract_glb,
        inputs=[multiimage_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo, mesh_simplify, texture_size],
        outputs=[output_buf, video_output, model_output, download_glb],
    ).then(
        lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
        outputs=[extract_gs_btn, download_glb],
    )

    # Clearing the preview disables all downstream buttons.
    video_output.clear(
        lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False), gr.Button(interactive=False)]),
        outputs=[extract_gs_btn, download_glb, download_gs],
    )

    extract_gs_btn.click(
        extract_gaussian,
        inputs=[output_buf],
        outputs=[model_output, download_gs],
    ).then(
        lambda: gr.Button(interactive=True),
        outputs=[download_gs],
    )

    model_output.clear(
        lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
        outputs=[download_glb, download_gs],
    )


# Launch the Gradio app
if __name__ == "__main__":
    # NOTE(review): `pipeline` is created here at module scope under
    # __main__, and the handlers above reference it as a global — the app
    # only works when run as a script, not when imported. Confirm intended.
    pipeline = TrellisVGGTTo3DPipeline.from_pretrained("Stable-X/trellis-vggt-v0-2")
    pipeline.cuda()
    pipeline.VGGT_model.cuda()
    pipeline.birefnet_model.cuda()
    demo.launch()