chateauxai commited on
Commit
0206e97
·
verified ·
1 Parent(s): 843444c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +286 -251
app.py CHANGED
@@ -1,280 +1,315 @@
1
  import gradio as gr
2
- import torch
3
- import numpy as np
4
  import os
5
  import shutil
 
 
 
 
6
  import imageio
 
7
  from PIL import Image
 
 
 
8
 
9
- # Ensure imports are available
10
- try:
11
- from trellis.pipelines import TrellisImageTo3DPipeline
12
- from trellis.representations import Gaussian, MeshExtractResult
13
- from trellis.utils import render_utils, postprocessing_utils
14
- from easydict import EasyDict as edict
15
- except ImportError as e:
16
- print(f"Error importing required libraries: {e}")
17
- print("Please install the following libraries:")
18
- print("- trellis-ai")
19
- print("- easydict")
20
- TrellisImageTo3DPipeline = None
21
-
22
- # Constants
23
  MAX_SEED = np.iinfo(np.int32).max
24
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
25
  os.makedirs(TMP_DIR, exist_ok=True)
26
 
27
- class ImageTo3DConverter:
28
- def __init__(self):
29
- # Initialize the pipeline with error handling
30
- try:
31
- self.pipeline = TrellisImageTo3DPipeline()
32
- except Exception as e:
33
- print(f"Failed to initialize pipeline: {e}")
34
- self.pipeline = None
35
-
36
- def validate_input(self, image, is_multiimage):
37
- """Validate input images before processing"""
38
- if not self.pipeline:
39
- raise ValueError("Pipeline not initialized. Check library installation.")
40
-
41
- if is_multiimage:
42
- # Handle multi-image input
43
- if not image or len(image) == 0:
44
- raise ValueError("No images provided for multi-image processing")
45
- # Ensure images are PIL Image objects
46
- valid_images = [img[0] if isinstance(img, list) else img for img in image]
47
- return valid_images
48
- else:
49
- # Handle single image input
50
- if image is None:
51
- raise ValueError("No image provided")
52
- return image
53
-
54
- def preprocess_image(self, image):
55
- """Safely preprocess a single image"""
56
- try:
57
- return self.pipeline.preprocess_image(image)
58
- except Exception as e:
59
- print(f"Image preprocessing error: {e}")
60
- return image
61
-
62
- def process_image(self,
63
- image,
64
- multiimages,
65
- is_multiimage,
66
- seed,
67
- ss_guidance_strength,
68
- ss_sampling_steps,
69
- slat_guidance_strength,
70
- slat_sampling_steps,
71
- multiimage_algo):
72
- """Main image to 3D conversion method"""
73
- # Validate and preprocess input
74
- try:
75
- processed_input = self.validate_input(image if not is_multiimage else multiimages, is_multiimage)
76
- except ValueError as e:
77
- print(f"Input validation error: {e}")
78
- return None, None
79
-
80
- # Determine processing method based on input type
81
- try:
82
- if not is_multiimage:
83
- outputs = self.pipeline.run(
84
- processed_input,
85
- seed=seed,
86
- formats=["gaussian", "mesh"],
87
- preprocess_image=False,
88
- sparse_structure_sampler_params={
89
- "steps": ss_sampling_steps,
90
- "cfg_strength": ss_guidance_strength,
91
- },
92
- slat_sampler_params={
93
- "steps": slat_sampling_steps,
94
- "cfg_strength": slat_guidance_strength,
95
- },
96
- )
97
- else:
98
- outputs = self.pipeline.run_multi_image(
99
- processed_input,
100
- seed=seed,
101
- formats=["gaussian", "mesh"],
102
- preprocess_image=False,
103
- sparse_structure_sampler_params={
104
- "steps": ss_sampling_steps,
105
- "cfg_strength": ss_guidance_strength,
106
- },
107
- slat_sampler_params={
108
- "steps": slat_sampling_steps,
109
- "cfg_strength": slat_guidance_strength,
110
- },
111
- mode=multiimage_algo,
112
- )
113
- except Exception as e:
114
- print(f"3D conversion error: {e}")
115
- return None, None
116
-
117
- # Generate video
118
- try:
119
- video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
120
- video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
121
- video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
122
-
123
- # Save video
124
- user_dir = os.path.join(TMP_DIR, 'temp_session')
125
- os.makedirs(user_dir, exist_ok=True)
126
- video_path = os.path.join(user_dir, 'sample.mp4')
127
- imageio.mimsave(video_path, video, fps=15)
128
-
129
- # Pack and return state
130
- state = {
131
- 'gaussian': {
132
- **outputs['gaussian'][0].init_params,
133
- '_xyz': outputs['gaussian'][0]._xyz.cpu().numpy(),
134
- '_features_dc': outputs['gaussian'][0]._features_dc.cpu().numpy(),
135
- '_scaling': outputs['gaussian'][0]._scaling.cpu().numpy(),
136
- '_rotation': outputs['gaussian'][0]._rotation.cpu().numpy(),
137
- '_opacity': outputs['gaussian'][0]._opacity.cpu().numpy(),
138
- },
139
- 'mesh': {
140
- 'vertices': outputs['mesh'][0].vertices.cpu().numpy(),
141
- 'faces': outputs['mesh'][0].faces.cpu().numpy(),
142
- },
143
- }
144
-
145
- return state, video_path
146
 
