xinjie.wang committed
Commit ddc47cd · Parent: a8ea627
app.py CHANGED
@@ -17,7 +17,9 @@
 
 import os
 
-os.environ["GRADIO_APP"] = "imageto3d"
+# GRADIO_APP == "imageto3d_sam3d", sam3d object model, by default.
+# GRADIO_APP == "imageto3d", TRELLIS model.
+os.environ["GRADIO_APP"] = "imageto3d_sam3d"
 from glob import glob
 
 import gradio as gr
@@ -30,13 +32,24 @@ from common import (
     extract_3d_representations_v3,
     extract_urdf,
     get_seed,
-    image_to_3d,
     preprocess_image_fn,
     preprocess_sam_image_fn,
     select_point,
     start_session,
 )
 
+app_name = os.getenv("GRADIO_APP")
+if app_name == "imageto3d_sam3d":
+    from common import image_to_3d_sam3d as image_to_3d
+
+    enable_pre_resize = False
+    sample_step = 25
+elif app_name == "imageto3d":
+    from common import image_to_3d
+
+    enable_pre_resize = True
+    sample_step = 12
+
 with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
     gr.HTML(image_css, visible=False)
     # gr.HTML(lighting_css, visible=False)
@@ -67,7 +80,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
         )
 
         with gr.Row():
-            with gr.Column(scale=2):
+            with gr.Column(scale=3):
                 with gr.Tabs() as input_tabs:
                     with gr.Tab(
                         label="Image(auto seg)", id=0
@@ -163,7 +176,11 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
                             step=0.1,
                         )
                         ss_sampling_steps = gr.Slider(
-                            1, 50, label="Sampling Steps", value=12, step=1
+                            1,
+                            50,
+                            label="Sampling Steps",
+                            value=sample_step,
+                            step=1,
                         )
                     gr.Markdown("Visual Appearance Generation")
                     with gr.Row():
@@ -175,7 +192,11 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
                             step=0.1,
                         )
                         slat_sampling_steps = gr.Slider(
-                            1, 50, label="Sampling Steps", value=12, step=1
+                            1,
+                            50,
+                            label="Sampling Steps",
+                            value=sample_step,
+                            step=1,
                         )
 
                 generate_btn = gr.Button(
@@ -242,7 +263,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
                    has quality inspection, open with an editor to view details.
                    """
                 )
-
+                enable_pre_resize = gr.State(enable_pre_resize)
                 with gr.Row() as single_image_example:
                     examples = gr.Examples(
                         label="Image Gallery",
@@ -252,7 +273,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
                                glob("assets/example_image/*")
                            )
                        ],
-                        inputs=[image_prompt, rmbg_tag],
+                        inputs=[image_prompt, rmbg_tag, enable_pre_resize],
                        fn=preprocess_image_fn,
                        outputs=[image_prompt, raw_image_cache],
                        run_on_click=True,
@@ -274,16 +295,16 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
                        run_on_click=True,
                        examples_per_page=10,
                    )
-        with gr.Column(scale=1):
+        with gr.Column(scale=2):
            gr.Markdown("<br>")
            video_output = gr.Video(
                label="Generated 3D Asset",
                autoplay=True,
                loop=True,
-                height=300,
+                height=400,
            )
            model_output_gs = gr.Model3D(
-                label="Gaussian Representation", height=300, interactive=False
+                label="Gaussian Representation", height=350, interactive=False
            )
            aligned_gs = gr.Textbox(visible=False)
            gr.Markdown(
@@ -292,9 +313,9 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
            with gr.Row():
                model_output_mesh = gr.Model3D(
                    label="Mesh Representation",
-                    height=300,
+                    height=350,
                    interactive=False,
-                    clear_color=[0.8, 0.8, 0.8, 1],
+                    clear_color=[0, 0, 0, 1],
                    elem_id="lighter_mesh",
                )
 
@@ -320,7 +341,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
 
    image_prompt.upload(
        preprocess_image_fn,
-        inputs=[image_prompt, rmbg_tag],
+        inputs=[image_prompt, rmbg_tag, enable_pre_resize],
        outputs=[image_prompt, raw_image_cache],
    )
    image_prompt.change(
@@ -437,11 +458,11 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
        inputs=[
            image_prompt,
            seed,
-            ss_guidance_strength,
            ss_sampling_steps,
-            slat_guidance_strength,
            slat_sampling_steps,
            raw_image_cache,
+            ss_guidance_strength,
+            slat_guidance_strength,
            image_seg_sam,
            is_samimage,
        ],
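
Note on the backend switch: common.py reads GRADIO_APP at import time to decide which pipeline to build, so the variable must be set before common is imported. A minimal launch sketch (names as in this commit; the TRELLIS branch shown for contrast):

    import os

    # Must run before `import common`, which inspects GRADIO_APP at import time.
    os.environ["GRADIO_APP"] = "imageto3d"  # TRELLIS; this commit defaults to "imageto3d_sam3d"

    from common import image_to_3d  # the sam3d branch would alias image_to_3d_sam3d instead
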
app_style.py CHANGED
@@ -1,10 +1,26 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
 from gradio.themes import Soft
 from gradio.themes.utils.colors import gray, neutral, slate, stone, teal, zinc
 
 lighting_css = """
 <style>
 #lighter_mesh canvas {
-    filter: brightness(2.0) !important;
+    filter: brightness(2.3) !important;
 }
 </style>
 """
common.py CHANGED
@@ -151,6 +151,21 @@ if os.getenv("GRADIO_APP") == "imageto3d":
        os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
    )
    os.makedirs(TMP_DIR, exist_ok=True)
+elif os.getenv("GRADIO_APP") == "imageto3d_sam3d":
+    from embodied_gen.models.sam3d import Sam3dInference
+
+    RBG_REMOVER = RembgRemover()
+    RBG14_REMOVER = BMGG14Remover()
+    SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
+    PIPELINE = Sam3dInference()
+    SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
+    GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
+    AESTHETIC_CHECKER = ImageAestheticChecker()
+    CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
+    TMP_DIR = os.path.join(
+        os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
+    )
+    os.makedirs(TMP_DIR, exist_ok=True)
 elif os.getenv("GRADIO_APP") == "textto3d":
    RBG_REMOVER = RembgRemover()
    RBG14_REMOVER = BMGG14Remover()
@@ -169,6 +184,23 @@ elif os.getenv("GRADIO_APP") == "textto3d":
        os.path.dirname(os.path.abspath(__file__)), "sessions/textto3d"
    )
    os.makedirs(TMP_DIR, exist_ok=True)
+elif os.getenv("GRADIO_APP") == "textto3d_sam3d":
+    from embodied_gen.models.sam3d import Sam3dInference
+
+    RBG_REMOVER = RembgRemover()
+    RBG14_REMOVER = BMGG14Remover()
+    PIPELINE = Sam3dInference()
+    text_model_dir = "weights/Kolors"
+    PIPELINE_IMG_IP = build_text2img_ip_pipeline(text_model_dir, ref_scale=0.3)
+    PIPELINE_IMG = build_text2img_pipeline(text_model_dir)
+    SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
+    GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
+    AESTHETIC_CHECKER = ImageAestheticChecker()
+    CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
+    TMP_DIR = os.path.join(
+        os.path.dirname(os.path.abspath(__file__)), "sessions/textto3d"
+    )
+    os.makedirs(TMP_DIR, exist_ok=True)
 elif os.getenv("GRADIO_APP") == "texture_edit":
    DELIGHT = DelightingModel()
    IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
@@ -201,18 +233,22 @@ def end_session(req: gr.Request) -> None:
 
 @spaces.GPU
 def preprocess_image_fn(
-    image: str | np.ndarray | Image.Image, rmbg_tag: str = "rembg"
+    image: str | np.ndarray | Image.Image,
+    rmbg_tag: str = "rembg",
+    preprocess: bool = True,
 ) -> tuple[Image.Image, Image.Image]:
    if isinstance(image, str):
        image = Image.open(image)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)
 
-    image_cache = resize_pil(image.copy(), 1024)
+    image_cache = image.copy()  # resize_pil(image.copy(), 1024)
 
    bg_remover = RBG_REMOVER if rmbg_tag == "rembg" else RBG14_REMOVER
    image = bg_remover(image)
-    image = trellis_preprocess(image)
+
+    if preprocess:
+        image = trellis_preprocess(image)
 
    return image, image_cache
 
@@ -349,11 +385,11 @@ def select_point(
 def image_to_3d(
    image: Image.Image,
    seed: int,
-    ss_guidance_strength: float,
    ss_sampling_steps: int,
-    slat_guidance_strength: float,
    slat_sampling_steps: int,
    raw_image_cache: Image.Image,
+    ss_guidance_strength: float,
+    slat_guidance_strength: float,
    sam_image: Image.Image = None,
    is_sam_image: bool = False,
    req: gr.Request = None,
@@ -392,8 +428,56 @@ def image_to_3d(
 
    gs_model = outputs["gaussian"][0]
    mesh_model = outputs["mesh"][0]
-    color_images = render_video(gs_model)["color"]
-    normal_images = render_video(mesh_model)["normal"]
+    color_images = render_video(gs_model, r=1.85)["color"]
+    normal_images = render_video(mesh_model, r=1.85)["normal"]
+
+    video_path = os.path.join(output_root, "gs_mesh.mp4")
+    merge_images_video(color_images, normal_images, video_path)
+    state = pack_state(gs_model, mesh_model)
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    return state, video_path
+
+
+@spaces.GPU
+def image_to_3d_sam3d(
+    image: Image.Image,
+    seed: int,
+    ss_sampling_steps: int,
+    slat_sampling_steps: int,
+    raw_image_cache: Image.Image,
+    ss_guidance_strength: float = None,
+    slat_guidance_strength: float = None,
+    sam_image: Image.Image = None,
+    is_sam_image: bool = False,
+    req: gr.Request = None,
+) -> tuple[dict, str]:
+    if is_sam_image:
+        seg_image = filter_image_small_connected_components(sam_image)
+        seg_image = Image.fromarray(seg_image, mode="RGBA")
+    else:
+        seg_image = image
+
+    if isinstance(seg_image, np.ndarray):
+        seg_image = Image.fromarray(seg_image)
+
+    output_root = os.path.join(TMP_DIR, str(req.session_hash))
+    os.makedirs(output_root, exist_ok=True)
+    seg_image.save(f"{output_root}/seg_image.png")
+    raw_image_cache.save(f"{output_root}/raw_image.png")
+    outputs = PIPELINE.run(
+        seg_image,
+        seed=seed,
+        stage1_inference_steps=ss_sampling_steps,
+        stage2_inference_steps=slat_sampling_steps,
+    )
+
+    gs_model = outputs["gaussian"][0]
+    mesh_model = outputs["mesh"][0]
+    color_images = render_video(gs_model, r=1.85)["color"]
+    normal_images = render_video(mesh_model, r=1.85)["normal"]
 
    video_path = os.path.join(output_root, "gs_mesh.mp4")
    merge_images_video(color_images, normal_images, video_path)
@@ -688,6 +772,7 @@ def text2image_fn(
    image_wh: int | tuple[int, int] = [1024, 1024],
    rmbg_tag: str = "rembg",
    seed: int = None,
+    enable_pre_resize: bool = True,
    n_sample: int = 3,
    req: gr.Request = None,
 ):
@@ -715,7 +800,9 @@ def text2image_fn(
 
    for idx in range(len(images)):
        image = images[idx]
-        images[idx], _ = preprocess_image_fn(image, rmbg_tag)
+        images[idx], _ = preprocess_image_fn(
+            image, rmbg_tag, enable_pre_resize
+        )
 
    save_paths = []
    for idx, image in enumerate(images):
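
image_to_3d_sam3d mirrors image_to_3d's positional signature, with the guidance strengths (unused by the SAM-3D path) moved behind raw_image_cache, so app.py can bind either function to the same Gradio inputs list. A hedged sketch of the new preprocess flag on preprocess_image_fn (the image path is illustrative):

    from common import preprocess_image_fn

    # preprocess=True keeps the TRELLIS behaviour (trellis_preprocess after
    # background removal); the sam3d app passes False so the segmented image
    # reaches the pipeline unresized and uncropped.
    image, raw_cache = preprocess_image_fn(
        "assets/example_image/sample.png",  # illustrative path
        rmbg_tag="rembg",
        preprocess=False,
    )
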
embodied_gen/data/backproject_v3.py CHANGED
@@ -14,10 +14,10 @@
 # implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-import os
 import argparse
 import logging
 import math
+import os
 from typing import Literal, Union
 
 import cv2
@@ -353,8 +353,8 @@ def parse_args():
    parser.add_argument(
        "--distance",
        type=float,
-        default=5,
-        help="Camera distance (default: 5)",
+        default=4.5,
+        help="Camera distance (default: 4.5)",
    )
    parser.add_argument(
        "--resolution_hw",
@@ -400,8 +400,8 @@ def parse_args():
    parser.add_argument(
        "--mesh_sipmlify_ratio",
        type=float,
-        default=0.9,
-        help="Mesh simplification ratio (default: 0.9)",
+        default=0.85,
+        help="Mesh simplification ratio (default: 0.85)",
    )
    parser.add_argument(
        "--delight", action="store_true", help="Use delighting model."
@@ -500,7 +500,7 @@ def entrypoint(
    faces = mesh.faces.astype(np.int32)
    vertices = vertices.astype(np.float32)
 
-    if not args.skip_fix_mesh and len(faces) > 10 * args.n_max_faces:
+    if not args.skip_fix_mesh:
        mesh_fixer = MeshFixer(vertices, faces, args.device)
        vertices, faces = mesh_fixer(
            filter_ratio=args.mesh_sipmlify_ratio,
@@ -512,7 +512,7 @@ def entrypoint(
    if len(faces) > args.n_max_faces:
        mesh_fixer = MeshFixer(vertices, faces, args.device)
        vertices, faces = mesh_fixer(
-            filter_ratio=max(0.05, args.mesh_sipmlify_ratio - 0.2),
+            filter_ratio=max(0.1, args.mesh_sipmlify_ratio - 0.1),
            max_hole_size=0.04,
            resolution=1024,
            num_views=1000,
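
The second MeshFixer pass is now gentler: with the new default --mesh_sipmlify_ratio of 0.85, the fallback simplification keeps more faces than before. A quick check of the arithmetic:

    # Fallback filter_ratio, old defaults vs. new defaults (pure arithmetic):
    old_ratio = max(0.05, 0.9 - 0.2)   # 0.70
    new_ratio = max(0.1, 0.85 - 0.1)   # 0.75 -> less aggressive decimation
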
embodied_gen/data/utils.py CHANGED
@@ -15,10 +15,13 @@
 # permissions and limitations under the License.
 
 
+import logging
 import math
 import os
-import random
+import time
 import zipfile
+from contextlib import contextmanager
+from dataclasses import dataclass, field
 from shutil import rmtree
 from typing import List, Tuple, Union
 
@@ -28,20 +31,9 @@ import numpy as np
 import nvdiffrast.torch as dr
 import torch
 import torch.nn.functional as F
-from PIL import Image, ImageEnhance
-
-try:
-    from kolors.models.modeling_chatglm import ChatGLMModel
-    from kolors.models.tokenization_chatglm import ChatGLMTokenizer
-except ImportError:
-    ChatGLMTokenizer = None
-    ChatGLMModel = None
-import logging
-from dataclasses import dataclass, field
-
 import trimesh
 from kaolin.render.camera import Camera
-from torch import nn
+from PIL import Image, ImageEnhance
 
 logger = logging.getLogger(__name__)
 
@@ -50,10 +42,8 @@ __all__ = [
    "DiffrastRender",
    "save_images",
    "render_pbr",
-    "prelabel_text_feature",
    "calc_vertex_normals",
    "normalize_vertices_array",
-    "load_mesh_to_unit_cube",
    "as_list",
    "CameraSetting",
    "import_kaolin_mesh",
@@ -67,6 +57,7 @@ __all__ = [
    "trellis_preprocess",
    "delete_dir",
    "kaolin_to_opencv_view",
+    "model_device_ctx",
 ]
 
 
@@ -520,114 +511,6 @@ def render_pbr(
    return image, albedo, diffuse, normal
 
 
-def _move_to_target_device(data, device: str):
-    if isinstance(data, dict):
-        for key, value in data.items():
-            data[key] = _move_to_target_device(value, device)
-    elif isinstance(data, torch.Tensor):
-        return data.to(device)
-
-    return data
-
-
-def _encode_prompt(
-    prompt_batch,
-    text_encoders,
-    tokenizers,
-    proportion_empty_prompts=0,
-    is_train=True,
-):
-    prompt_embeds_list = []
-
-    captions = []
-    for caption in prompt_batch:
-        if random.random() < proportion_empty_prompts:
-            captions.append("")
-        elif isinstance(caption, str):
-            captions.append(caption)
-        elif isinstance(caption, (list, np.ndarray)):
-            captions.append(random.choice(caption) if is_train else caption[0])
-
-    with torch.no_grad():
-        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
-            text_inputs = tokenizer(
-                captions,
-                padding="max_length",
-                max_length=256,
-                truncation=True,
-                return_tensors="pt",
-            ).to(text_encoder.device)
-
-            output = text_encoder(
-                input_ids=text_inputs.input_ids,
-                attention_mask=text_inputs.attention_mask,
-                position_ids=text_inputs.position_ids,
-                output_hidden_states=True,
-            )
-
-            # We are only interested in the pooled output of the text encoder.
-            prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
-            pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone()
-            bs_embed, seq_len, _ = prompt_embeds.shape
-            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
-            prompt_embeds_list.append(prompt_embeds)
-
-    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
-    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
-
-    return prompt_embeds, pooled_prompt_embeds
-
-
-def load_llm_models(pretrained_model_name_or_path: str, device: str):
-    tokenizer = ChatGLMTokenizer.from_pretrained(
-        pretrained_model_name_or_path,
-        subfolder="text_encoder",
-    )
-    text_encoder = ChatGLMModel.from_pretrained(
-        pretrained_model_name_or_path,
-        subfolder="text_encoder",
-    ).to(device)
-
-    text_encoders = [
-        text_encoder,
-    ]
-    tokenizers = [
-        tokenizer,
-    ]
-
-    logger.info(f"Load model from {pretrained_model_name_or_path} done.")
-
-    return tokenizers, text_encoders
-
-
-def prelabel_text_feature(
-    prompt_batch: List[str],
-    output_dir: str,
-    tokenizers: nn.Module,
-    text_encoders: nn.Module,
-) -> List[str]:
-    os.makedirs(output_dir, exist_ok=True)
-
-    # prompt_batch ["text..."]
-    prompt_embeds, pooled_prompt_embeds = _encode_prompt(
-        prompt_batch, text_encoders, tokenizers
-    )
-
-    prompt_embeds = _move_to_target_device(prompt_embeds, device="cpu")
-    pooled_prompt_embeds = _move_to_target_device(
-        pooled_prompt_embeds, device="cpu"
-    )
-
-    data_dict = dict(
-        prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds
-    )
-
-    save_path = os.path.join(output_dir, "text_feat.pth")
-    torch.save(data_dict, save_path)
-
-    return save_path
-
-
 def _calc_face_normals(
    vertices: torch.Tensor,  # V,3 first vertex may be unreferenced
    faces: torch.Tensor,  # F,3 long, first face may be all zero
@@ -683,25 +566,6 @@ def normalize_vertices_array(
    return vertices, scale, center
 
 
-def load_mesh_to_unit_cube(
-    mesh_file: str,
-    mesh_scale: float = 1.0,
-) -> tuple[trimesh.Trimesh, float, list[float]]:
-    if not os.path.exists(mesh_file):
-        raise FileNotFoundError(f"mesh_file path {mesh_file} not exists.")
-
-    mesh = trimesh.load(mesh_file)
-    if isinstance(mesh, trimesh.Scene):
-        mesh = trimesh.utils.concatenate(mesh)
-
-    vertices, scale, center = normalize_vertices_array(
-        mesh.vertices, mesh_scale
-    )
-    mesh.vertices = vertices
-
-    return mesh, scale, center
-
-
 def as_list(obj):
    if isinstance(obj, (list, tuple)):
        return obj
@@ -998,8 +862,9 @@ def gamma_shs(shs: torch.Tensor, gamma: float) -> torch.Tensor:
 
 
 def resize_pil(image: Image.Image, max_size: int = 1024) -> Image.Image:
-    max_size = max(image.size)
-    scale = min(1, 1024 / max_size)
+    current_max_dim = max(image.size)
+    scale = min(1, max_size / current_max_dim)
+
    if scale < 1:
        new_size = (int(image.width * scale), int(image.height * scale))
        image = image.resize(new_size, Image.Resampling.LANCZOS)
@@ -1068,3 +933,34 @@ def delete_dir(folder_path: str, keep_subs: list[str] = None) -> None:
            rmtree(item_path)
        else:
            os.remove(item_path)
+
+
+@contextmanager
+def model_device_ctx(
+    *models,
+    src_device: str = "cpu",
+    dst_device: str = "cuda",
+    verbose: bool = False,
+):
+    start = time.perf_counter()
+    for m in models:
+        if m is None:
+            continue
+        m.to(dst_device)
+    to_cuda_time = time.perf_counter() - start
+
+    try:
+        yield
+    finally:
+        start = time.perf_counter()
+        for m in models:
+            if m is None:
+                continue
+            m.to(src_device)
+        to_cpu_time = time.perf_counter() - start
+
+        if verbose:
+            model_names = [m.__class__.__name__ for m in models]
+            logger.debug(
+                f"[model_device_ctx] {model_names} to cuda: {to_cuda_time:.1f}s, to cpu: {to_cpu_time:.1f}s"
+            )
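
Usage sketch for the new model_device_ctx helper: the listed modules are moved to dst_device for the duration of the block and restored to src_device afterwards, even if the body raises. The stand-in modules below are illustrative:

    import torch

    from embodied_gen.data.utils import model_device_ctx

    encoder = torch.nn.Linear(8, 8)  # stand-ins for the pipeline's
    decoder = torch.nn.Linear(8, 8)  # generator/decoder submodules

    with model_device_ctx(encoder, decoder, src_device="cpu", dst_device="cuda"):
        x = torch.randn(1, 8, device="cuda")
        y = decoder(encoder(x))  # both modules live on the GPU here
    # ...and are back on the CPU here, releasing VRAM between requests.
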
embodied_gen/models/sam3d.py ADDED
@@ -0,0 +1,145 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+from embodied_gen.utils.monkey_patches import monkey_patch_sam3d
+
+monkey_patch_sam3d()
+import os
+import sys
+from typing import Optional, Union
+
+import numpy as np
+from hydra.utils import instantiate
+from modelscope import snapshot_download
+from omegaconf import OmegaConf
+from PIL import Image
+
+current_file_path = os.path.abspath(__file__)
+current_dir = os.path.dirname(current_file_path)
+sys.path.append(os.path.join(current_dir, "../.."))
+from thirdparty.sam3d.sam3d_objects.pipeline.inference_pipeline_pointmap import (
+    InferencePipelinePointMap,
+)
+
+__all__ = ["Sam3dInference"]
+
+
+def load_image(path: str) -> np.ndarray:
+    image = Image.open(path)
+    image = np.array(image)
+    image = image.astype(np.uint8)
+    return image
+
+
+def load_mask(path: str) -> np.ndarray:
+    mask = load_image(path)
+    mask = mask > 0
+    if mask.ndim == 3:
+        mask = mask[..., -1]
+    return mask
+
+
+class Sam3dInference:
+    def __init__(
+        self, local_dir: str = "weights/sam-3d-objects", compile: bool = False
+    ) -> None:
+        if not os.path.exists(local_dir):
+            snapshot_download("facebook/sam-3d-objects", local_dir=local_dir)
+        config_file = os.path.join(local_dir, "checkpoints/pipeline.yaml")
+        config = OmegaConf.load(config_file)
+        config.rendering_engine = "nvdiffrast"
+        config.compile_model = compile
+        config.workspace_dir = os.path.dirname(config_file)
+        # Generate 4 gs in each pixel.
+        config["slat_decoder_gs_config_path"] = config.pop(
+            "slat_decoder_gs_4_config_path", "slat_decoder_gs_4.yaml"
+        )
+        config["slat_decoder_gs_ckpt_path"] = config.pop(
+            "slat_decoder_gs_4_ckpt_path", "slat_decoder_gs_4.ckpt"
+        )
+        self.pipeline: InferencePipelinePointMap = instantiate(config)
+
+    def merge_mask_to_rgba(
+        self, image: np.ndarray, mask: np.ndarray
+    ) -> np.ndarray:
+        mask = mask.astype(np.uint8) * 255
+        mask = mask[..., None]
+        rgba_image = np.concatenate([image[..., :3], mask], axis=-1)
+
+        return rgba_image
+
+    def run(
+        self,
+        image: np.ndarray | Image.Image,
+        mask: np.ndarray = None,
+        seed: int = None,
+        pointmap: np.ndarray = None,
+        use_stage1_distillation: bool = False,
+        use_stage2_distillation: bool = False,
+        stage1_inference_steps: int = 25,
+        stage2_inference_steps: int = 25,
+    ) -> dict:
+        if isinstance(image, Image.Image):
+            image = np.array(image)
+        return self.pipeline.run(
+            image,
+            mask,
+            seed,
+            stage1_only=False,
+            with_mesh_postprocess=False,
+            with_texture_baking=False,
+            with_layout_postprocess=False,
+            use_vertex_color=True,
+            use_stage1_distillation=use_stage1_distillation,
+            use_stage2_distillation=use_stage2_distillation,
+            stage1_inference_steps=stage1_inference_steps,
+            stage2_inference_steps=stage2_inference_steps,
+            pointmap=pointmap,
+        )
+
+
+if __name__ == "__main__":
+    pipeline = Sam3dInference()
+
+    # load image
+    image = load_image(
+        "/home/users/xinjie.wang/xinjie/sam-3d-objects/notebook/images/shutterstock_stylish_kidsroom_1640806567/image.png"
+    )
+    mask = load_mask(
+        "/home/users/xinjie.wang/xinjie/sam-3d-objects/notebook/images/shutterstock_stylish_kidsroom_1640806567/13.png"
+    )
+
+    import torch
+
+    if torch.cuda.is_available():
+        torch.cuda.reset_peak_memory_stats()
+        torch.cuda.empty_cache()
+
+    from time import time
+
+    start = time()
+
+    output = pipeline(image, mask, seed=42)
+    print(f"Running cost: {round(time()-start, 1)}")
+
+    if torch.cuda.is_available():
+        max_memory = torch.cuda.max_memory_allocated() / (1024**3)
+        print(f"(Max VRAM): {max_memory:.2f} GB")
+
+    print(f"End: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
+
+    output["gs"].save_ply(f"outputs/splat.ply")
+    print("Your reconstruction has been saved to outputs/splat.ply")
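
Note that the __main__ block above invokes the pipeline object directly (pipeline(image, mask, seed=42)), which assumes InferencePipelinePointMap provides a __call__; the entry point this commit actually wires up is Sam3dInference.run(), as used by image_to_3d_sam3d in common.py. A hedged usage sketch with illustrative paths:

    from embodied_gen.models.sam3d import Sam3dInference, load_image, load_mask

    pipeline = Sam3dInference()
    image = load_image("example/image.png")  # RGB uint8 array, illustrative path
    mask = load_mask("example/mask.png")     # boolean object mask, illustrative path

    outputs = pipeline.run(
        image,
        mask,
        seed=42,
        stage1_inference_steps=25,
        stage2_inference_steps=25,
    )
    gs_model = outputs["gaussian"][0]  # as consumed by image_to_3d_sam3d
    mesh_model = outputs["mesh"][0]
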
embodied_gen/utils/monkey_patches.py CHANGED
@@ -25,6 +25,12 @@ from omegaconf import OmegaConf
 from PIL import Image
 from torchvision import transforms
 
+__all__ = [
+    "monkey_patch_pano2room",
+    "monkey_patch_maniskill",
+    "monkey_patch_sam3d",
+]
+
 
 def monkey_patch_pano2room():
    current_file_path = os.path.abspath(__file__)
@@ -216,3 +222,378 @@ def monkey_patch_maniskill():
    ManiSkillScene.get_human_render_camera_images = (
        get_human_render_camera_images
    )
+
+
+def monkey_patch_sam3d():
+    from typing import Optional, Union
+
+    from embodied_gen.data.utils import model_device_ctx
+    from embodied_gen.utils.log import logger
+
+    os.environ["LIDRA_SKIP_INIT"] = "true"
+
+    current_file_path = os.path.abspath(__file__)
+    current_dir = os.path.dirname(current_file_path)
+    sam3d_root = os.path.abspath(
+        os.path.join(current_dir, "../../thirdparty/sam3d")
+    )
+    if sam3d_root not in sys.path:
+        sys.path.insert(0, sam3d_root)
+
+    print(f"[MonkeyPatch] Added to sys.path: {sam3d_root}")
+
+    def patch_pointmap_infer_pipeline():
+        from copy import deepcopy
+
+        try:
+            from sam3d_objects.pipeline.inference_pipeline_pointmap import (
+                InferencePipelinePointMap,
+            )
+        except ImportError:
+            logger.error(
+                "[MonkeyPatch]: Could not import sam3d_objects directly. Check paths."
+            )
+            return
+
+        def patch_run(
+            self,
+            image: Union[None, Image.Image, np.ndarray],
+            mask: Union[None, Image.Image, np.ndarray] = None,
+            seed: Optional[int] = None,
+            stage1_only=False,
+            with_mesh_postprocess=True,
+            with_texture_baking=True,
+            with_layout_postprocess=True,
+            use_vertex_color=False,
+            stage1_inference_steps=None,
+            stage2_inference_steps=None,
+            use_stage1_distillation=False,
+            use_stage2_distillation=False,
+            pointmap=None,
+            decode_formats=None,
+            estimate_plane=False,
+        ) -> dict:
+            image = self.merge_image_and_mask(image, mask)
+            with self.device:
+                pointmap_dict = self.compute_pointmap(image, pointmap)
+                pointmap = pointmap_dict["pointmap"]
+                pts = type(self)._down_sample_img(pointmap)
+                pts_colors = type(self)._down_sample_img(
+                    pointmap_dict["pts_color"]
+                )
+
+                if estimate_plane:
+                    return self.estimate_plane(pointmap_dict, image)
+
+                ss_input_dict = self.preprocess_image(
+                    image, self.ss_preprocessor, pointmap=pointmap
+                )
+
+                slat_input_dict = self.preprocess_image(
+                    image, self.slat_preprocessor
+                )
+                if seed is not None:
+                    torch.manual_seed(seed)
+
+                with model_device_ctx(
+                    self.models["ss_generator"],
+                    self.models["ss_decoder"],
+                    self.condition_embedders["ss_condition_embedder"],
+                ):
+                    ss_return_dict = self.sample_sparse_structure(
+                        ss_input_dict,
+                        inference_steps=stage1_inference_steps,
+                        use_distillation=use_stage1_distillation,
+                    )
+
+                # We could probably use the decoder from the models themselves
+                pointmap_scale = ss_input_dict.get("pointmap_scale", None)
+                pointmap_shift = ss_input_dict.get("pointmap_shift", None)
+                ss_return_dict.update(
+                    self.pose_decoder(
+                        ss_return_dict,
+                        scene_scale=pointmap_scale,
+                        scene_shift=pointmap_shift,
+                    )
+                )
+
+                logger.info(
+                    f"Rescaling scale by {ss_return_dict['downsample_factor']} after downsampling"
+                )
+                ss_return_dict["scale"] = (
+                    ss_return_dict["scale"]
+                    * ss_return_dict["downsample_factor"]
+                )
+
+                if stage1_only:
+                    logger.info("Finished!")
+                    ss_return_dict["voxel"] = (
+                        ss_return_dict["coords"][:, 1:] / 64 - 0.5
+                    )
+                    return {
+                        **ss_return_dict,
+                        "pointmap": pts.cpu().permute((1, 2, 0)),  # HxWx3
+                        "pointmap_colors": pts_colors.cpu().permute(
+                            (1, 2, 0)
+                        ),  # HxWx3
+                    }
+                    # return ss_return_dict
+
+                coords = ss_return_dict["coords"]
+                with model_device_ctx(
+                    self.models["slat_generator"],
+                    self.condition_embedders["slat_condition_embedder"],
+                ):
+                    slat = self.sample_slat(
+                        slat_input_dict,
+                        coords,
+                        inference_steps=stage2_inference_steps,
+                        use_distillation=use_stage2_distillation,
+                    )
+
+                with model_device_ctx(
+                    self.models["slat_decoder_mesh"],
+                    self.models["slat_decoder_gs"],
+                    self.models["slat_decoder_gs_4"],
+                ):
+                    outputs = self.decode_slat(
+                        slat,
+                        (
+                            self.decode_formats
+                            if decode_formats is None
+                            else decode_formats
+                        ),
+                    )
+
+                outputs = self.postprocess_slat_output(
+                    outputs,
+                    with_mesh_postprocess,
+                    with_texture_baking,
+                    use_vertex_color,
+                )
+                glb = outputs.get("glb", None)
+
+                try:
+                    if (
+                        with_layout_postprocess
+                        and self.layout_post_optimization_method is not None
+                    ):
+                        assert (
+                            glb is not None
+                        ), "require mesh to run postprocessing"
+                        logger.info(
+                            "Running layout post optimization method..."
+                        )
+                        postprocessed_pose = self.run_post_optimization(
+                            deepcopy(glb),
+                            pointmap_dict["intrinsics"],
+                            ss_return_dict,
+                            ss_input_dict,
+                        )
+                        ss_return_dict.update(postprocessed_pose)
+                except Exception as e:
+                    logger.error(
+                        f"Error during layout post optimization: {e}",
+                        exc_info=True,
+                    )
+
+                # glb.export("sample.glb")
+                logger.info("Finished!")
+
+                return {
+                    **ss_return_dict,
+                    **outputs,
+                    "pointmap": pts.cpu().permute((1, 2, 0)),  # HxWx3
+                    "pointmap_colors": pts_colors.cpu().permute(
+                        (1, 2, 0)
+                    ),  # HxWx3
+                }
+
+        InferencePipelinePointMap.run = patch_run
+
+    def patch_infer_init():
+        import torch
+
+        try:
+            from sam3d_objects.pipeline import preprocess_utils
+            from sam3d_objects.pipeline.inference_pipeline_pointmap import (
+                InferencePipeline,
+            )
+            from sam3d_objects.pipeline.inference_utils import (
+                SLAT_MEAN,
+                SLAT_STD,
+            )
+        except ImportError:
+            print(
+                "[MonkeyPatch] Error: Could not import sam3d_objects directly for infer pipeline."
+            )
+            return
+
+        def patch_init(
+            self,
+            ss_generator_config_path,
+            ss_generator_ckpt_path,
+            slat_generator_config_path,
+            slat_generator_ckpt_path,
+            ss_decoder_config_path,
+            ss_decoder_ckpt_path,
+            slat_decoder_gs_config_path,
+            slat_decoder_gs_ckpt_path,
+            slat_decoder_mesh_config_path,
+            slat_decoder_mesh_ckpt_path,
+            slat_decoder_gs_4_config_path=None,
+            slat_decoder_gs_4_ckpt_path=None,
+            ss_encoder_config_path=None,
+            ss_encoder_ckpt_path=None,
+            decode_formats=["gaussian", "mesh"],
+            dtype="bfloat16",
+            pad_size=1.0,
+            version="v0",
+            device="cuda",
+            ss_preprocessor=preprocess_utils.get_default_preprocessor(),
+            slat_preprocessor=preprocess_utils.get_default_preprocessor(),
+            ss_condition_input_mapping=["image"],
+            slat_condition_input_mapping=["image"],
+            pose_decoder_name="default",
+            workspace_dir="",
+            downsample_ss_dist=0,  # the distance we use to downsample
+            ss_inference_steps=25,
+            ss_rescale_t=3,
+            ss_cfg_strength=7,
+            ss_cfg_interval=[0, 500],
+            ss_cfg_strength_pm=0.0,
+            slat_inference_steps=25,
+            slat_rescale_t=3,
+            slat_cfg_strength=5,
+            slat_cfg_interval=[0, 500],
+            rendering_engine: str = "nvdiffrast",  # nvdiffrast OR pytorch3d,
+            shape_model_dtype=None,
+            compile_model=False,
+            slat_mean=SLAT_MEAN,
+            slat_std=SLAT_STD,
+        ):
+            self.rendering_engine = rendering_engine
+            self.device = torch.device(device)
+            self.compile_model = compile_model
+            logger.info(f"self.device: {self.device}")
+            logger.info(
+                f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', None)}"
+            )
+            logger.info(f"Actually using GPU: {torch.cuda.current_device()}")
+            with self.device:
+                self.decode_formats = decode_formats
+                self.pad_size = pad_size
+                self.version = version
+                self.ss_condition_input_mapping = ss_condition_input_mapping
+                self.slat_condition_input_mapping = (
+                    slat_condition_input_mapping
+                )
+                self.workspace_dir = workspace_dir
+                self.downsample_ss_dist = downsample_ss_dist
+                self.ss_inference_steps = ss_inference_steps
+                self.ss_rescale_t = ss_rescale_t
+                self.ss_cfg_strength = ss_cfg_strength
+                self.ss_cfg_interval = ss_cfg_interval
+                self.ss_cfg_strength_pm = ss_cfg_strength_pm
+                self.slat_inference_steps = slat_inference_steps
+                self.slat_rescale_t = slat_rescale_t
+                self.slat_cfg_strength = slat_cfg_strength
+                self.slat_cfg_interval = slat_cfg_interval
+
+                self.dtype = self._get_dtype(dtype)
+                if shape_model_dtype is None:
+                    self.shape_model_dtype = self.dtype
+                else:
+                    self.shape_model_dtype = self._get_dtype(shape_model_dtype)
+
+                # Setup preprocessors
+                self.pose_decoder = self.init_pose_decoder(
+                    ss_generator_config_path, pose_decoder_name
+                )
+                self.ss_preprocessor = self.init_ss_preprocessor(
+                    ss_preprocessor, ss_generator_config_path
+                )
+                self.slat_preprocessor = slat_preprocessor
+
+                logger.info("Loading model weights...")
+                raw_device = self.device
+                self.device = torch.device("cpu")
+                ss_generator = self.init_ss_generator(
+                    ss_generator_config_path, ss_generator_ckpt_path
+                )
+                slat_generator = self.init_slat_generator(
+                    slat_generator_config_path, slat_generator_ckpt_path
+                )
+                ss_decoder = self.init_ss_decoder(
+                    ss_decoder_config_path, ss_decoder_ckpt_path
+                )
+                ss_encoder = self.init_ss_encoder(
+                    ss_encoder_config_path, ss_encoder_ckpt_path
+                )
+                slat_decoder_gs = self.init_slat_decoder_gs(
+                    slat_decoder_gs_config_path, slat_decoder_gs_ckpt_path
+                )
+                slat_decoder_gs_4 = self.init_slat_decoder_gs(
+                    slat_decoder_gs_4_config_path, slat_decoder_gs_4_ckpt_path
+                )
+                slat_decoder_mesh = self.init_slat_decoder_mesh(
+                    slat_decoder_mesh_config_path, slat_decoder_mesh_ckpt_path
+                )
+
+                # Load conditioner embedder so that we only load it once
+                ss_condition_embedder = self.init_ss_condition_embedder(
+                    ss_generator_config_path, ss_generator_ckpt_path
+                )
+                slat_condition_embedder = self.init_slat_condition_embedder(
+                    slat_generator_config_path, slat_generator_ckpt_path
+                )
+                self.device = raw_device
+
+                self.condition_embedders = {
+                    "ss_condition_embedder": ss_condition_embedder,
+                    "slat_condition_embedder": slat_condition_embedder,
+                }
+
+                # override generator and condition embedder setting
+                self.override_ss_generator_cfg_config(
+                    ss_generator,
+                    cfg_strength=ss_cfg_strength,
+                    inference_steps=ss_inference_steps,
+                    rescale_t=ss_rescale_t,
+                    cfg_interval=ss_cfg_interval,
+                    cfg_strength_pm=ss_cfg_strength_pm,
+                )
+                self.override_slat_generator_cfg_config(
+                    slat_generator,
+                    cfg_strength=slat_cfg_strength,
+                    inference_steps=slat_inference_steps,
+                    rescale_t=slat_rescale_t,
+                    cfg_interval=slat_cfg_interval,
+                )
+
+                self.models = torch.nn.ModuleDict(
+                    {
+                        "ss_generator": ss_generator,
+                        "slat_generator": slat_generator,
+                        "ss_encoder": ss_encoder,
+                        "ss_decoder": ss_decoder,
+                        "slat_decoder_gs": slat_decoder_gs,
+                        "slat_decoder_gs_4": slat_decoder_gs_4,
+                        "slat_decoder_mesh": slat_decoder_mesh,
+                    }
+                )
+                logger.info("Loading model weights completed!")
+
+                if self.compile_model:
+                    logger.info("Compiling model...")
+                    self._compile()
+                    logger.info("Model compilation completed!")
+            self.slat_mean = torch.tensor(slat_mean)
+            self.slat_std = torch.tensor(slat_std)
+
+        InferencePipeline.__init__ = patch_init
+
+    patch_pointmap_infer_pipeline()
+    patch_infer_init()
+
+    return
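
monkey_patch_sam3d swaps InferencePipelinePointMap.run and InferencePipeline.__init__ at import time, so the upstream sam3d code loads weights on the CPU and only moves submodules to the GPU inside model_device_ctx blocks. The underlying pattern, reduced to a self-contained sketch:

    # Minimal monkey-patching sketch: define a replacement with the same
    # signature and assign it onto the class before any instance is created.
    class Greeter:
        def greet(self) -> str:
            return "hello"

    def patched_greet(self) -> str:
        return "hello, patched"

    Greeter.greet = patched_greet  # every later call goes through the patch
    assert Greeter().greet() == "hello, patched"
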
embodied_gen/utils/trender.py CHANGED
@@ -16,6 +16,7 @@
 
 import os
 import sys
+from collections import defaultdict
 
 import numpy as np
 import spaces
@@ -25,10 +26,8 @@ from tqdm import tqdm
 current_file_path = os.path.abspath(__file__)
 current_dir = os.path.dirname(current_file_path)
 sys.path.append(os.path.join(current_dir, "../.."))
-from thirdparty.TRELLIS.trellis.renderers.mesh_renderer import MeshRenderer
-from thirdparty.TRELLIS.trellis.representations import MeshExtractResult
+from thirdparty.TRELLIS.trellis.renderers import GaussianRenderer, MeshRenderer
 from thirdparty.TRELLIS.trellis.utils.render_utils import (
-    render_frames,
    yaw_pitch_r_fov_to_extrinsics_intrinsics,
 )
 
@@ -38,7 +37,7 @@ __all__ = [
 
 
 @spaces.GPU
-def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs):
+def render_mesh_frames(sample, extrinsics, intrinsics, options={}, **kwargs):
    renderer = MeshRenderer()
    renderer.rendering_options.resolution = options.get("resolution", 512)
    renderer.rendering_options.near = options.get("near", 1)
@@ -60,6 +59,57 @@ def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs):
    return rets
 
 
+@spaces.GPU
+def render_gs_frames(
+    sample,
+    extrinsics,
+    intrinsics,
+    options=None,
+    colors_overwrite=None,
+    verbose=True,
+    **kwargs,
+):
+    def to_img(tensor):
+        return np.clip(
+            tensor.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255
+        ).astype(np.uint8)
+
+    def to_numpy(tensor):
+        return tensor.detach().cpu().numpy()
+
+    renderer = GaussianRenderer()
+    renderer.pipe.kernel_size = kwargs.get("kernel_size", 0.1)
+    renderer.pipe.use_mip_gaussian = True
+
+    defaults = {
+        "resolution": 512,
+        "near": 0.8,
+        "far": 1.6,
+        "bg_color": (0, 0, 0),
+        "ssaa": 1,
+    }
+    final_options = {**defaults, **(options or {})}
+
+    for k, v in final_options.items():
+        if hasattr(renderer.rendering_options, k):
+            setattr(renderer.rendering_options, k, v)
+
+    outputs = defaultdict(list)
+    iterator = zip(extrinsics, intrinsics)
+    if verbose:
+        iterator = tqdm(iterator, total=len(extrinsics), desc="Rendering")
+
+    for extr, intr in iterator:
+        res = renderer.render(
+            sample, extr, intr, colors_overwrite=colors_overwrite
+        )
+        outputs["color"].append(to_img(res["color"]))
+        depth = res.get("percent_depth") or res.get("depth")
+        outputs["depth"].append(to_numpy(depth) if depth is not None else None)
+
+    return dict(outputs)
+
+
 @spaces.GPU
 def render_video(
    sample,
@@ -77,7 +127,9 @@ def render_video(
        yaws, pitch, r, fov
    )
    render_fn = (
-        render_mesh if isinstance(sample, MeshExtractResult) else render_frames
+        render_mesh_frames
+        if sample.__class__.__name__ == "MeshExtractResult"
+        else render_gs_frames
    )
    result = render_fn(
        sample,
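
render_video now dispatches on the class name instead of an isinstance check, which avoids importing MeshExtractResult at module load (at the cost of matching any class with that name). The dispatch, isolated as a sketch that assumes the two renderers defined above:

    def pick_renderer(sample):
        # Name-based dispatch: no import of MeshExtractResult required.
        if sample.__class__.__name__ == "MeshExtractResult":
            return render_mesh_frames  # rasterize the extracted mesh
        return render_gs_frames        # splat the Gaussian representation
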
requirements.txt CHANGED
@@ -52,4 +52,11 @@ pyquaternion
 shapely
 sapien==3.0.0b1
 typing_extensions==4.14.1
-coacd
+ninja
+packaging
+lightning
+astor
+optree
+loguru
+seaborn
+hydra-core
thirdparty/TRELLIS/trellis/utils/postprocessing_utils.py CHANGED
@@ -440,7 +440,7 @@ def to_glb(
    vertices, faces, uvs = parametrize_mesh(vertices, faces)
 
    # bake texture
-    observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=200)
+    observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=100)
    masks = [np.any(observation > 0, axis=-1) for observation in observations]
    extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))]
    intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))]