Shivamkak commited on
Commit
ea6d36f
·
verified ·
1 Parent(s): da049aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -14
app.py CHANGED
@@ -1,31 +1,42 @@
 
 
 
 
 
 
 
 
1
  import torch
2
  import numpy as np
3
  import imageio
4
-
5
  from easydict import EasyDict as edict
6
  from PIL import Image
7
  from trellis.pipelines import TrellisImageTo3DPipeline
 
 
 
 
 
 
 
 
8
 
9
  def start_session(req: gr.Request):
10
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
11
-
12
  os.makedirs(user_dir, exist_ok=True)
13
 
14
 
15
  def end_session(req: gr.Request):
16
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
17
-
18
  shutil.rmtree(user_dir)
19
 
20
 
21
  def preprocess_image(image: Image.Image) -> Image.Image:
22
  """
23
  Preprocess the input image.
24
-
25
  image (Image.Image): The input image.
26
-
27
  Returns:
28
-
29
  Image.Image: The preprocessed image.
30
  """
31
  processed_image = pipeline.preprocess_image(image)
@@ -51,13 +62,36 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
51
  return {
52
  'gaussian': {
53
  **gs.init_params,
 
 
 
 
 
 
 
54
  'vertices': mesh.vertices.cpu().numpy(),
55
  'faces': mesh.faces.cpu().numpy(),
56
  },
57
-
58
  }
59
 
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  faces=torch.tensor(state['mesh']['faces'], device='cuda'),
62
  )
63
 
@@ -65,6 +99,12 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
65
 
66
 
67
  def get_seed(randomize_seed: bool, seed: int) -> int:
 
 
 
 
 
 
68
  @spaces.GPU
