Vibu46vk committed on
Commit
7d6f599
·
verified ·
1 Parent(s): fdc7f9f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +406 -0
app.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ from gradio_litmodel3d import LitModel3D
4
+
5
+ import os
6
+ import shutil
7
+ os.environ['SPCONV_ALGO'] = 'native'
8
+ from typing import *
9
+ import torch
10
+ import numpy as np
11
+ import imageio
12
+ from easydict import EasyDict as edict
13
+ from PIL import Image
14
+ from trellis.pipelines import TrellisImageTo3DPipeline
15
+ from trellis.representations import Gaussian, MeshExtractResult
16
+ from trellis.utils import render_utils, postprocessing_utils
17
+
18
+
19
# Upper bound for the seed slider; int32 max keeps the value valid for numpy's RNG.
MAX_SEED = np.iinfo(np.int32).max
# Per-session scratch space (videos, GLB/PLY exports), created next to this file.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)
22
+
23
+
24
def start_session(req: gr.Request):
    """
    Create the per-session scratch directory when a browser session connects.

    Args:
        req (gr.Request): Gradio request carrying the session hash.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)
27
+
28
+
29
def end_session(req: gr.Request):
    """
    Remove the per-session scratch directory when the browser session ends.

    Args:
        req (gr.Request): Gradio request carrying the session hash.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # ignore_errors: unload can fire for a session whose directory was never
    # created (start_session did not run) or was already purged by
    # Blocks(delete_cache=...), which would otherwise raise FileNotFoundError.
    shutil.rmtree(user_dir, ignore_errors=True)
32
+
33
+
34
def preprocess_image(image: Image.Image) -> Image.Image:
    """
    Preprocess the input image via the pipeline (background removal / cropping).

    Args:
        image (Image.Image): The input image.

    Returns:
        Image.Image: The preprocessed image.
    """
    return pipeline.preprocess_image(image)
44
+
45
+
46
def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
    """
    Preprocess every image in a gallery selection.

    Args:
        images (List[Tuple[Image.Image, str]]): Gallery items as (image, caption) pairs.

    Returns:
        List[Image.Image]: The preprocessed images, in the same order.
    """
    # Gallery items are (image, caption) tuples; only the image is needed.
    return [pipeline.preprocess_image(item[0]) for item in images]
59
+
60
+
61
def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
    """
    Serialize the generated Gaussian and mesh into a CPU-only dict.

    The result contains only numpy arrays and plain values, so it can be
    stored in gr.State between event handlers without holding GPU memory.

    Args:
        gs (Gaussian): The Gaussian representation.
        mesh (MeshExtractResult): The extracted mesh.

    Returns:
        dict: Packed state, consumed by unpack_state.
    """
    gaussian_dict = dict(gs.init_params)
    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        gaussian_dict[attr] = getattr(gs, attr).cpu().numpy()
    return {
        'gaussian': gaussian_dict,
        'mesh': {
            'vertices': mesh.vertices.cpu().numpy(),
            'faces': mesh.faces.cpu().numpy(),
        },
    }
76
+
77
+
78
def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
    """
    Rebuild GPU-resident Gaussian and mesh objects from a packed state dict.

    Inverse of pack_state: restores the Gaussian's constructor parameters and
    tensor attributes, and wraps the mesh arrays in an edict.

    Args:
        state (dict): Dict produced by pack_state (numpy arrays + params).

    Returns:
        Tuple[Gaussian, edict]: The Gaussian representation and the mesh.
    """
    # NOTE: 'mininum_kernel_size' spelling matches the upstream Gaussian API.
    gs = Gaussian(
        aabb=state['gaussian']['aabb'],
        sh_degree=state['gaussian']['sh_degree'],
        mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
        scaling_bias=state['gaussian']['scaling_bias'],
        opacity_bias=state['gaussian']['opacity_bias'],
        scaling_activation=state['gaussian']['scaling_activation'],
    )
    gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
    gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
    gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
    gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
    gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')

    mesh = edict(
        vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
        faces=torch.tensor(state['mesh']['faces'], device='cuda'),
    )

    # Fixed: the annotation previously declared Tuple[Gaussian, edict, str]
    # although the function returns exactly two values.
    return gs, mesh
99
+
100
+
101
def get_seed(randomize_seed: bool, seed: int) -> int:
    """
    Get the random seed.

    Returns a freshly drawn seed in [0, MAX_SEED) when *randomize_seed* is
    set; otherwise echoes the user-provided *seed* unchanged.
    """
    if randomize_seed:
        return np.random.randint(0, MAX_SEED)
    return seed