147
- except Exception as e:
148
- print(f"Video generation error: {e}")
149
- return None, None
 
 
 
 
 
 
 
150
 
151
- def extract_glb(self, state, mesh_simplify=0.95, texture_size=1024):
152
- """Extract GLB from the processed state"""
153
- try:
154
- # Reconstruct Gaussian and Mesh from state
155
- gs = Gaussian(
156
- aabb=state['gaussian']['aabb'],
157
- sh_degree=state['gaussian']['sh_degree'],
158
- mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
159
- scaling_bias=state['gaussian'].get('scaling_bias', 0.1),
160
- opacity_bias=state['gaussian'].get('opacity_bias', 0.1),
161
- scaling_activation=state['gaussian'].get('scaling_activation', 'softplus'),
162
- )
163
- gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
164
- gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
165
- gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
166
- gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
167
- gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
168
-
169
- mesh = edict(
170
- vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
171
- faces=torch.tensor(state['mesh']['faces'], device='cuda'),
172
- )
173
-
174
- # Convert mesh
175
- mesh.vertices, mesh.faces = postprocessing_utils.remesh_to_quads(
176
- vertices=mesh.vertices.cpu().numpy(),
177
- faces=mesh.faces.cpu().numpy(),
178
- simplify=mesh_simplify
179
- )
180
-
181
- # Generate GLB
182
- glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
183
-
184
- # Save GLB
185
- user_dir = os.path.join(TMP_DIR, 'temp_session')
186
- os.makedirs(user_dir, exist_ok=True)
187
- glb_path = os.path.join(user_dir, 'sample.glb')
188
- glb.export(glb_path)
189
 
190
- return glb_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
- except Exception as e:
193
- print(f"GLB extraction error: {e}")
194
- return None
 
 
195
 
196
- # Gradio Interface Setup
197
- def create_gradio_interface():
198
- converter = ImageTo3DConverter()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
- with gr.Blocks() as demo:
201
- # Input components
202
- with gr.Row():
203
- with gr.Column():
204
- with gr.Tabs() as input_tabs:
205
- with gr.Tab("Single Image"):
206
- single_image = gr.Image(label="Single Image Input")
207
- with gr.Tab("Multiple Images"):
208
- multi_images = gr.Gallery(label="Multiple Image Input")
209
-
210
- # Generation settings
211
- with gr.Accordion("Generation Settings"):
212
- seed = gr.Slider(0, MAX_SEED, label="Seed", value=0)
213
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
214
-
215
- with gr.Row():
216
- ss_guidance = gr.Slider(0, 10, label="Sparse Guidance Strength", value=7.5)
217
- ss_steps = gr.Slider(1, 50, label="Sparse Sampling Steps", value=12)
218
-
219
- with gr.Row():
220
- slat_guidance = gr.Slider(0, 10, label="Latent Guidance Strength", value=3.0)
221
- slat_steps = gr.Slider(1, 50, label="Latent Sampling Steps", value=12)
222
-
223
- multi_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")
224
 
225
- # Buttons
226
- generate_btn = gr.Button("Generate 3D Model")
227
-
228
- # GLB Extraction
229
- with gr.Accordion("GLB Extraction"):
230
- mesh_simplify = gr.Slider(0.9, 0.98, label="Mesh Simplify", value=0.95)
231
- texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024)
232
- extract_glb_btn = gr.Button("Extract GLB")
 
 
 
 
 
 
 
233
 
