xinjie.wang committed on
Commit
c154483
·
1 Parent(s): 58164f8
app.py CHANGED
@@ -44,11 +44,13 @@ if app_name == "imageto3d_sam3d":
44
 
45
  enable_pre_resize = False
46
  sample_step = 25
 
47
  elif app_name == "imageto3d":
48
  from common import image_to_3d
49
 
50
  enable_pre_resize = True
51
  sample_step = 12
 
52
 
53
  with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
54
  gr.HTML(image_css, visible=False)
@@ -155,7 +157,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
155
  )
156
  rmbg_tag = gr.Radio(
157
  choices=["rembg", "rmbg14"],
158
- value="rembg",
159
  label="Background Removal Model",
160
  )
161
  with gr.Row():
 
44
 
45
  enable_pre_resize = False
46
  sample_step = 25
47
+ bg_rm_model_name = "rembg" # "rembg", "rmbg14"
48
  elif app_name == "imageto3d":
49
  from common import image_to_3d
50
 
51
  enable_pre_resize = True
52
  sample_step = 12
53
+ bg_rm_model_name = "rembg" # "rembg", "rmbg14"
54
 
55
  with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
56
  gr.HTML(image_css, visible=False)
 
157
  )
158
  rmbg_tag = gr.Radio(
159
  choices=["rembg", "rmbg14"],
160
+ value=bg_rm_model_name,
161
  label="Background Removal Model",
162
  )
163
  with gr.Row():
common.py CHANGED
@@ -34,7 +34,7 @@ from PIL import Image
34
  from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
35
  from embodied_gen.data.backproject_v3 import entrypoint as backproject_api_v3
36
  from embodied_gen.data.differentiable_render import entrypoint as render_api
37
- from embodied_gen.data.utils import resize_pil, trellis_preprocess, zip_files
38
  from embodied_gen.models.delight_model import DelightingModel
39
  from embodied_gen.models.gs_model import GaussianOperator