106
+
107
+
108
@spaces.GPU
def image_to_3d(
    image: Image.Image,
    multiimages: List[Tuple[Image.Image, str]],
    is_multiimage: bool,
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    multiimage_algo: Literal["multidiffusion", "stochastic"],
    req: gr.Request,
) -> Tuple[dict, str]:
    """
    Convert an image to a 3D model.

    Args:
        image (Image.Image): The input image (single-image mode).
        multiimages (List[Tuple[Image.Image, str]]): The input images in multi-image mode.
        is_multiimage (bool): Whether is in multi-image mode.
        seed (int): The random seed.
        ss_guidance_strength (float): The guidance strength for sparse structure generation.
        ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
        slat_guidance_strength (float): The guidance strength for structured latent generation.
        slat_sampling_steps (int): The number of sampling steps for structured latent generation.
        multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.
        req (gr.Request): Gradio request, used to locate the per-session temp directory.

    Returns:
        dict: The packed state of the generated 3D model (see pack_state).
        str: The path to the preview video of the 3D model.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # preprocess_image=False: inputs were already preprocessed by the upload handlers.
    if not is_multiimage:
        outputs = pipeline.run(
            image,
            seed=seed,
            formats=["gaussian", "mesh"],
            preprocess_image=False,
            sparse_structure_sampler_params={
                "steps": ss_sampling_steps,
                "cfg_strength": ss_guidance_strength,
            },
            slat_sampler_params={
                "steps": slat_sampling_steps,
                "cfg_strength": slat_guidance_strength,
            },
        )
    else:
        outputs = pipeline.run_multi_image(
            [image[0] for image in multiimages],
            seed=seed,
            formats=["gaussian", "mesh"],
            preprocess_image=False,
            sparse_structure_sampler_params={
                "steps": ss_sampling_steps,
                "cfg_strength": ss_guidance_strength,
            },
            slat_sampler_params={
                "steps": slat_sampling_steps,
                "cfg_strength": slat_guidance_strength,
            },
            mode=multiimage_algo,
        )
    # 120-frame turntable preview: colour pass from the Gaussians, normal pass
    # from the mesh, concatenated side by side into one video.
    video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
    video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
    video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
    video_path = os.path.join(user_dir, 'sample.mp4')
    imageio.mimsave(video_path, video, fps=15)
    # Pack outputs into a CPU-only dict so gr.State can hold them between calls
    # without pinning GPU memory.
    state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
    torch.cuda.empty_cache()
    return state, video_path
177
+
178
+
179
@spaces.GPU(duration=90)
def extract_glb(
    state: dict,
    mesh_simplify: float,
    texture_size: int,
    req: gr.Request,
) -> Tuple[str, str]:
    """
    Extract a GLB file from the 3D model.

    Args:
        state (dict): The packed state of the generated 3D model.
        mesh_simplify (float): The mesh simplification factor.
        texture_size (int): The texture resolution.
        req (gr.Request): Gradio request, used to locate the session directory.

    Returns:
        Tuple[str, str]: The GLB path twice — once for the 3D viewer and once
        for the download button.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    glb_path = os.path.join(user_dir, 'sample.glb')
    gaussian_repr, mesh_repr = unpack_state(state)
    postprocessing_utils.to_glb(
        gaussian_repr,
        mesh_repr,
        simplify=mesh_simplify,
        texture_size=texture_size,
        verbose=False,
    ).export(glb_path)
    torch.cuda.empty_cache()
    return glb_path, glb_path