234
- # Output components
 
235
  with gr.Column():
236
- video_output = gr.Video(label="Generated 3D Asset Preview")
237
- glb_output = gr.File(label="Extracted GLB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
- # Event handlers
240
- def generate_3d(image, multi_image, seed, ss_guidance, ss_steps,
241
- slat_guidance, slat_steps, multi_algo):
242
- # Determine if it's multi-image mode
243
- is_multi = isinstance(multi_image, list) and len(multi_image) > 0
244
 
245
- # Randomize seed if selected
246
- if randomize_seed:
247
- seed = np.random.randint(0, MAX_SEED)
248
 
249
- # Process image
250
- state, video = converter.process_image(
251
- image, multi_image, is_multi, seed,
252
- ss_guidance, ss_steps,
253
- slat_guidance, slat_steps,
254
- multi_algo
255
- )
256
 
257
- return video if video else None
 
 
 
 
 
258
 
259
- def extract_glb(state, simplify, texture_size):
260
- if state is None:
261
- return None
262
- glb_path = converter.extract_glb(state, simplify, texture_size)
263
- return glb_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
 
265
- # Connect event handlers
266
- generate_btn.click(
267
- generate_3d,
268
- inputs=[single_image, multi_images, seed, ss_guidance, ss_steps,
269
- slat_guidance, slat_steps, multi_algo],
270
- outputs=[video_output]
271
- )
 
 
 
 
 
272
 
273
- extract_glb_btn.click(
274
- extract_glb,
275
- inputs=[ mesh_simplify, texture_size],
276
- outputs=[glb_output]
277
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
 
279
  # Launch the Gradio app
280
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ import spaces
3
+ from gradio_litmodel3d import LitModel3D
4
  import os
5
  import shutil
6
+ os.environ['SPCONV_ALGO'] = 'native'
7
+ from typing import *
8
+ import torch
9
+ import numpy as np
10
  import imageio
11
+ from easydict import EasyDict as edict
12
  from PIL import Image
13
+ from trellis.pipelines import TrellisImageTo3DPipeline
14
+ from trellis.representations import Gaussian, MeshExtractResult
15
+ from trellis.utils import render_utils, postprocessing_utils
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  MAX_SEED = np.iinfo(np.int32).max
18
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
19
  os.makedirs(TMP_DIR, exist_ok=True)
20
 
21
def start_session(req: gr.Request):
    """Create this browser session's scratch directory under TMP_DIR.

    Wired to ``demo.load``; Gradio injects *req*, whose session hash names
    the per-user directory.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)
25
def end_session(req: gr.Request):
    """Delete this browser session's scratch directory.

    Wired to ``demo.unload``; Gradio injects *req*, whose session hash names
    the per-user directory created by ``start_session``.

    Args:
        req (gr.Request): The request carrying the session hash.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # ignore_errors: an unload can fire for a session whose directory was
    # never created (load failed) or was already cleaned up — a bare
    # rmtree would raise FileNotFoundError in that case.
    shutil.rmtree(user_dir, ignore_errors=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
def preprocess_image(image: Image.Image) -> Image.Image:
    """Run the pipeline's built-in preprocessing on a single prompt image.

    Args:
        image (Image.Image): The input image.

    Returns:
        Image.Image: The preprocessed image.
    """
    return pipeline.preprocess_image(image)
39
 
40
def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
    """Run the pipeline's built-in preprocessing on every gallery entry.

    Args:
        images (List[Tuple[Image.Image, str]]): (image, caption) tuples as
            produced by ``gr.Gallery``.

    Returns:
        List[Image.Image]: The preprocessed images, captions dropped.
    """
    return [pipeline.preprocess_image(item[0]) for item in images]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
    """Serialize a Gaussian/mesh pair into a dict of CPU numpy arrays.

    The result is safe to hold in ``gr.State`` and can be rebuilt on the
    GPU later by ``unpack_state``.
    """
    gaussian_blob = dict(gs.init_params)
    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        gaussian_blob[attr] = getattr(gs, attr).cpu().numpy()
    mesh_blob = {
        'vertices': mesh.vertices.cpu().numpy(),
        'faces': mesh.faces.cpu().numpy(),
    }
    return {'gaussian': gaussian_blob, 'mesh': mesh_blob}
69
+
70
def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
    """Rebuild GPU-resident Gaussian and mesh objects from a packed state.

    Inverse of ``pack_state``: reconstructs the ``Gaussian`` from its saved
    init params and copies every tensor attribute back onto CUDA.

    Args:
        state (dict): State dict produced by ``pack_state``.

    Returns:
        Tuple[Gaussian, edict]: The reconstructed Gaussian representation
        and an edict with ``vertices``/``faces`` CUDA tensors.
    """
    # NOTE: the original annotation declared Tuple[Gaussian, edict, str],
    # but only two values are returned — annotation fixed to match.
    gs = Gaussian(
        aabb=state['gaussian']['aabb'],
        sh_degree=state['gaussian']['sh_degree'],
        mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
        scaling_bias=state['gaussian']['scaling_bias'],
        opacity_bias=state['gaussian']['opacity_bias'],
        scaling_activation=state['gaussian']['scaling_activation'],
    )
    gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
    gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
    gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
    gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
    gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')

    mesh = edict(
        vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
        faces=torch.tensor(state['mesh']['faces'], device='cuda'),
    )

    return gs, mesh
91
 
92
def get_seed(randomize_seed: bool, seed: int) -> int:
    """Return a freshly drawn random seed if requested, else the given one."""
    if randomize_seed:
        return np.random.randint(0, MAX_SEED)
    return seed
97
 
98
@spaces.GPU
def image_to_3d(
    image: Image.Image,
    multiimages: List[Tuple[Image.Image, str]],
    is_multiimage: bool,
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    multiimage_algo: Literal["multidiffusion", "stochastic"],
    req: gr.Request,
) -> Tuple[dict, str]:
    """
    Convert an image to a 3D model.
    Args:
        image (Image.Image): The input image.
        multiimages (List[Tuple[Image.Image, str]]): The input images in multi-image mode.
        is_multiimage (bool): Whether is in multi-image mode.
        seed (int): The random seed.
        ss_guidance_strength (float): The guidance strength for sparse structure generation.
        ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
        slat_guidance_strength (float): The guidance strength for structured latent generation.
        slat_sampling_steps (int): The number of sampling steps for structured latent generation.
        multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.
        req (gr.Request): Injected by Gradio (not listed in the click inputs);
            supplies the session hash used for per-user output paths.
    Returns:
        dict: The information of the generated 3D model.
        str: The path to the video of the 3D model.
    """
    # Per-session scratch directory, created by start_session on demo.load.
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    if not is_multiimage:
        # preprocess_image=False: the upload handlers already ran
        # pipeline.preprocess_image on the prompt image(s).
        outputs = pipeline.run(
            image,
            seed=seed,
            formats=["gaussian", "mesh"],
            preprocess_image=False,
            sparse_structure_sampler_params={
                "steps": ss_sampling_steps,
                "cfg_strength": ss_guidance_strength,
            },
            slat_sampler_params={
                "steps": slat_sampling_steps,
                "cfg_strength": slat_guidance_strength,
            },
        )
    else:
        # Gallery entries are (image, caption) tuples; keep only the images.
        outputs = pipeline.run_multi_image(
            [image[0] for image in multiimages],
            seed=seed,
            formats=["gaussian", "mesh"],
            preprocess_image=False,
            sparse_structure_sampler_params={
                "steps": ss_sampling_steps,
                "cfg_strength": ss_guidance_strength,
            },
            slat_sampler_params={
                "steps": slat_sampling_steps,
                "cfg_strength": slat_guidance_strength,
            },
            mode=multiimage_algo,
        )
    # Preview video: color render of the Gaussians side-by-side with the
    # normal-map render of the mesh, concatenated frame by frame.
    video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
    video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
    video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
    video_path = os.path.join(user_dir, 'sample.mp4')
    imageio.mimsave(video_path, video, fps=15)
    # Pack tensors to CPU numpy so the result can live in gr.State between calls.
    state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
    # Release cached GPU memory before the GPU slot is handed back.
    torch.cuda.empty_cache()
    return state, video_path
167
 
168
@spaces.GPU(duration=90)
def extract_glb(
    state: dict,
    mesh_simplify: float,
    texture_size: int,
    req: gr.Request,
) -> Tuple[str, str]:
    """
    Extract a GLB file from the 3D model.
    Args:
        state (dict): The state of the generated 3D model (from pack_state).
        mesh_simplify (float): The mesh simplification factor.
        texture_size (int): The texture resolution.
        req (gr.Request): Injected by Gradio; supplies the session hash.
    Returns:
        Tuple[str, str]: The GLB path twice — one copy for the 3D viewer
        component and one for the download button, which are wired to
        separate outputs in the UI.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    gs, mesh = unpack_state(state)
    glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
    glb_path = os.path.join(user_dir, 'sample.glb')
    glb.export(glb_path)
    # Release cached GPU memory before the GPU slot is handed back.
    torch.cuda.empty_cache()
    return glb_path, glb_path
 
191
 
192
@spaces.GPU
def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
    """
    Extract a Gaussian file from the 3D model.
    Args:
        state (dict): The state of the generated 3D model (from pack_state).
        req (gr.Request): Injected by Gradio; supplies the session hash.
    Returns:
        Tuple[str, str]: The PLY path twice — one copy for the 3D viewer
        component and one for the download button, which are wired to
        separate outputs in the UI.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # The mesh half of the state is not needed for the point-cloud export.
    gs, _ = unpack_state(state)
    gaussian_path = os.path.join(user_dir, 'sample.ply')
    gs.save_ply(gaussian_path)
    # Release cached GPU memory before the GPU slot is handed back.
    torch.cuda.empty_cache()
    return gaussian_path, gaussian_path
207
 
208
# UI definition. delete_cache=(600, 600) lets Gradio purge stale cached
# files every 600 s once they are 600 s old.
with gr.Blocks(theme=gr.themes.Default(), delete_cache=(600, 600)) as demo:
    # Layout: inputs and settings in the left column, preview/outputs on the right.
    with gr.Row():
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.Tab(label="Single Image", id=0) as single_image_input_tab:
                    image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
                with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
                    multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)

            with gr.Accordion(label="Generation Settings", open=False):
                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                # Sparse-structure stage sampler settings.
                with gr.Row():
                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                # Structured-latent stage sampler settings.
                with gr.Row():
                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")

            generate_btn = gr.Button("Generate", variant="primary")

            with gr.Accordion(label="GLB Extraction Settings", open=False):
                mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
                texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)

            with gr.Row():
                # Disabled until a model has been generated (enabled in the
                # generate_btn.click chain below).
                extract_glb_btn = gr.Button("Extract GLB", interactive=False)
                extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)

        with gr.Column():
            video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
            model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)

            with gr.Row():
                download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
                download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)

    # Cross-callback state: which input tab is active, and the packed model dict.
    is_multiimage = gr.State(False)
    output_buf = gr.State()

    # Handlers
    # Per-session temp-dir lifecycle.
    demo.load(start_session)
    demo.unload(end_session)

    # Track the active input tab so image_to_3d knows which input to use.
    single_image_input_tab.select(
        lambda: False,
        outputs=[is_multiimage]
    )
    multiimage_input_tab.select(
        lambda: True,
        outputs=[is_multiimage]
    )

    # Preprocess uploads in place; image_to_3d then runs with preprocess_image=False.
    image_prompt.upload(
        preprocess_image,
        inputs=[image_prompt],
        outputs=[image_prompt],
    )
    multiimage_prompt.upload(
        preprocess_images,
        inputs=[multiimage_prompt],
        outputs=[multiimage_prompt],
    )

    # Generate chain: resolve the seed, run the pipeline (the gr.Request
    # parameter of image_to_3d is injected by Gradio, not listed here),
    # then enable the extraction buttons.
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        image_to_3d,
        inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
        outputs=[output_buf, video_output],
    ).then(
        lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
        outputs=[extract_glb_btn, extract_gs_btn],
    )

    # Clearing the preview disables extraction again.
    video_output.clear(
        lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
        outputs=[extract_glb_btn, extract_gs_btn],
    )

    # Extraction chains: produce the file, show it in the viewer, then
    # enable the matching download button.
    extract_glb_btn.click(
        extract_glb,
        inputs=[output_buf, mesh_simplify, texture_size],
        outputs=[model_output, download_glb],
    ).then(
        lambda: gr.Button(interactive=True),
        outputs=[download_glb],
    )

    extract_gs_btn.click(
        extract_gaussian,
        inputs=[output_buf],
        outputs=[model_output, download_gs],
    ).then(
        lambda: gr.Button(interactive=True),
        outputs=[download_gs],
    )

    model_output.clear(
        lambda: gr.Button(interactive=False),
        outputs=[download_glb],
    )
313
 
314
  # Launch the Gradio app
315
  if __name__ == "__main__":