69
  def image_to_3d(
70
  image: Image.Image,
@@ -79,7 +119,7 @@ def image_to_3d(
79
  req: gr.Request,
80
  ) -> Tuple[dict, str]:
81
  """
82
-
83
  Args:
84
  image (Image.Image): The input image.
85
  multiimages (List[Tuple[Image.Image, str]]): The input images in multi-image mode.
@@ -90,7 +130,6 @@ def image_to_3d(
90
  slat_guidance_strength (float): The guidance strength for structured latent generation.
91
  slat_sampling_steps (int): The number of sampling steps for structured latent generation.
92
  multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.
93
-
94
  Returns:
95
  dict: The information of the generated 3D model.
96
  str: The path to the video of the 3D model.
@@ -131,7 +170,6 @@ def image_to_3d(
131
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
132
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
133
  video_path = os.path.join(user_dir, 'sample.mp4')
134
-
135
  imageio.mimsave(video_path, video, fps=15)
136
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
137
  torch.cuda.empty_cache()
@@ -142,6 +180,16 @@ def image_to_3d(
142
  def extract_glb(
143
  state: dict,
144
  mesh_simplify: float,
 
 
 
 
 
 
 
 
 
 
145
  str: The path to the extracted GLB file.
146
  """
147
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
@@ -157,10 +205,8 @@ def extract_glb(
157
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
158
  """
159
  Extract a Gaussian file from the 3D model.
160
-
161
  Args:
162
  state (dict): The state of the generated 3D model.
163
-
164
  Returns:
165
  str: The path to the extracted Gaussian file.
166
  """
@@ -225,6 +271,12 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
225
 
226
  with gr.Accordion(label="Generation Settings", open=False):
227
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
 
 
 
 
 
 
228
  with gr.Row():
229
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
230
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
@@ -232,6 +284,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
232
 
233
  generate_btn = gr.Button("Generate")
234
 
 
235
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
236
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
237
 
@@ -245,7 +298,6 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
245
  with gr.Column():
246
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
247
  model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
248
-
249
 
250
  with gr.Row():
251
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
@@ -259,6 +311,11 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
259
  examples = gr.Examples(
260
  examples=[
261
  f'assets/example_image/{image}'
 
 
 
 
 
262
  run_on_click=True,
263
  examples_per_page=64,
264
  )
@@ -298,6 +355,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
298
 
299
  generate_btn.click(
300
  get_seed,
 
301
  outputs=[seed],
302
  ).then(
303
  image_to_3d,
@@ -314,6 +372,10 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
314
  )
315
 
316
  extract_glb_btn.click(
 
 
 
 
317
  lambda: gr.Button(interactive=True),
318
  outputs=[download_glb],
319
  )
@@ -328,4 +390,17 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
328
  )
329
 
330
  model_output.clear(
331
- lambda: gr.Button(interactive=False),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ from gradio_litmodel3d import LitModel3D
4
+
5
+ import os
6
+ import shutil
7
+ os.environ['SPCONV_ALGO'] = 'native'
8
+ from typing import *
9
  import torch
10
  import numpy as np
11
  import imageio
 
12
  from easydict import EasyDict as edict
13
  from PIL import Image
14
  from trellis.pipelines import TrellisImageTo3DPipeline
15
+ from trellis.representations import Gaussian, MeshExtractResult
16
+ from trellis.utils import render_utils, postprocessing_utils
17
+
18
+
19
+ MAX_SEED = np.iinfo(np.int32).max
20
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
21
+ os.makedirs(TMP_DIR, exist_ok=True)
22
+
23
 
24
  def start_session(req: gr.Request):
25
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
 
26
  os.makedirs(user_dir, exist_ok=True)
27
 
28
 
29
  def end_session(req: gr.Request):
30
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
 
31
  shutil.rmtree(user_dir)
32
 
33
 
34
  def preprocess_image(image: Image.Image) -> Image.Image:
35
  """
36
  Preprocess the input image.
37
+ Args:
38
  image (Image.Image): The input image.
 
39
  Returns:
 
40
  Image.Image: The preprocessed image.
41
  """
42
  processed_image = pipeline.preprocess_image(image)
 
62
  return {
63
  'gaussian': {
64
  **gs.init_params,
65
+ '_xyz': gs._xyz.cpu().numpy(),
66
+ '_features_dc': gs._features_dc.cpu().numpy(),
67
+ '_scaling': gs._scaling.cpu().numpy(),
68
+ '_rotation': gs._rotation.cpu().numpy(),
69
+ '_opacity': gs._opacity.cpu().numpy(),
70
+ },
71
+ 'mesh': {
72
  'vertices': mesh.vertices.cpu().numpy(),
73
  'faces': mesh.faces.cpu().numpy(),
74
  },
 
75
  }
76
 
77
 
78
+ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
79
+ gs = Gaussian(
80
+ aabb=state['gaussian']['aabb'],
81
+ sh_degree=state['gaussian']['sh_degree'],
82
+ mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
83
+ scaling_bias=state['gaussian']['scaling_bias'],
84
+ opacity_bias=state['gaussian']['opacity_bias'],
85
+ scaling_activation=state['gaussian']['scaling_activation'],
86
+ )
87
+ gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
88
+ gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
89
+ gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
90
+ gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
91
+ gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
92
+
93
+ mesh = edict(
94
+ vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
95
  faces=torch.tensor(state['mesh']['faces'], device='cuda'),
96
  )
97
 
 
99
 
100
 
101
  def get_seed(randomize_seed: bool, seed: int) -> int:
102
+ """
103
+ Get the random seed.
104
+ """
105
+ return np.random.randint(0, MAX_SEED) if randomize_seed else seed
106
+
107
+
108
  @spaces.GPU
109
  def image_to_3d(
110
  image: Image.Image,
 
119
  req: gr.Request,
120
  ) -> Tuple[dict, str]:
121
  """
122
+ Convert an image to a 3D model.
123
  Args:
124
  image (Image.Image): The input image.
125
  multiimages (List[Tuple[Image.Image, str]]): The input images in multi-image mode.
 
130
  slat_guidance_strength (float): The guidance strength for structured latent generation.
131
  slat_sampling_steps (int): The number of sampling steps for structured latent generation.
132
  multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.
 
133
  Returns:
134
  dict: The information of the generated 3D model.
135
  str: The path to the video of the 3D model.
 
170
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
171
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
172
  video_path = os.path.join(user_dir, 'sample.mp4')
 
173
  imageio.mimsave(video_path, video, fps=15)
174
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
175
  torch.cuda.empty_cache()
 
180
  def extract_glb(
181
  state: dict,
182
  mesh_simplify: float,
183
+ texture_size: int,
184
+ req: gr.Request,
185
+ ) -> Tuple[str, str]:
186
+ """
187
+ Extract a GLB file from the 3D model.
188
+ Args:
189
+ state (dict): The state of the generated 3D model.
190
+ mesh_simplify (float): The mesh simplification factor.
191
+ texture_size (int): The texture resolution.
192
+ Returns:
193
  str: The path to the extracted GLB file.
194
  """
195
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
 
205
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
206
  """
207
  Extract a Gaussian file from the 3D model.
 
208
  Args:
209
  state (dict): The state of the generated 3D model.
 
210
  Returns:
211
  str: The path to the extracted Gaussian file.
212
  """
 
271
 
272
  with gr.Accordion(label="Generation Settings", open=False):
273
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
274
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
275
+ gr.Markdown("Stage 1: Sparse Structure Generation")
276
+ with gr.Row():
277
+ ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
278
+ ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
279
+ gr.Markdown("Stage 2: Structured Latent Generation")
280
  with gr.Row():
281
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
282
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
 
284
 
285
  generate_btn = gr.Button("Generate")
286
 
287
+ with gr.Accordion(label="GLB Extraction Settings", open=False):
288
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
289
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
290
 
 
298
  with gr.Column():
299
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
300
  model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
 
301
 
302
  with gr.Row():
303
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
 
311
  examples = gr.Examples(
312
  examples=[
313
  f'assets/example_image/{image}'
314
+ for image in os.listdir("assets/example_image")
315
+ ],
316
+ inputs=[image_prompt],
317
+ fn=preprocess_image,
318
+ outputs=[image_prompt],
319
  run_on_click=True,
320
  examples_per_page=64,
321
  )
 
355
 
356
  generate_btn.click(
357
  get_seed,
358
+ inputs=[randomize_seed, seed],
359
  outputs=[seed],
360
  ).then(
361
  image_to_3d,
 
372
  )
373
 
374
  extract_glb_btn.click(
375
+ extract_glb,
376
+ inputs=[output_buf, mesh_simplify, texture_size],
377
+ outputs=[model_output, download_glb],
378
+ ).then(
379
  lambda: gr.Button(interactive=True),
380
  outputs=[download_glb],
381
  )
 
390
  )
391
 
392
  model_output.clear(
393
+ lambda: gr.Button(interactive=False),
394
+ outputs=[download_glb],
395
+ )
396
+
397
+
398
+ # Launch the Gradio app
399
+ if __name__ == "__main__":
400
+ pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
401
+ pipeline.cuda()
402
+ try:
403
+ pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
404
+ except:
405
+ pass
406
+ demo.launch()