202
+
203
+
204
@spaces.GPU
def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
    """
    Extract a Gaussian (.ply) file from the 3D model.

    Args:
        state (dict): The packed state of the generated 3D model.
        req (gr.Request): Gradio request, used to locate the session directory.

    Returns:
        Tuple[str, str]: The .ply path twice — once for the 3D viewer and once
        for the download button.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    gaussian_path = os.path.join(user_dir, 'sample.ply')
    gaussian_repr, _mesh = unpack_state(state)
    gaussian_repr.save_ply(gaussian_path)
    torch.cuda.empty_cache()
    return gaussian_path, gaussian_path
219
+
220
+
221
def prepare_multi_example() -> List[Image.Image]:
    """
    Build preview strips for the multi-image examples.

    Scans assets/example_multi_image for files named '<case>_<i>.png', resizes
    views 1-3 of each case to height 512, and concatenates them horizontally
    into one preview image per case.
    """
    cases = {name.split('_')[0] for name in os.listdir("assets/example_multi_image")}
    previews = []
    for case in cases:
        views = []
        for idx in range(1, 4):
            img = Image.open(f'assets/example_multi_image/{case}_{idx}.png')
            width, height = img.size
            # Scale to a fixed height of 512, preserving aspect ratio.
            views.append(np.array(img.resize((int(width / height * 512), 512))))
        previews.append(Image.fromarray(np.concatenate(views, axis=1)))
    return previews
233
+
234
+
235
def split_image(image: Image.Image) -> List[Image.Image]:
    """
    Split a horizontally concatenated RGBA image into its component views.

    Columns whose alpha channel is fully transparent act as separators; each
    opaque run becomes one view, which is then preprocessed.
    """
    arr = np.array(image)
    # True for each column that contains at least one non-transparent pixel.
    col_has_content = np.any(arr[..., 3] > 0, axis=0)
    # Transparent->opaque transitions mark starts; opaque->transparent mark ends.
    starts = np.where(~col_has_content[:-1] & col_has_content[1:])[0]
    ends = np.where(col_has_content[:-1] & ~col_has_content[1:])[0]
    views = [Image.fromarray(arr[:, s:e + 1]) for s, e in zip(starts, ends)]
    return [preprocess_image(view) for view in views]
248
+
249
+
250
+ with gr.Blocks(delete_cache=(600, 600)) as demo:
251
+ gr.Markdown("""
252
+ ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
253
+ * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
254
+ * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
255
+
256
+ ✨New: 1) Experimental multi-image support. 2) Gaussian file extraction.
257
+ """)
258
+
259
+ with gr.Row():
260
+ with gr.Column():
261
+ with gr.Tabs() as input_tabs:
262
+ with gr.Tab(label="Single Image", id=0) as single_image_input_tab:
263
+ image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
264
+ with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
265
+ multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
266
+ gr.Markdown("""
267
+ Input different views of the object in separate images.
268
+
269
+ *NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.*
270
+ """)
271
+
272
+ with gr.Accordion(label="Generation Settings", open=False):
273
+ seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
274
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
275
+ gr.Markdown("Stage 1: Sparse Structure Generation")
276
+ with gr.Row():
277
+ ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
278
+ ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
279
+ gr.Markdown("Stage 2: Structured Latent Generation")
280
+ with gr.Row():
281
+ slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
282
+ slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
283
+ multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")
284
+
285
+ generate_btn = gr.Button("Generate")
286
+
287
+ with gr.Accordion(label="GLB Extraction Settings", open=False):
288
+ mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
289
+ texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
290
+
291
+ with gr.Row():
292
+ extract_glb_btn = gr.Button("Extract GLB", interactive=False)
293
+ extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
294
+ gr.Markdown("""
295
+ *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
296
+ """)
297
+
298
+ with gr.Column():
299
+ video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
300
+ model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
301
+
302
+ with gr.Row():
303
+ download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
304
+ download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
305
+
306
+ is_multiimage = gr.State(False)
307
+ output_buf = gr.State()
308
+
309
+ # Example images at the bottom of the page
310
+ with gr.Row() as single_image_example:
311
+ examples = gr.Examples(
312
+ examples=[
313
+ f'assets/example_image/{image}'
314
+ for image in os.listdir("assets/example_image")
315
+ ],
316
+ inputs=[image_prompt],
317
+ fn=preprocess_image,
318
+ outputs=[image_prompt],
319
+ run_on_click=True,
320
+ examples_per_page=64,
321
+ )
322
+ with gr.Row(visible=False) as multiimage_example:
323
+ examples_multi = gr.Examples(
324
+ examples=prepare_multi_example(),
325
+ inputs=[image_prompt],
326
+ fn=split_image,
327
+ outputs=[multiimage_prompt],
328
+ run_on_click=True,
329
+ examples_per_page=8,
330
+ )
331
+
332
+ # Handlers
333
+ demo.load(start_session)
334
+ demo.unload(end_session)
335
+
336
+ single_image_input_tab.select(
337
+ lambda: tuple([False, gr.Row.update(visible=True), gr.Row.update(visible=False)]),
338
+ outputs=[is_multiimage, single_image_example, multiimage_example]
339
+ )
340
+ multiimage_input_tab.select(
341
+ lambda: tuple([True, gr.Row.update(visible=False), gr.Row.update(visible=True)]),
342
+ outputs=[is_multiimage, single_image_example, multiimage_example]
343
+ )
344
+
345
+ image_prompt.upload(
346
+ preprocess_image,
347
+ inputs=[image_prompt],
348
+ outputs=[image_prompt],
349
+ )
350
+ multiimage_prompt.upload(
351
+ preprocess_images,
352
+ inputs=[multiimage_prompt],
353
+ outputs=[multiimage_prompt],
354
+ )
355
+
356
+ generate_btn.click(
357
+ get_seed,
358
+ inputs=[randomize_seed, seed],
359
+ outputs=[seed],
360
+ ).then(
361
+ image_to_3d,
362
+ inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
363
+ outputs=[output_buf, video_output],
364
+ ).then(
365
+ lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
366
+ outputs=[extract_glb_btn, extract_gs_btn],
367
+ )
368
+
369
+ video_output.clear(
370
+ lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
371
+ outputs=[extract_glb_btn, extract_gs_btn],
372
+ )
373
+
374
+ extract_glb_btn.click(
375
+ extract_glb,
376
+ inputs=[output_buf, mesh_simplify, texture_size],
377
+ outputs=[model_output, download_glb],
378
+ ).then(
379
+ lambda: gr.Button(interactive=True),
380
+ outputs=[download_glb],
381
+ )
382
+
383
+ extract_gs_btn.click(
384
+ extract_gaussian,
385
+ inputs=[output_buf],
386
+ outputs=[model_output, download_gs],
387
+ ).then(
388
+ lambda: gr.Button(interactive=True),
389
+ outputs=[download_gs],
390
+ )
391
+
392
+ model_output.clear(
393
+ lambda: gr.Button(interactive=False),
394
+ outputs=[download_glb],
395
+ )
396
+
397
+
398
+ # Launch the Gradio app
399
+ if __name__ == "__main__":
400
+ pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
401
+ pipeline.cuda()
402
+ try:
403
+ pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
404
+ except:
405
+ pass
406
+ demo.launch()