40
  from embodied_gen.models.segment_model import (
@@ -53,6 +53,7 @@ from embodied_gen.scripts.text2image import (
53
  from embodied_gen.utils.gpt_clients import GPT_CLIENT
54
  from embodied_gen.utils.process_media import (
55
  filter_image_small_connected_components,
 
56
  merge_images_video,
57
  )
58
  from embodied_gen.utils.tags import VERSION
@@ -246,6 +247,7 @@ def preprocess_image_fn(
246
 
247
  bg_remover = RBG_REMOVER if rmbg_tag == "rembg" else RBG14_REMOVER
248
  image = bg_remover(image)
 
249
 
250
  if preprocess:
251
  image = trellis_preprocess(image)
@@ -928,6 +930,7 @@ def backproject_texture_v2(
928
  texture_size: int,
929
  enable_delight: bool = True,
930
  fix_mesh: bool = False,
 
931
  uuid: str = "sample",
932
  req: gr.Request = None,
933
  ) -> str:
@@ -944,6 +947,7 @@ def backproject_texture_v2(
944
  skip_fix_mesh=not fix_mesh,
945
  delight=enable_delight,
946
  texture_wh=[texture_size, texture_size],
 
947
  )
948
 
949
  output_obj_mesh = os.path.join(output_dir, f"{uuid}.obj")
 
34
  from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
35
  from embodied_gen.data.backproject_v3 import entrypoint as backproject_api_v3
36
  from embodied_gen.data.differentiable_render import entrypoint as render_api
37
+ from embodied_gen.data.utils import trellis_preprocess, zip_files
38
  from embodied_gen.models.delight_model import DelightingModel
39
  from embodied_gen.models.gs_model import GaussianOperator
40
  from embodied_gen.models.segment_model import (
 
53
  from embodied_gen.utils.gpt_clients import GPT_CLIENT
54
  from embodied_gen.utils.process_media import (
55
  filter_image_small_connected_components,
56
+ keep_largest_connected_component,
57
  merge_images_video,
58
  )
59
  from embodied_gen.utils.tags import VERSION
 
247
 
248
  bg_remover = RBG_REMOVER if rmbg_tag == "rembg" else RBG14_REMOVER
249
  image = bg_remover(image)
250
+ image = keep_largest_connected_component(image)
251
 
252
  if preprocess:
253
  image = trellis_preprocess(image)
 
930
  texture_size: int,
931
  enable_delight: bool = True,
932
  fix_mesh: bool = False,
933
+ no_mesh_post_process: bool = False,
934
  uuid: str = "sample",
935
  req: gr.Request = None,
936
  ) -> str:
 
947
  skip_fix_mesh=not fix_mesh,
948
  delight=enable_delight,
949
  texture_wh=[texture_size, texture_size],
950
+ no_mesh_post_process=no_mesh_post_process,
951
  )
952
 
953
  output_obj_mesh = os.path.join(output_dir, f"{uuid}.obj")
embodied_gen/data/backproject_v2.py CHANGED
@@ -274,6 +274,7 @@ class TextureBacker:
274
  mask_thresh (float, optional): Threshold for visibility masks.
275
  smooth_texture (bool, optional): Apply post-processing to texture.
276
  inpaint_smooth (bool, optional): Apply inpainting smoothing.
 
277
 
278
  Example:
279
  ```py
@@ -308,6 +309,7 @@ class TextureBacker:
308
  mask_thresh: float = 0.5,
309
  smooth_texture: bool = True,
310
  inpaint_smooth: bool = False,
 
311
  ) -> None:
312
  self.camera_params = camera_params
313
  self.renderer = None
@@ -318,6 +320,7 @@ class TextureBacker:
318
  self.mask_thresh = mask_thresh
319
  self.smooth_texture = smooth_texture
320
  self.inpaint_smooth = inpaint_smooth
 
321
 
322
  self.bake_angle_thresh = bake_angle_thresh
323
  self.bake_unreliable_kernel_size = int(
@@ -668,7 +671,12 @@ class TextureBacker:
668
  mesh, self.scale, self.center
669
  )
670
  textured_mesh = save_mesh_with_mtl(
671
- vertices, faces, uv_map, texture_np, output_path
 
 
 
 
 
672
  )
673
 
674
  return textured_mesh
@@ -766,6 +774,7 @@ def parse_args():
766
  help="Disable saving delight image",
767
  )
768
  parser.add_argument("--n_max_faces", type=int, default=30000)
 
769
  args, unknown = parser.parse_known_args()
770
 
771
  return args
@@ -856,6 +865,7 @@ def entrypoint(
856
  render_wh=args.resolution_hw,
857
  texture_wh=args.texture_wh,
858
  smooth_texture=not args.no_smooth_texture,
 
859
  )
860
 
861
  textured_mesh = texture_backer(multiviews, mesh, args.output_path)
 
274
  mask_thresh (float, optional): Threshold for visibility masks.
275
  smooth_texture (bool, optional): Apply post-processing to texture.
276
  inpaint_smooth (bool, optional): Apply inpainting smoothing.
277
+ mesh_post_process (bool, optional): False for preventing modification of vertices.
278
 
279
  Example:
280
  ```py
 
309
  mask_thresh: float = 0.5,
310
  smooth_texture: bool = True,
311
  inpaint_smooth: bool = False,
312
+ mesh_post_process: bool = True,
313
  ) -> None:
314
  self.camera_params = camera_params
315
  self.renderer = None
 
320
  self.mask_thresh = mask_thresh
321
  self.smooth_texture = smooth_texture
322
  self.inpaint_smooth = inpaint_smooth
323
+ self.mesh_post_process = mesh_post_process
324
 
325
  self.bake_angle_thresh = bake_angle_thresh
326
  self.bake_unreliable_kernel_size = int(
 
671
  mesh, self.scale, self.center
672
  )
673
  textured_mesh = save_mesh_with_mtl(
674
+ vertices,
675
+ faces,
676
+ uv_map,
677
+ texture_np,
678
+ output_path,
679
+ mesh_process=self.mesh_post_process,
680
  )
681
 
682
  return textured_mesh
 
774
  help="Disable saving delight image",
775
  )
776
  parser.add_argument("--n_max_faces", type=int, default=30000)
777
+ parser.add_argument("--no_mesh_post_process", action="store_true")
778
  args, unknown = parser.parse_known_args()
