Spaces: Running on Zero

xinjie.wang committed · commit 263611e
Parent(s): 1599289

update

Browse files:
- app.py +30 -6
- app_style.py +17 -1
- common.py +95 -8
- embodied_gen/data/backproject_v3.py +7 -7
- embodied_gen/data/utils.py +40 -144
- embodied_gen/models/sam3d.py +145 -0
- embodied_gen/utils/monkey_patches.py +381 -0
- embodied_gen/utils/trender.py +57 -5
- requirements.txt +9 -1
- thirdparty/TRELLIS/trellis/utils/postprocessing_utils.py +1 -1
app.py CHANGED

@@ -19,6 +19,9 @@ import os
 
 os.environ["GRADIO_APP"] = "textto3d"
 
+# GRADIO_APP == "textto3d_sam3d", sam3d object model, by default.
+# GRADIO_APP == "textto3d", TRELLIS model.
+os.environ["GRADIO_APP"] = "textto3d_sam3d"
 
 import gradio as gr
 from app_style import custom_theme, image_css, lighting_css

@@ -32,11 +35,22 @@ from common import (
     get_cached_image,
     get_seed,
     get_selected_image,
-    image_to_3d,
     start_session,
     text2image_fn,
 )
 
+app_name = os.getenv("GRADIO_APP")
+if app_name == "textto3d_sam3d":
+    from common import image_to_3d_sam3d as image_to_3d
+
+    enable_pre_resize = False
+    sample_step = 25
+elif app_name == "textto3d":
+    from common import image_to_3d
+
+    enable_pre_resize = True
+    sample_step = 12
+
 with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
     gr.HTML(image_css, visible=False)
     # gr.HTML(lighting_css, visible=False)

@@ -162,7 +176,11 @@
                         step=0.1,
                     )
                     ss_sampling_steps = gr.Slider(
-                        1,
+                        1,
+                        50,
+                        label="Sampling Steps",
+                        value=sample_step,
+                        step=1,
                     )
             gr.Markdown("Visual Appearance Generation")
             with gr.Row():

@@ -174,7 +192,11 @@
                         step=0.1,
                     )
                     slat_sampling_steps = gr.Slider(
-                        1,
+                        1,
+                        50,
+                        label="Sampling Steps",
+                        value=sample_step,
+                        step=1,
                     )
 
             generate_btn = gr.Button(

@@ -285,7 +307,7 @@
 
         model_output_mesh = gr.Model3D(
             label="Mesh Representation",
-            clear_color=[0
+            clear_color=[0, 0, 0, 1],
             height=300,
             interactive=False,
             elem_id="lighter_mesh",

@@ -323,6 +345,7 @@
     )
 
     output_buf = gr.State()
+    enable_pre_resize = gr.State(enable_pre_resize)
 
     demo.load(start_session)
     demo.unload(end_session)

@@ -389,6 +412,7 @@
             img_resolution,
             rmbg_tag,
             seed,
+            enable_pre_resize,
         ],
         outputs=[
             image_sample1,

@@ -420,11 +444,11 @@
         inputs=[
             select_img,
             seed,
-            ss_guidance_strength,
             ss_sampling_steps,
-            slat_guidance_strength,
             slat_sampling_steps,
             raw_image_cache,
+            ss_guidance_strength,
+            slat_guidance_strength,
         ],
         outputs=[output_buf, video_output],
     ).success(
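Note on the reordered `inputs` list above: Gradio binds the `inputs` components to the click handler positionally, so the list order must mirror the handler signature. Moving the two guidance strengths to the tail matches the reordered image_to_3d signature in common.py and lets the SAM-3D variant declare them with defaults it can ignore. A minimal sketch of that contract, with parameter names shortened as stand-ins for the real ones:

# Sketch: Gradio maps inputs=[...] onto handler parameters 1:1, left to
# right. Parameter names here are shortened stand-ins, not real names.
def image_to_3d_sam3d(image, seed, ss_steps, slat_steps, raw_cache,
                      ss_guidance=None, slat_guidance=None):
    # The SAM-3D pipeline takes no guidance strengths, so they sit at
    # the tail with defaults and are simply ignored.
    return {}, "video.mp4"

# inputs=[select_img, seed, ss_sampling_steps, slat_sampling_steps,
#         raw_image_cache, ss_guidance_strength, slat_guidance_strength]
# binds onto the parameters above in order.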
app_style.py CHANGED

@@ -1,10 +1,26 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
 from gradio.themes import Soft
 from gradio.themes.utils.colors import gray, neutral, slate, stone, teal, zinc
 
 lighting_css = """
 <style>
 #lighter_mesh canvas {
-    filter: brightness(2.
+    filter: brightness(2.3) !important;
 }
 </style>
 """
common.py CHANGED

@@ -151,6 +151,21 @@ if os.getenv("GRADIO_APP") == "imageto3d":
         os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
     )
     os.makedirs(TMP_DIR, exist_ok=True)
+elif os.getenv("GRADIO_APP") == "imageto3d_sam3d":
+    from embodied_gen.models.sam3d import Sam3dInference
+
+    RBG_REMOVER = RembgRemover()
+    RBG14_REMOVER = BMGG14Remover()
+    SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
+    PIPELINE = Sam3dInference()
+    SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
+    GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
+    AESTHETIC_CHECKER = ImageAestheticChecker()
+    CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
+    TMP_DIR = os.path.join(
+        os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
+    )
+    os.makedirs(TMP_DIR, exist_ok=True)
 elif os.getenv("GRADIO_APP") == "textto3d":
     RBG_REMOVER = RembgRemover()
     RBG14_REMOVER = BMGG14Remover()

@@ -169,6 +184,23 @@ elif os.getenv("GRADIO_APP") == "textto3d":
         os.path.dirname(os.path.abspath(__file__)), "sessions/textto3d"
     )
     os.makedirs(TMP_DIR, exist_ok=True)
+elif os.getenv("GRADIO_APP") == "textto3d_sam3d":
+    from embodied_gen.models.sam3d import Sam3dInference
+
+    RBG_REMOVER = RembgRemover()
+    RBG14_REMOVER = BMGG14Remover()
+    PIPELINE = Sam3dInference()
+    text_model_dir = "weights/Kolors"
+    PIPELINE_IMG_IP = build_text2img_ip_pipeline(text_model_dir, ref_scale=0.3)
+    PIPELINE_IMG = build_text2img_pipeline(text_model_dir)
+    SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
+    GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
+    AESTHETIC_CHECKER = ImageAestheticChecker()
+    CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
+    TMP_DIR = os.path.join(
+        os.path.dirname(os.path.abspath(__file__)), "sessions/textto3d"
+    )
+    os.makedirs(TMP_DIR, exist_ok=True)
 elif os.getenv("GRADIO_APP") == "texture_edit":
     DELIGHT = DelightingModel()
     IMAGESR_MODEL = ImageRealESRGAN(outscale=4)

@@ -201,18 +233,22 @@ def end_session(req: gr.Request) -> None:
 
 @spaces.GPU
 def preprocess_image_fn(
-    image: str | np.ndarray | Image.Image,
+    image: str | np.ndarray | Image.Image,
+    rmbg_tag: str = "rembg",
+    preprocess: bool = True,
 ) -> tuple[Image.Image, Image.Image]:
     if isinstance(image, str):
         image = Image.open(image)
     elif isinstance(image, np.ndarray):
         image = Image.fromarray(image)
 
-    image_cache = resize_pil(image.copy(), 1024)
+    image_cache = image.copy()  # resize_pil(image.copy(), 1024)
 
     bg_remover = RBG_REMOVER if rmbg_tag == "rembg" else RBG14_REMOVER
     image = bg_remover(image)
-
+
+    if preprocess:
+        image = trellis_preprocess(image)
 
     return image, image_cache
 

@@ -349,11 +385,11 @@ def select_point(
 def image_to_3d(
     image: Image.Image,
     seed: int,
-    ss_guidance_strength: float,
     ss_sampling_steps: int,
-    slat_guidance_strength: float,
     slat_sampling_steps: int,
     raw_image_cache: Image.Image,
+    ss_guidance_strength: float,
+    slat_guidance_strength: float,
     sam_image: Image.Image = None,
     is_sam_image: bool = False,
     req: gr.Request = None,

@@ -392,8 +428,56 @@
 
     gs_model = outputs["gaussian"][0]
     mesh_model = outputs["mesh"][0]
-    color_images = render_video(gs_model)["color"]
-    normal_images = render_video(mesh_model)["normal"]
+    color_images = render_video(gs_model, r=1.85)["color"]
+    normal_images = render_video(mesh_model, r=1.85)["normal"]
+
+    video_path = os.path.join(output_root, "gs_mesh.mp4")
+    merge_images_video(color_images, normal_images, video_path)
+    state = pack_state(gs_model, mesh_model)
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    return state, video_path
+
+
+@spaces.GPU
+def image_to_3d_sam3d(
+    image: Image.Image,
+    seed: int,
+    ss_sampling_steps: int,
+    slat_sampling_steps: int,
+    raw_image_cache: Image.Image,
+    ss_guidance_strength: float = None,
+    slat_guidance_strength: float = None,
+    sam_image: Image.Image = None,
+    is_sam_image: bool = False,
+    req: gr.Request = None,
+) -> tuple[dict, str]:
+    if is_sam_image:
+        seg_image = filter_image_small_connected_components(sam_image)
+        seg_image = Image.fromarray(seg_image, mode="RGBA")
+    else:
+        seg_image = image
+
+    if isinstance(seg_image, np.ndarray):
+        seg_image = Image.fromarray(seg_image)
+
+    output_root = os.path.join(TMP_DIR, str(req.session_hash))
+    os.makedirs(output_root, exist_ok=True)
+    seg_image.save(f"{output_root}/seg_image.png")
+    raw_image_cache.save(f"{output_root}/raw_image.png")
+    outputs = PIPELINE.run(
+        seg_image,
+        seed=seed,
+        stage1_inference_steps=ss_sampling_steps,
+        stage2_inference_steps=slat_sampling_steps,
+    )
+
+    gs_model = outputs["gaussian"][0]
+    mesh_model = outputs["mesh"][0]
+    color_images = render_video(gs_model, r=1.85)["color"]
+    normal_images = render_video(mesh_model, r=1.85)["normal"]
 
     video_path = os.path.join(output_root, "gs_mesh.mp4")
     merge_images_video(color_images, normal_images, video_path)

@@ -688,6 +772,7 @@ def text2image_fn(
     image_wh: int | tuple[int, int] = [1024, 1024],
     rmbg_tag: str = "rembg",
     seed: int = None,
+    enable_pre_resize: bool = True,
     n_sample: int = 3,
     req: gr.Request = None,
 ):

@@ -715,7 +800,9 @@
 
     for idx in range(len(images)):
         image = images[idx]
-        images[idx], _ = preprocess_image_fn(
+        images[idx], _ = preprocess_image_fn(
+            image, rmbg_tag, enable_pre_resize
+        )
 
     save_paths = []
     for idx, image in enumerate(images):
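Note on the new preprocess_image_fn parameters: app.py passes the gr.State named enable_pre_resize positionally into the `preprocess` slot, which gates the trellis_preprocess step after background removal. A hedged usage sketch (the input path is hypothetical):

# Usage sketch for the new preprocess_image_fn signature; "input.png"
# is a hypothetical path, not from the commit.
from common import preprocess_image_fn

# TRELLIS mode (enable_pre_resize=True): run trellis_preprocess after
# background removal.
image, cache = preprocess_image_fn("input.png", "rembg", True)

# SAM-3D mode (enable_pre_resize=False): keep the raw RGBA cutout; the
# SAM-3D pipeline does its own preprocessing.
image, cache = preprocess_image_fn("input.png", "rembg", False)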
embodied_gen/data/backproject_v3.py CHANGED

@@ -14,10 +14,10 @@
 # implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-import os
 import argparse
 import logging
 import math
+import os
 from typing import Literal, Union
 
 import cv2

@@ -353,8 +353,8 @@ def parse_args():
     parser.add_argument(
         "--distance",
         type=float,
-        default=5,
-        help="Camera distance (default: 5)",
+        default=4.5,
+        help="Camera distance (default: 4.5)",
     )
     parser.add_argument(
         "--resolution_hw",

@@ -400,8 +400,8 @@ def parse_args():
     parser.add_argument(
         "--mesh_sipmlify_ratio",
         type=float,
-        default=0.
-        help="Mesh simplification ratio (default: 0.
+        default=0.85,
+        help="Mesh simplification ratio (default: 0.85)",
     )
     parser.add_argument(
         "--delight", action="store_true", help="Use delighting model."

@@ -500,7 +500,7 @@ def entrypoint(
     faces = mesh.faces.astype(np.int32)
     vertices = vertices.astype(np.float32)
 
-    if not args.skip_fix_mesh
+    if not args.skip_fix_mesh:
         mesh_fixer = MeshFixer(vertices, faces, args.device)
         vertices, faces = mesh_fixer(
             filter_ratio=args.mesh_sipmlify_ratio,

@@ -512,7 +512,7 @@ def entrypoint(
     if len(faces) > args.n_max_faces:
         mesh_fixer = MeshFixer(vertices, faces, args.device)
         vertices, faces = mesh_fixer(
-            filter_ratio=max(0.
+            filter_ratio=max(0.1, args.mesh_sipmlify_ratio - 0.1),
            max_hole_size=0.04,
            resolution=1024,
            num_views=1000,
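A small arithmetic note on the retry branch in entrypoint(): when the first MeshFixer pass still leaves more than --n_max_faces faces, a second pass runs with the ratio shifted by 0.1 and floored at 0.1:

# Under the new default --mesh_sipmlify_ratio of 0.85:
retry_ratio = max(0.1, 0.85 - 0.1)  # == 0.75 on the retry pass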
embodied_gen/data/utils.py CHANGED

@@ -15,10 +15,13 @@
 # permissions and limitations under the License.
 
 
+import logging
 import math
 import os
-import
+import time
 import zipfile
+from contextlib import contextmanager
+from dataclasses import dataclass, field
 from shutil import rmtree
 from typing import List, Tuple, Union
 

@@ -28,20 +31,9 @@ import numpy as np
 import nvdiffrast.torch as dr
 import torch
 import torch.nn.functional as F
-from PIL import Image, ImageEnhance
-
-try:
-    from kolors.models.modeling_chatglm import ChatGLMModel
-    from kolors.models.tokenization_chatglm import ChatGLMTokenizer
-except ImportError:
-    ChatGLMTokenizer = None
-    ChatGLMModel = None
-import logging
-from dataclasses import dataclass, field
-
 import trimesh
 from kaolin.render.camera import Camera
-from
+from PIL import Image, ImageEnhance
 
 logger = logging.getLogger(__name__)
 

@@ -50,10 +42,8 @@ __all__ = [
     "DiffrastRender",
     "save_images",
     "render_pbr",
-    "prelabel_text_feature",
     "calc_vertex_normals",
     "normalize_vertices_array",
-    "load_mesh_to_unit_cube",
     "as_list",
     "CameraSetting",
     "import_kaolin_mesh",

@@ -67,6 +57,7 @@ __all__ = [
     "trellis_preprocess",
     "delete_dir",
     "kaolin_to_opencv_view",
+    "model_device_ctx",
 ]
 
 

@@ -520,114 +511,6 @@ def render_pbr(
     return image, albedo, diffuse, normal
 
 
-def _move_to_target_device(data, device: str):
-    if isinstance(data, dict):
-        for key, value in data.items():
-            data[key] = _move_to_target_device(value, device)
-    elif isinstance(data, torch.Tensor):
-        return data.to(device)
-
-    return data
-
-
-def _encode_prompt(
-    prompt_batch,
-    text_encoders,
-    tokenizers,
-    proportion_empty_prompts=0,
-    is_train=True,
-):
-    prompt_embeds_list = []
-
-    captions = []
-    for caption in prompt_batch:
-        if random.random() < proportion_empty_prompts:
-            captions.append("")
-        elif isinstance(caption, str):
-            captions.append(caption)
-        elif isinstance(caption, (list, np.ndarray)):
-            captions.append(random.choice(caption) if is_train else caption[0])
-
-    with torch.no_grad():
-        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
-            text_inputs = tokenizer(
-                captions,
-                padding="max_length",
-                max_length=256,
-                truncation=True,
-                return_tensors="pt",
-            ).to(text_encoder.device)
-
-            output = text_encoder(
-                input_ids=text_inputs.input_ids,
-                attention_mask=text_inputs.attention_mask,
-                position_ids=text_inputs.position_ids,
-                output_hidden_states=True,
-            )
-
-            # We are only interested in the pooled output of the text encoder.
-            prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
-            pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone()
-            bs_embed, seq_len, _ = prompt_embeds.shape
-            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
-            prompt_embeds_list.append(prompt_embeds)
-
-    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
-    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
-
-    return prompt_embeds, pooled_prompt_embeds
-
-
-def load_llm_models(pretrained_model_name_or_path: str, device: str):
-    tokenizer = ChatGLMTokenizer.from_pretrained(
-        pretrained_model_name_or_path,
-        subfolder="text_encoder",
-    )
-    text_encoder = ChatGLMModel.from_pretrained(
-        pretrained_model_name_or_path,
-        subfolder="text_encoder",
-    ).to(device)
-
-    text_encoders = [
-        text_encoder,
-    ]
-    tokenizers = [
-        tokenizer,
-    ]
-
-    logger.info(f"Load model from {pretrained_model_name_or_path} done.")
-
-    return tokenizers, text_encoders
-
-
-def prelabel_text_feature(
-    prompt_batch: List[str],
-    output_dir: str,
-    tokenizers: nn.Module,
-    text_encoders: nn.Module,
-) -> List[str]:
-    os.makedirs(output_dir, exist_ok=True)
-
-    # prompt_batch ["text..."]
-    prompt_embeds, pooled_prompt_embeds = _encode_prompt(
-        prompt_batch, text_encoders, tokenizers
-    )
-
-    prompt_embeds = _move_to_target_device(prompt_embeds, device="cpu")
-    pooled_prompt_embeds = _move_to_target_device(
-        pooled_prompt_embeds, device="cpu"
-    )
-
-    data_dict = dict(
-        prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds
-    )
-
-    save_path = os.path.join(output_dir, "text_feat.pth")
-    torch.save(data_dict, save_path)
-
-    return save_path
-
-
 def _calc_face_normals(
     vertices: torch.Tensor,  # V,3 first vertex may be unreferenced
     faces: torch.Tensor,  # F,3 long, first face may be all zero

@@ -683,25 +566,6 @@ def normalize_vertices_array(
     return vertices, scale, center
 
 
-def load_mesh_to_unit_cube(
-    mesh_file: str,
-    mesh_scale: float = 1.0,
-) -> tuple[trimesh.Trimesh, float, list[float]]:
-    if not os.path.exists(mesh_file):
-        raise FileNotFoundError(f"mesh_file path {mesh_file} not exists.")
-
-    mesh = trimesh.load(mesh_file)
-    if isinstance(mesh, trimesh.Scene):
-        mesh = trimesh.utils.concatenate(mesh)
-
-    vertices, scale, center = normalize_vertices_array(
-        mesh.vertices, mesh_scale
-    )
-    mesh.vertices = vertices
-
-    return mesh, scale, center
-
-
 def as_list(obj):
     if isinstance(obj, (list, tuple)):
         return obj

@@ -998,8 +862,9 @@ def gamma_shs(shs: torch.Tensor, gamma: float) -> torch.Tensor:
 
 
 def resize_pil(image: Image.Image, max_size: int = 1024) -> Image.Image:
-
-    scale = min(1,
+    current_max_dim = max(image.size)
+    scale = min(1, max_size / current_max_dim)
+
     if scale < 1:
         new_size = (int(image.width * scale), int(image.height * scale))
         image = image.resize(new_size, Image.Resampling.LANCZOS)

@@ -1068,3 +933,34 @@ def delete_dir(folder_path: str, keep_subs: list[str] = None) -> None:
             rmtree(item_path)
         else:
             os.remove(item_path)
+
+
+@contextmanager
+def model_device_ctx(
+    *models,
+    src_device: str = "cpu",
+    dst_device: str = "cuda",
+    verbose: bool = False,
+):
+    start = time.perf_counter()
+    for m in models:
+        if m is None:
+            continue
+        m.to(dst_device)
+    to_cuda_time = time.perf_counter() - start
+
+    try:
+        yield
+    finally:
+        start = time.perf_counter()
+        for m in models:
+            if m is None:
+                continue
+            m.to(src_device)
+        to_cpu_time = time.perf_counter() - start
+
+        if verbose:
+            model_names = [m.__class__.__name__ for m in models]
+            logger.debug(
+                f"[model_device_ctx] {model_names} to cuda: {to_cuda_time:.1f}s, to cpu: {to_cpu_time:.1f}s"
+            )
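The new model_device_ctx helper keeps large submodules parked on CPU and moves them to GPU only for the duration of a pipeline stage, then moves them back so stages do not all occupy VRAM at once. A minimal usage sketch, assuming a CUDA device is available; the two Linear modules are placeholders for real pipeline submodules:

# Usage sketch for model_device_ctx (the helper added above).
import torch
import torch.nn as nn

from embodied_gen.data.utils import model_device_ctx

encoder = nn.Linear(8, 8)  # stand-ins for large submodules kept on CPU
decoder = nn.Linear(8, 8)

x = torch.randn(1, 8, device="cuda")
with model_device_ctx(encoder, decoder, verbose=True):
    # Inside the block both modules live on dst_device ("cuda").
    y = decoder(encoder(x))
# On exit they are moved back to src_device ("cpu"), freeing VRAM
# between stages.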
embodied_gen/models/sam3d.py ADDED

@@ -0,0 +1,145 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+from embodied_gen.utils.monkey_patches import monkey_patch_sam3d
+
+monkey_patch_sam3d()
+import os
+import sys
+from typing import Optional, Union
+
+import numpy as np
+from hydra.utils import instantiate
+from modelscope import snapshot_download
+from omegaconf import OmegaConf
+from PIL import Image
+
+current_file_path = os.path.abspath(__file__)
+current_dir = os.path.dirname(current_file_path)
+sys.path.append(os.path.join(current_dir, "../.."))
+from thirdparty.sam3d.sam3d_objects.pipeline.inference_pipeline_pointmap import (
+    InferencePipelinePointMap,
+)
+
+__all__ = ["Sam3dInference"]
+
+
+def load_image(path: str) -> np.ndarray:
+    image = Image.open(path)
+    image = np.array(image)
+    image = image.astype(np.uint8)
+    return image
+
+
+def load_mask(path: str) -> np.ndarray:
+    mask = load_image(path)
+    mask = mask > 0
+    if mask.ndim == 3:
+        mask = mask[..., -1]
+    return mask
+
+
+class Sam3dInference:
+    def __init__(
+        self, local_dir: str = "weights/sam-3d-objects", compile: bool = False
+    ) -> None:
+        if not os.path.exists(local_dir):
+            snapshot_download("facebook/sam-3d-objects", local_dir=local_dir)
+        config_file = os.path.join(local_dir, "checkpoints/pipeline.yaml")
+        config = OmegaConf.load(config_file)
+        config.rendering_engine = "nvdiffrast"
+        config.compile_model = compile
+        config.workspace_dir = os.path.dirname(config_file)
+        # Generate 4 gs in each pixel.
+        config["slat_decoder_gs_config_path"] = config.pop(
+            "slat_decoder_gs_4_config_path", "slat_decoder_gs_4.yaml"
+        )
+        config["slat_decoder_gs_ckpt_path"] = config.pop(
+            "slat_decoder_gs_4_ckpt_path", "slat_decoder_gs_4.ckpt"
+        )
+        self.pipeline: InferencePipelinePointMap = instantiate(config)
+
+    def merge_mask_to_rgba(
+        self, image: np.ndarray, mask: np.ndarray
+    ) -> np.ndarray:
+        mask = mask.astype(np.uint8) * 255
+        mask = mask[..., None]
+        rgba_image = np.concatenate([image[..., :3], mask], axis=-1)
+
+        return rgba_image
+
+    def run(
+        self,
+        image: np.ndarray | Image.Image,
+        mask: np.ndarray = None,
+        seed: int = None,
+        pointmap: np.ndarray = None,
+        use_stage1_distillation: bool = False,
+        use_stage2_distillation: bool = False,
+        stage1_inference_steps: int = 25,
+        stage2_inference_steps: int = 25,
+    ) -> dict:
+        if isinstance(image, Image.Image):
+            image = np.array(image)
+        return self.pipeline.run(
+            image,
+            mask,
+            seed,
+            stage1_only=False,
+            with_mesh_postprocess=False,
+            with_texture_baking=False,
+            with_layout_postprocess=False,
+            use_vertex_color=True,
+            use_stage1_distillation=use_stage1_distillation,
+            use_stage2_distillation=use_stage2_distillation,
+            stage1_inference_steps=stage1_inference_steps,
+            stage2_inference_steps=stage2_inference_steps,
+            pointmap=pointmap,
+        )
+
+
+if __name__ == "__main__":
+    pipeline = Sam3dInference()
+
+    # load image
+    image = load_image(
+        "/home/users/xinjie.wang/xinjie/sam-3d-objects/notebook/images/shutterstock_stylish_kidsroom_1640806567/image.png"
+    )
+    mask = load_mask(
+        "/home/users/xinjie.wang/xinjie/sam-3d-objects/notebook/images/shutterstock_stylish_kidsroom_1640806567/13.png"
+    )
+
+    import torch
+
+    if torch.cuda.is_available():
+        torch.cuda.reset_peak_memory_stats()
+        torch.cuda.empty_cache()
+
+    from time import time
+
+    start = time()
+
+    output = pipeline.run(image, mask, seed=42)
+    print(f"Running cost: {round(time()-start, 1)}")
+
+    if torch.cuda.is_available():
+        max_memory = torch.cuda.max_memory_allocated() / (1024**3)
+        print(f"(Max VRAM): {max_memory:.2f} GB")
+
+        print(f"End: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
+
+    output["gs"].save_ply(f"outputs/splat.ply")
+    print("Your reconstruction has been saved to outputs/splat.ply")
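A minimal usage sketch for Sam3dInference as common.py drives it in this commit; the input path is hypothetical. The class downloads facebook/sam-3d-objects on first use, then runs the two-stage pipeline with the step counts that the Gradio sliders expose:

# Usage sketch (hypothetical "seg_image.png"; step counts mirror the
# ss/slat sampling-step sliders in the app).
import numpy as np
from PIL import Image

from embodied_gen.models.sam3d import Sam3dInference

pipe = Sam3dInference()                       # weights/sam-3d-objects
rgba = np.array(Image.open("seg_image.png"))  # RGBA: alpha is the mask
out = pipe.run(
    rgba,
    seed=42,
    stage1_inference_steps=25,  # ss_sampling_steps in the app
    stage2_inference_steps=25,  # slat_sampling_steps in the app
)
gs_model, mesh_model = out["gaussian"][0], out["mesh"][0]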
embodied_gen/utils/monkey_patches.py CHANGED

@@ -25,6 +25,12 @@ from omegaconf import OmegaConf
 from PIL import Image
 from torchvision import transforms
 
+__all__ = [
+    "monkey_patch_pano2room",
+    "monkey_patch_maniskill",
+    "monkey_patch_sam3d",
+]
+
 
 def monkey_patch_pano2room():
     current_file_path = os.path.abspath(__file__)

@@ -216,3 +222,378 @@ def monkey_patch_maniskill():
     ManiSkillScene.get_human_render_camera_images = (
         get_human_render_camera_images
     )
+
+
+def monkey_patch_sam3d():
+    from typing import Optional, Union
+
+    from embodied_gen.data.utils import model_device_ctx
+    from embodied_gen.utils.log import logger
+
+    os.environ["LIDRA_SKIP_INIT"] = "true"
+
+    current_file_path = os.path.abspath(__file__)
+    current_dir = os.path.dirname(current_file_path)
+    sam3d_root = os.path.abspath(
+        os.path.join(current_dir, "../../thirdparty/sam3d")
+    )
+    if sam3d_root not in sys.path:
+        sys.path.insert(0, sam3d_root)
+
+    print(f"[MonkeyPatch] Added to sys.path: {sam3d_root}")
+
+    def patch_pointmap_infer_pipeline():
+        from copy import deepcopy
+
+        try:
+            from sam3d_objects.pipeline.inference_pipeline_pointmap import (
+                InferencePipelinePointMap,
+            )
+        except ImportError:
+            logger.error(
+                "[MonkeyPatch]: Could not import sam3d_objects directly. Check paths."
+            )
+            return
+
+        def patch_run(
+            self,
+            image: Union[None, Image.Image, np.ndarray],
+            mask: Union[None, Image.Image, np.ndarray] = None,
+            seed: Optional[int] = None,
+            stage1_only=False,
+            with_mesh_postprocess=True,
+            with_texture_baking=True,
+            with_layout_postprocess=True,
+            use_vertex_color=False,
+            stage1_inference_steps=None,
+            stage2_inference_steps=None,
+            use_stage1_distillation=False,
+            use_stage2_distillation=False,
+            pointmap=None,
+            decode_formats=None,
+            estimate_plane=False,
+        ) -> dict:
+            image = self.merge_image_and_mask(image, mask)
+            with self.device:
+                pointmap_dict = self.compute_pointmap(image, pointmap)
+                pointmap = pointmap_dict["pointmap"]
+                pts = type(self)._down_sample_img(pointmap)
+                pts_colors = type(self)._down_sample_img(
+                    pointmap_dict["pts_color"]
+                )
+
+                if estimate_plane:
+                    return self.estimate_plane(pointmap_dict, image)
+
+                ss_input_dict = self.preprocess_image(
+                    image, self.ss_preprocessor, pointmap=pointmap
+                )
+
+                slat_input_dict = self.preprocess_image(
+                    image, self.slat_preprocessor
+                )
+                if seed is not None:
+                    torch.manual_seed(seed)
+
+                with model_device_ctx(
+                    self.models["ss_generator"],
+                    self.models["ss_decoder"],
+                    self.condition_embedders["ss_condition_embedder"],
+                ):
+                    ss_return_dict = self.sample_sparse_structure(
+                        ss_input_dict,
+                        inference_steps=stage1_inference_steps,
+                        use_distillation=use_stage1_distillation,
+                    )
+
+                # We could probably use the decoder from the models themselves
+                pointmap_scale = ss_input_dict.get("pointmap_scale", None)
+                pointmap_shift = ss_input_dict.get("pointmap_shift", None)
+                ss_return_dict.update(
+                    self.pose_decoder(
+                        ss_return_dict,
+                        scene_scale=pointmap_scale,
+                        scene_shift=pointmap_shift,
+                    )
+                )
+
+                logger.info(
+                    f"Rescaling scale by {ss_return_dict['downsample_factor']} after downsampling"
+                )
+                ss_return_dict["scale"] = (
+                    ss_return_dict["scale"]
+                    * ss_return_dict["downsample_factor"]
+                )
+
+                if stage1_only:
+                    logger.info("Finished!")
+                    ss_return_dict["voxel"] = (
+                        ss_return_dict["coords"][:, 1:] / 64 - 0.5
+                    )
+                    return {
+                        **ss_return_dict,
+                        "pointmap": pts.cpu().permute((1, 2, 0)),  # HxWx3
+                        "pointmap_colors": pts_colors.cpu().permute(
+                            (1, 2, 0)
+                        ),  # HxWx3
+                    }
+                    # return ss_return_dict
+
+                coords = ss_return_dict["coords"]
+                with model_device_ctx(
+                    self.models["slat_generator"],
+                    self.condition_embedders["slat_condition_embedder"],
+                ):
+                    slat = self.sample_slat(
+                        slat_input_dict,
+                        coords,
+                        inference_steps=stage2_inference_steps,
+                        use_distillation=use_stage2_distillation,
+                    )
+
+                with model_device_ctx(
+                    self.models["slat_decoder_mesh"],
+                    self.models["slat_decoder_gs"],
+                    self.models["slat_decoder_gs_4"],
+                ):
+                    outputs = self.decode_slat(
+                        slat,
+                        (
+                            self.decode_formats
+                            if decode_formats is None
+                            else decode_formats
+                        ),
+                    )
+
+                outputs = self.postprocess_slat_output(
+                    outputs,
+                    with_mesh_postprocess,
+                    with_texture_baking,
+                    use_vertex_color,
+                )
+                glb = outputs.get("glb", None)
+
+                try:
+                    if (
+                        with_layout_postprocess
+                        and self.layout_post_optimization_method is not None
+                    ):
+                        assert (
+                            glb is not None
+                        ), "require mesh to run postprocessing"
+                        logger.info(
+                            "Running layout post optimization method..."
+                        )
+                        postprocessed_pose = self.run_post_optimization(
+                            deepcopy(glb),
+                            pointmap_dict["intrinsics"],
+                            ss_return_dict,
+                            ss_input_dict,
+                        )
+                        ss_return_dict.update(postprocessed_pose)
+                except Exception as e:
+                    logger.error(
+                        f"Error during layout post optimization: {e}",
+                        exc_info=True,
+                    )
+
+                # glb.export("sample.glb")
+                logger.info("Finished!")
+
+                return {
+                    **ss_return_dict,
+                    **outputs,
+                    "pointmap": pts.cpu().permute((1, 2, 0)),  # HxWx3
+                    "pointmap_colors": pts_colors.cpu().permute(
+                        (1, 2, 0)
+                    ),  # HxWx3
+                }
+
+        InferencePipelinePointMap.run = patch_run
+
+    def patch_infer_init():
+        import torch
+
+        try:
+            from sam3d_objects.pipeline import preprocess_utils
+            from sam3d_objects.pipeline.inference_pipeline_pointmap import (
+                InferencePipeline,
+            )
+            from sam3d_objects.pipeline.inference_utils import (
+                SLAT_MEAN,
+                SLAT_STD,
+            )
+        except ImportError:
+            print(
+                "[MonkeyPatch] Error: Could not import sam3d_objects directly for infer pipeline."
+            )
+            return
+
+        def patch_init(
+            self,
+            ss_generator_config_path,
+            ss_generator_ckpt_path,
+            slat_generator_config_path,
+            slat_generator_ckpt_path,
+            ss_decoder_config_path,
+            ss_decoder_ckpt_path,
+            slat_decoder_gs_config_path,
+            slat_decoder_gs_ckpt_path,
+            slat_decoder_mesh_config_path,
+            slat_decoder_mesh_ckpt_path,
+            slat_decoder_gs_4_config_path=None,
+            slat_decoder_gs_4_ckpt_path=None,
+            ss_encoder_config_path=None,
+            ss_encoder_ckpt_path=None,
+            decode_formats=["gaussian", "mesh"],
+            dtype="bfloat16",
+            pad_size=1.0,
+            version="v0",
+            device="cuda",
+            ss_preprocessor=preprocess_utils.get_default_preprocessor(),
+            slat_preprocessor=preprocess_utils.get_default_preprocessor(),
+            ss_condition_input_mapping=["image"],
+            slat_condition_input_mapping=["image"],
+            pose_decoder_name="default",
+            workspace_dir="",
+            downsample_ss_dist=0,  # the distance we use to downsample
+            ss_inference_steps=25,
+            ss_rescale_t=3,
+            ss_cfg_strength=7,
+            ss_cfg_interval=[0, 500],
+            ss_cfg_strength_pm=0.0,
+            slat_inference_steps=25,
+            slat_rescale_t=3,
+            slat_cfg_strength=5,
+            slat_cfg_interval=[0, 500],
+            rendering_engine: str = "nvdiffrast",  # nvdiffrast OR pytorch3d
+            shape_model_dtype=None,
+            compile_model=False,
+            slat_mean=SLAT_MEAN,
+            slat_std=SLAT_STD,
+        ):
+            self.rendering_engine = rendering_engine
+            self.device = torch.device(device)
+            self.compile_model = compile_model
+            logger.info(f"self.device: {self.device}")
+            logger.info(
+                f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', None)}"
+            )
+            logger.info(f"Actually using GPU: {torch.cuda.current_device()}")
+            with self.device:
+                self.decode_formats = decode_formats
+                self.pad_size = pad_size
+                self.version = version
+                self.ss_condition_input_mapping = ss_condition_input_mapping
+                self.slat_condition_input_mapping = (
+                    slat_condition_input_mapping
+                )
+                self.workspace_dir = workspace_dir
+                self.downsample_ss_dist = downsample_ss_dist
+                self.ss_inference_steps = ss_inference_steps
+                self.ss_rescale_t = ss_rescale_t
+                self.ss_cfg_strength = ss_cfg_strength
+                self.ss_cfg_interval = ss_cfg_interval
+                self.ss_cfg_strength_pm = ss_cfg_strength_pm
+                self.slat_inference_steps = slat_inference_steps
+                self.slat_rescale_t = slat_rescale_t
+                self.slat_cfg_strength = slat_cfg_strength
+                self.slat_cfg_interval = slat_cfg_interval
+
+                self.dtype = self._get_dtype(dtype)
+                if shape_model_dtype is None:
+                    self.shape_model_dtype = self.dtype
+                else:
+                    self.shape_model_dtype = self._get_dtype(shape_model_dtype)
+
+                # Setup preprocessors
+                self.pose_decoder = self.init_pose_decoder(
+                    ss_generator_config_path, pose_decoder_name
+                )
+                self.ss_preprocessor = self.init_ss_preprocessor(
+                    ss_preprocessor, ss_generator_config_path
+                )
+                self.slat_preprocessor = slat_preprocessor
+
+                logger.info("Loading model weights...")
+                raw_device = self.device
+                self.device = torch.device("cpu")
+                ss_generator = self.init_ss_generator(
+                    ss_generator_config_path, ss_generator_ckpt_path
+                )
+                slat_generator = self.init_slat_generator(
+                    slat_generator_config_path, slat_generator_ckpt_path
+                )
+                ss_decoder = self.init_ss_decoder(
+                    ss_decoder_config_path, ss_decoder_ckpt_path
+                )
+                ss_encoder = self.init_ss_encoder(
+                    ss_encoder_config_path, ss_encoder_ckpt_path
+                )
+                slat_decoder_gs = self.init_slat_decoder_gs(
+                    slat_decoder_gs_config_path, slat_decoder_gs_ckpt_path
+                )
+                slat_decoder_gs_4 = self.init_slat_decoder_gs(
+                    slat_decoder_gs_4_config_path, slat_decoder_gs_4_ckpt_path
+                )
+                slat_decoder_mesh = self.init_slat_decoder_mesh(
+                    slat_decoder_mesh_config_path, slat_decoder_mesh_ckpt_path
+                )
+
+                # Load conditioner embedder so that we only load it once
+                ss_condition_embedder = self.init_ss_condition_embedder(
+                    ss_generator_config_path, ss_generator_ckpt_path
+                )
+                slat_condition_embedder = self.init_slat_condition_embedder(
+                    slat_generator_config_path, slat_generator_ckpt_path
+                )
+                self.device = raw_device
+
+                self.condition_embedders = {
+                    "ss_condition_embedder": ss_condition_embedder,
+                    "slat_condition_embedder": slat_condition_embedder,
+                }
+
+                # override generator and condition embedder setting
+                self.override_ss_generator_cfg_config(
+                    ss_generator,
+                    cfg_strength=ss_cfg_strength,
+                    inference_steps=ss_inference_steps,
+                    rescale_t=ss_rescale_t,
+                    cfg_interval=ss_cfg_interval,
+                    cfg_strength_pm=ss_cfg_strength_pm,
+                )
+                self.override_slat_generator_cfg_config(
+                    slat_generator,
+                    cfg_strength=slat_cfg_strength,
+                    inference_steps=slat_inference_steps,
+                    rescale_t=slat_rescale_t,
+                    cfg_interval=slat_cfg_interval,
+                )
+
+                self.models = torch.nn.ModuleDict(
+                    {
+                        "ss_generator": ss_generator,
+                        "slat_generator": slat_generator,
+                        "ss_encoder": ss_encoder,
+                        "ss_decoder": ss_decoder,
+                        "slat_decoder_gs": slat_decoder_gs,
+                        "slat_decoder_gs_4": slat_decoder_gs_4,
+                        "slat_decoder_mesh": slat_decoder_mesh,
+                    }
+                )
+                logger.info("Loading model weights completed!")
+
+                if self.compile_model:
+                    logger.info("Compiling model...")
+                    self._compile()
+                    logger.info("Model compilation completed!")
+            self.slat_mean = torch.tensor(slat_mean)
+            self.slat_std = torch.tensor(slat_std)
+
+        InferencePipeline.__init__ = patch_init
+
+    patch_pointmap_infer_pipeline()
+    patch_infer_init()
+
+    return
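The patching mechanism here is the standard one: rebind a method on the third-party class before any instance exists, so every later instance (for example the pipeline hydra instantiates inside Sam3dInference) picks up the replacement. A self-contained sketch of the pattern; ThirdPartyPipeline is a stand-in, not a real sam3d class name:

# Sketch of the method-rebinding pattern monkey_patch_sam3d() relies on.
class ThirdPartyPipeline:
    def run(self, x):
        return x  # original behavior


def patched_run(self, x):
    # Replacement behavior; free to extend or bypass the original logic.
    return x * 2


ThirdPartyPipeline.run = patched_run  # patch applied at import time

pipe = ThirdPartyPipeline()
assert pipe.run(21) == 42  # instances created afterwards see the patch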
embodied_gen/utils/trender.py CHANGED

@@ -16,6 +16,7 @@
 
 import os
 import sys
+from collections import defaultdict
 
 import numpy as np
 import spaces

@@ -25,10 +26,8 @@ from tqdm import tqdm
 current_file_path = os.path.abspath(__file__)
 current_dir = os.path.dirname(current_file_path)
 sys.path.append(os.path.join(current_dir, "../.."))
-from thirdparty.TRELLIS.trellis.renderers
-from thirdparty.TRELLIS.trellis.representations import MeshExtractResult
+from thirdparty.TRELLIS.trellis.renderers import GaussianRenderer, MeshRenderer
 from thirdparty.TRELLIS.trellis.utils.render_utils import (
-    render_frames,
     yaw_pitch_r_fov_to_extrinsics_intrinsics,
 )
 

@@ -38,7 +37,7 @@ __all__ = [
 
 
 @spaces.GPU
-def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs):
+def render_mesh_frames(sample, extrinsics, intrinsics, options={}, **kwargs):
     renderer = MeshRenderer()
     renderer.rendering_options.resolution = options.get("resolution", 512)
     renderer.rendering_options.near = options.get("near", 1)

@@ -60,6 +59,57 @@ def render_mesh_frames(sample, extrinsics, intrinsics, options={}, **kwargs):
     return rets
 
 
+@spaces.GPU
+def render_gs_frames(
+    sample,
+    extrinsics,
+    intrinsics,
+    options=None,
+    colors_overwrite=None,
+    verbose=True,
+    **kwargs,
+):
+    def to_img(tensor):
+        return np.clip(
+            tensor.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255
+        ).astype(np.uint8)
+
+    def to_numpy(tensor):
+        return tensor.detach().cpu().numpy()
+
+    renderer = GaussianRenderer()
+    renderer.pipe.kernel_size = kwargs.get("kernel_size", 0.1)
+    renderer.pipe.use_mip_gaussian = True
+
+    defaults = {
+        "resolution": 512,
+        "near": 0.8,
+        "far": 1.6,
+        "bg_color": (0, 0, 0),
+        "ssaa": 1,
+    }
+    final_options = {**defaults, **(options or {})}
+
+    for k, v in final_options.items():
+        if hasattr(renderer.rendering_options, k):
+            setattr(renderer.rendering_options, k, v)
+
+    outputs = defaultdict(list)
+    iterator = zip(extrinsics, intrinsics)
+    if verbose:
+        iterator = tqdm(iterator, total=len(extrinsics), desc="Rendering")
+
+    for extr, intr in iterator:
+        res = renderer.render(
+            sample, extr, intr, colors_overwrite=colors_overwrite
+        )
+        outputs["color"].append(to_img(res["color"]))
+        depth = res.get("percent_depth") or res.get("depth")
+        outputs["depth"].append(to_numpy(depth) if depth is not None else None)
+
+    return dict(outputs)
+
+
 @spaces.GPU
 def render_video(
     sample,

@@ -77,7 +127,9 @@ def render_video(
         yaws, pitch, r, fov
     )
     render_fn = (
-
+        render_mesh_frames
+        if sample.__class__.__name__ == "MeshExtractResult"
+        else render_gs_frames
     )
     result = render_fn(
         sample,
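render_video stays the single entry point and now dispatches on the sample type. A hedged sketch of how common.py consumes it after this commit; gs_model, mesh_model, and video_path are placeholders for the pipeline outputs and target file:

# Mesh samples route to render_mesh_frames, Gaussian splats to
# render_gs_frames; r=1.85 is the camera radius fed into
# yaw_pitch_r_fov_to_extrinsics_intrinsics.
color_images = render_video(gs_model, r=1.85)["color"]      # GaussianRenderer path
normal_images = render_video(mesh_model, r=1.85)["normal"]  # MeshRenderer path
merge_images_video(color_images, normal_images, video_path)

Dispatching on sample.__class__.__name__ rather than isinstance lets the module drop its import of MeshExtractResult, which this commit removes.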
requirements.txt CHANGED

@@ -52,4 +52,12 @@ pyquaternion
 shapely
 sapien==3.0.0b1
 typing_extensions==4.14.1
-coacd
+coacd
+ninja
+packaging
+lightning
+astor
+optree
+loguru
+seaborn
+hydra-core
thirdparty/TRELLIS/trellis/utils/postprocessing_utils.py CHANGED

@@ -440,7 +440,7 @@ def to_glb(
     vertices, faces, uvs = parametrize_mesh(vertices, faces)
 
     # bake texture
-    observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=
+    observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=100)
     masks = [np.any(observation > 0, axis=-1) for observation in observations]
     extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))]
     intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))]
|