Ashraf commited on
Commit
25f6c4f
·
verified ·
1 Parent(s): 60686ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +201 -3
app.py CHANGED
@@ -1,10 +1,209 @@
1
  import gradio as gr
2
  import spaces
3
  from gradio_litmodel3d import LitModel3D
 
 
 
 
 
4
  import numpy as np
 
 
 
 
 
 
5
 
6
- # Add this constant definition
7
  MAX_SEED = np.iinfo(np.int32).max
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  with gr.Blocks(theme=gr.themes.Default(), delete_cache=(600, 600)) as demo:
10
  with gr.Row():
@@ -120,5 +319,4 @@ if __name__ == "__main__":
120
  pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
121
  except:
122
  pass
123
- demo.launch()
124
-
 
1
  import gradio as gr
2
  import spaces
3
  from gradio_litmodel3d import LitModel3D
4
+ import os
5
+ import shutil
6
+ os.environ['SPCONV_ALGO'] = 'native'
7
+ from typing import *
8
+ import torch
9
  import numpy as np
10
+ import imageio
11
+ from easydict import EasyDict as edict
12
+ from PIL import Image
13
+ from trellis.pipelines import TrellisImageTo3DPipeline
14
+ from trellis.representations import Gaussian, MeshExtractResult
15
+ from trellis.utils import render_utils, postprocessing_utils
16
 
 
17
  MAX_SEED = np.iinfo(np.int32).max
18
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
19
+ os.makedirs(TMP_DIR, exist_ok=True)
20
+
21
+ def start_session(req: gr.Request):
22
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
23
+ os.makedirs(user_dir, exist_ok=True)
24
+
25
+ def end_session(req: gr.Request):
26
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
27
+ shutil.rmtree(user_dir)
28
+
29
+ def preprocess_image(image: Image.Image) -> Image.Image:
30
+ """
31
+ Preprocess the input image.
32
+ Args:
33
+ image (Image.Image): The input image.
34
+ Returns:
35
+ Image.Image: The preprocessed image.
36
+ """
37
+ processed_image = pipeline.preprocess_image(image)
38
+ return processed_image
39
+
40
+ def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
41
+ """
42
+ Preprocess a list of input images.
43
+
44
+ Args:
45
+ images (List[Tuple[Image.Image, str]]): The input images.
46
+
47
+ Returns:
48
+ List[Image.Image]: The preprocessed images.
49
+ """
50
+ images = [image[0] for image in images]
51
+ processed_images = [pipeline.preprocess_image(image) for image in images]
52
+ return processed_images
53
+
54
+ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
55
+ return {
56
+ 'gaussian': {
57
+ **gs.init_params,
58
+ '_xyz': gs._xyz.cpu().numpy(),
59
+ '_features_dc': gs._features_dc.cpu().numpy(),
60
+ '_scaling': gs._scaling.cpu().numpy(),
61
+ '_rotation': gs._rotation.cpu().numpy(),
62
+ '_opacity': gs._opacity.cpu().numpy(),
63
+ },
64
+ 'mesh': {
65
+ 'vertices': mesh.vertices.cpu().numpy(),
66
+ 'faces': mesh.faces.cpu().numpy(),
67
+ },
68
+ }
69
+
70
+ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
71
+ gs = Gaussian(
72
+ aabb=state['gaussian']['aabb'],
73
+ sh_degree=state['gaussian']['sh_degree'],
74
+ mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
75
+ scaling_bias=state['gaussian']['scaling_bias'],
76
+ opacity_bias=state['gaussian']['opacity_bias'],
77
+ scaling_activation=state['gaussian']['scaling_activation'],
78
+ )
79
+ gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
80
+ gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
81
+ gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
82
+ gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
83
+ gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
84
+
85
+ mesh = edict(
86
+ vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
87
+ faces=torch.tensor(state['mesh']['faces'], device='cuda'),
88
+ )
89
+
90
+ return gs, mesh
91
+
92
+ def get_seed(randomize_seed: bool, seed: int) -> int:
93
+ """
94
+ Get the random seed.
95
+ """
96
+ return np.random.randint(0, MAX_SEED) if randomize_seed else seed
97
+
98
+ @spaces.GPU
99
+ def image_to_3d(
100
+ image: Image.Image,
101
+ multiimages: List[Tuple[Image.Image, str]],
102
+ is_multiimage: bool,
103
+ seed: int,
104
+ ss_guidance_strength: float,
105
+ ss_sampling_steps: int,
106
+ slat_guidance_strength: float,
107
+ slat_sampling_steps: int,
108
+ multiimage_algo: Literal["multidiffusion", "stochastic"],
109
+ req: gr.Request,
110
+ ) -> Tuple[dict, str]:
111
+ """
112
+ Convert an image to a 3D model.
113
+ Args:
114
+ image (Image.Image): The input image.
115
+ multiimages (List[Tuple[Image.Image, str]]): The input images in multi-image mode.
116
+ is_multiimage (bool): Whether is in multi-image mode.
117
+ seed (int): The random seed.
118
+ ss_guidance_strength (float): The guidance strength for sparse structure generation.
119
+ ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
120
+ slat_guidance_strength (float): The guidance strength for structured latent generation.
121
+ slat_sampling_steps (int): The number of sampling steps for structured latent generation.
122
+ multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.
123
+ Returns:
124
+ dict: The information of the generated 3D model.
125
+ str: The path to the video of the 3D model.
126
+ """
127
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
128
+ if not is_multiimage:
129
+ outputs = pipeline.run(
130
+ image,
131
+ seed=seed,
132
+ formats=["gaussian", "mesh"],
133
+ preprocess_image=False,
134
+ sparse_structure_sampler_params={
135
+ "steps": ss_sampling_steps,
136
+ "cfg_strength": ss_guidance_strength,
137
+ },
138
+ slat_sampler_params={
139
+ "steps": slat_sampling_steps,
140
+ "cfg_strength": slat_guidance_strength,
141
+ },
142
+ )
143
+ else:
144
+ outputs = pipeline.run_multi_image(
145
+ [image[0] for image in multiimages],
146
+ seed=seed,
147
+ formats=["gaussian", "mesh"],
148
+ preprocess_image=False,
149
+ sparse_structure_sampler_params={
150
+ "steps": ss_sampling_steps,
151
+ "cfg_strength": ss_guidance_strength,
152
+ },
153
+ slat_sampler_params={
154
+ "steps": slat_sampling_steps,
155
+ "cfg_strength": slat_guidance_strength,
156
+ },
157
+ mode=multiimage_algo,
158
+ )
159
+ video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
160
+ video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
161
+ video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
162
+ video_path = os.path.join(user_dir, 'sample.mp4')
163
+ imageio.mimsave(video_path, video, fps=15)
164
+ state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
165
+ torch.cuda.empty_cache()
166
+ return state, video_path
167
+
168
+ @spaces.GPU(duration=90)
169
+ def extract_glb(
170
+ state: dict,
171
+ mesh_simplify: float,
172
+ texture_size: int,
173
+ req: gr.Request,
174
+ ) -> Tuple[str, str]:
175
+ """
176
+ Extract a GLB file from the 3D model.
177
+ Args:
178
+ state (dict): The state of the generated 3D model.
179
+ mesh_simplify (float): The mesh simplification factor.
180
+ texture_size (int): The texture resolution.
181
+ Returns:
182
+ str: The path to the extracted GLB file.
183
+ """
184
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
185
+ gs, mesh = unpack_state(state)
186
+ glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
187
+ glb_path = os.path.join(user_dir, 'sample.glb')
188
+ glb.export(glb_path)
189
+ torch.cuda.empty_cache()
190
+ return glb_path, glb_path
191
+
192
+ @spaces.GPU
193
+ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
194
+ """
195
+ Extract a Gaussian file from the 3D model.
196
+ Args:
197
+ state (dict): The state of the generated 3D model.
198
+ Returns:
199
+ str: The path to the extracted Gaussian file.
200
+ """
201
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
202
+ gs, _ = unpack_state(state)
203
+ gaussian_path = os.path.join(user_dir, 'sample.ply')
204
+ gs.save_ply(gaussian_path)
205
+ torch.cuda.empty_cache()
206
+ return gaussian_path, gaussian_path
207
 
208
  with gr.Blocks(theme=gr.themes.Default(), delete_cache=(600, 600)) as demo:
209
  with gr.Row():
 
319
  pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
320
  except:
321
  pass
322
+ demo.launch()