779
 
780
  return args
 
865
  render_wh=args.resolution_hw,
866
  texture_wh=args.texture_wh,
867
  smooth_texture=not args.no_smooth_texture,
868
+ mesh_post_process=not args.no_mesh_post_process,
869
  )
870
 
871
  textured_mesh = texture_backer(multiviews, mesh, args.output_path)
embodied_gen/data/utils.py CHANGED
@@ -726,6 +726,7 @@ def save_mesh_with_mtl(
726
  texture: Union[Image.Image, np.ndarray],
727
  output_path: str,
728
  material_base=(250, 250, 250, 255),
 
729
  ) -> trimesh.Trimesh:
730
  if isinstance(texture, np.ndarray):
731
  texture = Image.fromarray(texture)
@@ -734,6 +735,7 @@ def save_mesh_with_mtl(
734
  vertices,
735
  faces,
736
  visual=trimesh.visual.TextureVisuals(uv=uvs, image=texture),
 
737
  )
738
  mesh.visual.material = trimesh.visual.material.SimpleMaterial(
739
  image=texture,
 
726
  texture: Union[Image.Image, np.ndarray],
727
  output_path: str,
728
  material_base=(250, 250, 250, 255),
729
+ mesh_process: bool = True,
730
  ) -> trimesh.Trimesh:
731
  if isinstance(texture, np.ndarray):
732
  texture = Image.fromarray(texture)
 
735
  vertices,
736
  faces,
737
  visual=trimesh.visual.TextureVisuals(uv=uvs, image=texture),
738
+ process=mesh_process, # False prevents trimesh from merging/modifying vertices
739
  )
740
  mesh.visual.material = trimesh.visual.material.SimpleMaterial(
741
  image=texture,
embodied_gen/models/segment_model.py CHANGED
@@ -43,6 +43,7 @@ __all__ = [
43
  "SAMRemover",
44
  "SAMPredictor",
45
  "RembgRemover",
 
46
  "get_segmented_image_by_agent",
47
  ]
48
 
@@ -376,7 +377,7 @@ class BMGG14Remover(object):
376
 
377
  def __call__(
378
  self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
379
- ):
380
  """Removes background from an image.
381
 
382
  Args:
@@ -496,13 +497,18 @@ if __name__ == "__main__":
496
  # input_image = "outputs/text2image/tmp/bucket.jpeg"
497
  # output_image = "outputs/text2image/tmp/bucket_seg.png"
498
 
499
- remover = SAMRemover(model_type="vit_h")
500
- remover = RembgRemover()
501
- clean_image = remover(input_image)
502
- clean_image.save(output_image)
503
- get_segmented_image_by_agent(
504
- Image.open(input_image), remover, remover, None, "./test_seg.png"
505
- )
506
 
507
  remover = BMGG14Remover()
508
- remover("embodied_gen/models/test_seg.jpg", "./seg.png")
 
 
 
 
 
 
43
  "SAMRemover",
44
  "SAMPredictor",
45
  "RembgRemover",
46
+ "BMGG14Remover",
47
  "get_segmented_image_by_agent",
48
  ]
49
 
 
377
 
378
  def __call__(
379
  self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
380
+ ) -> Image.Image:
381
  """Removes background from an image.
382
 
383
  Args:
 
497
  # input_image = "outputs/text2image/tmp/bucket.jpeg"
498
  # output_image = "outputs/text2image/tmp/bucket_seg.png"
499
 
500
+ # remover = SAMRemover(model_type="vit_h")
501
+ # remover = RembgRemover()
502
+ # clean_image = remover(input_image)
503
+ # clean_image.save(output_image)
504
+ # get_segmented_image_by_agent(
505
+ # Image.open(input_image), remover, remover, None, "./test_seg.png"
506
+ # )
507
 
508
  remover = BMGG14Remover()
509
+ clean_image = remover("./camera.jpeg", "./seg.png")
510
+ from embodied_gen.utils.process_media import (
511
+ keep_largest_connected_component,
512
+ )
513
+
514
+ keep_largest_connected_component(clean_image).save("./seg_post.png")
embodied_gen/scripts/gen_texture.py CHANGED
@@ -94,6 +94,7 @@ def entrypoint() -> None:
94
  delight=cfg.delight,
95
  no_save_delight_img=True,
96
  texture_wh=[cfg.texture_size, cfg.texture_size],
 
97
  )
