Spaces:
Running
on
Zero
Running
on
Zero
xinjie.wang
committed on
Commit
·
c154483
1
Parent(s):
58164f8
update
Browse files- app.py +3 -1
- common.py +5 -1
- embodied_gen/data/backproject_v2.py +11 -1
- embodied_gen/data/utils.py +2 -0
- embodied_gen/models/segment_model.py +15 -9
- embodied_gen/scripts/gen_texture.py +1 -0
- embodied_gen/utils/process_media.py +26 -3
app.py
CHANGED
|
@@ -44,11 +44,13 @@ if app_name == "imageto3d_sam3d":
|
|
| 44 |
|
| 45 |
enable_pre_resize = False
|
| 46 |
sample_step = 25
|
|
|
|
| 47 |
elif app_name == "imageto3d":
|
| 48 |
from common import image_to_3d
|
| 49 |
|
| 50 |
enable_pre_resize = True
|
| 51 |
sample_step = 12
|
|
|
|
| 52 |
|
| 53 |
with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
| 54 |
gr.HTML(image_css, visible=False)
|
|
@@ -155,7 +157,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|
| 155 |
)
|
| 156 |
rmbg_tag = gr.Radio(
|
| 157 |
choices=["rembg", "rmbg14"],
|
| 158 |
-
value=
|
| 159 |
label="Background Removal Model",
|
| 160 |
)
|
| 161 |
with gr.Row():
|
|
|
|
| 44 |
|
| 45 |
enable_pre_resize = False
|
| 46 |
sample_step = 25
|
| 47 |
+
bg_rm_model_name = "rembg" # "rembg", "rmbg14"
|
| 48 |
elif app_name == "imageto3d":
|
| 49 |
from common import image_to_3d
|
| 50 |
|
| 51 |
enable_pre_resize = True
|
| 52 |
sample_step = 12
|
| 53 |
+
bg_rm_model_name = "rembg" # "rembg", "rmbg14"
|
| 54 |
|
| 55 |
with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
| 56 |
gr.HTML(image_css, visible=False)
|
|
|
|
| 157 |
)
|
| 158 |
rmbg_tag = gr.Radio(
|
| 159 |
choices=["rembg", "rmbg14"],
|
| 160 |
+
value=bg_rm_model_name,
|
| 161 |
label="Background Removal Model",
|
| 162 |
)
|
| 163 |
with gr.Row():
|
common.py
CHANGED
|
@@ -34,7 +34,7 @@ from PIL import Image
|
|
| 34 |
from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
|
| 35 |
from embodied_gen.data.backproject_v3 import entrypoint as backproject_api_v3
|
| 36 |
from embodied_gen.data.differentiable_render import entrypoint as render_api
|
| 37 |
-
from embodied_gen.data.utils import
|
| 38 |
from embodied_gen.models.delight_model import DelightingModel
|
| 39 |
from embodied_gen.models.gs_model import GaussianOperator
|
| 40 |
from embodied_gen.models.segment_model import (
|
|
@@ -53,6 +53,7 @@ from embodied_gen.scripts.text2image import (
|
|
| 53 |
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
| 54 |
from embodied_gen.utils.process_media import (
|
| 55 |
filter_image_small_connected_components,
|
|
|
|
| 56 |
merge_images_video,
|
| 57 |
)
|
| 58 |
from embodied_gen.utils.tags import VERSION
|
|
@@ -246,6 +247,7 @@ def preprocess_image_fn(
|
|
| 246 |
|
| 247 |
bg_remover = RBG_REMOVER if rmbg_tag == "rembg" else RBG14_REMOVER
|
| 248 |
image = bg_remover(image)
|
|
|
|
| 249 |
|
| 250 |
if preprocess:
|
| 251 |
image = trellis_preprocess(image)
|
|
@@ -928,6 +930,7 @@ def backproject_texture_v2(
|
|
| 928 |
texture_size: int,
|
| 929 |
enable_delight: bool = True,
|
| 930 |
fix_mesh: bool = False,
|
|
|
|
| 931 |
uuid: str = "sample",
|
| 932 |
req: gr.Request = None,
|
| 933 |
) -> str:
|
|
@@ -944,6 +947,7 @@ def backproject_texture_v2(
|
|
| 944 |
skip_fix_mesh=not fix_mesh,
|
| 945 |
delight=enable_delight,
|
| 946 |
texture_wh=[texture_size, texture_size],
|
|
|
|
| 947 |
)
|
| 948 |
|
| 949 |
output_obj_mesh = os.path.join(output_dir, f"{uuid}.obj")
|
|
|
|
| 34 |
from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
|
| 35 |
from embodied_gen.data.backproject_v3 import entrypoint as backproject_api_v3
|
| 36 |
from embodied_gen.data.differentiable_render import entrypoint as render_api
|
| 37 |
+
from embodied_gen.data.utils import trellis_preprocess, zip_files
|
| 38 |
from embodied_gen.models.delight_model import DelightingModel
|
| 39 |
from embodied_gen.models.gs_model import GaussianOperator
|
| 40 |
from embodied_gen.models.segment_model import (
|
|
|
|
| 53 |
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
| 54 |
from embodied_gen.utils.process_media import (
|
| 55 |
filter_image_small_connected_components,
|
| 56 |
+
keep_largest_connected_component,
|
| 57 |
merge_images_video,
|
| 58 |
)
|
| 59 |
from embodied_gen.utils.tags import VERSION
|
|
|
|
| 247 |
|
| 248 |
bg_remover = RBG_REMOVER if rmbg_tag == "rembg" else RBG14_REMOVER
|
| 249 |
image = bg_remover(image)
|
| 250 |
+
image = keep_largest_connected_component(image)
|
| 251 |
|
| 252 |
if preprocess:
|
| 253 |
image = trellis_preprocess(image)
|
|
|
|
| 930 |
texture_size: int,
|
| 931 |
enable_delight: bool = True,
|
| 932 |
fix_mesh: bool = False,
|
| 933 |
+
no_mesh_post_process: bool = False,
|
| 934 |
uuid: str = "sample",
|
| 935 |
req: gr.Request = None,
|
| 936 |
) -> str:
|
|
|
|
| 947 |
skip_fix_mesh=not fix_mesh,
|
| 948 |
delight=enable_delight,
|
| 949 |
texture_wh=[texture_size, texture_size],
|
| 950 |
+
no_mesh_post_process=no_mesh_post_process,
|
| 951 |
)
|
| 952 |
|
| 953 |
output_obj_mesh = os.path.join(output_dir, f"{uuid}.obj")
|
embodied_gen/data/backproject_v2.py
CHANGED
|
@@ -274,6 +274,7 @@ class TextureBacker:
|
|
| 274 |
mask_thresh (float, optional): Threshold for visibility masks.
|
| 275 |
smooth_texture (bool, optional): Apply post-processing to texture.
|
| 276 |
inpaint_smooth (bool, optional): Apply inpainting smoothing.
|
|
|
|
| 277 |
|
| 278 |
Example:
|
| 279 |
```py
|
|
@@ -308,6 +309,7 @@ class TextureBacker:
|
|
| 308 |
mask_thresh: float = 0.5,
|
| 309 |
smooth_texture: bool = True,
|
| 310 |
inpaint_smooth: bool = False,
|
|
|
|
| 311 |
) -> None:
|
| 312 |
self.camera_params = camera_params
|
| 313 |
self.renderer = None
|
|
@@ -318,6 +320,7 @@ class TextureBacker:
|
|
| 318 |
self.mask_thresh = mask_thresh
|
| 319 |
self.smooth_texture = smooth_texture
|
| 320 |
self.inpaint_smooth = inpaint_smooth
|
|
|
|
| 321 |
|
| 322 |
self.bake_angle_thresh = bake_angle_thresh
|
| 323 |
self.bake_unreliable_kernel_size = int(
|
|
@@ -668,7 +671,12 @@ class TextureBacker:
|
|
| 668 |
mesh, self.scale, self.center
|
| 669 |
)
|
| 670 |
textured_mesh = save_mesh_with_mtl(
|
| 671 |
-
vertices,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 672 |
)
|
| 673 |
|
| 674 |
return textured_mesh
|
|
@@ -766,6 +774,7 @@ def parse_args():
|
|
| 766 |
help="Disable saving delight image",
|
| 767 |
)
|
| 768 |
parser.add_argument("--n_max_faces", type=int, default=30000)
|
|
|
|
| 769 |
args, unknown = parser.parse_known_args()
|
| 770 |
|
| 771 |
return args
|
|
@@ -856,6 +865,7 @@ def entrypoint(
|
|
| 856 |
render_wh=args.resolution_hw,
|
| 857 |
texture_wh=args.texture_wh,
|
| 858 |
smooth_texture=not args.no_smooth_texture,
|
|
|
|
| 859 |
)
|
| 860 |
|
| 861 |
textured_mesh = texture_backer(multiviews, mesh, args.output_path)
|
|
|
|
| 274 |
mask_thresh (float, optional): Threshold for visibility masks.
|
| 275 |
smooth_texture (bool, optional): Apply post-processing to texture.
|
| 276 |
inpaint_smooth (bool, optional): Apply inpainting smoothing.
|
| 277 |
+
mesh_post_process (bool, optional): Set to False to prevent modification of vertices.
|
| 278 |
|
| 279 |
Example:
|
| 280 |
```py
|
|
|
|
| 309 |
mask_thresh: float = 0.5,
|
| 310 |
smooth_texture: bool = True,
|
| 311 |
inpaint_smooth: bool = False,
|
| 312 |
+
mesh_post_process: bool = True,
|
| 313 |
) -> None:
|
| 314 |
self.camera_params = camera_params
|
| 315 |
self.renderer = None
|
|
|
|
| 320 |
self.mask_thresh = mask_thresh
|
| 321 |
self.smooth_texture = smooth_texture
|
| 322 |
self.inpaint_smooth = inpaint_smooth
|
| 323 |
+
self.mesh_post_process = mesh_post_process
|
| 324 |
|
| 325 |
self.bake_angle_thresh = bake_angle_thresh
|
| 326 |
self.bake_unreliable_kernel_size = int(
|
|
|
|
| 671 |
mesh, self.scale, self.center
|
| 672 |
)
|
| 673 |
textured_mesh = save_mesh_with_mtl(
|
| 674 |
+
vertices,
|
| 675 |
+
faces,
|
| 676 |
+
uv_map,
|
| 677 |
+
texture_np,
|
| 678 |
+
output_path,
|
| 679 |
+
mesh_process=self.mesh_post_process,
|
| 680 |
)
|
| 681 |
|
| 682 |
return textured_mesh
|
|
|
|
| 774 |
help="Disable saving delight image",
|
| 775 |
)
|
| 776 |
parser.add_argument("--n_max_faces", type=int, default=30000)
|
| 777 |
+
parser.add_argument("--no_mesh_post_process", action="store_true")
|
| 778 |
args, unknown = parser.parse_known_args()
|
| 779 |
|
| 780 |
return args
|
|
|
|
| 865 |
render_wh=args.resolution_hw,
|
| 866 |
texture_wh=args.texture_wh,
|
| 867 |
smooth_texture=not args.no_smooth_texture,
|
| 868 |
+
mesh_post_process=not args.no_mesh_post_process,
|
| 869 |
)
|
| 870 |
|
| 871 |
textured_mesh = texture_backer(multiviews, mesh, args.output_path)
|
embodied_gen/data/utils.py
CHANGED
|
@@ -726,6 +726,7 @@ def save_mesh_with_mtl(
|
|
| 726 |
texture: Union[Image.Image, np.ndarray],
|
| 727 |
output_path: str,
|
| 728 |
material_base=(250, 250, 250, 255),
|
|
|
|
| 729 |
) -> trimesh.Trimesh:
|
| 730 |
if isinstance(texture, np.ndarray):
|
| 731 |
texture = Image.fromarray(texture)
|
|
@@ -734,6 +735,7 @@ def save_mesh_with_mtl(
|
|
| 734 |
vertices,
|
| 735 |
faces,
|
| 736 |
visual=trimesh.visual.TextureVisuals(uv=uvs, image=texture),
|
|
|
|
| 737 |
)
|
| 738 |
mesh.visual.material = trimesh.visual.material.SimpleMaterial(
|
| 739 |
image=texture,
|
|
|
|
| 726 |
texture: Union[Image.Image, np.ndarray],
|
| 727 |
output_path: str,
|
| 728 |
material_base=(250, 250, 250, 255),
|
| 729 |
+
mesh_process: bool = True,
|
| 730 |
) -> trimesh.Trimesh:
|
| 731 |
if isinstance(texture, np.ndarray):
|
| 732 |
texture = Image.fromarray(texture)
|
|
|
|
| 735 |
vertices,
|
| 736 |
faces,
|
| 737 |
visual=trimesh.visual.TextureVisuals(uv=uvs, image=texture),
|
| 738 |
+
process=mesh_process, # False prevents modification of vertices
|
| 739 |
)
|
| 740 |
mesh.visual.material = trimesh.visual.material.SimpleMaterial(
|
| 741 |
image=texture,
|
embodied_gen/models/segment_model.py
CHANGED
|
@@ -43,6 +43,7 @@ __all__ = [
|
|
| 43 |
"SAMRemover",
|
| 44 |
"SAMPredictor",
|
| 45 |
"RembgRemover",
|
|
|
|
| 46 |
"get_segmented_image_by_agent",
|
| 47 |
]
|
| 48 |
|
|
@@ -376,7 +377,7 @@ class BMGG14Remover(object):
|
|
| 376 |
|
| 377 |
def __call__(
|
| 378 |
self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
|
| 379 |
-
):
|
| 380 |
"""Removes background from an image.
|
| 381 |
|
| 382 |
Args:
|
|
@@ -496,13 +497,18 @@ if __name__ == "__main__":
|
|
| 496 |
# input_image = "outputs/text2image/tmp/bucket.jpeg"
|
| 497 |
# output_image = "outputs/text2image/tmp/bucket_seg.png"
|
| 498 |
|
| 499 |
-
remover = SAMRemover(model_type="vit_h")
|
| 500 |
-
remover = RembgRemover()
|
| 501 |
-
clean_image = remover(input_image)
|
| 502 |
-
clean_image.save(output_image)
|
| 503 |
-
get_segmented_image_by_agent(
|
| 504 |
-
|
| 505 |
-
)
|
| 506 |
|
| 507 |
remover = BMGG14Remover()
|
| 508 |
-
remover("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
"SAMRemover",
|
| 44 |
"SAMPredictor",
|
| 45 |
"RembgRemover",
|
| 46 |
+
"BMGG14Remover",
|
| 47 |
"get_segmented_image_by_agent",
|
| 48 |
]
|
| 49 |
|
|
|
|
| 377 |
|
| 378 |
def __call__(
|
| 379 |
self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
|
| 380 |
+
) -> Image.Image:
|
| 381 |
"""Removes background from an image.
|
| 382 |
|
| 383 |
Args:
|
|
|
|
| 497 |
# input_image = "outputs/text2image/tmp/bucket.jpeg"
|
| 498 |
# output_image = "outputs/text2image/tmp/bucket_seg.png"
|
| 499 |
|
| 500 |
+
# remover = SAMRemover(model_type="vit_h")
|
| 501 |
+
# remover = RembgRemover()
|
| 502 |
+
# clean_image = remover(input_image)
|
| 503 |
+
# clean_image.save(output_image)
|
| 504 |
+
# get_segmented_image_by_agent(
|
| 505 |
+
# Image.open(input_image), remover, remover, None, "./test_seg.png"
|
| 506 |
+
# )
|
| 507 |
|
| 508 |
remover = BMGG14Remover()
|
| 509 |
+
clean_image = remover("./camera.jpeg", "./seg.png")
|
| 510 |
+
from embodied_gen.utils.process_media import (
|
| 511 |
+
keep_largest_connected_component,
|
| 512 |
+
)
|
| 513 |
+
|
| 514 |
+
keep_largest_connected_component(clean_image).save("./seg_post.png")
|
embodied_gen/scripts/gen_texture.py
CHANGED
|
@@ -94,6 +94,7 @@ def entrypoint() -> None:
|
|
| 94 |
delight=cfg.delight,
|
| 95 |
no_save_delight_img=True,
|
| 96 |
texture_wh=[cfg.texture_size, cfg.texture_size],
|
|
|
|
| 97 |
)
|
| 98 |
drender_api(
|
| 99 |
mesh_path=f"{output_root}/texture_mesh/{uuid}.obj",
|
|
|
|
| 94 |
delight=cfg.delight,
|
| 95 |
no_save_delight_img=True,
|
| 96 |
texture_wh=[cfg.texture_size, cfg.texture_size],
|
| 97 |
+
no_mesh_post_process=True,
|
| 98 |
)
|
| 99 |
drender_api(
|
| 100 |
mesh_path=f"{output_root}/texture_mesh/{uuid}.obj",
|
embodied_gen/utils/process_media.py
CHANGED
|
@@ -230,6 +230,29 @@ def filter_image_small_connected_components(
|
|
| 230 |
return image
|
| 231 |
|
| 232 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
def combine_images_to_grid(
|
| 234 |
images: list[str | Image.Image],
|
| 235 |
cat_row_col: tuple[int, int] = None,
|
|
@@ -439,7 +462,7 @@ class SceneTreeVisualizer:
|
|
| 439 |
plt.axis("off")
|
| 440 |
|
| 441 |
legend_handles = [
|
| 442 |
-
Patch(facecolor=color, edgecolor=
|
| 443 |
for role, color in self.role_colors.items()
|
| 444 |
]
|
| 445 |
plt.legend(
|
|
@@ -465,7 +488,7 @@ def load_scene_dict(file_path: str) -> dict:
|
|
| 465 |
dict: Mapping from scene ID to description.
|
| 466 |
"""
|
| 467 |
scene_dict = {}
|
| 468 |
-
with open(file_path, "r", encoding=
|
| 469 |
for line in f:
|
| 470 |
line = line.strip()
|
| 471 |
if not line or ":" not in line:
|
|
@@ -487,7 +510,7 @@ def is_image_file(filename: str) -> bool:
|
|
| 487 |
"""
|
| 488 |
mime_type, _ = mimetypes.guess_type(filename)
|
| 489 |
|
| 490 |
-
return mime_type is not None and mime_type.startswith(
|
| 491 |
|
| 492 |
|
| 493 |
def parse_text_prompts(prompts: list[str]) -> list[str]:
|
|
|
|
| 230 |
return image
|
| 231 |
|
| 232 |
|
| 233 |
+
def keep_largest_connected_component(pil_img: Image.Image) -> Image.Image:
|
| 234 |
+
if pil_img.mode != "RGBA":
|
| 235 |
+
pil_img = pil_img.convert("RGBA")
|
| 236 |
+
|
| 237 |
+
img_arr = np.array(pil_img)
|
| 238 |
+
alpha_channel = img_arr[:, :, 3]
|
| 239 |
+
|
| 240 |
+
_, binary_mask = cv2.threshold(alpha_channel, 0, 255, cv2.THRESH_BINARY)
|
| 241 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
|
| 242 |
+
binary_mask, connectivity=8
|
| 243 |
+
)
|
| 244 |
+
if num_labels < 2:
|
| 245 |
+
return pil_img
|
| 246 |
+
|
| 247 |
+
largest_label = 1 + np.argmax(stats[1:, cv2.CC_STAT_AREA])
|
| 248 |
+
new_alpha = np.where(labels == largest_label, alpha_channel, 0).astype(
|
| 249 |
+
np.uint8
|
| 250 |
+
)
|
| 251 |
+
img_arr[:, :, 3] = new_alpha
|
| 252 |
+
|
| 253 |
+
return Image.fromarray(img_arr)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
def combine_images_to_grid(
|
| 257 |
images: list[str | Image.Image],
|
| 258 |
cat_row_col: tuple[int, int] = None,
|
|
|
|
| 462 |
plt.axis("off")
|
| 463 |
|
| 464 |
legend_handles = [
|
| 465 |
+
Patch(facecolor=color, edgecolor="black", label=role)
|
| 466 |
for role, color in self.role_colors.items()
|
| 467 |
]
|
| 468 |
plt.legend(
|
|
|
|
| 488 |
dict: Mapping from scene ID to description.
|
| 489 |
"""
|
| 490 |
scene_dict = {}
|
| 491 |
+
with open(file_path, "r", encoding="utf-8") as f:
|
| 492 |
for line in f:
|
| 493 |
line = line.strip()
|
| 494 |
if not line or ":" not in line:
|
|
|
|
| 510 |
"""
|
| 511 |
mime_type, _ = mimetypes.guess_type(filename)
|
| 512 |
|
| 513 |
+
return mime_type is not None and mime_type.startswith("image")
|
| 514 |
|
| 515 |
|
| 516 |
def parse_text_prompts(prompts: list[str]) -> list[str]:
|