notenoughram commited on
Commit
752eebf
Β·
verified Β·
1 Parent(s): 2618acd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +200 -114
app.py CHANGED
@@ -1,33 +1,20 @@
 
 
 
1
  import os
2
- import sys
3
- import subprocess
4
- import gc
5
  import shutil
6
- from typing import *
7
-
8
- # [AUTO-INSTALL] accelerate 라이브러리
9
- try:
10
- import accelerate
11
- except ImportError:
12
- subprocess.check_call([sys.executable, "-m", "pip", "install", "accelerate"])
13
-
14
- # [μ€‘μš”] OOM λ°©μ§€λ₯Ό μœ„ν•œ λ©”λͺ¨λ¦¬ νŒŒνŽΈν™” μ„€μ •
15
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
16
  os.environ['SPCONV_ALGO'] = 'native'
17
-
18
  import torch
19
- import torch.nn as nn
20
  import numpy as np
21
  import imageio
22
  from easydict import EasyDict as edict
23
  from PIL import Image
24
- import gradio as gr
25
- from gradio_litmodel3d import LitModel3D
26
-
27
  from trellis.pipelines import TrellisVGGTTo3DPipeline
28
  from trellis.representations import Gaussian, MeshExtractResult
29
  from trellis.utils import render_utils, postprocessing_utils
30
- from accelerate import dispatch_model, infer_auto_device_map
 
31
 
32
  MAX_SEED = np.iinfo(np.int32).max
33
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
@@ -37,18 +24,42 @@ def start_session(req: gr.Request):
37
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
38
  os.makedirs(user_dir, exist_ok=True)
39
 
 
40
  def end_session(req: gr.Request):
41
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
42
- if os.path.exists(user_dir):
43
- shutil.rmtree(user_dir)
44
- gc.collect()
45
- torch.cuda.empty_cache()
46
 
47
  def preprocess_image(image: Image.Image) -> Image.Image:
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  processed_image = pipeline.preprocess_image(image)
49
  return processed_image
50
 
51
  def preprocess_videos(video: str) -> List[Tuple[Image.Image, str]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  vid = imageio.get_reader(video, 'ffmpeg')
53
  fps = vid.get_meta_data()['fps']
54
  images = []
@@ -63,10 +74,23 @@ def preprocess_videos(video: str) -> List[Tuple[Image.Image, str]]:
63
  return processed_images
64
 
65
  def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
 
 
 
 
 
 
 
 
 
 
 
 
66
  images = [image[0] for image in images]
67
  processed_images = [pipeline.preprocess_image(image) for image in images]
68
  return processed_images
69
 
 
70
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
71
  return {
72
  'gaussian': {
@@ -82,9 +106,9 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
82
  'faces': mesh.faces.cpu().numpy(),
83
  },
84
  }
85
-
 
86
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
87
- device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
88
  gs = Gaussian(
89
  aabb=state['gaussian']['aabb'],
90
  sh_degree=state['gaussian']['sh_degree'],
@@ -93,21 +117,37 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
93
  opacity_bias=state['gaussian']['opacity_bias'],
94
  scaling_activation=state['gaussian']['scaling_activation'],
95
  )
96
- gs._xyz = torch.tensor(state['gaussian']['_xyz'], device=device)
97
- gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device=device)
98
- gs._scaling = torch.tensor(state['gaussian']['_scaling'], device=device)
99
- gs._rotation = torch.tensor(state['gaussian']['_rotation'], device=device)
100
- gs._opacity = torch.tensor(state['gaussian']['_opacity'], device=device)
101
 
102
  mesh = edict(
103
- vertices=torch.tensor(state['mesh']['vertices'], device=device),
104
- faces=torch.tensor(state['mesh']['faces'], device=device),
105
  )
 
106
  return gs, mesh