98
  drender_api(
99
  mesh_path=f"{output_root}/texture_mesh/{uuid}.obj",
 
94
  delight=cfg.delight,
95
  no_save_delight_img=True,
96
  texture_wh=[cfg.texture_size, cfg.texture_size],
97
+ no_mesh_post_process=True,
98
  )
99
  drender_api(
100
  mesh_path=f"{output_root}/texture_mesh/{uuid}.obj",
embodied_gen/utils/process_media.py CHANGED
@@ -230,6 +230,29 @@ def filter_image_small_connected_components(
230
  return image
231
 
232
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
  def combine_images_to_grid(
234
  images: list[str | Image.Image],
235
  cat_row_col: tuple[int, int] = None,
@@ -439,7 +462,7 @@ class SceneTreeVisualizer:
439
  plt.axis("off")
440
 
441
  legend_handles = [
442
- Patch(facecolor=color, edgecolor='black', label=role)
443
  for role, color in self.role_colors.items()
444
  ]
445
  plt.legend(
@@ -465,7 +488,7 @@ def load_scene_dict(file_path: str) -> dict:
465
  dict: Mapping from scene ID to description.
466
  """
467
  scene_dict = {}
468
- with open(file_path, "r", encoding='utf-8') as f:
469
  for line in f:
470
  line = line.strip()
471
  if not line or ":" not in line:
@@ -487,7 +510,7 @@ def is_image_file(filename: str) -> bool:
487
  """
488
  mime_type, _ = mimetypes.guess_type(filename)
489
 
490
- return mime_type is not None and mime_type.startswith('image')
491
 
492
 
493
  def parse_text_prompts(prompts: list[str]) -> list[str]:
 
230
  return image
231
 
232
 
233
def keep_largest_connected_component(pil_img: Image.Image) -> Image.Image:
    """Keep only the largest opaque region of an RGBA image.

    Every pixel outside the largest 8-connected component of the alpha
    mask is made fully transparent; RGB channels are left untouched.

    Args:
        pil_img (Image.Image): Input image; converted to RGBA if needed.

    Returns:
        Image.Image: RGBA image whose alpha retains only the largest blob.
    """
    rgba = pil_img if pil_img.mode == "RGBA" else pil_img.convert("RGBA")
    data = np.array(rgba)
    alpha = data[:, :, 3]

    # Binarize alpha: any non-zero alpha counts as foreground.
    _, fg_mask = cv2.threshold(alpha, 0, 255, cv2.THRESH_BINARY)
    n_labels, label_map, comp_stats, _ = cv2.connectedComponentsWithStats(
        fg_mask, connectivity=8
    )

    # Label 0 is the background; fewer than two labels means the alpha
    # mask is empty, so there is nothing to filter.
    if n_labels < 2:
        return rgba

    # Pick the largest foreground component (skip background at index 0).
    keep_label = int(np.argmax(comp_stats[1:, cv2.CC_STAT_AREA])) + 1
    data[:, :, 3] = np.where(label_map == keep_label, alpha, 0).astype(
        np.uint8
    )

    return Image.fromarray(data)
254
+
255
+
256
  def combine_images_to_grid(
257
  images: list[str | Image.Image],
258
  cat_row_col: tuple[int, int] = None,
 
462
  plt.axis("off")
463
 
464
  legend_handles = [
465
+ Patch(facecolor=color, edgecolor="black", label=role)
466
  for role, color in self.role_colors.items()
467
  ]
468
  plt.legend(
 
488
  dict: Mapping from scene ID to description.
489
  """
490
  scene_dict = {}
491
+ with open(file_path, "r", encoding="utf-8") as f:
492
  for line in f:
493
  line = line.strip()
494
  if not line or ":" not in line:
 
510
  """
511
  mime_type, _ = mimetypes.guess_type(filename)
512
 
513
+ return mime_type is not None and mime_type.startswith("image")
514
 
515
 
516
  def parse_text_prompts(prompts: list[str]) -> list[str]: