xinjie.wang committed on
Commit
b05f3ac
·
1 Parent(s): 74b41f3
app.py CHANGED
@@ -267,7 +267,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
267
 
268
  demo.load(start_session)
269
  demo.unload(end_session)
270
-
271
  mesh_input.change(
272
  lambda: tuple(
273
  [
@@ -368,6 +368,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
368
  texture_size,
369
  project_delight,
370
  fix_mesh,
 
371
  ],
372
  outputs=[mesh_output, mesh_outpath, download_btn],
373
  ).success(
 
267
 
268
  demo.load(start_session)
269
  demo.unload(end_session)
270
+ no_mesh_post_process = gr.State(True)
271
  mesh_input.change(
272
  lambda: tuple(
273
  [
 
368
  texture_size,
369
  project_delight,
370
  fix_mesh,
371
+ no_mesh_post_process,
372
  ],
373
  outputs=[mesh_output, mesh_outpath, download_btn],
374
  ).success(
app_style.py CHANGED
@@ -1,10 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from gradio.themes import Soft
2
  from gradio.themes.utils.colors import gray, neutral, slate, stone, teal, zinc
3
 
4
  lighting_css = """
5
  <style>
6
  #lighter_mesh canvas {
7
- filter: brightness(2.0) !important;
8
  }
9
  </style>
10
  """
 
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
  from gradio.themes import Soft
18
  from gradio.themes.utils.colors import gray, neutral, slate, stone, teal, zinc
19
 
20
  lighting_css = """
21
  <style>
22
  #lighter_mesh canvas {
23
+ filter: brightness(2.3) !important;
24
  }
25
  </style>
26
  """
common.py CHANGED
@@ -53,6 +53,7 @@ from embodied_gen.scripts.text2image import (
53
  from embodied_gen.utils.gpt_clients import GPT_CLIENT
54
  from embodied_gen.utils.process_media import (
55
  filter_image_small_connected_components,
 
56
  merge_images_video,
57
  )
58
  from embodied_gen.utils.tags import VERSION
@@ -151,6 +152,21 @@ if os.getenv("GRADIO_APP") == "imageto3d":
151
  os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
152
  )
153
  os.makedirs(TMP_DIR, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  elif os.getenv("GRADIO_APP") == "textto3d":
155
  RBG_REMOVER = RembgRemover()
156
  RBG14_REMOVER = BMGG14Remover()
@@ -169,6 +185,23 @@ elif os.getenv("GRADIO_APP") == "textto3d":
169
  os.path.dirname(os.path.abspath(__file__)), "sessions/textto3d"
170
  )
171
  os.makedirs(TMP_DIR, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  elif os.getenv("GRADIO_APP") == "texture_edit":
173
  DELIGHT = DelightingModel()
174
  IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
@@ -201,18 +234,23 @@ def end_session(req: gr.Request) -> None:
201
 
202
  @spaces.GPU
203
  def preprocess_image_fn(
204
- image: str | np.ndarray | Image.Image, rmbg_tag: str = "rembg"
 
 
205
  ) -> tuple[Image.Image, Image.Image]:
206
  if isinstance(image, str):
207
  image = Image.open(image)
208
  elif isinstance(image, np.ndarray):
209
  image = Image.fromarray(image)
210
 
211
- image_cache = image.copy().resize((512, 512))
212
 
213
  bg_remover = RBG_REMOVER if rmbg_tag == "rembg" else RBG14_REMOVER
214
  image = bg_remover(image)
215
- image = trellis_preprocess(image)
 
 
 
216
 
217
  return image, image_cache
218
 
@@ -224,7 +262,7 @@ def preprocess_sam_image_fn(
224
  image = Image.fromarray(image)
225
 
226
  sam_image = SAM_PREDICTOR.preprocess_image(image)
227
- image_cache = Image.fromarray(sam_image).resize((512, 512))
228
  SAM_PREDICTOR.predictor.set_image(sam_image)
229
 
230
  return sam_image, image_cache
@@ -349,11 +387,11 @@ def select_point(
349
  def image_to_3d(
350
  image: Image.Image,
351
  seed: int,
352
- ss_guidance_strength: float,
353
  ss_sampling_steps: int,
354
- slat_guidance_strength: float,
355
  slat_sampling_steps: int,
356
  raw_image_cache: Image.Image,
 
 
357
  sam_image: Image.Image = None,
358
  is_sam_image: bool = False,
359
  req: gr.Request = None,
@@ -392,8 +430,56 @@ def image_to_3d(
392
 
393
  gs_model = outputs["gaussian"][0]
394
  mesh_model = outputs["mesh"][0]
395
- color_images = render_video(gs_model)["color"]
396
- normal_images = render_video(mesh_model)["normal"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
397
 
398
  video_path = os.path.join(output_root, "gs_mesh.mp4")
399
  merge_images_video(color_images, normal_images, video_path)
@@ -688,6 +774,7 @@ def text2image_fn(
688
  image_wh: int | tuple[int, int] = [1024, 1024],
689
  rmbg_tag: str = "rembg",
690
  seed: int = None,
 
691
  n_sample: int = 3,
692
  req: gr.Request = None,
693
  ):
@@ -715,7 +802,9 @@ def text2image_fn(
715
 
716
  for idx in range(len(images)):
717
  image = images[idx]
718
- images[idx], _ = preprocess_image_fn(image, rmbg_tag)
 
 
719
 
720
  save_paths = []
721
  for idx, image in enumerate(images):
@@ -841,6 +930,7 @@ def backproject_texture_v2(
841
  texture_size: int,
842
  enable_delight: bool = True,
843
  fix_mesh: bool = False,
 
844
  uuid: str = "sample",
845
  req: gr.Request = None,
846
  ) -> str:
@@ -857,6 +947,7 @@ def backproject_texture_v2(
857
  skip_fix_mesh=not fix_mesh,
858
  delight=enable_delight,
859
  texture_wh=[texture_size, texture_size],
 
860
  )
861
 
862
  output_obj_mesh = os.path.join(output_dir, f"{uuid}.obj")
 
53
  from embodied_gen.utils.gpt_clients import GPT_CLIENT
54
  from embodied_gen.utils.process_media import (
55
  filter_image_small_connected_components,
56
+ keep_largest_connected_component,
57
  merge_images_video,
58
  )
59
  from embodied_gen.utils.tags import VERSION
 
152
  os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
153
  )
154
  os.makedirs(TMP_DIR, exist_ok=True)
155
+ elif os.getenv("GRADIO_APP") == "imageto3d_sam3d":
156
+ from embodied_gen.models.sam3d import Sam3dInference
157
+
158
+ RBG_REMOVER = RembgRemover()
159
+ RBG14_REMOVER = BMGG14Remover()
160
+ SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
161
+ PIPELINE = Sam3dInference()
162
+ SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
163
+ GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
164
+ AESTHETIC_CHECKER = ImageAestheticChecker()
165
+ CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
166
+ TMP_DIR = os.path.join(
167
+ os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
168
+ )
169
+ os.makedirs(TMP_DIR, exist_ok=True)
170
  elif os.getenv("GRADIO_APP") == "textto3d":
171
  RBG_REMOVER = RembgRemover()
172
  RBG14_REMOVER = BMGG14Remover()
 
185
  os.path.dirname(os.path.abspath(__file__)), "sessions/textto3d"
186
  )
187
  os.makedirs(TMP_DIR, exist_ok=True)
188
+ elif os.getenv("GRADIO_APP") == "textto3d_sam3d":
189
+ from embodied_gen.models.sam3d import Sam3dInference
190
+
191
+ RBG_REMOVER = RembgRemover()
192
+ RBG14_REMOVER = BMGG14Remover()
193
+ PIPELINE = Sam3dInference()
194
+ text_model_dir = "weights/Kolors"
195
+ PIPELINE_IMG_IP = build_text2img_ip_pipeline(text_model_dir, ref_scale=0.3)
196
+ PIPELINE_IMG = build_text2img_pipeline(text_model_dir)
197
+ SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
198
+ GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
199
+ AESTHETIC_CHECKER = ImageAestheticChecker()
200
+ CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
201
+ TMP_DIR = os.path.join(
202
+ os.path.dirname(os.path.abspath(__file__)), "sessions/textto3d"
203
+ )
204
+ os.makedirs(TMP_DIR, exist_ok=True)
205
  elif os.getenv("GRADIO_APP") == "texture_edit":
206
  DELIGHT = DelightingModel()
207
  IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
 
234
 
235
@spaces.GPU
def preprocess_image_fn(
    image: str | np.ndarray | Image.Image,
    rmbg_tag: str = "rembg",
    preprocess: bool = True,
) -> tuple[Image.Image, Image.Image]:
    """Remove the background of an input image and optionally preprocess it.

    Args:
        image: File path, numpy array, or PIL image to process.
        rmbg_tag: ``"rembg"`` selects ``RBG_REMOVER``; any other value
            selects ``RBG14_REMOVER``.
        preprocess: When True, the segmented image is additionally passed
            through ``trellis_preprocess``.

    Returns:
        A ``(processed_image, raw_copy)`` tuple, where ``raw_copy`` is an
        unmodified copy of the input taken before background removal.
    """
    # Normalise every accepted input form to a PIL image.
    if isinstance(image, str):
        pil_image = Image.open(image)
    elif isinstance(image, np.ndarray):
        pil_image = Image.fromarray(image)
    else:
        pil_image = image

    # Snapshot of the untouched input, returned alongside the result.
    raw_copy = pil_image.copy()

    if rmbg_tag == "rembg":
        remover = RBG_REMOVER
    else:
        remover = RBG14_REMOVER

    foreground = keep_largest_connected_component(remover(pil_image))
    if not preprocess:
        return foreground, raw_copy

    return trellis_preprocess(foreground), raw_copy
256
 
 
262
  image = Image.fromarray(image)
263
 
264
  sam_image = SAM_PREDICTOR.preprocess_image(image)
265
+ image_cache = sam_image.copy()
266
  SAM_PREDICTOR.predictor.set_image(sam_image)
267
 
268
  return sam_image, image_cache
 
387
  def image_to_3d(
388
  image: Image.Image,
389
  seed: int,
 
390
  ss_sampling_steps: int,
 
391
  slat_sampling_steps: int,
392
  raw_image_cache: Image.Image,
393
+ ss_guidance_strength: float,
394
+ slat_guidance_strength: float,
395
  sam_image: Image.Image = None,
396
  is_sam_image: bool = False,
397
  req: gr.Request = None,
 
430
 
431
  gs_model = outputs["gaussian"][0]
432
  mesh_model = outputs["mesh"][0]
433
+ color_images = render_video(gs_model, r=1.85)["color"]
434
+ normal_images = render_video(mesh_model, r=1.85)["normal"]
435
+
436
+ video_path = os.path.join(output_root, "gs_mesh.mp4")
437
+ merge_images_video(color_images, normal_images, video_path)
438
+ state = pack_state(gs_model, mesh_model)
439
+
440
+ gc.collect()
441
+ torch.cuda.empty_cache()
442
+
443
+ return state, video_path
444
+
445
+
446
+ @spaces.GPU
447
+ def image_to_3d_sam3d(
448
+ image: Image.Image,
449
+ seed: int,
450
+ ss_sampling_steps: int,
451
+ slat_sampling_steps: int,
452
+ raw_image_cache: Image.Image,
453
+ ss_guidance_strength: float = None,
454
+ slat_guidance_strength: float = None,
455
+ sam_image: Image.Image = None,
456
+ is_sam_image: bool = False,
457
+ req: gr.Request = None,
458
+ ) -> tuple[dict, str]:
459
+ if is_sam_image:
460
+ seg_image = filter_image_small_connected_components(sam_image)
461
+ seg_image = Image.fromarray(seg_image, mode="RGBA")
462
+ else:
463
+ seg_image = image
464
+
465
+ if isinstance(seg_image, np.ndarray):
466
+ seg_image = Image.fromarray(seg_image)
467
+
468
+ output_root = os.path.join(TMP_DIR, str(req.session_hash))
469
+ os.makedirs(output_root, exist_ok=True)
470
+ seg_image.save(f"{output_root}/seg_image.png")
471
+ raw_image_cache.save(f"{output_root}/raw_image.png")
472
+ outputs = PIPELINE.run(
473
+ seg_image,
474
+ seed=seed,
475
+ stage1_inference_steps=ss_sampling_steps,
476
+ stage2_inference_steps=slat_sampling_steps,
477
+ )
478
+
479
+ gs_model = outputs["gaussian"][0]
480
+ mesh_model = outputs["mesh"][0]
481
+ color_images = render_video(gs_model, r=1.85)["color"]
482
+ normal_images = render_video(mesh_model, r=1.85)["normal"]
483
 
484
  video_path = os.path.join(output_root, "gs_mesh.mp4")
485
  merge_images_video(color_images, normal_images, video_path)
 
774
  image_wh: int | tuple[int, int] = [1024, 1024],
775
  rmbg_tag: str = "rembg",
776
  seed: int = None,
777
+ enable_pre_resize: bool = True,
778
  n_sample: int = 3,
779
  req: gr.Request = None,
780
  ):
 
802
 
803
  for idx in range(len(images)):
804
  image = images[idx]
805
+ images[idx], _ = preprocess_image_fn(
806
+ image, rmbg_tag, enable_pre_resize
807
+ )
808
 
809
  save_paths = []
810
  for idx, image in enumerate(images):
 
930
  texture_size: int,
931
  enable_delight: bool = True,
932
  fix_mesh: bool = False,
933
+ no_mesh_post_process: bool = False,
934
  uuid: str = "sample",
935
  req: gr.Request = None,
936
  ) -> str:
 
947
  skip_fix_mesh=not fix_mesh,
948
  delight=enable_delight,
949
  texture_wh=[texture_size, texture_size],
950
+ no_mesh_post_process=no_mesh_post_process,
951
  )
952
 
953
  output_obj_mesh = os.path.join(output_dir, f"{uuid}.obj")
embodied_gen/data/backproject_v2.py CHANGED
@@ -274,6 +274,7 @@ class TextureBacker:
274
  mask_thresh (float, optional): Threshold for visibility masks.
275
  smooth_texture (bool, optional): Apply post-processing to texture.
276
  inpaint_smooth (bool, optional): Apply inpainting smoothing.
 
277
 
278
  Example:
279
  ```py
@@ -308,6 +309,7 @@ class TextureBacker:
308
  mask_thresh: float = 0.5,
309
  smooth_texture: bool = True,
310
  inpaint_smooth: bool = False,
 
311
  ) -> None:
312
  self.camera_params = camera_params
313
  self.renderer = None
@@ -318,6 +320,7 @@ class TextureBacker:
318
  self.mask_thresh = mask_thresh
319
  self.smooth_texture = smooth_texture
320
  self.inpaint_smooth = inpaint_smooth
 
321
 
322
  self.bake_angle_thresh = bake_angle_thresh
323
  self.bake_unreliable_kernel_size = int(
@@ -668,7 +671,12 @@ class TextureBacker:
668
  mesh, self.scale, self.center
669
  )
670
  textured_mesh = save_mesh_with_mtl(
671
- vertices, faces, uv_map, texture_np, output_path
 
 
 
 
 
672
  )
673
 
674
  return textured_mesh
@@ -766,6 +774,7 @@ def parse_args():
766
  help="Disable saving delight image",
767
  )
768
  parser.add_argument("--n_max_faces", type=int, default=30000)
 
769
  args, unknown = parser.parse_known_args()
770
 
771
  return args
@@ -856,6 +865,7 @@ def entrypoint(
856
  render_wh=args.resolution_hw,
857
  texture_wh=args.texture_wh,
858
  smooth_texture=not args.no_smooth_texture,
 
859
  )
860
 
861
  textured_mesh = texture_backer(multiviews, mesh, args.output_path)
 
274
  mask_thresh (float, optional): Threshold for visibility masks.
275
  smooth_texture (bool, optional): Apply post-processing to texture.
276
  inpaint_smooth (bool, optional): Apply inpainting smoothing.
277
+ mesh_post_process (bool, optional): False for preventing modification of vertices.
278
 
279
  Example:
280
  ```py
 
309
  mask_thresh: float = 0.5,
310
  smooth_texture: bool = True,
311
  inpaint_smooth: bool = False,
312
+ mesh_post_process: bool = True,
313
  ) -> None:
314
  self.camera_params = camera_params
315
  self.renderer = None
 
320
  self.mask_thresh = mask_thresh
321
  self.smooth_texture = smooth_texture
322
  self.inpaint_smooth = inpaint_smooth
323
+ self.mesh_post_process = mesh_post_process
324
 
325
  self.bake_angle_thresh = bake_angle_thresh
326
  self.bake_unreliable_kernel_size = int(
 
671
  mesh, self.scale, self.center
672
  )
673
  textured_mesh = save_mesh_with_mtl(
674
+ vertices,
675
+ faces,
676
+ uv_map,
677
+ texture_np,
678
+ output_path,
679
+ mesh_process=self.mesh_post_process,
680
  )
681
 
682
  return textured_mesh
 
774
  help="Disable saving delight image",
775
  )
776
  parser.add_argument("--n_max_faces", type=int, default=30000)
777
+ parser.add_argument("--no_mesh_post_process", action="store_true")
778
  args, unknown = parser.parse_known_args()
779
 
780
  return args
 
865
  render_wh=args.resolution_hw,
866
  texture_wh=args.texture_wh,
867
  smooth_texture=not args.no_smooth_texture,
868
+ mesh_post_process=not args.no_mesh_post_process,
869
  )
870
 
871
  textured_mesh = texture_backer(multiviews, mesh, args.output_path)
embodied_gen/data/backproject_v3.py CHANGED
@@ -14,10 +14,10 @@
14
  # implied. See the License for the specific language governing
15
  # permissions and limitations under the License.
16
 
17
-
18
  import argparse
19
  import logging
20
  import math
 
21
  from typing import Literal, Union
22
 
23
  import cv2
@@ -353,8 +353,8 @@ def parse_args():
353
  parser.add_argument(
354
  "--distance",
355
  type=float,
356
- default=5,
357
- help="Camera distance (default: 5)",
358
  )
359
  parser.add_argument(
360
  "--resolution_hw",
@@ -400,8 +400,8 @@ def parse_args():
400
  parser.add_argument(
401
  "--mesh_sipmlify_ratio",
402
  type=float,
403
- default=0.9,
404
- help="Mesh simplification ratio (default: 0.9)",
405
  )
406
  parser.add_argument(
407
  "--delight", action="store_true", help="Use delighting model."
@@ -425,6 +425,7 @@ def parse_args():
425
  return args
426
 
427
 
 
428
  def entrypoint(
429
  delight_model: DelightingModel = None,
430
  imagesr_model: ImageRealESRGAN = None,
@@ -499,7 +500,7 @@ def entrypoint(
499
  faces = mesh.faces.astype(np.int32)
500
  vertices = vertices.astype(np.float32)
501
 
502
- if not args.skip_fix_mesh and len(faces) > 10 * args.n_max_faces:
503
  mesh_fixer = MeshFixer(vertices, faces, args.device)
504
  vertices, faces = mesh_fixer(
505
  filter_ratio=args.mesh_sipmlify_ratio,
@@ -511,7 +512,7 @@ def entrypoint(
511
  if len(faces) > args.n_max_faces:
512
  mesh_fixer = MeshFixer(vertices, faces, args.device)
513
  vertices, faces = mesh_fixer(
514
- filter_ratio=max(0.05, args.mesh_sipmlify_ratio - 0.2),
515
  max_hole_size=0.04,
516
  resolution=1024,
517
  num_views=1000,
 
14
  # implied. See the License for the specific language governing
15
  # permissions and limitations under the License.
16
 
 
17
  import argparse
18
  import logging
19
  import math
20
+ import os
21
  from typing import Literal, Union
22
 
23
  import cv2
 
353
  parser.add_argument(
354
  "--distance",
355
  type=float,
356
+ default=4.5,
357
+ help="Camera distance (default: 4.5)",
358
  )
359
  parser.add_argument(
360
  "--resolution_hw",
 
400
  parser.add_argument(
401
  "--mesh_sipmlify_ratio",
402
  type=float,
403
+ default=0.85,
404
+ help="Mesh simplification ratio (default: 0.85)",
405
  )
406
  parser.add_argument(
407
  "--delight", action="store_true", help="Use delighting model."
 
425
  return args
426
 
427
 
428
+ @spaces.GPU
429
  def entrypoint(
430
  delight_model: DelightingModel = None,
431
  imagesr_model: ImageRealESRGAN = None,
 
500
  faces = mesh.faces.astype(np.int32)
501
  vertices = vertices.astype(np.float32)
502
 
503
+ if not args.skip_fix_mesh:
504
  mesh_fixer = MeshFixer(vertices, faces, args.device)
505
  vertices, faces = mesh_fixer(
506
  filter_ratio=args.mesh_sipmlify_ratio,
 
512
  if len(faces) > args.n_max_faces:
513
  mesh_fixer = MeshFixer(vertices, faces, args.device)
514
  vertices, faces = mesh_fixer(
515
+ filter_ratio=max(0.1, args.mesh_sipmlify_ratio - 0.1),
516
  max_hole_size=0.04,
517
  resolution=1024,
518
  num_views=1000,
embodied_gen/data/utils.py CHANGED
@@ -15,10 +15,13 @@
15
  # permissions and limitations under the License.
16
 
17
 
 
18
  import math
19
  import os
20
- import random
21
  import zipfile
 
 
22
  from shutil import rmtree
23
  from typing import List, Tuple, Union
24
 
@@ -28,20 +31,9 @@ import numpy as np
28
  import nvdiffrast.torch as dr
29
  import torch
30
  import torch.nn.functional as F
31
- from PIL import Image, ImageEnhance
32
-
33
- try:
34
- from kolors.models.modeling_chatglm import ChatGLMModel
35
- from kolors.models.tokenization_chatglm import ChatGLMTokenizer
36
- except ImportError:
37
- ChatGLMTokenizer = None
38
- ChatGLMModel = None
39
- import logging
40
- from dataclasses import dataclass, field
41
-
42
  import trimesh
43
  from kaolin.render.camera import Camera
44
- from torch import nn
45
 
46
  logger = logging.getLogger(__name__)
47
 
@@ -50,10 +42,8 @@ __all__ = [
50
  "DiffrastRender",
51
  "save_images",
52
  "render_pbr",
53
- "prelabel_text_feature",
54
  "calc_vertex_normals",
55
  "normalize_vertices_array",
56
- "load_mesh_to_unit_cube",
57
  "as_list",
58
  "CameraSetting",
59
  "import_kaolin_mesh",
@@ -67,6 +57,7 @@ __all__ = [
67
  "trellis_preprocess",
68
  "delete_dir",
69
  "kaolin_to_opencv_view",
 
70
  ]
71
 
72
 
@@ -520,114 +511,6 @@ def render_pbr(
520
  return image, albedo, diffuse, normal
521
 
522
 
523
- def _move_to_target_device(data, device: str):
524
- if isinstance(data, dict):
525
- for key, value in data.items():
526
- data[key] = _move_to_target_device(value, device)
527
- elif isinstance(data, torch.Tensor):
528
- return data.to(device)
529
-
530
- return data
531
-
532
-
533
- def _encode_prompt(
534
- prompt_batch,
535
- text_encoders,
536
- tokenizers,
537
- proportion_empty_prompts=0,
538
- is_train=True,
539
- ):
540
- prompt_embeds_list = []
541
-
542
- captions = []
543
- for caption in prompt_batch:
544
- if random.random() < proportion_empty_prompts:
545
- captions.append("")
546
- elif isinstance(caption, str):
547
- captions.append(caption)
548
- elif isinstance(caption, (list, np.ndarray)):
549
- captions.append(random.choice(caption) if is_train else caption[0])
550
-
551
- with torch.no_grad():
552
- for tokenizer, text_encoder in zip(tokenizers, text_encoders):
553
- text_inputs = tokenizer(
554
- captions,
555
- padding="max_length",
556
- max_length=256,
557
- truncation=True,
558
- return_tensors="pt",
559
- ).to(text_encoder.device)
560
-
561
- output = text_encoder(
562
- input_ids=text_inputs.input_ids,
563
- attention_mask=text_inputs.attention_mask,
564
- position_ids=text_inputs.position_ids,
565
- output_hidden_states=True,
566
- )
567
-
568
- # We are only interested in the pooled output of the text encoder.
569
- prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
570
- pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone()
571
- bs_embed, seq_len, _ = prompt_embeds.shape
572
- prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
573
- prompt_embeds_list.append(prompt_embeds)
574
-
575
- prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
576
- pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
577
-
578
- return prompt_embeds, pooled_prompt_embeds
579
-
580
-
581
- def load_llm_models(pretrained_model_name_or_path: str, device: str):
582
- tokenizer = ChatGLMTokenizer.from_pretrained(
583
- pretrained_model_name_or_path,
584
- subfolder="text_encoder",
585
- )
586
- text_encoder = ChatGLMModel.from_pretrained(
587
- pretrained_model_name_or_path,
588
- subfolder="text_encoder",
589
- ).to(device)
590
-
591
- text_encoders = [
592
- text_encoder,
593
- ]
594
- tokenizers = [
595
- tokenizer,
596
- ]
597
-
598
- logger.info(f"Load model from {pretrained_model_name_or_path} done.")
599
-
600
- return tokenizers, text_encoders
601
-
602
-
603
- def prelabel_text_feature(
604
- prompt_batch: List[str],
605
- output_dir: str,
606
- tokenizers: nn.Module,
607
- text_encoders: nn.Module,
608
- ) -> List[str]:
609
- os.makedirs(output_dir, exist_ok=True)
610
-
611
- # prompt_batch ["text..."]
612
- prompt_embeds, pooled_prompt_embeds = _encode_prompt(
613
- prompt_batch, text_encoders, tokenizers
614
- )
615
-
616
- prompt_embeds = _move_to_target_device(prompt_embeds, device="cpu")
617
- pooled_prompt_embeds = _move_to_target_device(
618
- pooled_prompt_embeds, device="cpu"
619
- )
620
-
621
- data_dict = dict(
622
- prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds
623
- )
624
-
625
- save_path = os.path.join(output_dir, "text_feat.pth")
626
- torch.save(data_dict, save_path)
627
-
628
- return save_path
629
-
630
-
631
  def _calc_face_normals(
632
  vertices: torch.Tensor, # V,3 first vertex may be unreferenced
633
  faces: torch.Tensor, # F,3 long, first face may be all zero
@@ -683,25 +566,6 @@ def normalize_vertices_array(
683
  return vertices, scale, center
684
 
685
 
686
- def load_mesh_to_unit_cube(
687
- mesh_file: str,
688
- mesh_scale: float = 1.0,
689
- ) -> tuple[trimesh.Trimesh, float, list[float]]:
690
- if not os.path.exists(mesh_file):
691
- raise FileNotFoundError(f"mesh_file path {mesh_file} not exists.")
692
-
693
- mesh = trimesh.load(mesh_file)
694
- if isinstance(mesh, trimesh.Scene):
695
- mesh = trimesh.utils.concatenate(mesh)
696
-
697
- vertices, scale, center = normalize_vertices_array(
698
- mesh.vertices, mesh_scale
699
- )
700
- mesh.vertices = vertices
701
-
702
- return mesh, scale, center
703
-
704
-
705
  def as_list(obj):
706
  if isinstance(obj, (list, tuple)):
707
  return obj
@@ -862,6 +726,7 @@ def save_mesh_with_mtl(
862
  texture: Union[Image.Image, np.ndarray],
863
  output_path: str,
864
  material_base=(250, 250, 250, 255),
 
865
  ) -> trimesh.Trimesh:
866
  if isinstance(texture, np.ndarray):
867
  texture = Image.fromarray(texture)
@@ -870,6 +735,7 @@ def save_mesh_with_mtl(
870
  vertices,
871
  faces,
872
  visual=trimesh.visual.TextureVisuals(uv=uvs, image=texture),
 
873
  )
874
  mesh.visual.material = trimesh.visual.material.SimpleMaterial(
875
  image=texture,
@@ -998,8 +864,9 @@ def gamma_shs(shs: torch.Tensor, gamma: float) -> torch.Tensor:
998
 
999
 
1000
  def resize_pil(image: Image.Image, max_size: int = 1024) -> Image.Image:
1001
- max_size = max(image.size)
1002
- scale = min(1, 1024 / max_size)
 
1003
  if scale < 1:
1004
  new_size = (int(image.width * scale), int(image.height * scale))
1005
  image = image.resize(new_size, Image.Resampling.LANCZOS)
@@ -1068,3 +935,34 @@ def delete_dir(folder_path: str, keep_subs: list[str] = None) -> None:
1068
  rmtree(item_path)
1069
  else:
1070
  os.remove(item_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  # permissions and limitations under the License.
16
 
17
 
18
+ import logging
19
  import math
20
  import os
21
+ import time
22
  import zipfile
23
+ from contextlib import contextmanager
24
+ from dataclasses import dataclass, field
25
  from shutil import rmtree
26
  from typing import List, Tuple, Union
27
 
 
31
  import nvdiffrast.torch as dr
32
  import torch
33
  import torch.nn.functional as F
 
 
 
 
 
 
 
 
 
 
 
34
  import trimesh
35
  from kaolin.render.camera import Camera
36
+ from PIL import Image, ImageEnhance
37
 
38
  logger = logging.getLogger(__name__)
39
 
 
42
  "DiffrastRender",
43
  "save_images",
44
  "render_pbr",
 
45
  "calc_vertex_normals",
46
  "normalize_vertices_array",
 
47
  "as_list",
48
  "CameraSetting",
49
  "import_kaolin_mesh",
 
57
  "trellis_preprocess",
58
  "delete_dir",
59
  "kaolin_to_opencv_view",
60
+ "model_device_ctx",
61
  ]
62
 
63
 
 
511
  return image, albedo, diffuse, normal
512
 
513
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
514
  def _calc_face_normals(
515
  vertices: torch.Tensor, # V,3 first vertex may be unreferenced
516
  faces: torch.Tensor, # F,3 long, first face may be all zero
 
566
  return vertices, scale, center
567
 
568
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
569
  def as_list(obj):
570
  if isinstance(obj, (list, tuple)):
571
  return obj
 
726
  texture: Union[Image.Image, np.ndarray],
727
  output_path: str,
728
  material_base=(250, 250, 250, 255),
729
+ mesh_process: bool = True,
730
  ) -> trimesh.Trimesh:
731
  if isinstance(texture, np.ndarray):
732
  texture = Image.fromarray(texture)
 
735
  vertices,
736
  faces,
737
  visual=trimesh.visual.TextureVisuals(uv=uvs, image=texture),
738
+ process=mesh_process, # False prevents trimesh from modifying (merging/removing) vertices
739
  )
740
  mesh.visual.material = trimesh.visual.material.SimpleMaterial(
741
  image=texture,
 
864
 
865
 
866
  def resize_pil(image: Image.Image, max_size: int = 1024) -> Image.Image:
867
+ current_max_dim = max(image.size)
868
+ scale = min(1, max_size / current_max_dim)
869
+
870
  if scale < 1:
871
  new_size = (int(image.width * scale), int(image.height * scale))
872
  image = image.resize(new_size, Image.Resampling.LANCZOS)
 
935
  rmtree(item_path)
936
  else:
937
  os.remove(item_path)
938
+
939
+
940
@contextmanager
def model_device_ctx(
    *models,
    src_device: str = "cpu",
    dst_device: str = "cuda",
    verbose: bool = False,
):
    """Temporarily move models to ``dst_device`` and restore them on exit.

    Every non-``None`` model has ``.to(dst_device)`` called on entry and
    ``.to(src_device)`` called on exit. The restore runs even when the
    managed body raises, and also when moving one of the models to
    ``dst_device`` fails partway through, so no model is left stranded on
    the destination device.

    Args:
        *models: Objects exposing a ``.to(device)`` method; ``None``
            entries are skipped.
        src_device: Device the models are moved back to on exit.
        dst_device: Device the models are moved to on entry.
        verbose: When True, log the time spent on both transfers at
            debug level.

    Yields:
        None.
    """
    active = [m for m in models if m is not None]
    # Pre-set so the verbose log in ``finally`` is safe even if the
    # initial transfer raises before the timing is recorded.
    to_dst_time = 0.0

    try:
        start = time.perf_counter()
        # Inside the ``try`` on purpose: a failure on the N-th model must
        # still send the first N-1 models back to ``src_device``.
        for m in active:
            m.to(dst_device)
        to_dst_time = time.perf_counter() - start

        yield
    finally:
        start = time.perf_counter()
        for m in active:
            m.to(src_device)
        to_src_time = time.perf_counter() - start

        if verbose:
            model_names = [m.__class__.__name__ for m in active]
            logger.debug(
                "[model_device_ctx] %s to %s: %.1fs, to %s: %.1fs",
                model_names,
                dst_device,
                to_dst_time,
                src_device,
                to_src_time,
            )
embodied_gen/models/sam3d.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ from embodied_gen.utils.monkey_patches import monkey_patch_sam3d
18
+
19
+ monkey_patch_sam3d()
20
+ import os
21
+ import sys
22
+ from typing import Optional, Union
23
+
24
+ import numpy as np
25
+ from hydra.utils import instantiate
26
+ from modelscope import snapshot_download
27
+ from omegaconf import OmegaConf
28
+ from PIL import Image
29
+
30
+ current_file_path = os.path.abspath(__file__)
31
+ current_dir = os.path.dirname(current_file_path)
32
+ sys.path.append(os.path.join(current_dir, "../.."))
33
+ from thirdparty.sam3d.sam3d_objects.pipeline.inference_pipeline_pointmap import (
34
+ InferencePipelinePointMap,
35
+ )
36
+
37
+ __all__ = ["Sam3dInference"]
38
+
39
+
40
def load_image(path: str) -> np.ndarray:
    """Read an image file into a ``uint8`` numpy array.

    Args:
        path: Path of the image file to read.

    Returns:
        The decoded pixel data cast to ``np.uint8``.
    """
    # The context manager releases the file handle deterministically;
    # ``np.array`` forces the full decode before the file is closed.
    with Image.open(path) as pil_image:
        pixels = np.array(pil_image)

    return pixels.astype(np.uint8)
45
+
46
+
47
def load_mask(path: str) -> np.ndarray:
    """Read an image file and turn it into a 2-D boolean mask.

    Args:
        path: Path of the mask image to read.

    Returns:
        Boolean array that is True where the pixel value is greater than
        zero. For multi-channel images only the last channel is used.
    """
    binary = load_image(path) > 0
    # Multi-channel input: keep the last channel only.
    return binary[..., -1] if binary.ndim == 3 else binary
53
+
54
+
55
class Sam3dInference:
    """Build the SAM-3D-Objects inference pipeline and run it on one image.

    The pipeline is instantiated from the ``checkpoints/pipeline.yaml``
    config found under ``local_dir``; the weights are downloaded first if
    that directory does not exist.

    Args:
        local_dir: Directory holding (or receiving) the model snapshot.
        compile: Value written to the config's ``compile_model`` field.
    """

    def __init__(
        self, local_dir: str = "weights/sam-3d-objects", compile: bool = False
    ) -> None:
        if not os.path.exists(local_dir):
            snapshot_download("facebook/sam-3d-objects", local_dir=local_dir)
        config_file = os.path.join(local_dir, "checkpoints/pipeline.yaml")
        config = OmegaConf.load(config_file)
        config.rendering_engine = "nvdiffrast"
        config.compile_model = compile
        config.workspace_dir = os.path.dirname(config_file)
        # Generate 4 gs in each pixel: route the "gs_4" decoder entries
        # into the default gaussian-decoder config slots.
        config["slat_decoder_gs_config_path"] = config.pop(
            "slat_decoder_gs_4_config_path", "slat_decoder_gs_4.yaml"
        )
        config["slat_decoder_gs_ckpt_path"] = config.pop(
            "slat_decoder_gs_4_ckpt_path", "slat_decoder_gs_4.ckpt"
        )
        self.pipeline: InferencePipelinePointMap = instantiate(config)

    def merge_mask_to_rgba(
        self, image: np.ndarray, mask: np.ndarray
    ) -> np.ndarray:
        """Attach ``mask`` to ``image`` as an alpha channel.

        Args:
            image: Array of shape (H, W, C) with at least three channels;
                only the first three are kept.
            mask: Foreground mask of shape (H, W), or (H, W, C) in which
                case the last channel is used. Boolean, 0/1 and 0/255
                masks are all accepted.

        Returns:
            Array of shape (H, W, 4) whose last channel is 255 where the
            mask is set and 0 elsewhere.
        """
        mask = np.asarray(mask)
        if mask.ndim == 3:
            # Same convention as ``load_mask``: the last channel carries
            # the mask.
            mask = mask[..., -1]
        # Binarise before scaling. Multiplying a 0/255 uint8 mask by 255
        # directly would wrap around and yield an alpha of 1.
        alpha = (mask.astype(np.uint8) > 0).astype(np.uint8) * 255
        alpha = alpha[..., None]
        rgba_image = np.concatenate([image[..., :3], alpha], axis=-1)

        return rgba_image

    def run(
        self,
        image: np.ndarray | Image.Image,
        mask: np.ndarray = None,
        seed: int = None,
        pointmap: np.ndarray = None,
        use_stage1_distillation: bool = False,
        use_stage2_distillation: bool = False,
        stage1_inference_steps: int = 25,
        stage2_inference_steps: int = 25,
    ) -> dict:
        """Run the full two-stage pipeline on a single image.

        Args:
            image: Input image. PIL images are converted to numpy arrays.
            mask: Optional foreground mask; when given it is merged into
                ``image`` as the alpha channel. When omitted, ``image`` is
                passed through unchanged (presumably expected to already
                carry an alpha channel; confirm against the pipeline).
            seed: Seed forwarded to the pipeline.
            pointmap: Optional point map forwarded to the pipeline.
            use_stage1_distillation: Forwarded to the pipeline.
            use_stage2_distillation: Forwarded to the pipeline.
            stage1_inference_steps: Forwarded to the pipeline.
            stage2_inference_steps: Forwarded to the pipeline.

        Returns:
            The dictionary returned by the underlying pipeline's ``run``.
        """
        if isinstance(image, Image.Image):
            image = np.array(image)
        if mask is not None:
            image = self.merge_mask_to_rgba(image, mask)
        # The second positional argument is always ``None`` because any
        # mask has already been folded into the image above. Mesh, texture
        # and layout post-processing are all switched off here.
        return self.pipeline.run(
            image,
            None,
            seed,
            stage1_only=False,
            with_mesh_postprocess=False,
            with_texture_baking=False,
            with_layout_postprocess=False,
            use_vertex_color=True,
            use_stage1_distillation=use_stage1_distillation,
            use_stage2_distillation=use_stage2_distillation,
            stage1_inference_steps=stage1_inference_steps,
            stage2_inference_steps=stage2_inference_steps,
            pointmap=pointmap,
        )
114
+
115
+
116
if __name__ == "__main__":
    # Manual smoke test: run the pipeline on one image/mask pair, report
    # wall time and GPU memory, then save the gaussian splat to disk.
    pipeline = Sam3dInference()

    # load image
    image = load_image(
        "/home/users/xinjie.wang/xinjie/sam-3d-objects/notebook/images/shutterstock_stylish_kidsroom_1640806567/image.png"
    )
    mask = load_mask(
        "/home/users/xinjie.wang/xinjie/sam-3d-objects/notebook/images/shutterstock_stylish_kidsroom_1640806567/13.png"
    )

    import torch

    if torch.cuda.is_available():
        # Reset counters so the peak reported below covers this run only.
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

    from time import time

    start = time()

    output = pipeline.run(image, mask, seed=42)
    print(f"Running cost: {round(time()-start, 1)}")

    if torch.cuda.is_available():
        max_memory = torch.cuda.max_memory_allocated() / (1024**3)
        print(f"(Max VRAM): {max_memory:.2f} GB")

    print(f"End: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")

    # Make sure the target directory exists so a missing "outputs" folder
    # cannot fail the save after the (expensive) inference has finished.
    os.makedirs("outputs", exist_ok=True)
    output["gs"].save_ply("outputs/splat.ply")
    print("Your reconstruction has been saved to outputs/splat.ply")
embodied_gen/models/segment_model.py CHANGED
@@ -43,6 +43,7 @@ __all__ = [
43
  "SAMRemover",
44
  "SAMPredictor",
45
  "RembgRemover",
 
46
  "get_segmented_image_by_agent",
47
  ]
48
 
@@ -376,7 +377,7 @@ class BMGG14Remover(object):
376
 
377
  def __call__(
378
  self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
379
- ):
380
  """Removes background from an image.
381
 
382
  Args:
@@ -496,13 +497,18 @@ if __name__ == "__main__":
496
  # input_image = "outputs/text2image/tmp/bucket.jpeg"
497
  # output_image = "outputs/text2image/tmp/bucket_seg.png"
498
 
499
- remover = SAMRemover(model_type="vit_h")
500
- remover = RembgRemover()
501
- clean_image = remover(input_image)
502
- clean_image.save(output_image)
503
- get_segmented_image_by_agent(
504
- Image.open(input_image), remover, remover, None, "./test_seg.png"
505
- )
506
 
507
  remover = BMGG14Remover()
508
- remover("embodied_gen/models/test_seg.jpg", "./seg.png")
 
 
 
 
 
 
43
  "SAMRemover",
44
  "SAMPredictor",
45
  "RembgRemover",
46
+ "BMGG14Remover",
47
  "get_segmented_image_by_agent",
48
  ]
49
 
 
377
 
378
  def __call__(
379
  self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
380
+ ) -> Image.Image:
381
  """Removes background from an image.
382
 
383
  Args:
 
497
  # input_image = "outputs/text2image/tmp/bucket.jpeg"
498
  # output_image = "outputs/text2image/tmp/bucket_seg.png"
499
 
500
+ # remover = SAMRemover(model_type="vit_h")
501
+ # remover = RembgRemover()
502
+ # clean_image = remover(input_image)
503
+ # clean_image.save(output_image)
504
+ # get_segmented_image_by_agent(
505
+ # Image.open(input_image), remover, remover, None, "./test_seg.png"
506
+ # )
507
 
508
  remover = BMGG14Remover()
509
+ clean_image = remover("./camera.jpeg", "./seg.png")
510
+ from embodied_gen.utils.process_media import (
511
+ keep_largest_connected_component,
512
+ )
513
+
514
+ keep_largest_connected_component(clean_image).save("./seg_post.png")
embodied_gen/scripts/gen_texture.py CHANGED
@@ -94,6 +94,7 @@ def entrypoint() -> None:
94
  delight=cfg.delight,
95
  no_save_delight_img=True,
96
  texture_wh=[cfg.texture_size, cfg.texture_size],
 
97
  )
98
  drender_api(
99
  mesh_path=f"{output_root}/texture_mesh/{uuid}.obj",
 
94
  delight=cfg.delight,
95
  no_save_delight_img=True,
96
  texture_wh=[cfg.texture_size, cfg.texture_size],
97
+ no_mesh_post_process=True,
98
  )
99
  drender_api(
100
  mesh_path=f"{output_root}/texture_mesh/{uuid}.obj",
embodied_gen/scripts/imageto3d.py CHANGED
@@ -26,12 +26,14 @@ import numpy as np
26
  import torch
27
  import trimesh
28
  from PIL import Image
29
- from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
30
  from embodied_gen.data.utils import delete_dir, trellis_preprocess
31
- from embodied_gen.models.delight_model import DelightingModel
 
32
  from embodied_gen.models.gs_model import GaussianOperator
33
  from embodied_gen.models.segment_model import RembgRemover
34
- from embodied_gen.models.sr_model import ImageRealESRGAN
 
35
  from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
36
  from embodied_gen.utils.gpt_clients import GPT_CLIENT
37
  from embodied_gen.utils.log import logger
@@ -59,8 +61,8 @@ os.environ["SPCONV_ALGO"] = "native"
59
  random.seed(0)
60
 
61
  logger.info("Loading Image3D Models...")
62
- DELIGHT = DelightingModel()
63
- IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
64
  RBG_REMOVER = RembgRemover()
65
  PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
66
  "microsoft/TRELLIS-image-large"
@@ -108,9 +110,7 @@ def parse_args():
108
  default=2,
109
  )
110
  parser.add_argument("--disable_decompose_convex", action="store_true")
111
- parser.add_argument(
112
- "--texture_wh", type=int, nargs=2, default=[2048, 2048]
113
- )
114
  args, unknown = parser.parse_known_args()
115
 
116
  return args
@@ -248,16 +248,14 @@ def entrypoint(**kwargs):
248
  mesh.export(mesh_obj_path)
249
 
250
  mesh = backproject_api(
251
- delight_model=DELIGHT,
252
- imagesr_model=IMAGESR_MODEL,
253
- color_path=color_path,
254
  mesh_path=mesh_obj_path,
255
  output_path=mesh_obj_path,
256
  skip_fix_mesh=False,
257
- delight=True,
258
- texture_wh=args.texture_wh,
259
- elevation=[20, -10, 60, -50],
260
- num_images=12,
261
  )
262
 
263
  mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
 
26
  import torch
27
  import trimesh
28
  from PIL import Image
29
+ from embodied_gen.data.backproject_v3 import entrypoint as backproject_api
30
  from embodied_gen.data.utils import delete_dir, trellis_preprocess
31
+
32
+ # from embodied_gen.models.delight_model import DelightingModel
33
  from embodied_gen.models.gs_model import GaussianOperator
34
  from embodied_gen.models.segment_model import RembgRemover
35
+
36
+ # from embodied_gen.models.sr_model import ImageRealESRGAN
37
  from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
38
  from embodied_gen.utils.gpt_clients import GPT_CLIENT
39
  from embodied_gen.utils.log import logger
 
61
  random.seed(0)
62
 
63
  logger.info("Loading Image3D Models...")
64
+ # DELIGHT = DelightingModel()
65
+ # IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
66
  RBG_REMOVER = RembgRemover()
67
  PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
68
  "microsoft/TRELLIS-image-large"
 
110
  default=2,
111
  )
112
  parser.add_argument("--disable_decompose_convex", action="store_true")
113
+ parser.add_argument("--texture_size", type=int, default=2048)
 
 
114
  args, unknown = parser.parse_known_args()
115
 
116
  return args
 
248
  mesh.export(mesh_obj_path)
249
 
250
  mesh = backproject_api(
251
+ # delight_model=DELIGHT,
252
+ # imagesr_model=IMAGESR_MODEL,
253
+ gs_path=aligned_gs_path,
254
  mesh_path=mesh_obj_path,
255
  output_path=mesh_obj_path,
256
  skip_fix_mesh=False,
257
+ texture_size=args.texture_size,
258
+ delight=False,
 
 
259
  )
260
 
261
  mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
embodied_gen/utils/monkey_patches.py CHANGED
@@ -25,6 +25,12 @@ from omegaconf import OmegaConf
25
  from PIL import Image
26
  from torchvision import transforms
27
 
 
 
 
 
 
 
28
 
29
  def monkey_patch_pano2room():
30
  current_file_path = os.path.abspath(__file__)
@@ -216,3 +222,374 @@ def monkey_patch_maniskill():
216
  ManiSkillScene.get_human_render_camera_images = (
217
  get_human_render_camera_images
218
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  from PIL import Image
26
  from torchvision import transforms
27
 
28
+ __all__ = [
29
+ "monkey_patch_pano2room",
30
+ "monkey_patch_maniskill",
31
+ "monkey_patch_sam3d",
32
+ ]
33
+
34
 
35
  def monkey_patch_pano2room():
36
  current_file_path = os.path.abspath(__file__)
 
222
  ManiSkillScene.get_human_render_camera_images = (
223
  get_human_render_camera_images
224
  )
225
+
226
+
227
def monkey_patch_sam3d():
    """Patch the vendored SAM-3D inference pipeline classes in place.

    The function
      1. sets ``LIDRA_SKIP_INIT=true`` in the environment before any
         ``sam3d_objects`` import (presumably read by that package at import
         time -- confirm against thirdparty/sam3d),
      2. puts ``thirdparty/sam3d`` at the front of ``sys.path``,
      3. replaces ``InferencePipelinePointMap.run`` with a version that wraps
         each model stage in ``model_device_ctx``, and
      4. replaces ``InferencePipeline.__init__`` with a version that points
         ``self.device`` at the CPU while the ``init_*`` loaders run.

    If ``sam3d_objects`` cannot be imported the failure is logged and the
    corresponding patch is skipped; no exception is raised.
    """
    from typing import Optional, Union

    from embodied_gen.data.utils import model_device_ctx
    from embodied_gen.utils.log import logger

    os.environ["LIDRA_SKIP_INIT"] = "true"

    current_file_path = os.path.abspath(__file__)
    current_dir = os.path.dirname(current_file_path)
    sam3d_root = os.path.abspath(
        os.path.join(current_dir, "../../thirdparty/sam3d")
    )
    if sam3d_root not in sys.path:
        sys.path.insert(0, sam3d_root)

    logger.info(f"[MonkeyPatch] Added to sys.path: {sam3d_root}")

    def patch_pointmap_infer_pipeline():
        """Install the replacement ``InferencePipelinePointMap.run``."""
        from copy import deepcopy

        # Imported locally so ``patch_run`` does not depend on a
        # module-level ``torch`` import.
        import torch

        try:
            from sam3d_objects.pipeline.inference_pipeline_pointmap import (
                InferencePipelinePointMap,
            )
        except ImportError:
            logger.error(
                "[MonkeyPatch]: Could not import sam3d_objects directly. Check paths."
            )
            return

        def patch_run(
            self,
            image: Union[None, Image.Image, np.ndarray],
            mask: Union[None, Image.Image, np.ndarray] = None,
            seed: Optional[int] = None,
            stage1_only=False,
            with_mesh_postprocess=True,
            with_texture_baking=True,
            with_layout_postprocess=True,
            use_vertex_color=False,
            stage1_inference_steps=None,
            stage2_inference_steps=None,
            use_stage1_distillation=False,
            use_stage2_distillation=False,
            pointmap=None,
            decode_formats=None,
            estimate_plane=False,
        ) -> dict:
            """Run the pipeline, scoping each model group with
            ``model_device_ctx``.

            Stages: point-map computation, sparse-structure sampling, pose
            decoding, structured-latent sampling, latent decoding, output
            post-processing and (optionally) layout post-optimization.

            Returns:
                dict: With ``estimate_plane`` the result of
                ``self.estimate_plane``. Otherwise the sparse-structure
                outputs plus ``pointmap`` / ``pointmap_colors``; when
                ``stage1_only`` is False the decoded outputs are merged in
                as well, and when it is True a ``voxel`` entry is added.
            """
            image = self.merge_image_and_mask(image, mask)
            with self.device:
                pointmap_dict = self.compute_pointmap(image, pointmap)
                pointmap = pointmap_dict["pointmap"]
                pts = type(self)._down_sample_img(pointmap)
                pts_colors = type(self)._down_sample_img(
                    pointmap_dict["pts_color"]
                )

                if estimate_plane:
                    return self.estimate_plane(pointmap_dict, image)

                ss_input_dict = self.preprocess_image(
                    image, self.ss_preprocessor, pointmap=pointmap
                )

                slat_input_dict = self.preprocess_image(
                    image, self.slat_preprocessor
                )
                if seed is not None:
                    torch.manual_seed(seed)

                # Stage 1: sparse structure. model_device_ctx presumably
                # keeps the listed modules on the compute device only for
                # the duration of the block -- confirm in
                # embodied_gen.data.utils.
                with model_device_ctx(
                    self.models["ss_generator"],
                    self.models["ss_decoder"],
                    self.condition_embedders["ss_condition_embedder"],
                ):
                    ss_return_dict = self.sample_sparse_structure(
                        ss_input_dict,
                        inference_steps=stage1_inference_steps,
                        use_distillation=use_stage1_distillation,
                    )

                # We could probably use the decoder from the models themselves
                pointmap_scale = ss_input_dict.get("pointmap_scale", None)
                pointmap_shift = ss_input_dict.get("pointmap_shift", None)
                ss_return_dict.update(
                    self.pose_decoder(
                        ss_return_dict,
                        scene_scale=pointmap_scale,
                        scene_shift=pointmap_shift,
                    )
                )

                logger.info(
                    f"Rescaling scale by {ss_return_dict['downsample_factor']} after downsampling"
                )
                ss_return_dict["scale"] = (
                    ss_return_dict["scale"]
                    * ss_return_dict["downsample_factor"]
                )

                if stage1_only:
                    logger.info("Finished!")
                    ss_return_dict["voxel"] = (
                        ss_return_dict["coords"][:, 1:] / 64 - 0.5
                    )
                    return {
                        **ss_return_dict,
                        "pointmap": pts.cpu().permute((1, 2, 0)),  # HxWx3
                        "pointmap_colors": pts_colors.cpu().permute(
                            (1, 2, 0)
                        ),  # HxWx3
                    }

                # Stage 2: structured latents on the sampled coordinates.
                coords = ss_return_dict["coords"]
                with model_device_ctx(
                    self.models["slat_generator"],
                    self.condition_embedders["slat_condition_embedder"],
                ):
                    slat = self.sample_slat(
                        slat_input_dict,
                        coords,
                        inference_steps=stage2_inference_steps,
                        use_distillation=use_stage2_distillation,
                    )

                # Decode with the per-call formats when given, otherwise the
                # formats configured on the pipeline.
                with model_device_ctx(
                    self.models["slat_decoder_mesh"],
                    self.models["slat_decoder_gs"],
                    self.models["slat_decoder_gs_4"],
                ):
                    outputs = self.decode_slat(
                        slat,
                        (
                            self.decode_formats
                            if decode_formats is None
                            else decode_formats
                        ),
                    )

                outputs = self.postprocess_slat_output(
                    outputs,
                    with_mesh_postprocess,
                    with_texture_baking,
                    use_vertex_color,
                )
                glb = outputs.get("glb", None)

                # Layout optimization is best effort: a failure is logged
                # and the un-optimized pose is kept.
                try:
                    if (
                        with_layout_postprocess
                        and self.layout_post_optimization_method is not None
                    ):
                        assert (
                            glb is not None
                        ), "require mesh to run postprocessing"
                        logger.info(
                            "Running layout post optimization method..."
                        )
                        postprocessed_pose = self.run_post_optimization(
                            deepcopy(glb),
                            pointmap_dict["intrinsics"],
                            ss_return_dict,
                            ss_input_dict,
                        )
                        ss_return_dict.update(postprocessed_pose)
                except Exception as e:
                    logger.error(
                        f"Error during layout post optimization: {e}",
                        exc_info=True,
                    )

                result = {
                    **ss_return_dict,
                    **outputs,
                    "pointmap": pts.cpu().permute((1, 2, 0)),
                    "pointmap_colors": pts_colors.cpu().permute((1, 2, 0)),
                }
                return result

        InferencePipelinePointMap.run = patch_run

    def patch_infer_init():
        """Install the replacement ``InferencePipeline.__init__``."""
        import torch

        try:
            from sam3d_objects.pipeline import preprocess_utils
            from sam3d_objects.pipeline.inference_pipeline_pointmap import (
                InferencePipeline,
            )
            from sam3d_objects.pipeline.inference_utils import (
                SLAT_MEAN,
                SLAT_STD,
            )
        except ImportError:
            logger.error(
                "[MonkeyPatch] Error: Could not import sam3d_objects directly for infer pipeline."
            )
            return

        # NOTE: the list defaults below are shared between calls and are
        # stored on the instance as-is, so they must not be mutated in place.
        def patch_init(
            self,
            ss_generator_config_path,
            ss_generator_ckpt_path,
            slat_generator_config_path,
            slat_generator_ckpt_path,
            ss_decoder_config_path,
            ss_decoder_ckpt_path,
            slat_decoder_gs_config_path,
            slat_decoder_gs_ckpt_path,
            slat_decoder_mesh_config_path,
            slat_decoder_mesh_ckpt_path,
            slat_decoder_gs_4_config_path=None,
            slat_decoder_gs_4_ckpt_path=None,
            ss_encoder_config_path=None,
            ss_encoder_ckpt_path=None,
            decode_formats=["gaussian", "mesh"],
            dtype="bfloat16",
            pad_size=1.0,
            version="v0",
            device="cuda",
            ss_preprocessor=preprocess_utils.get_default_preprocessor(),
            slat_preprocessor=preprocess_utils.get_default_preprocessor(),
            ss_condition_input_mapping=["image"],
            slat_condition_input_mapping=["image"],
            pose_decoder_name="default",
            workspace_dir="",
            downsample_ss_dist=0,  # the distance we use to downsample
            ss_inference_steps=25,
            ss_rescale_t=3,
            ss_cfg_strength=7,
            ss_cfg_interval=[0, 500],
            ss_cfg_strength_pm=0.0,
            slat_inference_steps=25,
            slat_rescale_t=3,
            slat_cfg_strength=5,
            slat_cfg_interval=[0, 500],
            rendering_engine: str = "nvdiffrast",  # nvdiffrast OR pytorch3d,
            shape_model_dtype=None,
            compile_model=False,
            slat_mean=SLAT_MEAN,
            slat_std=SLAT_STD,
        ):
            """Build the pipeline with ``self.device`` set to CPU while the
            ``init_*`` loaders run; the requested device is restored
            afterwards.
            """
            self.rendering_engine = rendering_engine
            self.device = torch.device(device)
            self.compile_model = compile_model
            logger.info(f"self.device: {self.device}")
            logger.info(
                f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', None)}"
            )
            # current_device() raises without a usable CUDA runtime, so only
            # query it when CUDA is actually available.
            if torch.cuda.is_available():
                logger.info(
                    f"Actually using GPU: {torch.cuda.current_device()}"
                )
            with self.device:
                self.decode_formats = decode_formats
                self.pad_size = pad_size
                self.version = version
                self.ss_condition_input_mapping = ss_condition_input_mapping
                self.slat_condition_input_mapping = (
                    slat_condition_input_mapping
                )
                self.workspace_dir = workspace_dir
                self.downsample_ss_dist = downsample_ss_dist
                self.ss_inference_steps = ss_inference_steps
                self.ss_rescale_t = ss_rescale_t
                self.ss_cfg_strength = ss_cfg_strength
                self.ss_cfg_interval = ss_cfg_interval
                self.ss_cfg_strength_pm = ss_cfg_strength_pm
                self.slat_inference_steps = slat_inference_steps
                self.slat_rescale_t = slat_rescale_t
                self.slat_cfg_strength = slat_cfg_strength
                self.slat_cfg_interval = slat_cfg_interval

                self.dtype = self._get_dtype(dtype)
                if shape_model_dtype is None:
                    self.shape_model_dtype = self.dtype
                else:
                    self.shape_model_dtype = self._get_dtype(shape_model_dtype)

                # Setup preprocessors
                self.pose_decoder = self.init_pose_decoder(
                    ss_generator_config_path, pose_decoder_name
                )
                self.ss_preprocessor = self.init_ss_preprocessor(
                    ss_preprocessor, ss_generator_config_path
                )
                self.slat_preprocessor = slat_preprocessor

                logger.info("Loading model weights...")
                # Temporarily point self.device at the CPU for the loaders
                # below (which presumably read self.device), then restore
                # the requested device.
                raw_device = self.device
                self.device = torch.device("cpu")
                ss_generator = self.init_ss_generator(
                    ss_generator_config_path, ss_generator_ckpt_path
                )
                slat_generator = self.init_slat_generator(
                    slat_generator_config_path, slat_generator_ckpt_path
                )
                ss_decoder = self.init_ss_decoder(
                    ss_decoder_config_path, ss_decoder_ckpt_path
                )
                ss_encoder = self.init_ss_encoder(
                    ss_encoder_config_path, ss_encoder_ckpt_path
                )
                slat_decoder_gs = self.init_slat_decoder_gs(
                    slat_decoder_gs_config_path, slat_decoder_gs_ckpt_path
                )
                slat_decoder_gs_4 = self.init_slat_decoder_gs(
                    slat_decoder_gs_4_config_path, slat_decoder_gs_4_ckpt_path
                )
                slat_decoder_mesh = self.init_slat_decoder_mesh(
                    slat_decoder_mesh_config_path, slat_decoder_mesh_ckpt_path
                )

                # Load conditioner embedder so that we only load it once
                ss_condition_embedder = self.init_ss_condition_embedder(
                    ss_generator_config_path, ss_generator_ckpt_path
                )
                slat_condition_embedder = self.init_slat_condition_embedder(
                    slat_generator_config_path, slat_generator_ckpt_path
                )
                self.device = raw_device

                self.condition_embedders = {
                    "ss_condition_embedder": ss_condition_embedder,
                    "slat_condition_embedder": slat_condition_embedder,
                }

                # override generator and condition embedder setting
                self.override_ss_generator_cfg_config(
                    ss_generator,
                    cfg_strength=ss_cfg_strength,
                    inference_steps=ss_inference_steps,
                    rescale_t=ss_rescale_t,
                    cfg_interval=ss_cfg_interval,
                    cfg_strength_pm=ss_cfg_strength_pm,
                )
                self.override_slat_generator_cfg_config(
                    slat_generator,
                    cfg_strength=slat_cfg_strength,
                    inference_steps=slat_inference_steps,
                    rescale_t=slat_rescale_t,
                    cfg_interval=slat_cfg_interval,
                )

                self.models = torch.nn.ModuleDict(
                    {
                        "ss_generator": ss_generator,
                        "slat_generator": slat_generator,
                        "ss_encoder": ss_encoder,
                        "ss_decoder": ss_decoder,
                        "slat_decoder_gs": slat_decoder_gs,
                        "slat_decoder_gs_4": slat_decoder_gs_4,
                        "slat_decoder_mesh": slat_decoder_mesh,
                    }
                )
                logger.info("Loading model weights completed!")

                if self.compile_model:
                    logger.info("Compiling model...")
                    self._compile()
                    logger.info("Model compilation completed!")
                self.slat_mean = torch.tensor(slat_mean)
                self.slat_std = torch.tensor(slat_std)

        InferencePipeline.__init__ = patch_init

    patch_pointmap_infer_pipeline()
    patch_infer_init()
embodied_gen/utils/process_media.py CHANGED
@@ -230,6 +230,29 @@ def filter_image_small_connected_components(
230
  return image
231
 
232
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
  def combine_images_to_grid(
234
  images: list[str | Image.Image],
235
  cat_row_col: tuple[int, int] = None,
@@ -439,7 +462,7 @@ class SceneTreeVisualizer:
439
  plt.axis("off")
440
 
441
  legend_handles = [
442
- Patch(facecolor=color, edgecolor='black', label=role)
443
  for role, color in self.role_colors.items()
444
  ]
445
  plt.legend(
@@ -465,7 +488,7 @@ def load_scene_dict(file_path: str) -> dict:
465
  dict: Mapping from scene ID to description.
466
  """
467
  scene_dict = {}
468
- with open(file_path, "r", encoding='utf-8') as f:
469
  for line in f:
470
  line = line.strip()
471
  if not line or ":" not in line:
@@ -487,7 +510,7 @@ def is_image_file(filename: str) -> bool:
487
  """
488
  mime_type, _ = mimetypes.guess_type(filename)
489
 
490
- return mime_type is not None and mime_type.startswith('image')
491
 
492
 
493
  def parse_text_prompts(prompts: list[str]) -> list[str]:
 
230
  return image
231
 
232
 
233
def keep_largest_connected_component(pil_img: Image.Image) -> Image.Image:
    """Make everything except the largest opaque region transparent.

    The image is converted to RGBA when it is in another mode. Pixels with a
    non-zero alpha value are grouped into 8-connected components; the alpha
    of every pixel outside the component with the largest area is set to 0.
    Colour channels are left untouched.

    Args:
        pil_img: Input image in any PIL mode.

    Returns:
        An RGBA image. When no foreground component is found, the
        RGBA version of the input is returned without further changes.
    """
    rgba_img = pil_img if pil_img.mode == "RGBA" else pil_img.convert("RGBA")

    pixels = np.array(rgba_img)
    alpha = pixels[:, :, 3]

    # Any non-zero alpha counts as foreground.
    _, foreground = cv2.threshold(alpha, 0, 255, cv2.THRESH_BINARY)
    component_count, label_map, component_stats, _ = (
        cv2.connectedComponentsWithStats(foreground, connectivity=8)
    )

    # Label 0 is skipped below as the background, so fewer than two labels
    # means there is nothing to filter.
    if component_count < 2:
        return rgba_img

    areas = component_stats[1:, cv2.CC_STAT_AREA]
    winner = 1 + np.argmax(areas)
    pixels[:, :, 3] = np.where(label_map == winner, alpha, 0).astype(np.uint8)

    return Image.fromarray(pixels)
254
+
255
+
256
  def combine_images_to_grid(
257
  images: list[str | Image.Image],
258
  cat_row_col: tuple[int, int] = None,
 
462
  plt.axis("off")
463
 
464
  legend_handles = [
465
+ Patch(facecolor=color, edgecolor="black", label=role)
466
  for role, color in self.role_colors.items()
467
  ]
468
  plt.legend(
 
488
  dict: Mapping from scene ID to description.
489
  """
490
  scene_dict = {}
491
+ with open(file_path, "r", encoding="utf-8") as f:
492
  for line in f:
493
  line = line.strip()
494
  if not line or ":" not in line:
 
510
  """
511
  mime_type, _ = mimetypes.guess_type(filename)
512
 
513
+ return mime_type is not None and mime_type.startswith("image")
514
 
515
 
516
  def parse_text_prompts(prompts: list[str]) -> list[str]:
embodied_gen/utils/trender.py CHANGED
@@ -16,6 +16,7 @@
16
 
17
  import os
18
  import sys
 
19
 
20
  import numpy as np
21
  import spaces
@@ -25,10 +26,8 @@ from tqdm import tqdm
25
  current_file_path = os.path.abspath(__file__)
26
  current_dir = os.path.dirname(current_file_path)
27
  sys.path.append(os.path.join(current_dir, "../.."))
28
- from thirdparty.TRELLIS.trellis.renderers.mesh_renderer import MeshRenderer
29
- from thirdparty.TRELLIS.trellis.representations import MeshExtractResult
30
  from thirdparty.TRELLIS.trellis.utils.render_utils import (
31
- render_frames,
32
  yaw_pitch_r_fov_to_extrinsics_intrinsics,
33
  )
34
 
@@ -38,7 +37,7 @@ __all__ = [
38
 
39
 
40
  @spaces.GPU
41
- def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs):
42
  renderer = MeshRenderer()
43
  renderer.rendering_options.resolution = options.get("resolution", 512)
44
  renderer.rendering_options.near = options.get("near", 1)
@@ -60,6 +59,57 @@ def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs):
60
  return rets
61
 
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  @spaces.GPU
64
  def render_video(
65
  sample,
@@ -77,7 +127,9 @@ def render_video(
77
  yaws, pitch, r, fov
78
  )
79
  render_fn = (
80
- render_mesh if isinstance(sample, MeshExtractResult) else render_frames
 
 
81
  )
82
  result = render_fn(
83
  sample,
 
16
 
17
  import os
18
  import sys
19
+ from collections import defaultdict
20
 
21
  import numpy as np
22
  import spaces
 
26
  current_file_path = os.path.abspath(__file__)
27
  current_dir = os.path.dirname(current_file_path)
28
  sys.path.append(os.path.join(current_dir, "../.."))
29
+ from thirdparty.TRELLIS.trellis.renderers import GaussianRenderer, MeshRenderer
 
30
  from thirdparty.TRELLIS.trellis.utils.render_utils import (
 
31
  yaw_pitch_r_fov_to_extrinsics_intrinsics,
32
  )
33
 
 
37
 
38
 
39
  @spaces.GPU
40
+ def render_mesh_frames(sample, extrinsics, intrinsics, options={}, **kwargs):
41
  renderer = MeshRenderer()
42
  renderer.rendering_options.resolution = options.get("resolution", 512)
43
  renderer.rendering_options.near = options.get("near", 1)
 
59
  return rets
60
 
61
 
62
@spaces.GPU
def render_gs_frames(
    sample,
    extrinsics,
    intrinsics,
    options=None,
    colors_overwrite=None,
    verbose=True,
    **kwargs,
):
    """Render a Gaussian sample from a sequence of camera poses.

    Args:
        sample: Object passed as the first argument to
            ``GaussianRenderer.render``.
        extrinsics: Per-view camera extrinsics; must support ``len()`` when
            ``verbose`` is True.
        intrinsics: Per-view camera intrinsics, iterated in lockstep with
            ``extrinsics``.
        options: Optional overrides for the rendering options
            (``resolution``, ``near``, ``far``, ``bg_color``, ``ssaa``).
            Keys that are not attributes of the renderer's
            ``rendering_options`` are ignored.
        colors_overwrite: Forwarded unchanged to the renderer.
        verbose: Show a tqdm progress bar while rendering.
        **kwargs: ``kernel_size`` (default 0.1) is read; other entries are
            ignored.

    Returns:
        dict: ``"color"`` holds one uint8 array per view and ``"depth"``
        holds one array per view, or ``None`` when the renderer returned
        neither ``percent_depth`` nor ``depth``. The dict is empty when no
        camera is given.
    """

    def to_img(tensor):
        # Axes reordered with (1, 2, 0) (presumably CxHxW -> HxWxC), scaled
        # by 255 and clipped into the uint8 range.
        return np.clip(
            tensor.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255
        ).astype(np.uint8)

    def to_numpy(tensor):
        return tensor.detach().cpu().numpy()

    renderer = GaussianRenderer()
    renderer.pipe.kernel_size = kwargs.get("kernel_size", 0.1)
    renderer.pipe.use_mip_gaussian = True

    defaults = {
        "resolution": 512,
        "near": 0.8,
        "far": 1.6,
        "bg_color": (0, 0, 0),
        "ssaa": 1,
    }
    final_options = {**defaults, **(options or {})}

    for k, v in final_options.items():
        if hasattr(renderer.rendering_options, k):
            setattr(renderer.rendering_options, k, v)

    outputs = defaultdict(list)
    iterator = zip(extrinsics, intrinsics)
    if verbose:
        iterator = tqdm(iterator, total=len(extrinsics), desc="Rendering")

    for extr, intr in iterator:
        res = renderer.render(
            sample, extr, intr, colors_overwrite=colors_overwrite
        )
        outputs["color"].append(to_img(res["color"]))
        # Prefer percent_depth, fall back to depth. Compare against None
        # explicitly: the truth value of a tensor is ambiguous for more than
        # one element, and a zero scalar would wrongly trigger the fallback.
        depth = res.get("percent_depth")
        if depth is None:
            depth = res.get("depth")
        outputs["depth"].append(to_numpy(depth) if depth is not None else None)

    return dict(outputs)
111
+
112
+
113
  @spaces.GPU
114
  def render_video(
115
  sample,
 
127
  yaws, pitch, r, fov
128
  )
129
  render_fn = (
130
+ render_mesh_frames
131
+ if sample.__class__.__name__ == "MeshExtractResult"
132
+ else render_gs_frames
133
  )
134
  result = render_fn(
135
  sample,
thirdparty/TRELLIS/trellis/utils/postprocessing_utils.py CHANGED
@@ -440,7 +440,7 @@ def to_glb(
440
  vertices, faces, uvs = parametrize_mesh(vertices, faces)
441
 
442
  # bake texture
443
- observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=200)
444
  masks = [np.any(observation > 0, axis=-1) for observation in observations]
445
  extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))]
446
  intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))]
 
440
  vertices, faces, uvs = parametrize_mesh(vertices, faces)
441
 
442
  # bake texture
443
+ observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=100)
444
  masks = [np.any(observation > 0, axis=-1) for observation in observations]
445
  extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))]
446
  intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))]