Spaces:
Running
on
Zero
Running
on
Zero
xinjie.wang
commited on
Commit
·
b05f3ac
1
Parent(s):
74b41f3
update
Browse files- app.py +2 -1
- app_style.py +17 -1
- common.py +100 -9
- embodied_gen/data/backproject_v2.py +11 -1
- embodied_gen/data/backproject_v3.py +8 -7
- embodied_gen/data/utils.py +42 -144
- embodied_gen/models/sam3d.py +147 -0
- embodied_gen/models/segment_model.py +15 -9
- embodied_gen/scripts/gen_texture.py +1 -0
- embodied_gen/scripts/imageto3d.py +13 -15
- embodied_gen/utils/monkey_patches.py +377 -0
- embodied_gen/utils/process_media.py +26 -3
- embodied_gen/utils/trender.py +57 -5
- thirdparty/TRELLIS/trellis/utils/postprocessing_utils.py +1 -1
app.py
CHANGED
|
@@ -267,7 +267,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|
| 267 |
|
| 268 |
demo.load(start_session)
|
| 269 |
demo.unload(end_session)
|
| 270 |
-
|
| 271 |
mesh_input.change(
|
| 272 |
lambda: tuple(
|
| 273 |
[
|
|
@@ -368,6 +368,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|
| 368 |
texture_size,
|
| 369 |
project_delight,
|
| 370 |
fix_mesh,
|
|
|
|
| 371 |
],
|
| 372 |
outputs=[mesh_output, mesh_outpath, download_btn],
|
| 373 |
).success(
|
|
|
|
| 267 |
|
| 268 |
demo.load(start_session)
|
| 269 |
demo.unload(end_session)
|
| 270 |
+
no_mesh_post_process = gr.State(True)
|
| 271 |
mesh_input.change(
|
| 272 |
lambda: tuple(
|
| 273 |
[
|
|
|
|
| 368 |
texture_size,
|
| 369 |
project_delight,
|
| 370 |
fix_mesh,
|
| 371 |
+
no_mesh_post_process,
|
| 372 |
],
|
| 373 |
outputs=[mesh_output, mesh_outpath, download_btn],
|
| 374 |
).success(
|
app_style.py
CHANGED
|
@@ -1,10 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from gradio.themes import Soft
|
| 2 |
from gradio.themes.utils.colors import gray, neutral, slate, stone, teal, zinc
|
| 3 |
|
| 4 |
lighting_css = """
|
| 5 |
<style>
|
| 6 |
#lighter_mesh canvas {
|
| 7 |
-
filter: brightness(2.
|
| 8 |
}
|
| 9 |
</style>
|
| 10 |
"""
|
|
|
|
| 1 |
+
# Project EmbodiedGen
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 14 |
+
# implied. See the License for the specific language governing
|
| 15 |
+
# permissions and limitations under the License.
|
| 16 |
+
|
| 17 |
from gradio.themes import Soft
|
| 18 |
from gradio.themes.utils.colors import gray, neutral, slate, stone, teal, zinc
|
| 19 |
|
| 20 |
lighting_css = """
|
| 21 |
<style>
|
| 22 |
#lighter_mesh canvas {
|
| 23 |
+
filter: brightness(2.3) !important;
|
| 24 |
}
|
| 25 |
</style>
|
| 26 |
"""
|
common.py
CHANGED
|
@@ -53,6 +53,7 @@ from embodied_gen.scripts.text2image import (
|
|
| 53 |
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
| 54 |
from embodied_gen.utils.process_media import (
|
| 55 |
filter_image_small_connected_components,
|
|
|
|
| 56 |
merge_images_video,
|
| 57 |
)
|
| 58 |
from embodied_gen.utils.tags import VERSION
|
|
@@ -151,6 +152,21 @@ if os.getenv("GRADIO_APP") == "imageto3d":
|
|
| 151 |
os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
|
| 152 |
)
|
| 153 |
os.makedirs(TMP_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
elif os.getenv("GRADIO_APP") == "textto3d":
|
| 155 |
RBG_REMOVER = RembgRemover()
|
| 156 |
RBG14_REMOVER = BMGG14Remover()
|
|
@@ -169,6 +185,23 @@ elif os.getenv("GRADIO_APP") == "textto3d":
|
|
| 169 |
os.path.dirname(os.path.abspath(__file__)), "sessions/textto3d"
|
| 170 |
)
|
| 171 |
os.makedirs(TMP_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
elif os.getenv("GRADIO_APP") == "texture_edit":
|
| 173 |
DELIGHT = DelightingModel()
|
| 174 |
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
|
|
@@ -201,18 +234,23 @@ def end_session(req: gr.Request) -> None:
|
|
| 201 |
|
| 202 |
@spaces.GPU
|
| 203 |
def preprocess_image_fn(
|
| 204 |
-
image: str | np.ndarray | Image.Image,
|
|
|
|
|
|
|
| 205 |
) -> tuple[Image.Image, Image.Image]:
|
| 206 |
if isinstance(image, str):
|
| 207 |
image = Image.open(image)
|
| 208 |
elif isinstance(image, np.ndarray):
|
| 209 |
image = Image.fromarray(image)
|
| 210 |
|
| 211 |
-
image_cache = image.copy().
|
| 212 |
|
| 213 |
bg_remover = RBG_REMOVER if rmbg_tag == "rembg" else RBG14_REMOVER
|
| 214 |
image = bg_remover(image)
|
| 215 |
-
image =
|
|
|
|
|
|
|
|
|
|
| 216 |
|
| 217 |
return image, image_cache
|
| 218 |
|
|
@@ -224,7 +262,7 @@ def preprocess_sam_image_fn(
|
|
| 224 |
image = Image.fromarray(image)
|
| 225 |
|
| 226 |
sam_image = SAM_PREDICTOR.preprocess_image(image)
|
| 227 |
-
image_cache =
|
| 228 |
SAM_PREDICTOR.predictor.set_image(sam_image)
|
| 229 |
|
| 230 |
return sam_image, image_cache
|
|
@@ -349,11 +387,11 @@ def select_point(
|
|
| 349 |
def image_to_3d(
|
| 350 |
image: Image.Image,
|
| 351 |
seed: int,
|
| 352 |
-
ss_guidance_strength: float,
|
| 353 |
ss_sampling_steps: int,
|
| 354 |
-
slat_guidance_strength: float,
|
| 355 |
slat_sampling_steps: int,
|
| 356 |
raw_image_cache: Image.Image,
|
|
|
|
|
|
|
| 357 |
sam_image: Image.Image = None,
|
| 358 |
is_sam_image: bool = False,
|
| 359 |
req: gr.Request = None,
|
|
@@ -392,8 +430,56 @@ def image_to_3d(
|
|
| 392 |
|
| 393 |
gs_model = outputs["gaussian"][0]
|
| 394 |
mesh_model = outputs["mesh"][0]
|
| 395 |
-
color_images = render_video(gs_model)["color"]
|
| 396 |
-
normal_images = render_video(mesh_model)["normal"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
|
| 398 |
video_path = os.path.join(output_root, "gs_mesh.mp4")
|
| 399 |
merge_images_video(color_images, normal_images, video_path)
|
|
@@ -688,6 +774,7 @@ def text2image_fn(
|
|
| 688 |
image_wh: int | tuple[int, int] = [1024, 1024],
|
| 689 |
rmbg_tag: str = "rembg",
|
| 690 |
seed: int = None,
|
|
|
|
| 691 |
n_sample: int = 3,
|
| 692 |
req: gr.Request = None,
|
| 693 |
):
|
|
@@ -715,7 +802,9 @@ def text2image_fn(
|
|
| 715 |
|
| 716 |
for idx in range(len(images)):
|
| 717 |
image = images[idx]
|
| 718 |
-
images[idx], _ = preprocess_image_fn(
|
|
|
|
|
|
|
| 719 |
|
| 720 |
save_paths = []
|
| 721 |
for idx, image in enumerate(images):
|
|
@@ -841,6 +930,7 @@ def backproject_texture_v2(
|
|
| 841 |
texture_size: int,
|
| 842 |
enable_delight: bool = True,
|
| 843 |
fix_mesh: bool = False,
|
|
|
|
| 844 |
uuid: str = "sample",
|
| 845 |
req: gr.Request = None,
|
| 846 |
) -> str:
|
|
@@ -857,6 +947,7 @@ def backproject_texture_v2(
|
|
| 857 |
skip_fix_mesh=not fix_mesh,
|
| 858 |
delight=enable_delight,
|
| 859 |
texture_wh=[texture_size, texture_size],
|
|
|
|
| 860 |
)
|
| 861 |
|
| 862 |
output_obj_mesh = os.path.join(output_dir, f"{uuid}.obj")
|
|
|
|
| 53 |
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
| 54 |
from embodied_gen.utils.process_media import (
|
| 55 |
filter_image_small_connected_components,
|
| 56 |
+
keep_largest_connected_component,
|
| 57 |
merge_images_video,
|
| 58 |
)
|
| 59 |
from embodied_gen.utils.tags import VERSION
|
|
|
|
| 152 |
os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
|
| 153 |
)
|
| 154 |
os.makedirs(TMP_DIR, exist_ok=True)
|
| 155 |
+
elif os.getenv("GRADIO_APP") == "imageto3d_sam3d":
|
| 156 |
+
from embodied_gen.models.sam3d import Sam3dInference
|
| 157 |
+
|
| 158 |
+
RBG_REMOVER = RembgRemover()
|
| 159 |
+
RBG14_REMOVER = BMGG14Remover()
|
| 160 |
+
SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
|
| 161 |
+
PIPELINE = Sam3dInference()
|
| 162 |
+
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
|
| 163 |
+
GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
|
| 164 |
+
AESTHETIC_CHECKER = ImageAestheticChecker()
|
| 165 |
+
CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
|
| 166 |
+
TMP_DIR = os.path.join(
|
| 167 |
+
os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
|
| 168 |
+
)
|
| 169 |
+
os.makedirs(TMP_DIR, exist_ok=True)
|
| 170 |
elif os.getenv("GRADIO_APP") == "textto3d":
|
| 171 |
RBG_REMOVER = RembgRemover()
|
| 172 |
RBG14_REMOVER = BMGG14Remover()
|
|
|
|
| 185 |
os.path.dirname(os.path.abspath(__file__)), "sessions/textto3d"
|
| 186 |
)
|
| 187 |
os.makedirs(TMP_DIR, exist_ok=True)
|
| 188 |
+
elif os.getenv("GRADIO_APP") == "textto3d_sam3d":
|
| 189 |
+
from embodied_gen.models.sam3d import Sam3dInference
|
| 190 |
+
|
| 191 |
+
RBG_REMOVER = RembgRemover()
|
| 192 |
+
RBG14_REMOVER = BMGG14Remover()
|
| 193 |
+
PIPELINE = Sam3dInference()
|
| 194 |
+
text_model_dir = "weights/Kolors"
|
| 195 |
+
PIPELINE_IMG_IP = build_text2img_ip_pipeline(text_model_dir, ref_scale=0.3)
|
| 196 |
+
PIPELINE_IMG = build_text2img_pipeline(text_model_dir)
|
| 197 |
+
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
|
| 198 |
+
GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
|
| 199 |
+
AESTHETIC_CHECKER = ImageAestheticChecker()
|
| 200 |
+
CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
|
| 201 |
+
TMP_DIR = os.path.join(
|
| 202 |
+
os.path.dirname(os.path.abspath(__file__)), "sessions/textto3d"
|
| 203 |
+
)
|
| 204 |
+
os.makedirs(TMP_DIR, exist_ok=True)
|
| 205 |
elif os.getenv("GRADIO_APP") == "texture_edit":
|
| 206 |
DELIGHT = DelightingModel()
|
| 207 |
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
|
|
|
|
| 234 |
|
| 235 |
@spaces.GPU
|
| 236 |
def preprocess_image_fn(
|
| 237 |
+
image: str | np.ndarray | Image.Image,
|
| 238 |
+
rmbg_tag: str = "rembg",
|
| 239 |
+
preprocess: bool = True,
|
| 240 |
) -> tuple[Image.Image, Image.Image]:
|
| 241 |
if isinstance(image, str):
|
| 242 |
image = Image.open(image)
|
| 243 |
elif isinstance(image, np.ndarray):
|
| 244 |
image = Image.fromarray(image)
|
| 245 |
|
| 246 |
+
image_cache = image.copy() # resize_pil(image.copy(), 1024)
|
| 247 |
|
| 248 |
bg_remover = RBG_REMOVER if rmbg_tag == "rembg" else RBG14_REMOVER
|
| 249 |
image = bg_remover(image)
|
| 250 |
+
image = keep_largest_connected_component(image)
|
| 251 |
+
|
| 252 |
+
if preprocess:
|
| 253 |
+
image = trellis_preprocess(image)
|
| 254 |
|
| 255 |
return image, image_cache
|
| 256 |
|
|
|
|
| 262 |
image = Image.fromarray(image)
|
| 263 |
|
| 264 |
sam_image = SAM_PREDICTOR.preprocess_image(image)
|
| 265 |
+
image_cache = sam_image.copy()
|
| 266 |
SAM_PREDICTOR.predictor.set_image(sam_image)
|
| 267 |
|
| 268 |
return sam_image, image_cache
|
|
|
|
| 387 |
def image_to_3d(
|
| 388 |
image: Image.Image,
|
| 389 |
seed: int,
|
|
|
|
| 390 |
ss_sampling_steps: int,
|
|
|
|
| 391 |
slat_sampling_steps: int,
|
| 392 |
raw_image_cache: Image.Image,
|
| 393 |
+
ss_guidance_strength: float,
|
| 394 |
+
slat_guidance_strength: float,
|
| 395 |
sam_image: Image.Image = None,
|
| 396 |
is_sam_image: bool = False,
|
| 397 |
req: gr.Request = None,
|
|
|
|
| 430 |
|
| 431 |
gs_model = outputs["gaussian"][0]
|
| 432 |
mesh_model = outputs["mesh"][0]
|
| 433 |
+
color_images = render_video(gs_model, r=1.85)["color"]
|
| 434 |
+
normal_images = render_video(mesh_model, r=1.85)["normal"]
|
| 435 |
+
|
| 436 |
+
video_path = os.path.join(output_root, "gs_mesh.mp4")
|
| 437 |
+
merge_images_video(color_images, normal_images, video_path)
|
| 438 |
+
state = pack_state(gs_model, mesh_model)
|
| 439 |
+
|
| 440 |
+
gc.collect()
|
| 441 |
+
torch.cuda.empty_cache()
|
| 442 |
+
|
| 443 |
+
return state, video_path
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
@spaces.GPU
|
| 447 |
+
def image_to_3d_sam3d(
|
| 448 |
+
image: Image.Image,
|
| 449 |
+
seed: int,
|
| 450 |
+
ss_sampling_steps: int,
|
| 451 |
+
slat_sampling_steps: int,
|
| 452 |
+
raw_image_cache: Image.Image,
|
| 453 |
+
ss_guidance_strength: float = None,
|
| 454 |
+
slat_guidance_strength: float = None,
|
| 455 |
+
sam_image: Image.Image = None,
|
| 456 |
+
is_sam_image: bool = False,
|
| 457 |
+
req: gr.Request = None,
|
| 458 |
+
) -> tuple[dict, str]:
|
| 459 |
+
if is_sam_image:
|
| 460 |
+
seg_image = filter_image_small_connected_components(sam_image)
|
| 461 |
+
seg_image = Image.fromarray(seg_image, mode="RGBA")
|
| 462 |
+
else:
|
| 463 |
+
seg_image = image
|
| 464 |
+
|
| 465 |
+
if isinstance(seg_image, np.ndarray):
|
| 466 |
+
seg_image = Image.fromarray(seg_image)
|
| 467 |
+
|
| 468 |
+
output_root = os.path.join(TMP_DIR, str(req.session_hash))
|
| 469 |
+
os.makedirs(output_root, exist_ok=True)
|
| 470 |
+
seg_image.save(f"{output_root}/seg_image.png")
|
| 471 |
+
raw_image_cache.save(f"{output_root}/raw_image.png")
|
| 472 |
+
outputs = PIPELINE.run(
|
| 473 |
+
seg_image,
|
| 474 |
+
seed=seed,
|
| 475 |
+
stage1_inference_steps=ss_sampling_steps,
|
| 476 |
+
stage2_inference_steps=slat_sampling_steps,
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
gs_model = outputs["gaussian"][0]
|
| 480 |
+
mesh_model = outputs["mesh"][0]
|
| 481 |
+
color_images = render_video(gs_model, r=1.85)["color"]
|
| 482 |
+
normal_images = render_video(mesh_model, r=1.85)["normal"]
|
| 483 |
|
| 484 |
video_path = os.path.join(output_root, "gs_mesh.mp4")
|
| 485 |
merge_images_video(color_images, normal_images, video_path)
|
|
|
|
| 774 |
image_wh: int | tuple[int, int] = [1024, 1024],
|
| 775 |
rmbg_tag: str = "rembg",
|
| 776 |
seed: int = None,
|
| 777 |
+
enable_pre_resize: bool = True,
|
| 778 |
n_sample: int = 3,
|
| 779 |
req: gr.Request = None,
|
| 780 |
):
|
|
|
|
| 802 |
|
| 803 |
for idx in range(len(images)):
|
| 804 |
image = images[idx]
|
| 805 |
+
images[idx], _ = preprocess_image_fn(
|
| 806 |
+
image, rmbg_tag, enable_pre_resize
|
| 807 |
+
)
|
| 808 |
|
| 809 |
save_paths = []
|
| 810 |
for idx, image in enumerate(images):
|
|
|
|
| 930 |
texture_size: int,
|
| 931 |
enable_delight: bool = True,
|
| 932 |
fix_mesh: bool = False,
|
| 933 |
+
no_mesh_post_process: bool = False,
|
| 934 |
uuid: str = "sample",
|
| 935 |
req: gr.Request = None,
|
| 936 |
) -> str:
|
|
|
|
| 947 |
skip_fix_mesh=not fix_mesh,
|
| 948 |
delight=enable_delight,
|
| 949 |
texture_wh=[texture_size, texture_size],
|
| 950 |
+
no_mesh_post_process=no_mesh_post_process,
|
| 951 |
)
|
| 952 |
|
| 953 |
output_obj_mesh = os.path.join(output_dir, f"{uuid}.obj")
|
embodied_gen/data/backproject_v2.py
CHANGED
|
@@ -274,6 +274,7 @@ class TextureBacker:
|
|
| 274 |
mask_thresh (float, optional): Threshold for visibility masks.
|
| 275 |
smooth_texture (bool, optional): Apply post-processing to texture.
|
| 276 |
inpaint_smooth (bool, optional): Apply inpainting smoothing.
|
|
|
|
| 277 |
|
| 278 |
Example:
|
| 279 |
```py
|
|
@@ -308,6 +309,7 @@ class TextureBacker:
|
|
| 308 |
mask_thresh: float = 0.5,
|
| 309 |
smooth_texture: bool = True,
|
| 310 |
inpaint_smooth: bool = False,
|
|
|
|
| 311 |
) -> None:
|
| 312 |
self.camera_params = camera_params
|
| 313 |
self.renderer = None
|
|
@@ -318,6 +320,7 @@ class TextureBacker:
|
|
| 318 |
self.mask_thresh = mask_thresh
|
| 319 |
self.smooth_texture = smooth_texture
|
| 320 |
self.inpaint_smooth = inpaint_smooth
|
|
|
|
| 321 |
|
| 322 |
self.bake_angle_thresh = bake_angle_thresh
|
| 323 |
self.bake_unreliable_kernel_size = int(
|
|
@@ -668,7 +671,12 @@ class TextureBacker:
|
|
| 668 |
mesh, self.scale, self.center
|
| 669 |
)
|
| 670 |
textured_mesh = save_mesh_with_mtl(
|
| 671 |
-
vertices,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 672 |
)
|
| 673 |
|
| 674 |
return textured_mesh
|
|
@@ -766,6 +774,7 @@ def parse_args():
|
|
| 766 |
help="Disable saving delight image",
|
| 767 |
)
|
| 768 |
parser.add_argument("--n_max_faces", type=int, default=30000)
|
|
|
|
| 769 |
args, unknown = parser.parse_known_args()
|
| 770 |
|
| 771 |
return args
|
|
@@ -856,6 +865,7 @@ def entrypoint(
|
|
| 856 |
render_wh=args.resolution_hw,
|
| 857 |
texture_wh=args.texture_wh,
|
| 858 |
smooth_texture=not args.no_smooth_texture,
|
|
|
|
| 859 |
)
|
| 860 |
|
| 861 |
textured_mesh = texture_backer(multiviews, mesh, args.output_path)
|
|
|
|
| 274 |
mask_thresh (float, optional): Threshold for visibility masks.
|
| 275 |
smooth_texture (bool, optional): Apply post-processing to texture.
|
| 276 |
inpaint_smooth (bool, optional): Apply inpainting smoothing.
|
| 277 |
+
mesh_post_process (bool, optional): False for preventing modification of vertices.
|
| 278 |
|
| 279 |
Example:
|
| 280 |
```py
|
|
|
|
| 309 |
mask_thresh: float = 0.5,
|
| 310 |
smooth_texture: bool = True,
|
| 311 |
inpaint_smooth: bool = False,
|
| 312 |
+
mesh_post_process: bool = True,
|
| 313 |
) -> None:
|
| 314 |
self.camera_params = camera_params
|
| 315 |
self.renderer = None
|
|
|
|
| 320 |
self.mask_thresh = mask_thresh
|
| 321 |
self.smooth_texture = smooth_texture
|
| 322 |
self.inpaint_smooth = inpaint_smooth
|
| 323 |
+
self.mesh_post_process = mesh_post_process
|
| 324 |
|
| 325 |
self.bake_angle_thresh = bake_angle_thresh
|
| 326 |
self.bake_unreliable_kernel_size = int(
|
|
|
|
| 671 |
mesh, self.scale, self.center
|
| 672 |
)
|
| 673 |
textured_mesh = save_mesh_with_mtl(
|
| 674 |
+
vertices,
|
| 675 |
+
faces,
|
| 676 |
+
uv_map,
|
| 677 |
+
texture_np,
|
| 678 |
+
output_path,
|
| 679 |
+
mesh_process=self.mesh_post_process,
|
| 680 |
)
|
| 681 |
|
| 682 |
return textured_mesh
|
|
|
|
| 774 |
help="Disable saving delight image",
|
| 775 |
)
|
| 776 |
parser.add_argument("--n_max_faces", type=int, default=30000)
|
| 777 |
+
parser.add_argument("--no_mesh_post_process", action="store_true")
|
| 778 |
args, unknown = parser.parse_known_args()
|
| 779 |
|
| 780 |
return args
|
|
|
|
| 865 |
render_wh=args.resolution_hw,
|
| 866 |
texture_wh=args.texture_wh,
|
| 867 |
smooth_texture=not args.no_smooth_texture,
|
| 868 |
+
mesh_post_process=not args.no_mesh_post_process,
|
| 869 |
)
|
| 870 |
|
| 871 |
textured_mesh = texture_backer(multiviews, mesh, args.output_path)
|
embodied_gen/data/backproject_v3.py
CHANGED
|
@@ -14,10 +14,10 @@
|
|
| 14 |
# implied. See the License for the specific language governing
|
| 15 |
# permissions and limitations under the License.
|
| 16 |
|
| 17 |
-
|
| 18 |
import argparse
|
| 19 |
import logging
|
| 20 |
import math
|
|
|
|
| 21 |
from typing import Literal, Union
|
| 22 |
|
| 23 |
import cv2
|
|
@@ -353,8 +353,8 @@ def parse_args():
|
|
| 353 |
parser.add_argument(
|
| 354 |
"--distance",
|
| 355 |
type=float,
|
| 356 |
-
default=5,
|
| 357 |
-
help="Camera distance (default: 5)",
|
| 358 |
)
|
| 359 |
parser.add_argument(
|
| 360 |
"--resolution_hw",
|
|
@@ -400,8 +400,8 @@ def parse_args():
|
|
| 400 |
parser.add_argument(
|
| 401 |
"--mesh_sipmlify_ratio",
|
| 402 |
type=float,
|
| 403 |
-
default=0.
|
| 404 |
-
help="Mesh simplification ratio (default: 0.
|
| 405 |
)
|
| 406 |
parser.add_argument(
|
| 407 |
"--delight", action="store_true", help="Use delighting model."
|
|
@@ -425,6 +425,7 @@ def parse_args():
|
|
| 425 |
return args
|
| 426 |
|
| 427 |
|
|
|
|
| 428 |
def entrypoint(
|
| 429 |
delight_model: DelightingModel = None,
|
| 430 |
imagesr_model: ImageRealESRGAN = None,
|
|
@@ -499,7 +500,7 @@ def entrypoint(
|
|
| 499 |
faces = mesh.faces.astype(np.int32)
|
| 500 |
vertices = vertices.astype(np.float32)
|
| 501 |
|
| 502 |
-
if not args.skip_fix_mesh
|
| 503 |
mesh_fixer = MeshFixer(vertices, faces, args.device)
|
| 504 |
vertices, faces = mesh_fixer(
|
| 505 |
filter_ratio=args.mesh_sipmlify_ratio,
|
|
@@ -511,7 +512,7 @@ def entrypoint(
|
|
| 511 |
if len(faces) > args.n_max_faces:
|
| 512 |
mesh_fixer = MeshFixer(vertices, faces, args.device)
|
| 513 |
vertices, faces = mesh_fixer(
|
| 514 |
-
filter_ratio=max(0.
|
| 515 |
max_hole_size=0.04,
|
| 516 |
resolution=1024,
|
| 517 |
num_views=1000,
|
|
|
|
| 14 |
# implied. See the License for the specific language governing
|
| 15 |
# permissions and limitations under the License.
|
| 16 |
|
|
|
|
| 17 |
import argparse
|
| 18 |
import logging
|
| 19 |
import math
|
| 20 |
+
import os
|
| 21 |
from typing import Literal, Union
|
| 22 |
|
| 23 |
import cv2
|
|
|
|
| 353 |
parser.add_argument(
|
| 354 |
"--distance",
|
| 355 |
type=float,
|
| 356 |
+
default=4.5,
|
| 357 |
+
help="Camera distance (default: 4.5)",
|
| 358 |
)
|
| 359 |
parser.add_argument(
|
| 360 |
"--resolution_hw",
|
|
|
|
| 400 |
parser.add_argument(
|
| 401 |
"--mesh_sipmlify_ratio",
|
| 402 |
type=float,
|
| 403 |
+
default=0.85,
|
| 404 |
+
help="Mesh simplification ratio (default: 0.85)",
|
| 405 |
)
|
| 406 |
parser.add_argument(
|
| 407 |
"--delight", action="store_true", help="Use delighting model."
|
|
|
|
| 425 |
return args
|
| 426 |
|
| 427 |
|
| 428 |
+
@spaces.GPU
|
| 429 |
def entrypoint(
|
| 430 |
delight_model: DelightingModel = None,
|
| 431 |
imagesr_model: ImageRealESRGAN = None,
|
|
|
|
| 500 |
faces = mesh.faces.astype(np.int32)
|
| 501 |
vertices = vertices.astype(np.float32)
|
| 502 |
|
| 503 |
+
if not args.skip_fix_mesh:
|
| 504 |
mesh_fixer = MeshFixer(vertices, faces, args.device)
|
| 505 |
vertices, faces = mesh_fixer(
|
| 506 |
filter_ratio=args.mesh_sipmlify_ratio,
|
|
|
|
| 512 |
if len(faces) > args.n_max_faces:
|
| 513 |
mesh_fixer = MeshFixer(vertices, faces, args.device)
|
| 514 |
vertices, faces = mesh_fixer(
|
| 515 |
+
filter_ratio=max(0.1, args.mesh_sipmlify_ratio - 0.1),
|
| 516 |
max_hole_size=0.04,
|
| 517 |
resolution=1024,
|
| 518 |
num_views=1000,
|
embodied_gen/data/utils.py
CHANGED
|
@@ -15,10 +15,13 @@
|
|
| 15 |
# permissions and limitations under the License.
|
| 16 |
|
| 17 |
|
|
|
|
| 18 |
import math
|
| 19 |
import os
|
| 20 |
-
import
|
| 21 |
import zipfile
|
|
|
|
|
|
|
| 22 |
from shutil import rmtree
|
| 23 |
from typing import List, Tuple, Union
|
| 24 |
|
|
@@ -28,20 +31,9 @@ import numpy as np
|
|
| 28 |
import nvdiffrast.torch as dr
|
| 29 |
import torch
|
| 30 |
import torch.nn.functional as F
|
| 31 |
-
from PIL import Image, ImageEnhance
|
| 32 |
-
|
| 33 |
-
try:
|
| 34 |
-
from kolors.models.modeling_chatglm import ChatGLMModel
|
| 35 |
-
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
|
| 36 |
-
except ImportError:
|
| 37 |
-
ChatGLMTokenizer = None
|
| 38 |
-
ChatGLMModel = None
|
| 39 |
-
import logging
|
| 40 |
-
from dataclasses import dataclass, field
|
| 41 |
-
|
| 42 |
import trimesh
|
| 43 |
from kaolin.render.camera import Camera
|
| 44 |
-
from
|
| 45 |
|
| 46 |
logger = logging.getLogger(__name__)
|
| 47 |
|
|
@@ -50,10 +42,8 @@ __all__ = [
|
|
| 50 |
"DiffrastRender",
|
| 51 |
"save_images",
|
| 52 |
"render_pbr",
|
| 53 |
-
"prelabel_text_feature",
|
| 54 |
"calc_vertex_normals",
|
| 55 |
"normalize_vertices_array",
|
| 56 |
-
"load_mesh_to_unit_cube",
|
| 57 |
"as_list",
|
| 58 |
"CameraSetting",
|
| 59 |
"import_kaolin_mesh",
|
|
@@ -67,6 +57,7 @@ __all__ = [
|
|
| 67 |
"trellis_preprocess",
|
| 68 |
"delete_dir",
|
| 69 |
"kaolin_to_opencv_view",
|
|
|
|
| 70 |
]
|
| 71 |
|
| 72 |
|
|
@@ -520,114 +511,6 @@ def render_pbr(
|
|
| 520 |
return image, albedo, diffuse, normal
|
| 521 |
|
| 522 |
|
| 523 |
-
def _move_to_target_device(data, device: str):
|
| 524 |
-
if isinstance(data, dict):
|
| 525 |
-
for key, value in data.items():
|
| 526 |
-
data[key] = _move_to_target_device(value, device)
|
| 527 |
-
elif isinstance(data, torch.Tensor):
|
| 528 |
-
return data.to(device)
|
| 529 |
-
|
| 530 |
-
return data
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
def _encode_prompt(
|
| 534 |
-
prompt_batch,
|
| 535 |
-
text_encoders,
|
| 536 |
-
tokenizers,
|
| 537 |
-
proportion_empty_prompts=0,
|
| 538 |
-
is_train=True,
|
| 539 |
-
):
|
| 540 |
-
prompt_embeds_list = []
|
| 541 |
-
|
| 542 |
-
captions = []
|
| 543 |
-
for caption in prompt_batch:
|
| 544 |
-
if random.random() < proportion_empty_prompts:
|
| 545 |
-
captions.append("")
|
| 546 |
-
elif isinstance(caption, str):
|
| 547 |
-
captions.append(caption)
|
| 548 |
-
elif isinstance(caption, (list, np.ndarray)):
|
| 549 |
-
captions.append(random.choice(caption) if is_train else caption[0])
|
| 550 |
-
|
| 551 |
-
with torch.no_grad():
|
| 552 |
-
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
|
| 553 |
-
text_inputs = tokenizer(
|
| 554 |
-
captions,
|
| 555 |
-
padding="max_length",
|
| 556 |
-
max_length=256,
|
| 557 |
-
truncation=True,
|
| 558 |
-
return_tensors="pt",
|
| 559 |
-
).to(text_encoder.device)
|
| 560 |
-
|
| 561 |
-
output = text_encoder(
|
| 562 |
-
input_ids=text_inputs.input_ids,
|
| 563 |
-
attention_mask=text_inputs.attention_mask,
|
| 564 |
-
position_ids=text_inputs.position_ids,
|
| 565 |
-
output_hidden_states=True,
|
| 566 |
-
)
|
| 567 |
-
|
| 568 |
-
# We are only interested in the pooled output of the text encoder.
|
| 569 |
-
prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
|
| 570 |
-
pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone()
|
| 571 |
-
bs_embed, seq_len, _ = prompt_embeds.shape
|
| 572 |
-
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
|
| 573 |
-
prompt_embeds_list.append(prompt_embeds)
|
| 574 |
-
|
| 575 |
-
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
| 576 |
-
pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
|
| 577 |
-
|
| 578 |
-
return prompt_embeds, pooled_prompt_embeds
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
def load_llm_models(pretrained_model_name_or_path: str, device: str):
|
| 582 |
-
tokenizer = ChatGLMTokenizer.from_pretrained(
|
| 583 |
-
pretrained_model_name_or_path,
|
| 584 |
-
subfolder="text_encoder",
|
| 585 |
-
)
|
| 586 |
-
text_encoder = ChatGLMModel.from_pretrained(
|
| 587 |
-
pretrained_model_name_or_path,
|
| 588 |
-
subfolder="text_encoder",
|
| 589 |
-
).to(device)
|
| 590 |
-
|
| 591 |
-
text_encoders = [
|
| 592 |
-
text_encoder,
|
| 593 |
-
]
|
| 594 |
-
tokenizers = [
|
| 595 |
-
tokenizer,
|
| 596 |
-
]
|
| 597 |
-
|
| 598 |
-
logger.info(f"Load model from {pretrained_model_name_or_path} done.")
|
| 599 |
-
|
| 600 |
-
return tokenizers, text_encoders
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
def prelabel_text_feature(
|
| 604 |
-
prompt_batch: List[str],
|
| 605 |
-
output_dir: str,
|
| 606 |
-
tokenizers: nn.Module,
|
| 607 |
-
text_encoders: nn.Module,
|
| 608 |
-
) -> List[str]:
|
| 609 |
-
os.makedirs(output_dir, exist_ok=True)
|
| 610 |
-
|
| 611 |
-
# prompt_batch ["text..."]
|
| 612 |
-
prompt_embeds, pooled_prompt_embeds = _encode_prompt(
|
| 613 |
-
prompt_batch, text_encoders, tokenizers
|
| 614 |
-
)
|
| 615 |
-
|
| 616 |
-
prompt_embeds = _move_to_target_device(prompt_embeds, device="cpu")
|
| 617 |
-
pooled_prompt_embeds = _move_to_target_device(
|
| 618 |
-
pooled_prompt_embeds, device="cpu"
|
| 619 |
-
)
|
| 620 |
-
|
| 621 |
-
data_dict = dict(
|
| 622 |
-
prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds
|
| 623 |
-
)
|
| 624 |
-
|
| 625 |
-
save_path = os.path.join(output_dir, "text_feat.pth")
|
| 626 |
-
torch.save(data_dict, save_path)
|
| 627 |
-
|
| 628 |
-
return save_path
|
| 629 |
-
|
| 630 |
-
|
| 631 |
def _calc_face_normals(
|
| 632 |
vertices: torch.Tensor, # V,3 first vertex may be unreferenced
|
| 633 |
faces: torch.Tensor, # F,3 long, first face may be all zero
|
|
@@ -683,25 +566,6 @@ def normalize_vertices_array(
|
|
| 683 |
return vertices, scale, center
|
| 684 |
|
| 685 |
|
| 686 |
-
def load_mesh_to_unit_cube(
|
| 687 |
-
mesh_file: str,
|
| 688 |
-
mesh_scale: float = 1.0,
|
| 689 |
-
) -> tuple[trimesh.Trimesh, float, list[float]]:
|
| 690 |
-
if not os.path.exists(mesh_file):
|
| 691 |
-
raise FileNotFoundError(f"mesh_file path {mesh_file} not exists.")
|
| 692 |
-
|
| 693 |
-
mesh = trimesh.load(mesh_file)
|
| 694 |
-
if isinstance(mesh, trimesh.Scene):
|
| 695 |
-
mesh = trimesh.utils.concatenate(mesh)
|
| 696 |
-
|
| 697 |
-
vertices, scale, center = normalize_vertices_array(
|
| 698 |
-
mesh.vertices, mesh_scale
|
| 699 |
-
)
|
| 700 |
-
mesh.vertices = vertices
|
| 701 |
-
|
| 702 |
-
return mesh, scale, center
|
| 703 |
-
|
| 704 |
-
|
| 705 |
def as_list(obj):
|
| 706 |
if isinstance(obj, (list, tuple)):
|
| 707 |
return obj
|
|
@@ -862,6 +726,7 @@ def save_mesh_with_mtl(
|
|
| 862 |
texture: Union[Image.Image, np.ndarray],
|
| 863 |
output_path: str,
|
| 864 |
material_base=(250, 250, 250, 255),
|
|
|
|
| 865 |
) -> trimesh.Trimesh:
|
| 866 |
if isinstance(texture, np.ndarray):
|
| 867 |
texture = Image.fromarray(texture)
|
|
@@ -870,6 +735,7 @@ def save_mesh_with_mtl(
|
|
| 870 |
vertices,
|
| 871 |
faces,
|
| 872 |
visual=trimesh.visual.TextureVisuals(uv=uvs, image=texture),
|
|
|
|
| 873 |
)
|
| 874 |
mesh.visual.material = trimesh.visual.material.SimpleMaterial(
|
| 875 |
image=texture,
|
|
@@ -998,8 +864,9 @@ def gamma_shs(shs: torch.Tensor, gamma: float) -> torch.Tensor:
|
|
| 998 |
|
| 999 |
|
| 1000 |
def resize_pil(image: Image.Image, max_size: int = 1024) -> Image.Image:
|
| 1001 |
-
|
| 1002 |
-
scale = min(1,
|
|
|
|
| 1003 |
if scale < 1:
|
| 1004 |
new_size = (int(image.width * scale), int(image.height * scale))
|
| 1005 |
image = image.resize(new_size, Image.Resampling.LANCZOS)
|
|
@@ -1068,3 +935,34 @@ def delete_dir(folder_path: str, keep_subs: list[str] = None) -> None:
|
|
| 1068 |
rmtree(item_path)
|
| 1069 |
else:
|
| 1070 |
os.remove(item_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
# permissions and limitations under the License.
|
| 16 |
|
| 17 |
|
| 18 |
+
import logging
|
| 19 |
import math
|
| 20 |
import os
|
| 21 |
+
import time
|
| 22 |
import zipfile
|
| 23 |
+
from contextlib import contextmanager
|
| 24 |
+
from dataclasses import dataclass, field
|
| 25 |
from shutil import rmtree
|
| 26 |
from typing import List, Tuple, Union
|
| 27 |
|
|
|
|
| 31 |
import nvdiffrast.torch as dr
|
| 32 |
import torch
|
| 33 |
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
import trimesh
|
| 35 |
from kaolin.render.camera import Camera
|
| 36 |
+
from PIL import Image, ImageEnhance
|
| 37 |
|
| 38 |
logger = logging.getLogger(__name__)
|
| 39 |
|
|
|
|
| 42 |
"DiffrastRender",
|
| 43 |
"save_images",
|
| 44 |
"render_pbr",
|
|
|
|
| 45 |
"calc_vertex_normals",
|
| 46 |
"normalize_vertices_array",
|
|
|
|
| 47 |
"as_list",
|
| 48 |
"CameraSetting",
|
| 49 |
"import_kaolin_mesh",
|
|
|
|
| 57 |
"trellis_preprocess",
|
| 58 |
"delete_dir",
|
| 59 |
"kaolin_to_opencv_view",
|
| 60 |
+
"model_device_ctx",
|
| 61 |
]
|
| 62 |
|
| 63 |
|
|
|
|
| 511 |
return image, albedo, diffuse, normal
|
| 512 |
|
| 513 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 514 |
def _calc_face_normals(
|
| 515 |
vertices: torch.Tensor, # V,3 first vertex may be unreferenced
|
| 516 |
faces: torch.Tensor, # F,3 long, first face may be all zero
|
|
|
|
| 566 |
return vertices, scale, center
|
| 567 |
|
| 568 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 569 |
def as_list(obj):
|
| 570 |
if isinstance(obj, (list, tuple)):
|
| 571 |
return obj
|
|
|
|
| 726 |
texture: Union[Image.Image, np.ndarray],
|
| 727 |
output_path: str,
|
| 728 |
material_base=(250, 250, 250, 255),
|
| 729 |
+
mesh_process: bool = True,
|
| 730 |
) -> trimesh.Trimesh:
|
| 731 |
if isinstance(texture, np.ndarray):
|
| 732 |
texture = Image.fromarray(texture)
|
|
|
|
| 735 |
vertices,
|
| 736 |
faces,
|
| 737 |
visual=trimesh.visual.TextureVisuals(uv=uvs, image=texture),
|
| 738 |
+
process=mesh_process, # True for preventing modification of vertices
|
| 739 |
)
|
| 740 |
mesh.visual.material = trimesh.visual.material.SimpleMaterial(
|
| 741 |
image=texture,
|
|
|
|
| 864 |
|
| 865 |
|
| 866 |
def resize_pil(image: Image.Image, max_size: int = 1024) -> Image.Image:
|
| 867 |
+
current_max_dim = max(image.size)
|
| 868 |
+
scale = min(1, max_size / current_max_dim)
|
| 869 |
+
|
| 870 |
if scale < 1:
|
| 871 |
new_size = (int(image.width * scale), int(image.height * scale))
|
| 872 |
image = image.resize(new_size, Image.Resampling.LANCZOS)
|
|
|
|
| 935 |
rmtree(item_path)
|
| 936 |
else:
|
| 937 |
os.remove(item_path)
|
| 938 |
+
|
| 939 |
+
|
| 940 |
+
@contextmanager
|
| 941 |
+
def model_device_ctx(
|
| 942 |
+
*models,
|
| 943 |
+
src_device: str = "cpu",
|
| 944 |
+
dst_device: str = "cuda",
|
| 945 |
+
verbose: bool = False,
|
| 946 |
+
):
|
| 947 |
+
start = time.perf_counter()
|
| 948 |
+
for m in models:
|
| 949 |
+
if m is None:
|
| 950 |
+
continue
|
| 951 |
+
m.to(dst_device)
|
| 952 |
+
to_cuda_time = time.perf_counter() - start
|
| 953 |
+
|
| 954 |
+
try:
|
| 955 |
+
yield
|
| 956 |
+
finally:
|
| 957 |
+
start = time.perf_counter()
|
| 958 |
+
for m in models:
|
| 959 |
+
if m is None:
|
| 960 |
+
continue
|
| 961 |
+
m.to(src_device)
|
| 962 |
+
to_cpu_time = time.perf_counter() - start
|
| 963 |
+
|
| 964 |
+
if verbose:
|
| 965 |
+
model_names = [m.__class__.__name__ for m in models]
|
| 966 |
+
logger.debug(
|
| 967 |
+
f"[model_device_ctx] {model_names} to cuda: {to_cuda_time:.1f}s, to cpu: {to_cpu_time:.1f}s"
|
| 968 |
+
)
|
embodied_gen/models/sam3d.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Project EmbodiedGen
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 14 |
+
# implied. See the License for the specific language governing
|
| 15 |
+
# permissions and limitations under the License.
|
| 16 |
+
|
| 17 |
+
from embodied_gen.utils.monkey_patches import monkey_patch_sam3d
|
| 18 |
+
|
| 19 |
+
monkey_patch_sam3d()
|
| 20 |
+
import os
|
| 21 |
+
import sys
|
| 22 |
+
from typing import Optional, Union
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
from hydra.utils import instantiate
|
| 26 |
+
from modelscope import snapshot_download
|
| 27 |
+
from omegaconf import OmegaConf
|
| 28 |
+
from PIL import Image
|
| 29 |
+
|
| 30 |
+
current_file_path = os.path.abspath(__file__)
|
| 31 |
+
current_dir = os.path.dirname(current_file_path)
|
| 32 |
+
sys.path.append(os.path.join(current_dir, "../.."))
|
| 33 |
+
from thirdparty.sam3d.sam3d_objects.pipeline.inference_pipeline_pointmap import (
|
| 34 |
+
InferencePipelinePointMap,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
__all__ = ["Sam3dInference"]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def load_image(path: str) -> np.ndarray:
|
| 41 |
+
image = Image.open(path)
|
| 42 |
+
image = np.array(image)
|
| 43 |
+
image = image.astype(np.uint8)
|
| 44 |
+
return image
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def load_mask(path: str) -> np.ndarray:
|
| 48 |
+
mask = load_image(path)
|
| 49 |
+
mask = mask > 0
|
| 50 |
+
if mask.ndim == 3:
|
| 51 |
+
mask = mask[..., -1]
|
| 52 |
+
return mask
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class Sam3dInference:
|
| 56 |
+
def __init__(
|
| 57 |
+
self, local_dir: str = "weights/sam-3d-objects", compile: bool = False
|
| 58 |
+
) -> None:
|
| 59 |
+
if not os.path.exists(local_dir):
|
| 60 |
+
snapshot_download("facebook/sam-3d-objects", local_dir=local_dir)
|
| 61 |
+
config_file = os.path.join(local_dir, "checkpoints/pipeline.yaml")
|
| 62 |
+
config = OmegaConf.load(config_file)
|
| 63 |
+
config.rendering_engine = "nvdiffrast"
|
| 64 |
+
config.compile_model = compile
|
| 65 |
+
config.workspace_dir = os.path.dirname(config_file)
|
| 66 |
+
# Generate 4 gs in each pixel.
|
| 67 |
+
config["slat_decoder_gs_config_path"] = config.pop(
|
| 68 |
+
"slat_decoder_gs_4_config_path", "slat_decoder_gs_4.yaml"
|
| 69 |
+
)
|
| 70 |
+
config["slat_decoder_gs_ckpt_path"] = config.pop(
|
| 71 |
+
"slat_decoder_gs_4_ckpt_path", "slat_decoder_gs_4.ckpt"
|
| 72 |
+
)
|
| 73 |
+
self.pipeline: InferencePipelinePointMap = instantiate(config)
|
| 74 |
+
|
| 75 |
+
def merge_mask_to_rgba(
|
| 76 |
+
self, image: np.ndarray, mask: np.ndarray
|
| 77 |
+
) -> np.ndarray:
|
| 78 |
+
mask = mask.astype(np.uint8) * 255
|
| 79 |
+
mask = mask[..., None]
|
| 80 |
+
rgba_image = np.concatenate([image[..., :3], mask], axis=-1)
|
| 81 |
+
|
| 82 |
+
return rgba_image
|
| 83 |
+
|
| 84 |
+
def run(
|
| 85 |
+
self,
|
| 86 |
+
image: np.ndarray | Image.Image,
|
| 87 |
+
mask: np.ndarray = None,
|
| 88 |
+
seed: int = None,
|
| 89 |
+
pointmap: np.ndarray = None,
|
| 90 |
+
use_stage1_distillation: bool = False,
|
| 91 |
+
use_stage2_distillation: bool = False,
|
| 92 |
+
stage1_inference_steps: int = 25,
|
| 93 |
+
stage2_inference_steps: int = 25,
|
| 94 |
+
) -> dict:
|
| 95 |
+
if isinstance(image, Image.Image):
|
| 96 |
+
image = np.array(image)
|
| 97 |
+
if mask is not None:
|
| 98 |
+
image = self.merge_mask_to_rgba(image, mask)
|
| 99 |
+
return self.pipeline.run(
|
| 100 |
+
image,
|
| 101 |
+
None,
|
| 102 |
+
seed,
|
| 103 |
+
stage1_only=False,
|
| 104 |
+
with_mesh_postprocess=False,
|
| 105 |
+
with_texture_baking=False,
|
| 106 |
+
with_layout_postprocess=False,
|
| 107 |
+
use_vertex_color=True,
|
| 108 |
+
use_stage1_distillation=use_stage1_distillation,
|
| 109 |
+
use_stage2_distillation=use_stage2_distillation,
|
| 110 |
+
stage1_inference_steps=stage1_inference_steps,
|
| 111 |
+
stage2_inference_steps=stage2_inference_steps,
|
| 112 |
+
pointmap=pointmap,
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
if __name__ == "__main__":
|
| 117 |
+
pipeline = Sam3dInference()
|
| 118 |
+
|
| 119 |
+
# load image
|
| 120 |
+
image = load_image(
|
| 121 |
+
"/home/users/xinjie.wang/xinjie/sam-3d-objects/notebook/images/shutterstock_stylish_kidsroom_1640806567/image.png"
|
| 122 |
+
)
|
| 123 |
+
mask = load_mask(
|
| 124 |
+
"/home/users/xinjie.wang/xinjie/sam-3d-objects/notebook/images/shutterstock_stylish_kidsroom_1640806567/13.png"
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
import torch
|
| 128 |
+
|
| 129 |
+
if torch.cuda.is_available():
|
| 130 |
+
torch.cuda.reset_peak_memory_stats()
|
| 131 |
+
torch.cuda.empty_cache()
|
| 132 |
+
|
| 133 |
+
from time import time
|
| 134 |
+
|
| 135 |
+
start = time()
|
| 136 |
+
|
| 137 |
+
output = pipeline.run(image, mask, seed=42)
|
| 138 |
+
print(f"Running cost: {round(time()-start, 1)}")
|
| 139 |
+
|
| 140 |
+
if torch.cuda.is_available():
|
| 141 |
+
max_memory = torch.cuda.max_memory_allocated() / (1024**3)
|
| 142 |
+
print(f"(Max VRAM): {max_memory:.2f} GB")
|
| 143 |
+
|
| 144 |
+
print(f"End: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
|
| 145 |
+
|
| 146 |
+
output["gs"].save_ply(f"outputs/splat.ply")
|
| 147 |
+
print("Your reconstruction has been saved to outputs/splat.ply")
|
embodied_gen/models/segment_model.py
CHANGED
|
@@ -43,6 +43,7 @@ __all__ = [
|
|
| 43 |
"SAMRemover",
|
| 44 |
"SAMPredictor",
|
| 45 |
"RembgRemover",
|
|
|
|
| 46 |
"get_segmented_image_by_agent",
|
| 47 |
]
|
| 48 |
|
|
@@ -376,7 +377,7 @@ class BMGG14Remover(object):
|
|
| 376 |
|
| 377 |
def __call__(
|
| 378 |
self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
|
| 379 |
-
):
|
| 380 |
"""Removes background from an image.
|
| 381 |
|
| 382 |
Args:
|
|
@@ -496,13 +497,18 @@ if __name__ == "__main__":
|
|
| 496 |
# input_image = "outputs/text2image/tmp/bucket.jpeg"
|
| 497 |
# output_image = "outputs/text2image/tmp/bucket_seg.png"
|
| 498 |
|
| 499 |
-
remover = SAMRemover(model_type="vit_h")
|
| 500 |
-
remover = RembgRemover()
|
| 501 |
-
clean_image = remover(input_image)
|
| 502 |
-
clean_image.save(output_image)
|
| 503 |
-
get_segmented_image_by_agent(
|
| 504 |
-
|
| 505 |
-
)
|
| 506 |
|
| 507 |
remover = BMGG14Remover()
|
| 508 |
-
remover("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
"SAMRemover",
|
| 44 |
"SAMPredictor",
|
| 45 |
"RembgRemover",
|
| 46 |
+
"BMGG14Remover",
|
| 47 |
"get_segmented_image_by_agent",
|
| 48 |
]
|
| 49 |
|
|
|
|
| 377 |
|
| 378 |
def __call__(
|
| 379 |
self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
|
| 380 |
+
) -> Image.Image:
|
| 381 |
"""Removes background from an image.
|
| 382 |
|
| 383 |
Args:
|
|
|
|
| 497 |
# input_image = "outputs/text2image/tmp/bucket.jpeg"
|
| 498 |
# output_image = "outputs/text2image/tmp/bucket_seg.png"
|
| 499 |
|
| 500 |
+
# remover = SAMRemover(model_type="vit_h")
|
| 501 |
+
# remover = RembgRemover()
|
| 502 |
+
# clean_image = remover(input_image)
|
| 503 |
+
# clean_image.save(output_image)
|
| 504 |
+
# get_segmented_image_by_agent(
|
| 505 |
+
# Image.open(input_image), remover, remover, None, "./test_seg.png"
|
| 506 |
+
# )
|
| 507 |
|
| 508 |
remover = BMGG14Remover()
|
| 509 |
+
clean_image = remover("./camera.jpeg", "./seg.png")
|
| 510 |
+
from embodied_gen.utils.process_media import (
|
| 511 |
+
keep_largest_connected_component,
|
| 512 |
+
)
|
| 513 |
+
|
| 514 |
+
keep_largest_connected_component(clean_image).save("./seg_post.png")
|
embodied_gen/scripts/gen_texture.py
CHANGED
|
@@ -94,6 +94,7 @@ def entrypoint() -> None:
|
|
| 94 |
delight=cfg.delight,
|
| 95 |
no_save_delight_img=True,
|
| 96 |
texture_wh=[cfg.texture_size, cfg.texture_size],
|
|
|
|
| 97 |
)
|
| 98 |
drender_api(
|
| 99 |
mesh_path=f"{output_root}/texture_mesh/{uuid}.obj",
|
|
|
|
| 94 |
delight=cfg.delight,
|
| 95 |
no_save_delight_img=True,
|
| 96 |
texture_wh=[cfg.texture_size, cfg.texture_size],
|
| 97 |
+
no_mesh_post_process=True,
|
| 98 |
)
|
| 99 |
drender_api(
|
| 100 |
mesh_path=f"{output_root}/texture_mesh/{uuid}.obj",
|
embodied_gen/scripts/imageto3d.py
CHANGED
|
@@ -26,12 +26,14 @@ import numpy as np
|
|
| 26 |
import torch
|
| 27 |
import trimesh
|
| 28 |
from PIL import Image
|
| 29 |
-
from embodied_gen.data.
|
| 30 |
from embodied_gen.data.utils import delete_dir, trellis_preprocess
|
| 31 |
-
|
|
|
|
| 32 |
from embodied_gen.models.gs_model import GaussianOperator
|
| 33 |
from embodied_gen.models.segment_model import RembgRemover
|
| 34 |
-
|
|
|
|
| 35 |
from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
|
| 36 |
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
| 37 |
from embodied_gen.utils.log import logger
|
|
@@ -59,8 +61,8 @@ os.environ["SPCONV_ALGO"] = "native"
|
|
| 59 |
random.seed(0)
|
| 60 |
|
| 61 |
logger.info("Loading Image3D Models...")
|
| 62 |
-
DELIGHT = DelightingModel()
|
| 63 |
-
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
|
| 64 |
RBG_REMOVER = RembgRemover()
|
| 65 |
PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
|
| 66 |
"microsoft/TRELLIS-image-large"
|
|
@@ -108,9 +110,7 @@ def parse_args():
|
|
| 108 |
default=2,
|
| 109 |
)
|
| 110 |
parser.add_argument("--disable_decompose_convex", action="store_true")
|
| 111 |
-
parser.add_argument(
|
| 112 |
-
"--texture_wh", type=int, nargs=2, default=[2048, 2048]
|
| 113 |
-
)
|
| 114 |
args, unknown = parser.parse_known_args()
|
| 115 |
|
| 116 |
return args
|
|
@@ -248,16 +248,14 @@ def entrypoint(**kwargs):
|
|
| 248 |
mesh.export(mesh_obj_path)
|
| 249 |
|
| 250 |
mesh = backproject_api(
|
| 251 |
-
delight_model=DELIGHT,
|
| 252 |
-
imagesr_model=IMAGESR_MODEL,
|
| 253 |
-
|
| 254 |
mesh_path=mesh_obj_path,
|
| 255 |
output_path=mesh_obj_path,
|
| 256 |
skip_fix_mesh=False,
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
elevation=[20, -10, 60, -50],
|
| 260 |
-
num_images=12,
|
| 261 |
)
|
| 262 |
|
| 263 |
mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
|
|
|
|
| 26 |
import torch
|
| 27 |
import trimesh
|
| 28 |
from PIL import Image
|
| 29 |
+
from embodied_gen.data.backproject_v3 import entrypoint as backproject_api
|
| 30 |
from embodied_gen.data.utils import delete_dir, trellis_preprocess
|
| 31 |
+
|
| 32 |
+
# from embodied_gen.models.delight_model import DelightingModel
|
| 33 |
from embodied_gen.models.gs_model import GaussianOperator
|
| 34 |
from embodied_gen.models.segment_model import RembgRemover
|
| 35 |
+
|
| 36 |
+
# from embodied_gen.models.sr_model import ImageRealESRGAN
|
| 37 |
from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
|
| 38 |
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
| 39 |
from embodied_gen.utils.log import logger
|
|
|
|
| 61 |
random.seed(0)
|
| 62 |
|
| 63 |
logger.info("Loading Image3D Models...")
|
| 64 |
+
# DELIGHT = DelightingModel()
|
| 65 |
+
# IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
|
| 66 |
RBG_REMOVER = RembgRemover()
|
| 67 |
PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
|
| 68 |
"microsoft/TRELLIS-image-large"
|
|
|
|
| 110 |
default=2,
|
| 111 |
)
|
| 112 |
parser.add_argument("--disable_decompose_convex", action="store_true")
|
| 113 |
+
parser.add_argument("--texture_size", type=int, default=2048)
|
|
|
|
|
|
|
| 114 |
args, unknown = parser.parse_known_args()
|
| 115 |
|
| 116 |
return args
|
|
|
|
| 248 |
mesh.export(mesh_obj_path)
|
| 249 |
|
| 250 |
mesh = backproject_api(
|
| 251 |
+
# delight_model=DELIGHT,
|
| 252 |
+
# imagesr_model=IMAGESR_MODEL,
|
| 253 |
+
gs_path=aligned_gs_path,
|
| 254 |
mesh_path=mesh_obj_path,
|
| 255 |
output_path=mesh_obj_path,
|
| 256 |
skip_fix_mesh=False,
|
| 257 |
+
texture_size=args.texture_size,
|
| 258 |
+
delight=False,
|
|
|
|
|
|
|
| 259 |
)
|
| 260 |
|
| 261 |
mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
|
embodied_gen/utils/monkey_patches.py
CHANGED
|
@@ -25,6 +25,12 @@ from omegaconf import OmegaConf
|
|
| 25 |
from PIL import Image
|
| 26 |
from torchvision import transforms
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
def monkey_patch_pano2room():
|
| 30 |
current_file_path = os.path.abspath(__file__)
|
|
@@ -216,3 +222,374 @@ def monkey_patch_maniskill():
|
|
| 216 |
ManiSkillScene.get_human_render_camera_images = (
|
| 217 |
get_human_render_camera_images
|
| 218 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
from PIL import Image
|
| 26 |
from torchvision import transforms
|
| 27 |
|
| 28 |
+
__all__ = [
|
| 29 |
+
"monkey_patch_pano2room",
|
| 30 |
+
"monkey_patch_maniskill",
|
| 31 |
+
"monkey_patch_sam3d",
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
|
| 35 |
def monkey_patch_pano2room():
|
| 36 |
current_file_path = os.path.abspath(__file__)
|
|
|
|
| 222 |
ManiSkillScene.get_human_render_camera_images = (
|
| 223 |
get_human_render_camera_images
|
| 224 |
)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def monkey_patch_sam3d():
|
| 228 |
+
from typing import Optional, Union
|
| 229 |
+
|
| 230 |
+
from embodied_gen.data.utils import model_device_ctx
|
| 231 |
+
from embodied_gen.utils.log import logger
|
| 232 |
+
|
| 233 |
+
os.environ["LIDRA_SKIP_INIT"] = "true"
|
| 234 |
+
|
| 235 |
+
current_file_path = os.path.abspath(__file__)
|
| 236 |
+
current_dir = os.path.dirname(current_file_path)
|
| 237 |
+
sam3d_root = os.path.abspath(
|
| 238 |
+
os.path.join(current_dir, "../../thirdparty/sam3d")
|
| 239 |
+
)
|
| 240 |
+
if sam3d_root not in sys.path:
|
| 241 |
+
sys.path.insert(0, sam3d_root)
|
| 242 |
+
|
| 243 |
+
print(f"[MonkeyPatch] Added to sys.path: {sam3d_root}")
|
| 244 |
+
|
| 245 |
+
def patch_pointmap_infer_pipeline():
|
| 246 |
+
from copy import deepcopy
|
| 247 |
+
|
| 248 |
+
try:
|
| 249 |
+
from sam3d_objects.pipeline.inference_pipeline_pointmap import (
|
| 250 |
+
InferencePipelinePointMap,
|
| 251 |
+
)
|
| 252 |
+
except ImportError:
|
| 253 |
+
logger.error(
|
| 254 |
+
"[MonkeyPatch]: Could not import sam3d_objects directly. Check paths."
|
| 255 |
+
)
|
| 256 |
+
return
|
| 257 |
+
|
| 258 |
+
def patch_run(
|
| 259 |
+
self,
|
| 260 |
+
image: Union[None, Image.Image, np.ndarray],
|
| 261 |
+
mask: Union[None, Image.Image, np.ndarray] = None,
|
| 262 |
+
seed: Optional[int] = None,
|
| 263 |
+
stage1_only=False,
|
| 264 |
+
with_mesh_postprocess=True,
|
| 265 |
+
with_texture_baking=True,
|
| 266 |
+
with_layout_postprocess=True,
|
| 267 |
+
use_vertex_color=False,
|
| 268 |
+
stage1_inference_steps=None,
|
| 269 |
+
stage2_inference_steps=None,
|
| 270 |
+
use_stage1_distillation=False,
|
| 271 |
+
use_stage2_distillation=False,
|
| 272 |
+
pointmap=None,
|
| 273 |
+
decode_formats=None,
|
| 274 |
+
estimate_plane=False,
|
| 275 |
+
) -> dict:
|
| 276 |
+
image = self.merge_image_and_mask(image, mask)
|
| 277 |
+
with self.device:
|
| 278 |
+
pointmap_dict = self.compute_pointmap(image, pointmap)
|
| 279 |
+
pointmap = pointmap_dict["pointmap"]
|
| 280 |
+
pts = type(self)._down_sample_img(pointmap)
|
| 281 |
+
pts_colors = type(self)._down_sample_img(
|
| 282 |
+
pointmap_dict["pts_color"]
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
if estimate_plane:
|
| 286 |
+
return self.estimate_plane(pointmap_dict, image)
|
| 287 |
+
|
| 288 |
+
ss_input_dict = self.preprocess_image(
|
| 289 |
+
image, self.ss_preprocessor, pointmap=pointmap
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
slat_input_dict = self.preprocess_image(
|
| 293 |
+
image, self.slat_preprocessor
|
| 294 |
+
)
|
| 295 |
+
if seed is not None:
|
| 296 |
+
torch.manual_seed(seed)
|
| 297 |
+
|
| 298 |
+
with model_device_ctx(
|
| 299 |
+
self.models["ss_generator"],
|
| 300 |
+
self.models["ss_decoder"],
|
| 301 |
+
self.condition_embedders["ss_condition_embedder"],
|
| 302 |
+
):
|
| 303 |
+
ss_return_dict = self.sample_sparse_structure(
|
| 304 |
+
ss_input_dict,
|
| 305 |
+
inference_steps=stage1_inference_steps,
|
| 306 |
+
use_distillation=use_stage1_distillation,
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
# We could probably use the decoder from the models themselves
|
| 310 |
+
pointmap_scale = ss_input_dict.get("pointmap_scale", None)
|
| 311 |
+
pointmap_shift = ss_input_dict.get("pointmap_shift", None)
|
| 312 |
+
ss_return_dict.update(
|
| 313 |
+
self.pose_decoder(
|
| 314 |
+
ss_return_dict,
|
| 315 |
+
scene_scale=pointmap_scale,
|
| 316 |
+
scene_shift=pointmap_shift,
|
| 317 |
+
)
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
logger.info(
|
| 321 |
+
f"Rescaling scale by {ss_return_dict['downsample_factor']} after downsampling"
|
| 322 |
+
)
|
| 323 |
+
ss_return_dict["scale"] = (
|
| 324 |
+
ss_return_dict["scale"]
|
| 325 |
+
* ss_return_dict["downsample_factor"]
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
if stage1_only:
|
| 329 |
+
logger.info("Finished!")
|
| 330 |
+
ss_return_dict["voxel"] = (
|
| 331 |
+
ss_return_dict["coords"][:, 1:] / 64 - 0.5
|
| 332 |
+
)
|
| 333 |
+
return {
|
| 334 |
+
**ss_return_dict,
|
| 335 |
+
"pointmap": pts.cpu().permute((1, 2, 0)), # HxWx3
|
| 336 |
+
"pointmap_colors": pts_colors.cpu().permute(
|
| 337 |
+
(1, 2, 0)
|
| 338 |
+
), # HxWx3
|
| 339 |
+
}
|
| 340 |
+
# return ss_return_dict
|
| 341 |
+
|
| 342 |
+
coords = ss_return_dict["coords"]
|
| 343 |
+
with model_device_ctx(
|
| 344 |
+
self.models["slat_generator"],
|
| 345 |
+
self.condition_embedders["slat_condition_embedder"],
|
| 346 |
+
):
|
| 347 |
+
slat = self.sample_slat(
|
| 348 |
+
slat_input_dict,
|
| 349 |
+
coords,
|
| 350 |
+
inference_steps=stage2_inference_steps,
|
| 351 |
+
use_distillation=use_stage2_distillation,
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
with model_device_ctx(
|
| 355 |
+
self.models["slat_decoder_mesh"],
|
| 356 |
+
self.models["slat_decoder_gs"],
|
| 357 |
+
self.models["slat_decoder_gs_4"],
|
| 358 |
+
):
|
| 359 |
+
outputs = self.decode_slat(
|
| 360 |
+
slat,
|
| 361 |
+
(
|
| 362 |
+
self.decode_formats
|
| 363 |
+
if decode_formats is None
|
| 364 |
+
else decode_formats
|
| 365 |
+
),
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
outputs = self.postprocess_slat_output(
|
| 369 |
+
outputs,
|
| 370 |
+
with_mesh_postprocess,
|
| 371 |
+
with_texture_baking,
|
| 372 |
+
use_vertex_color,
|
| 373 |
+
)
|
| 374 |
+
glb = outputs.get("glb", None)
|
| 375 |
+
|
| 376 |
+
try:
|
| 377 |
+
if (
|
| 378 |
+
with_layout_postprocess
|
| 379 |
+
and self.layout_post_optimization_method is not None
|
| 380 |
+
):
|
| 381 |
+
assert (
|
| 382 |
+
glb is not None
|
| 383 |
+
), "require mesh to run postprocessing"
|
| 384 |
+
logger.info(
|
| 385 |
+
"Running layout post optimization method..."
|
| 386 |
+
)
|
| 387 |
+
postprocessed_pose = self.run_post_optimization(
|
| 388 |
+
deepcopy(glb),
|
| 389 |
+
pointmap_dict["intrinsics"],
|
| 390 |
+
ss_return_dict,
|
| 391 |
+
ss_input_dict,
|
| 392 |
+
)
|
| 393 |
+
ss_return_dict.update(postprocessed_pose)
|
| 394 |
+
except Exception as e:
|
| 395 |
+
logger.error(
|
| 396 |
+
f"Error during layout post optimization: {e}",
|
| 397 |
+
exc_info=True,
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
result = {
|
| 401 |
+
**ss_return_dict,
|
| 402 |
+
**outputs,
|
| 403 |
+
"pointmap": pts.cpu().permute((1, 2, 0)),
|
| 404 |
+
"pointmap_colors": pts_colors.cpu().permute((1, 2, 0)),
|
| 405 |
+
}
|
| 406 |
+
return result
|
| 407 |
+
|
| 408 |
+
InferencePipelinePointMap.run = patch_run
|
| 409 |
+
|
| 410 |
+
def patch_infer_init():
|
| 411 |
+
import torch
|
| 412 |
+
|
| 413 |
+
try:
|
| 414 |
+
from sam3d_objects.pipeline import preprocess_utils
|
| 415 |
+
from sam3d_objects.pipeline.inference_pipeline_pointmap import (
|
| 416 |
+
InferencePipeline,
|
| 417 |
+
)
|
| 418 |
+
from sam3d_objects.pipeline.inference_utils import (
|
| 419 |
+
SLAT_MEAN,
|
| 420 |
+
SLAT_STD,
|
| 421 |
+
)
|
| 422 |
+
except ImportError:
|
| 423 |
+
print(
|
| 424 |
+
"[MonkeyPatch] Error: Could not import sam3d_objects directly for infer pipeline."
|
| 425 |
+
)
|
| 426 |
+
return
|
| 427 |
+
|
| 428 |
+
def patch_init(
|
| 429 |
+
self,
|
| 430 |
+
ss_generator_config_path,
|
| 431 |
+
ss_generator_ckpt_path,
|
| 432 |
+
slat_generator_config_path,
|
| 433 |
+
slat_generator_ckpt_path,
|
| 434 |
+
ss_decoder_config_path,
|
| 435 |
+
ss_decoder_ckpt_path,
|
| 436 |
+
slat_decoder_gs_config_path,
|
| 437 |
+
slat_decoder_gs_ckpt_path,
|
| 438 |
+
slat_decoder_mesh_config_path,
|
| 439 |
+
slat_decoder_mesh_ckpt_path,
|
| 440 |
+
slat_decoder_gs_4_config_path=None,
|
| 441 |
+
slat_decoder_gs_4_ckpt_path=None,
|
| 442 |
+
ss_encoder_config_path=None,
|
| 443 |
+
ss_encoder_ckpt_path=None,
|
| 444 |
+
decode_formats=["gaussian", "mesh"],
|
| 445 |
+
dtype="bfloat16",
|
| 446 |
+
pad_size=1.0,
|
| 447 |
+
version="v0",
|
| 448 |
+
device="cuda",
|
| 449 |
+
ss_preprocessor=preprocess_utils.get_default_preprocessor(),
|
| 450 |
+
slat_preprocessor=preprocess_utils.get_default_preprocessor(),
|
| 451 |
+
ss_condition_input_mapping=["image"],
|
| 452 |
+
slat_condition_input_mapping=["image"],
|
| 453 |
+
pose_decoder_name="default",
|
| 454 |
+
workspace_dir="",
|
| 455 |
+
downsample_ss_dist=0, # the distance we use to downsample
|
| 456 |
+
ss_inference_steps=25,
|
| 457 |
+
ss_rescale_t=3,
|
| 458 |
+
ss_cfg_strength=7,
|
| 459 |
+
ss_cfg_interval=[0, 500],
|
| 460 |
+
ss_cfg_strength_pm=0.0,
|
| 461 |
+
slat_inference_steps=25,
|
| 462 |
+
slat_rescale_t=3,
|
| 463 |
+
slat_cfg_strength=5,
|
| 464 |
+
slat_cfg_interval=[0, 500],
|
| 465 |
+
rendering_engine: str = "nvdiffrast", # nvdiffrast OR pytorch3d,
|
| 466 |
+
shape_model_dtype=None,
|
| 467 |
+
compile_model=False,
|
| 468 |
+
slat_mean=SLAT_MEAN,
|
| 469 |
+
slat_std=SLAT_STD,
|
| 470 |
+
):
|
| 471 |
+
self.rendering_engine = rendering_engine
|
| 472 |
+
self.device = torch.device(device)
|
| 473 |
+
self.compile_model = compile_model
|
| 474 |
+
logger.info(f"self.device: {self.device}")
|
| 475 |
+
logger.info(
|
| 476 |
+
f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', None)}"
|
| 477 |
+
)
|
| 478 |
+
logger.info(f"Actually using GPU: {torch.cuda.current_device()}")
|
| 479 |
+
with self.device:
|
| 480 |
+
self.decode_formats = decode_formats
|
| 481 |
+
self.pad_size = pad_size
|
| 482 |
+
self.version = version
|
| 483 |
+
self.ss_condition_input_mapping = ss_condition_input_mapping
|
| 484 |
+
self.slat_condition_input_mapping = (
|
| 485 |
+
slat_condition_input_mapping
|
| 486 |
+
)
|
| 487 |
+
self.workspace_dir = workspace_dir
|
| 488 |
+
self.downsample_ss_dist = downsample_ss_dist
|
| 489 |
+
self.ss_inference_steps = ss_inference_steps
|
| 490 |
+
self.ss_rescale_t = ss_rescale_t
|
| 491 |
+
self.ss_cfg_strength = ss_cfg_strength
|
| 492 |
+
self.ss_cfg_interval = ss_cfg_interval
|
| 493 |
+
self.ss_cfg_strength_pm = ss_cfg_strength_pm
|
| 494 |
+
self.slat_inference_steps = slat_inference_steps
|
| 495 |
+
self.slat_rescale_t = slat_rescale_t
|
| 496 |
+
self.slat_cfg_strength = slat_cfg_strength
|
| 497 |
+
self.slat_cfg_interval = slat_cfg_interval
|
| 498 |
+
|
| 499 |
+
self.dtype = self._get_dtype(dtype)
|
| 500 |
+
if shape_model_dtype is None:
|
| 501 |
+
self.shape_model_dtype = self.dtype
|
| 502 |
+
else:
|
| 503 |
+
self.shape_model_dtype = self._get_dtype(shape_model_dtype)
|
| 504 |
+
|
| 505 |
+
# Setup preprocessors
|
| 506 |
+
self.pose_decoder = self.init_pose_decoder(
|
| 507 |
+
ss_generator_config_path, pose_decoder_name
|
| 508 |
+
)
|
| 509 |
+
self.ss_preprocessor = self.init_ss_preprocessor(
|
| 510 |
+
ss_preprocessor, ss_generator_config_path
|
| 511 |
+
)
|
| 512 |
+
self.slat_preprocessor = slat_preprocessor
|
| 513 |
+
|
| 514 |
+
logger.info("Loading model weights...")
|
| 515 |
+
raw_device = self.device
|
| 516 |
+
self.device = torch.device("cpu")
|
| 517 |
+
ss_generator = self.init_ss_generator(
|
| 518 |
+
ss_generator_config_path, ss_generator_ckpt_path
|
| 519 |
+
)
|
| 520 |
+
slat_generator = self.init_slat_generator(
|
| 521 |
+
slat_generator_config_path, slat_generator_ckpt_path
|
| 522 |
+
)
|
| 523 |
+
ss_decoder = self.init_ss_decoder(
|
| 524 |
+
ss_decoder_config_path, ss_decoder_ckpt_path
|
| 525 |
+
)
|
| 526 |
+
ss_encoder = self.init_ss_encoder(
|
| 527 |
+
ss_encoder_config_path, ss_encoder_ckpt_path
|
| 528 |
+
)
|
| 529 |
+
slat_decoder_gs = self.init_slat_decoder_gs(
|
| 530 |
+
slat_decoder_gs_config_path, slat_decoder_gs_ckpt_path
|
| 531 |
+
)
|
| 532 |
+
slat_decoder_gs_4 = self.init_slat_decoder_gs(
|
| 533 |
+
slat_decoder_gs_4_config_path, slat_decoder_gs_4_ckpt_path
|
| 534 |
+
)
|
| 535 |
+
slat_decoder_mesh = self.init_slat_decoder_mesh(
|
| 536 |
+
slat_decoder_mesh_config_path, slat_decoder_mesh_ckpt_path
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
# Load conditioner embedder so that we only load it once
|
| 540 |
+
ss_condition_embedder = self.init_ss_condition_embedder(
|
| 541 |
+
ss_generator_config_path, ss_generator_ckpt_path
|
| 542 |
+
)
|
| 543 |
+
slat_condition_embedder = self.init_slat_condition_embedder(
|
| 544 |
+
slat_generator_config_path, slat_generator_ckpt_path
|
| 545 |
+
)
|
| 546 |
+
self.device = raw_device
|
| 547 |
+
|
| 548 |
+
self.condition_embedders = {
|
| 549 |
+
"ss_condition_embedder": ss_condition_embedder,
|
| 550 |
+
"slat_condition_embedder": slat_condition_embedder,
|
| 551 |
+
}
|
| 552 |
+
|
| 553 |
+
# override generator and condition embedder setting
|
| 554 |
+
self.override_ss_generator_cfg_config(
|
| 555 |
+
ss_generator,
|
| 556 |
+
cfg_strength=ss_cfg_strength,
|
| 557 |
+
inference_steps=ss_inference_steps,
|
| 558 |
+
rescale_t=ss_rescale_t,
|
| 559 |
+
cfg_interval=ss_cfg_interval,
|
| 560 |
+
cfg_strength_pm=ss_cfg_strength_pm,
|
| 561 |
+
)
|
| 562 |
+
self.override_slat_generator_cfg_config(
|
| 563 |
+
slat_generator,
|
| 564 |
+
cfg_strength=slat_cfg_strength,
|
| 565 |
+
inference_steps=slat_inference_steps,
|
| 566 |
+
rescale_t=slat_rescale_t,
|
| 567 |
+
cfg_interval=slat_cfg_interval,
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
self.models = torch.nn.ModuleDict(
|
| 571 |
+
{
|
| 572 |
+
"ss_generator": ss_generator,
|
| 573 |
+
"slat_generator": slat_generator,
|
| 574 |
+
"ss_encoder": ss_encoder,
|
| 575 |
+
"ss_decoder": ss_decoder,
|
| 576 |
+
"slat_decoder_gs": slat_decoder_gs,
|
| 577 |
+
"slat_decoder_gs_4": slat_decoder_gs_4,
|
| 578 |
+
"slat_decoder_mesh": slat_decoder_mesh,
|
| 579 |
+
}
|
| 580 |
+
)
|
| 581 |
+
logger.info("Loading model weights completed!")
|
| 582 |
+
|
| 583 |
+
if self.compile_model:
|
| 584 |
+
logger.info("Compiling model...")
|
| 585 |
+
self._compile()
|
| 586 |
+
logger.info("Model compilation completed!")
|
| 587 |
+
self.slat_mean = torch.tensor(slat_mean)
|
| 588 |
+
self.slat_std = torch.tensor(slat_std)
|
| 589 |
+
|
| 590 |
+
InferencePipeline.__init__ = patch_init
|
| 591 |
+
|
| 592 |
+
patch_pointmap_infer_pipeline()
|
| 593 |
+
patch_infer_init()
|
| 594 |
+
|
| 595 |
+
return
|
embodied_gen/utils/process_media.py
CHANGED
|
@@ -230,6 +230,29 @@ def filter_image_small_connected_components(
|
|
| 230 |
return image
|
| 231 |
|
| 232 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
def combine_images_to_grid(
|
| 234 |
images: list[str | Image.Image],
|
| 235 |
cat_row_col: tuple[int, int] = None,
|
|
@@ -439,7 +462,7 @@ class SceneTreeVisualizer:
|
|
| 439 |
plt.axis("off")
|
| 440 |
|
| 441 |
legend_handles = [
|
| 442 |
-
Patch(facecolor=color, edgecolor=
|
| 443 |
for role, color in self.role_colors.items()
|
| 444 |
]
|
| 445 |
plt.legend(
|
|
@@ -465,7 +488,7 @@ def load_scene_dict(file_path: str) -> dict:
|
|
| 465 |
dict: Mapping from scene ID to description.
|
| 466 |
"""
|
| 467 |
scene_dict = {}
|
| 468 |
-
with open(file_path, "r", encoding=
|
| 469 |
for line in f:
|
| 470 |
line = line.strip()
|
| 471 |
if not line or ":" not in line:
|
|
@@ -487,7 +510,7 @@ def is_image_file(filename: str) -> bool:
|
|
| 487 |
"""
|
| 488 |
mime_type, _ = mimetypes.guess_type(filename)
|
| 489 |
|
| 490 |
-
return mime_type is not None and mime_type.startswith(
|
| 491 |
|
| 492 |
|
| 493 |
def parse_text_prompts(prompts: list[str]) -> list[str]:
|
|
|
|
| 230 |
return image
|
| 231 |
|
| 232 |
|
| 233 |
+
def keep_largest_connected_component(pil_img: Image.Image) -> Image.Image:
|
| 234 |
+
if pil_img.mode != "RGBA":
|
| 235 |
+
pil_img = pil_img.convert("RGBA")
|
| 236 |
+
|
| 237 |
+
img_arr = np.array(pil_img)
|
| 238 |
+
alpha_channel = img_arr[:, :, 3]
|
| 239 |
+
|
| 240 |
+
_, binary_mask = cv2.threshold(alpha_channel, 0, 255, cv2.THRESH_BINARY)
|
| 241 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
|
| 242 |
+
binary_mask, connectivity=8
|
| 243 |
+
)
|
| 244 |
+
if num_labels < 2:
|
| 245 |
+
return pil_img
|
| 246 |
+
|
| 247 |
+
largest_label = 1 + np.argmax(stats[1:, cv2.CC_STAT_AREA])
|
| 248 |
+
new_alpha = np.where(labels == largest_label, alpha_channel, 0).astype(
|
| 249 |
+
np.uint8
|
| 250 |
+
)
|
| 251 |
+
img_arr[:, :, 3] = new_alpha
|
| 252 |
+
|
| 253 |
+
return Image.fromarray(img_arr)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
def combine_images_to_grid(
|
| 257 |
images: list[str | Image.Image],
|
| 258 |
cat_row_col: tuple[int, int] = None,
|
|
|
|
| 462 |
plt.axis("off")
|
| 463 |
|
| 464 |
legend_handles = [
|
| 465 |
+
Patch(facecolor=color, edgecolor="black", label=role)
|
| 466 |
for role, color in self.role_colors.items()
|
| 467 |
]
|
| 468 |
plt.legend(
|
|
|
|
| 488 |
dict: Mapping from scene ID to description.
|
| 489 |
"""
|
| 490 |
scene_dict = {}
|
| 491 |
+
with open(file_path, "r", encoding="utf-8") as f:
|
| 492 |
for line in f:
|
| 493 |
line = line.strip()
|
| 494 |
if not line or ":" not in line:
|
|
|
|
| 510 |
"""
|
| 511 |
mime_type, _ = mimetypes.guess_type(filename)
|
| 512 |
|
| 513 |
+
return mime_type is not None and mime_type.startswith("image")
|
| 514 |
|
| 515 |
|
| 516 |
def parse_text_prompts(prompts: list[str]) -> list[str]:
|
embodied_gen/utils/trender.py
CHANGED
|
@@ -16,6 +16,7 @@
|
|
| 16 |
|
| 17 |
import os
|
| 18 |
import sys
|
|
|
|
| 19 |
|
| 20 |
import numpy as np
|
| 21 |
import spaces
|
|
@@ -25,10 +26,8 @@ from tqdm import tqdm
|
|
| 25 |
current_file_path = os.path.abspath(__file__)
|
| 26 |
current_dir = os.path.dirname(current_file_path)
|
| 27 |
sys.path.append(os.path.join(current_dir, "../.."))
|
| 28 |
-
from thirdparty.TRELLIS.trellis.renderers
|
| 29 |
-
from thirdparty.TRELLIS.trellis.representations import MeshExtractResult
|
| 30 |
from thirdparty.TRELLIS.trellis.utils.render_utils import (
|
| 31 |
-
render_frames,
|
| 32 |
yaw_pitch_r_fov_to_extrinsics_intrinsics,
|
| 33 |
)
|
| 34 |
|
|
@@ -38,7 +37,7 @@ __all__ = [
|
|
| 38 |
|
| 39 |
|
| 40 |
@spaces.GPU
|
| 41 |
-
def
|
| 42 |
renderer = MeshRenderer()
|
| 43 |
renderer.rendering_options.resolution = options.get("resolution", 512)
|
| 44 |
renderer.rendering_options.near = options.get("near", 1)
|
|
@@ -60,6 +59,57 @@ def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs):
|
|
| 60 |
return rets
|
| 61 |
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
@spaces.GPU
|
| 64 |
def render_video(
|
| 65 |
sample,
|
|
@@ -77,7 +127,9 @@ def render_video(
|
|
| 77 |
yaws, pitch, r, fov
|
| 78 |
)
|
| 79 |
render_fn = (
|
| 80 |
-
|
|
|
|
|
|
|
| 81 |
)
|
| 82 |
result = render_fn(
|
| 83 |
sample,
|
|
|
|
| 16 |
|
| 17 |
import os
|
| 18 |
import sys
|
| 19 |
+
from collections import defaultdict
|
| 20 |
|
| 21 |
import numpy as np
|
| 22 |
import spaces
|
|
|
|
| 26 |
current_file_path = os.path.abspath(__file__)
|
| 27 |
current_dir = os.path.dirname(current_file_path)
|
| 28 |
sys.path.append(os.path.join(current_dir, "../.."))
|
| 29 |
+
from thirdparty.TRELLIS.trellis.renderers import GaussianRenderer, MeshRenderer
|
|
|
|
| 30 |
from thirdparty.TRELLIS.trellis.utils.render_utils import (
|
|
|
|
| 31 |
yaw_pitch_r_fov_to_extrinsics_intrinsics,
|
| 32 |
)
|
| 33 |
|
|
|
|
| 37 |
|
| 38 |
|
| 39 |
@spaces.GPU
|
| 40 |
+
def render_mesh_frames(sample, extrinsics, intrinsics, options={}, **kwargs):
|
| 41 |
renderer = MeshRenderer()
|
| 42 |
renderer.rendering_options.resolution = options.get("resolution", 512)
|
| 43 |
renderer.rendering_options.near = options.get("near", 1)
|
|
|
|
| 59 |
return rets
|
| 60 |
|
| 61 |
|
| 62 |
+
@spaces.GPU
|
| 63 |
+
def render_gs_frames(
|
| 64 |
+
sample,
|
| 65 |
+
extrinsics,
|
| 66 |
+
intrinsics,
|
| 67 |
+
options=None,
|
| 68 |
+
colors_overwrite=None,
|
| 69 |
+
verbose=True,
|
| 70 |
+
**kwargs,
|
| 71 |
+
):
|
| 72 |
+
def to_img(tensor):
|
| 73 |
+
return np.clip(
|
| 74 |
+
tensor.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255
|
| 75 |
+
).astype(np.uint8)
|
| 76 |
+
|
| 77 |
+
def to_numpy(tensor):
|
| 78 |
+
return tensor.detach().cpu().numpy()
|
| 79 |
+
|
| 80 |
+
renderer = GaussianRenderer()
|
| 81 |
+
renderer.pipe.kernel_size = kwargs.get("kernel_size", 0.1)
|
| 82 |
+
renderer.pipe.use_mip_gaussian = True
|
| 83 |
+
|
| 84 |
+
defaults = {
|
| 85 |
+
"resolution": 512,
|
| 86 |
+
"near": 0.8,
|
| 87 |
+
"far": 1.6,
|
| 88 |
+
"bg_color": (0, 0, 0),
|
| 89 |
+
"ssaa": 1,
|
| 90 |
+
}
|
| 91 |
+
final_options = {**defaults, **(options or {})}
|
| 92 |
+
|
| 93 |
+
for k, v in final_options.items():
|
| 94 |
+
if hasattr(renderer.rendering_options, k):
|
| 95 |
+
setattr(renderer.rendering_options, k, v)
|
| 96 |
+
|
| 97 |
+
outputs = defaultdict(list)
|
| 98 |
+
iterator = zip(extrinsics, intrinsics)
|
| 99 |
+
if verbose:
|
| 100 |
+
iterator = tqdm(iterator, total=len(extrinsics), desc="Rendering")
|
| 101 |
+
|
| 102 |
+
for extr, intr in iterator:
|
| 103 |
+
res = renderer.render(
|
| 104 |
+
sample, extr, intr, colors_overwrite=colors_overwrite
|
| 105 |
+
)
|
| 106 |
+
outputs["color"].append(to_img(res["color"]))
|
| 107 |
+
depth = res.get("percent_depth") or res.get("depth")
|
| 108 |
+
outputs["depth"].append(to_numpy(depth) if depth is not None else None)
|
| 109 |
+
|
| 110 |
+
return dict(outputs)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
@spaces.GPU
|
| 114 |
def render_video(
|
| 115 |
sample,
|
|
|
|
| 127 |
yaws, pitch, r, fov
|
| 128 |
)
|
| 129 |
render_fn = (
|
| 130 |
+
render_mesh_frames
|
| 131 |
+
if sample.__class__.__name__ == "MeshExtractResult"
|
| 132 |
+
else render_gs_frames
|
| 133 |
)
|
| 134 |
result = render_fn(
|
| 135 |
sample,
|
thirdparty/TRELLIS/trellis/utils/postprocessing_utils.py
CHANGED
|
@@ -440,7 +440,7 @@ def to_glb(
|
|
| 440 |
vertices, faces, uvs = parametrize_mesh(vertices, faces)
|
| 441 |
|
| 442 |
# bake texture
|
| 443 |
-
observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=
|
| 444 |
masks = [np.any(observation > 0, axis=-1) for observation in observations]
|
| 445 |
extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))]
|
| 446 |
intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))]
|
|
|
|
| 440 |
vertices, faces, uvs = parametrize_mesh(vertices, faces)
|
| 441 |
|
| 442 |
# bake texture
|
| 443 |
+
observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=100)
|
| 444 |
masks = [np.any(observation > 0, axis=-1) for observation in observations]
|
| 445 |
extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))]
|
| 446 |
intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))]
|