107
 
 
108
  def get_seed(randomize_seed: bool, seed: int) -> int:
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
110
 
 
111
  def generate_and_extract_glb(
112
  multiimages: List[Tuple[Image.Image, str]],
113
  seed: int,
@@ -120,35 +160,54 @@ def generate_and_extract_glb(
120
  texture_size: int,
121
  req: gr.Request,
122
  ) -> Tuple[dict, str, str, str]:
123
-
124
- gc.collect()
125
- torch.cuda.empty_cache()
126
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
128
  image_files = [image[0] for image in multiimages]
129
 
130
- try:
131
- # [μ€‘μš”] μΆ”λ‘  μ‹œ κ·Έλž˜λ””μ–ΈνŠΈ 계산 끔 (λ©”λͺ¨λ¦¬ μ ˆμ•½)
132
- with torch.no_grad():
133
- outputs, _, _ = pipeline.run(
134
- image=image_files,
135
- seed=seed,
136
- formats=["gaussian", "mesh"],
137
- preprocess_image=False,
138
- sparse_structure_sampler_params={
139
- "steps": ss_sampling_steps,
140
- "cfg_strength": ss_guidance_strength,
141
- },
142
- slat_sampler_params={
143
- "steps": slat_sampling_steps,
144
- "cfg_strength": slat_guidance_strength,
145
- },
146
- mode=multiimage_algo,
147
- )
148
- except Exception as e:
149
- torch.cuda.empty_cache()
150
- # ꡬ체적인 μ—λŸ¬ λ©”μ‹œμ§€ λ°˜ν™˜
151
- raise RuntimeError(f"Generation Failed: {str(e)}")
 
152
 
153
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
154
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
@@ -156,32 +215,44 @@ def generate_and_extract_glb(
156
  video_path = os.path.join(user_dir, 'sample.mp4')
157
  imageio.mimsave(video_path, video, fps=15)
158
 
 
159
  gs = outputs['gaussian'][0]
160
  mesh = outputs['mesh'][0]
161
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
162
  glb_path = os.path.join(user_dir, 'sample.glb')
163
  glb.export(glb_path)
164
 
 
165
  state = pack_state(gs, mesh)
166
 
167
- del outputs, gs, mesh, glb
168
- gc.collect()
169
  torch.cuda.empty_cache()
170
-
171
  return state, video_path, glb_path, glb_path
172
 
 
173
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
175
  gs, _ = unpack_state(state)
176
  gaussian_path = os.path.join(user_dir, 'sample.ply')
177
  gs.save_ply(gaussian_path)
178
- del gs
179
  torch.cuda.empty_cache()
180
  return gaussian_path, gaussian_path
181
 
 
182
  def prepare_multi_example() -> List[Image.Image]:
183
- if not os.path.exists("assets/example_multi_image"):
184
- return []
185
  multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
186
  images = []
187
  for case in multi_case:
@@ -196,7 +267,21 @@ def prepare_multi_example() -> List[Image.Image]:
196
  images.append(Image.fromarray(np.concatenate(_images, axis=1)))
197
  return images
198
 
 
199
  def split_image(image: Image.Image) -> List[Image.Image]:
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  image = np.array(image)
201
  alpha = image[..., 3]
202
  alpha = np.any(alpha>0, axis=0)
@@ -219,7 +304,22 @@ demo = gr.Blocks(
219
  """
220
  )
221
  with demo:
222
- gr.Markdown("# πŸ’» ReconViaGen (GPU 0 Freed)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
 
224
  with gr.Row():
225
  with gr.Column():
@@ -228,6 +328,9 @@ with demo:
228
  input_video = gr.Video(label="Upload Video", interactive=True, height=300)
229
  image_prompt = gr.Image(label="Image Prompt", format="png", visible=False, image_mode="RGBA", type="pil", height=300)
230
  multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
 
 
 
231
 
232
  with gr.Accordion(label="Generation Settings", open=False):
233
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
@@ -248,6 +351,9 @@ with demo:
248
 
249
  generate_btn = gr.Button("Generate & Extract GLB", variant="primary")
250
  extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
 
 
 
251
 
252
  with gr.Column():
253
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
@@ -259,6 +365,7 @@ with demo:
259
 
260
  output_buf = gr.State()
261
 
 
262
  with gr.Row() as multiimage_example:
263
  examples_multi = gr.Examples(
264
  examples=prepare_multi_example(),
@@ -273,12 +380,25 @@ with demo:
273
  demo.load(start_session)
274
  demo.unload(end_session)
275
 
276
- input_video.upload(preprocess_videos, inputs=[input_video], outputs=[multiimage_prompt])
277
- input_video.clear(lambda: tuple([None, None]), outputs=[input_video, multiimage_prompt])
278
- multiimage_prompt.upload(preprocess_images, inputs=[multiimage_prompt], outputs=[multiimage_prompt])
 
 
 
 
 
 
 
 
 
 
 
279
 
280
  generate_btn.click(
281
- get_seed, inputs=[randomize_seed, seed], outputs=[seed]
 
 
282
  ).then(
283
  generate_and_extract_glb,
284
  inputs=[multiimage_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo, mesh_simplify, texture_size],
@@ -293,59 +413,25 @@ with demo:
293
  outputs=[extract_gs_btn, download_glb, download_gs],
294
  )
295
 
296
- extract_gs_btn.click(extract_gaussian, inputs=[output_buf], outputs=[model_output, download_gs]).then(
297
- lambda: gr.Button(interactive=True), outputs=[download_gs]
 
 
 
 
 
298
  )
299
 
300
  model_output.clear(
301
  lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
302
  outputs=[download_glb, download_gs],
303
  )
 
304
 
305
- # Launch Script
306
  if __name__ == "__main__":
307
- print("πŸš€ Initializing Pipeline...")
308
- # 1. Pipeline λ‘œλ“œ
309
  pipeline = TrellisVGGTTo3DPipeline.from_pretrained("esther11/trellis-vggt-v0-2")
310
-
311
- # 2. λͺ¨λ“  λͺ¨λΈμ„ 일단 CUDA:0에 μ˜¬λ €μ„œ κΈ°λ³Έ μ„€μ •(device mismatch λ°©μ§€)을 μ™„λ£Œν•¨
312
  pipeline.cuda()
313
- pipeline._device = torch.device("cuda:0") # λ‚΄λΆ€ device 속성 κ³ μ •
314
-
315
- gpu_count = torch.cuda.device_count()
316
- print(f"⚑ Detected {gpu_count} GPUs.")
317
-
318
- if gpu_count > 1:
319
- print("⚑ Multi-GPU Mode: Offloading VGGT from GPU 0.")
320
-
321
- # [핡심 둜직] GPU 0을 λΉ„μš°κΈ° μœ„ν•œ μ „λž΅
322
- # VGGT λͺ¨λΈμ„ μž μ‹œ CPU둜 λ‚΄λ¦½λ‹ˆλ‹€.
323
- pipeline.VGGT_model.cpu()
324
-
325
- print(" - Calculating Device Map (Banning GPU 0 for VGGT)...")
326
-
327
- # max_memory μ„€μ •:
328
- # GPU 0: "10MiB" (사싀상 VGGT λͺ¨λΈ 적재 κΈˆμ§€)
329
- # GPU 1~N: "20GiB" (μ—¬μœ λ‘­κ²Œ ν• λ‹Ή)
330
- max_mem = {0: "10MiB"}
331
- for i in range(1, gpu_count):
332
- max_mem[i] = "20GiB"
333
-
334
- # 이 μ„€μ •μœΌλ‘œ 맡을 짜면 accelerateλŠ” GPU 0을 κ±΄λ„ˆλ›°κ³  GPU 1λΆ€ν„° λͺ¨λΈμ„ μ±„μ›λ‹ˆλ‹€.
335
- device_map = infer_auto_device_map(
336
- pipeline.VGGT_model,
337
- max_memory=max_mem,
338
- no_split_module_classes=["Block", "ResnetBlock"]
339
- )
340
-
341
- # λ§΅ μ μš©ν•˜μ—¬ λΆ„μ‚° λ‘œλ“œ
342
- pipeline.VGGT_model = dispatch_model(pipeline.VGGT_model, device_map=device_map)
343
-
344
- print("βœ… VGGT Model successfully pushed to GPU 1+.")
345
- print(" - GPU 0: Birefnet (Preprocessing) + Controller")
346
- print(" - GPU 1+: VGGT (Inference)")
347
-
348
- else:
349
- print("⚠️ Warning: Only 1 GPU detected. Expect OOM if VRAM < 24GB.")
350
-
351
  demo.launch()
 
1
+ import gradio as gr
2
+ from gradio_litmodel3d import LitModel3D
3
+
4
  import os
 
 
 
5
  import shutil
 
 
 
 
 
 
 
 
 
 
6
  os.environ['SPCONV_ALGO'] = 'native'
7
+ from typing import *
8
  import torch
 
9
  import numpy as np
10
  import imageio
11
  from easydict import EasyDict as edict
12
  from PIL import Image
 
 
 
13
  from trellis.pipelines import TrellisVGGTTo3DPipeline
14
  from trellis.representations import Gaussian, MeshExtractResult
15
  from trellis.utils import render_utils, postprocessing_utils
16
+
17
+
18
 
19
  MAX_SEED = np.iinfo(np.int32).max
20
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
 
24
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
25
  os.makedirs(user_dir, exist_ok=True)
26
 
27
+
28
  def end_session(req: gr.Request):
29
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
30
+ shutil.rmtree(user_dir)
 
 
 
31
 
32
  def preprocess_image(image: Image.Image) -> Image.Image:
33
+ """
34
+ Preprocess the input image for 3D generation.
35
+
36
+ This function is called when a user uploads an image or selects an example.
37
+ It applies background removal and other preprocessing steps necessary for
38
+ optimal 3D model generation.
39
+
40
+ Args:
41
+ image (Image.Image): The input image from the user
42
+
43
+ Returns:
44
+ Image.Image: The preprocessed image ready for 3D generation
45
+ """
46
  processed_image = pipeline.preprocess_image(image)
47
  return processed_image
48
 
49
  def preprocess_videos(video: str) -> List[Tuple[Image.Image, str]]:
50
+ """
51
+ Preprocess the input video for multi-image 3D generation.
52
+
53
+ This function is called when a user uploads a video.
54
+ It extracts frames from the video and processes each frame to prepare them
55
+ for the multi-image 3D generation pipeline.
56
+
57
+ Args:
58
+ video (str): The path to the input video file
59
+
60
+ Returns:
61
+ List[Tuple[Image.Image, str]]: The list of preprocessed images ready for 3D generation
62
+ """
63
  vid = imageio.get_reader(video, 'ffmpeg')
64
  fps = vid.get_meta_data()['fps']
65
  images = []
 
74
  return processed_images
75
 
76
  def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
77
+ """
78
+ Preprocess a list of input images for multi-image 3D generation.
79
+
80
+ This function is called when users upload multiple images in the gallery.
81
+ It processes each image to prepare them for the multi-image 3D generation pipeline.
82
+
83
+ Args:
84
+ images (List[Tuple[Image.Image, str]]): The input images from the gallery
85
+
86
+ Returns:
87
+ List[Image.Image]: The preprocessed images ready for 3D generation
88
+ """
89
  images = [image[0] for image in images]
90
  processed_images = [pipeline.preprocess_image(image) for image in images]
91
  return processed_images
92
 
93
+
94
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
95
  return {
96
  'gaussian': {
 
106
  'faces': mesh.faces.cpu().numpy(),
107
  },
108
  }
109
+
110
+
111
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
 
112
  gs = Gaussian(
113
  aabb=state['gaussian']['aabb'],
114
  sh_degree=state['gaussian']['sh_degree'],
 
117
  opacity_bias=state['gaussian']['opacity_bias'],
118
  scaling_activation=state['gaussian']['scaling_activation'],
119
  )
120
+ gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
121
+ gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
122
+ gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
123
+ gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
124
+ gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
125
 
126
  mesh = edict(
127
+ vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
128
+ faces=torch.tensor(state['mesh']['faces'], device='cuda'),
129
  )
130
+
131
  return gs, mesh
132
 
133
+
134
  def get_seed(randomize_seed: bool, seed: int) -> int:
135
+ """
136
+ Get the random seed for generation.
137
+
138
+ This function is called by the generate button to determine whether to use
139
+ a random seed or the user-specified seed value.
140
+
141
+ Args:
142
+ randomize_seed (bool): Whether to generate a random seed
143
+ seed (int): The user-specified seed value
144
+
145
+ Returns:
146
+ int: The seed to use for generation
147
+ """
148
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
149
 
150
+
151
  def generate_and_extract_glb(
152
  multiimages: List[Tuple[Image.Image, str]],
153
  seed: int,
 
160
  texture_size: int,
161
  req: gr.Request,
162
  ) -> Tuple[dict, str, str, str]:
163
+ """
164
+ Convert an image to a 3D model and extract GLB file.
165
+
166
+ Args:
167
+ image (Image.Image): The input image.
168
+ multiimages (List[Tuple[Image.Image, str]]): The input images in multi-image mode.
169
+ is_multiimage (bool): Whether is in multi-image mode.
170
+ seed (int): The random seed.
171
+ ss_guidance_strength (float): The guidance strength for sparse structure generation.
172
+ ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
173
+ slat_guidance_strength (float): The guidance strength for structured latent generation.
174
+ slat_sampling_steps (int): The number of sampling steps for structured latent generation.
175
+ multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.
176
+ mesh_simplify (float): The mesh simplification factor.
177
+ texture_size (int): The texture resolution.
178
+
179
+ Returns:
180
+ dict: The information of the generated 3D model.
181
+ str: The path to the video of the 3D model.
182
+ str: The path to the extracted GLB file.
183
+ str: The path to the extracted GLB file (for download).
184
+ """
185
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
186
  image_files = [image[0] for image in multiimages]
187
 
188
+ # Generate 3D model
189
+ outputs, _, _ = pipeline.run(
190
+ image=image_files,
191
+ seed=seed,
192
+ formats=["gaussian", "mesh"],
193
+ preprocess_image=False,
194
+ sparse_structure_sampler_params={
195
+ "steps": ss_sampling_steps,
196
+ "cfg_strength": ss_guidance_strength,
197
+ },
198
+ slat_sampler_params={
199
+ "steps": slat_sampling_steps,
200
+ "cfg_strength": slat_guidance_strength,
201
+ },
202
+ mode=multiimage_algo,
203
+ )
204
+
205
+ # Render video
206
+ # import uuid
207
+ # output_id = str(uuid.uuid4())
208
+ # os.makedirs(f"{TMP_DIR}/{output_id}", exist_ok=True)
209
+ # video_path = f"{TMP_DIR}/{output_id}/preview.mp4"
210
+ # glb_path = f"{TMP_DIR}/{output_id}/mesh.glb"
211
 
212
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
213
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
 
215
  video_path = os.path.join(user_dir, 'sample.mp4')
216
  imageio.mimsave(video_path, video, fps=15)
217
 
218
+ # Extract GLB
219
  gs = outputs['gaussian'][0]
220
  mesh = outputs['mesh'][0]
221
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
222
  glb_path = os.path.join(user_dir, 'sample.glb')
223
  glb.export(glb_path)
224
 
225
+ # Pack state for optional Gaussian extraction
226
  state = pack_state(gs, mesh)
227
 
 
 
228
  torch.cuda.empty_cache()
 
229
  return state, video_path, glb_path, glb_path
230
 
231
+
232
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
233
+ """
234
+ Extract a Gaussian splatting file from the generated 3D model.
235
+
236
+ This function is called when the user clicks "Extract Gaussian" button.
237
+ It converts the 3D model state into a .ply file format containing
238
+ Gaussian splatting data for advanced 3D applications.
239
+
240
+ Args:
241
+ state (dict): The state of the generated 3D model containing Gaussian data
242
+ req (gr.Request): Gradio request object for session management
243
+
244
+ Returns:
245
+ Tuple[str, str]: Paths to the extracted Gaussian file (for display and download)
246
+ """
247
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
248
  gs, _ = unpack_state(state)
249
  gaussian_path = os.path.join(user_dir, 'sample.ply')
250
  gs.save_ply(gaussian_path)
 
251
  torch.cuda.empty_cache()
252
  return gaussian_path, gaussian_path
253
 
254
+
255
  def prepare_multi_example() -> List[Image.Image]:
 
 
256
  multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
257
  images = []
258
  for case in multi_case:
 
267
  images.append(Image.fromarray(np.concatenate(_images, axis=1)))
268
  return images
269
 
270
+
271
  def split_image(image: Image.Image) -> List[Image.Image]:
272
+ """
273
+ Split a multi-view image into separate view images.
274
+
275
+ This function is called when users select multi-image examples that contain
276
+ multiple views in a single concatenated image. It automatically splits them
277
+ based on alpha channel boundaries and preprocesses each view.
278
+
279
+ Args:
280
+ image (Image.Image): A concatenated image containing multiple views
281
+
282
+ Returns:
283
+ List[Image.Image]: List of individual preprocessed view images
284
+ """
285
  image = np.array(image)
286
  alpha = image[..., 3]
287
  alpha = np.any(alpha>0, axis=0)
 
304
  """
305
  )
306
  with demo:
307
+ gr.Markdown("""
308
+ # πŸ’» ReconViaGen
309
+ <p align="center">
310
+ <a title="Github" href="https://github.com/GAP-LAB-CUHK-SZ/ReconViaGen" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
311
+ <img src="https://img.shields.io/github/stars/GAP-LAB-CUHK-SZ/ReconViaGen?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
312
+ </a>
313
+ <a title="Website" href="https://jiahao620.github.io/reconviagen/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
314
+ <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
315
+ </a>
316
+ <a title="arXiv" href="https://jiahao620.github.io/reconviagen/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
317
+ <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
318
+ </a>
319
+ </p>
320
+
321
+ ✨This demo is partial. We will release the whole model later. Stay tuned!✨
322
+ """)
323
 
324
  with gr.Row():
325
  with gr.Column():
 
328
  input_video = gr.Video(label="Upload Video", interactive=True, height=300)
329
  image_prompt = gr.Image(label="Image Prompt", format="png", visible=False, image_mode="RGBA", type="pil", height=300)
330
  multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
331
+ gr.Markdown("""
332
+ Input different views of the object in separate images.
333
+ """)
334
 
335
  with gr.Accordion(label="Generation Settings", open=False):
336
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
 
351
 
352
  generate_btn = gr.Button("Generate & Extract GLB", variant="primary")
353
  extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
354
+ gr.Markdown("""
355
+ *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
356
+ """)
357
 
358
  with gr.Column():
359
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
 
365
 
366
  output_buf = gr.State()
367
 
368
+ # Example images at the bottom of the page
369
  with gr.Row() as multiimage_example:
370
  examples_multi = gr.Examples(
371
  examples=prepare_multi_example(),
 
380
  demo.load(start_session)
381
  demo.unload(end_session)
382
 
383
+ input_video.upload(
384
+ preprocess_videos,
385
+ inputs=[input_video],
386
+ outputs=[multiimage_prompt],
387
+ )
388
+ input_video.clear(
389
+ lambda: tuple([None, None]),
390
+ outputs=[input_video, multiimage_prompt],
391
+ )
392
+ multiimage_prompt.upload(
393
+ preprocess_images,
394
+ inputs=[multiimage_prompt],
395
+ outputs=[multiimage_prompt],
396
+ )
397
 
398
  generate_btn.click(
399
+ get_seed,
400
+ inputs=[randomize_seed, seed],
401
+ outputs=[seed],
402
  ).then(
403
  generate_and_extract_glb,
404
  inputs=[multiimage_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo, mesh_simplify, texture_size],
 
413
  outputs=[extract_gs_btn, download_glb, download_gs],
414
  )
415
 
416
+ extract_gs_btn.click(
417
+ extract_gaussian,
418
+ inputs=[output_buf],
419
+ outputs=[model_output, download_gs],
420
+ ).then(
421
+ lambda: gr.Button(interactive=True),
422
+ outputs=[download_gs],
423
  )
424
 
425
  model_output.clear(
426
  lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
427
  outputs=[download_glb, download_gs],
428
  )
429
+
430
 
431
+ # Launch the Gradio app
432
  if __name__ == "__main__":
 
 
433
  pipeline = TrellisVGGTTo3DPipeline.from_pretrained("esther11/trellis-vggt-v0-2")
 
 
434
  pipeline.cuda()
435
+ pipeline.VGGT_model.cuda()
436
+ pipeline.birefnet_model.cuda()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
437
  demo.launch()