Spaces:
Running on Zero
Running on Zero
xinjie.wang commited on
Commit ·
54da04d
1
Parent(s): be013ba
update
Browse files- app.py +1 -61
- common.py +6 -622
- embodied_gen/utils/monkey_patch/sam3d.py +4 -4
app.py
CHANGED
|
@@ -471,67 +471,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|
| 471 |
inputs=image_seg_sam,
|
| 472 |
outputs=generate_btn,
|
| 473 |
)
|
| 474 |
-
|
| 475 |
-
generate_btn.click(
|
| 476 |
-
get_seed,
|
| 477 |
-
inputs=[randomize_seed, seed],
|
| 478 |
-
outputs=[seed],
|
| 479 |
-
).success(
|
| 480 |
-
image_to_3d,
|
| 481 |
-
inputs=[
|
| 482 |
-
image_prompt,
|
| 483 |
-
seed,
|
| 484 |
-
ss_sampling_steps,
|
| 485 |
-
slat_sampling_steps,
|
| 486 |
-
raw_image_cache,
|
| 487 |
-
ss_guidance_strength,
|
| 488 |
-
slat_guidance_strength,
|
| 489 |
-
image_seg_sam,
|
| 490 |
-
is_samimage,
|
| 491 |
-
],
|
| 492 |
-
outputs=[output_buf, video_output],
|
| 493 |
-
).success(
|
| 494 |
-
extract_3d_representations_v3,
|
| 495 |
-
inputs=[
|
| 496 |
-
output_buf,
|
| 497 |
-
project_delight,
|
| 498 |
-
texture_size,
|
| 499 |
-
],
|
| 500 |
-
outputs=[
|
| 501 |
-
model_output_mesh,
|
| 502 |
-
model_output_gs,
|
| 503 |
-
model_output_obj,
|
| 504 |
-
aligned_gs,
|
| 505 |
-
],
|
| 506 |
-
).success(
|
| 507 |
-
lambda: gr.Button(interactive=True),
|
| 508 |
-
outputs=[extract_urdf_btn],
|
| 509 |
-
)
|
| 510 |
-
|
| 511 |
-
extract_urdf_btn.click(
|
| 512 |
-
extract_urdf,
|
| 513 |
-
inputs=[
|
| 514 |
-
aligned_gs,
|
| 515 |
-
model_output_obj,
|
| 516 |
-
asset_cat_text,
|
| 517 |
-
height_range_text,
|
| 518 |
-
mass_range_text,
|
| 519 |
-
asset_version_text,
|
| 520 |
-
],
|
| 521 |
-
outputs=[
|
| 522 |
-
download_urdf,
|
| 523 |
-
est_type_text,
|
| 524 |
-
est_height_text,
|
| 525 |
-
est_mass_text,
|
| 526 |
-
est_mu_text,
|
| 527 |
-
],
|
| 528 |
-
queue=True,
|
| 529 |
-
show_progress="full",
|
| 530 |
-
).success(
|
| 531 |
-
lambda: gr.Button(interactive=True),
|
| 532 |
-
outputs=[download_urdf],
|
| 533 |
-
)
|
| 534 |
-
|
| 535 |
|
| 536 |
if __name__ == "__main__":
|
| 537 |
demo.launch()
|
|
|
|
| 471 |
inputs=image_seg_sam,
|
| 472 |
outputs=generate_btn,
|
| 473 |
)
|
| 474 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 475 |
|
| 476 |
if __name__ == "__main__":
|
| 477 |
demo.launch()
|
common.py
CHANGED
|
@@ -15,10 +15,6 @@
|
|
| 15 |
# permissions and limitations under the License.
|
| 16 |
|
| 17 |
import spaces
|
| 18 |
-
from embodied_gen.utils.monkey_patch.trellis import monkey_path_trellis
|
| 19 |
-
|
| 20 |
-
monkey_path_trellis()
|
| 21 |
-
|
| 22 |
import gc
|
| 23 |
import logging
|
| 24 |
import os
|
|
@@ -32,48 +28,21 @@ import gradio as gr
|
|
| 32 |
import numpy as np
|
| 33 |
import torch
|
| 34 |
import trimesh
|
| 35 |
-
from PIL import Image
|
| 36 |
-
from embodied_gen.data.
|
| 37 |
-
from embodied_gen.data.backproject_v3 import entrypoint as backproject_api_v3
|
| 38 |
-
from embodied_gen.data.differentiable_render import entrypoint as render_api
|
| 39 |
-
from embodied_gen.data.utils import trellis_preprocess, zip_files
|
| 40 |
-
from embodied_gen.models.delight_model import DelightingModel
|
| 41 |
-
from embodied_gen.models.gs_model import GaussianOperator
|
| 42 |
-
from embodied_gen.models.sam3d import Sam3dInference
|
| 43 |
from embodied_gen.models.segment_model import (
|
| 44 |
BMGG14Remover,
|
| 45 |
RembgRemover,
|
| 46 |
SAMPredictor,
|
| 47 |
-
)
|
| 48 |
-
from embodied_gen.models.sr_model import ImageRealESRGAN, ImageStableSR
|
| 49 |
-
from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
|
| 50 |
-
from embodied_gen.scripts.render_mv import build_texture_gen_pipe, infer_pipe
|
| 51 |
-
from embodied_gen.scripts.text2image import (
|
| 52 |
-
build_text2img_ip_pipeline,
|
| 53 |
-
build_text2img_pipeline,
|
| 54 |
-
text2img_gen,
|
| 55 |
-
)
|
| 56 |
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
| 57 |
from embodied_gen.utils.process_media import (
|
| 58 |
filter_image_small_connected_components,
|
| 59 |
keep_largest_connected_component,
|
| 60 |
merge_images_video,
|
| 61 |
)
|
| 62 |
-
from embodied_gen.utils.tags import VERSION
|
| 63 |
-
|
| 64 |
-
from embodied_gen.validators.quality_checkers import (
|
| 65 |
-
BaseChecker,
|
| 66 |
-
ImageAestheticChecker,
|
| 67 |
-
ImageSegChecker,
|
| 68 |
-
MeshGeoChecker,
|
| 69 |
-
)
|
| 70 |
-
from embodied_gen.validators.urdf_convertor import URDFGenerator
|
| 71 |
-
|
| 72 |
-
current_file_path = os.path.abspath(__file__)
|
| 73 |
-
current_dir = os.path.dirname(current_file_path)
|
| 74 |
-
sys.path.append(os.path.join(current_dir, ".."))
|
| 75 |
-
from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
|
| 76 |
-
from thirdparty.TRELLIS.trellis.utils import postprocessing_utils
|
| 77 |
|
| 78 |
logging.basicConfig(
|
| 79 |
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
|
|
@@ -83,67 +52,15 @@ logger = logging.getLogger(__name__)
|
|
| 83 |
os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
|
| 84 |
os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")
|
| 85 |
MAX_SEED = 100000
|
| 86 |
-
|
| 87 |
-
# DELIGHT = DelightingModel()
|
| 88 |
-
# IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
|
| 89 |
-
# IMAGESR_MODEL = ImageStableSR()
|
| 90 |
if os.getenv("GRADIO_APP").startswith("imageto3d"):
|
| 91 |
RBG_REMOVER = RembgRemover()
|
| 92 |
RBG14_REMOVER = BMGG14Remover()
|
| 93 |
SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
|
| 94 |
-
# if "sam3d" in os.getenv("GRADIO_APP"):
|
| 95 |
-
# PIPELINE = Sam3dInference()
|
| 96 |
-
# else:
|
| 97 |
-
# PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
|
| 98 |
-
# "microsoft/TRELLIS-image-large"
|
| 99 |
-
# )
|
| 100 |
-
# # PIPELINE.cuda()
|
| 101 |
-
# SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
|
| 102 |
-
# GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
|
| 103 |
-
# AESTHETIC_CHECKER = ImageAestheticChecker()
|
| 104 |
-
# CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
|
| 105 |
TMP_DIR = os.path.join(
|
| 106 |
os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
|
| 107 |
)
|
| 108 |
os.makedirs(TMP_DIR, exist_ok=True)
|
| 109 |
-
elif os.getenv("GRADIO_APP").startswith("textto3d"):
|
| 110 |
-
RBG_REMOVER = RembgRemover()
|
| 111 |
-
RBG14_REMOVER = BMGG14Remover()
|
| 112 |
-
if "sam3d" in os.getenv("GRADIO_APP"):
|
| 113 |
-
PIPELINE = Sam3dInference()
|
| 114 |
-
else:
|
| 115 |
-
PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
|
| 116 |
-
"microsoft/TRELLIS-image-large"
|
| 117 |
-
)
|
| 118 |
-
# PIPELINE.cuda()
|
| 119 |
-
text_model_dir = "weights/Kolors"
|
| 120 |
-
PIPELINE_IMG_IP = build_text2img_ip_pipeline(text_model_dir, ref_scale=0.3)
|
| 121 |
-
PIPELINE_IMG = build_text2img_pipeline(text_model_dir)
|
| 122 |
-
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
|
| 123 |
-
GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
|
| 124 |
-
AESTHETIC_CHECKER = ImageAestheticChecker()
|
| 125 |
-
CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
|
| 126 |
-
TMP_DIR = os.path.join(
|
| 127 |
-
os.path.dirname(os.path.abspath(__file__)), "sessions/textto3d"
|
| 128 |
-
)
|
| 129 |
-
os.makedirs(TMP_DIR, exist_ok=True)
|
| 130 |
-
elif os.getenv("GRADIO_APP") == "texture_edit":
|
| 131 |
-
DELIGHT = DelightingModel()
|
| 132 |
-
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
|
| 133 |
-
PIPELINE_IP = build_texture_gen_pipe(
|
| 134 |
-
base_ckpt_dir="./weights",
|
| 135 |
-
ip_adapt_scale=0.7,
|
| 136 |
-
device="cuda",
|
| 137 |
-
)
|
| 138 |
-
PIPELINE = build_texture_gen_pipe(
|
| 139 |
-
base_ckpt_dir="./weights",
|
| 140 |
-
ip_adapt_scale=0,
|
| 141 |
-
device="cuda",
|
| 142 |
-
)
|
| 143 |
-
TMP_DIR = os.path.join(
|
| 144 |
-
os.path.dirname(os.path.abspath(__file__)), "sessions/texture_edit"
|
| 145 |
-
)
|
| 146 |
-
os.makedirs(TMP_DIR, exist_ok=True)
|
| 147 |
|
| 148 |
|
| 149 |
def start_session(req: gr.Request) -> None:
|
|
@@ -262,536 +179,3 @@ def select_point(
|
|
| 262 |
|
| 263 |
return (image, masks), seg_image
|
| 264 |
|
| 265 |
-
|
| 266 |
-
@spaces.GPU(duration=300)
|
| 267 |
-
def image_to_3d(
|
| 268 |
-
image: Image.Image,
|
| 269 |
-
seed: int,
|
| 270 |
-
ss_sampling_steps: int,
|
| 271 |
-
slat_sampling_steps: int,
|
| 272 |
-
raw_image_cache: Image.Image,
|
| 273 |
-
ss_guidance_strength: float,
|
| 274 |
-
slat_guidance_strength: float,
|
| 275 |
-
sam_image: Image.Image = None,
|
| 276 |
-
is_sam_image: bool = False,
|
| 277 |
-
req: gr.Request = None,
|
| 278 |
-
) -> tuple[dict, str]:
|
| 279 |
-
if is_sam_image:
|
| 280 |
-
seg_image = filter_image_small_connected_components(sam_image)
|
| 281 |
-
seg_image = Image.fromarray(seg_image, mode="RGBA")
|
| 282 |
-
else:
|
| 283 |
-
seg_image = image
|
| 284 |
-
|
| 285 |
-
if isinstance(seg_image, np.ndarray):
|
| 286 |
-
seg_image = Image.fromarray(seg_image)
|
| 287 |
-
|
| 288 |
-
logger.info("Start generating 3D representation from image...")
|
| 289 |
-
if isinstance(PIPELINE, Sam3dInference):
|
| 290 |
-
outputs = PIPELINE.run(
|
| 291 |
-
seg_image,
|
| 292 |
-
seed=seed,
|
| 293 |
-
stage1_inference_steps=ss_sampling_steps,
|
| 294 |
-
stage2_inference_steps=slat_sampling_steps,
|
| 295 |
-
)
|
| 296 |
-
else:
|
| 297 |
-
PIPELINE.cuda()
|
| 298 |
-
seg_image = trellis_preprocess(seg_image)
|
| 299 |
-
outputs = PIPELINE.run(
|
| 300 |
-
seg_image,
|
| 301 |
-
seed=seed,
|
| 302 |
-
formats=["gaussian", "mesh"],
|
| 303 |
-
preprocess_image=False,
|
| 304 |
-
sparse_structure_sampler_params={
|
| 305 |
-
"steps": ss_sampling_steps,
|
| 306 |
-
"cfg_strength": ss_guidance_strength,
|
| 307 |
-
},
|
| 308 |
-
slat_sampler_params={
|
| 309 |
-
"steps": slat_sampling_steps,
|
| 310 |
-
"cfg_strength": slat_guidance_strength,
|
| 311 |
-
},
|
| 312 |
-
)
|
| 313 |
-
# Set back to cpu for memory saving.
|
| 314 |
-
PIPELINE.cpu()
|
| 315 |
-
|
| 316 |
-
gs_model = outputs["gaussian"][0]
|
| 317 |
-
mesh_model = outputs["mesh"][0]
|
| 318 |
-
color_images = render_video(gs_model, r=1.85)["color"]
|
| 319 |
-
normal_images = render_video(mesh_model, r=1.85)["normal"]
|
| 320 |
-
|
| 321 |
-
output_root = os.path.join(TMP_DIR, str(req.session_hash))
|
| 322 |
-
os.makedirs(output_root, exist_ok=True)
|
| 323 |
-
seg_image.save(f"{output_root}/seg_image.png")
|
| 324 |
-
raw_image_cache.save(f"{output_root}/raw_image.png")
|
| 325 |
-
|
| 326 |
-
video_path = os.path.join(output_root, "gs_mesh.mp4")
|
| 327 |
-
merge_images_video(color_images, normal_images, video_path)
|
| 328 |
-
state = pack_state(gs_model, mesh_model)
|
| 329 |
-
|
| 330 |
-
gc.collect()
|
| 331 |
-
torch.cuda.empty_cache()
|
| 332 |
-
|
| 333 |
-
return state, video_path
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
def extract_3d_representations_v2(
|
| 337 |
-
state: dict,
|
| 338 |
-
enable_delight: bool,
|
| 339 |
-
texture_size: int,
|
| 340 |
-
req: gr.Request,
|
| 341 |
-
):
|
| 342 |
-
"""Back-Projection Version of Texture Super-Resolution."""
|
| 343 |
-
output_root = TMP_DIR
|
| 344 |
-
user_dir = os.path.join(output_root, str(req.session_hash))
|
| 345 |
-
gs_model, mesh_model = unpack_state(state, device="cpu")
|
| 346 |
-
|
| 347 |
-
filename = "sample"
|
| 348 |
-
gs_path = os.path.join(user_dir, f"{filename}_gs.ply")
|
| 349 |
-
gs_model.save_ply(gs_path)
|
| 350 |
-
|
| 351 |
-
# Rotate mesh and GS by 90 degrees around Z-axis.
|
| 352 |
-
rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
|
| 353 |
-
gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
|
| 354 |
-
mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
|
| 355 |
-
|
| 356 |
-
# Addtional rotation for GS to align mesh.
|
| 357 |
-
gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
|
| 358 |
-
pose = GaussianOperator.trans_to_quatpose(gs_rot)
|
| 359 |
-
aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
|
| 360 |
-
GaussianOperator.resave_ply(
|
| 361 |
-
in_ply=gs_path,
|
| 362 |
-
out_ply=aligned_gs_path,
|
| 363 |
-
instance_pose=pose,
|
| 364 |
-
device="cpu",
|
| 365 |
-
)
|
| 366 |
-
color_path = os.path.join(user_dir, "color.png")
|
| 367 |
-
render_gs_api(
|
| 368 |
-
input_gs=aligned_gs_path,
|
| 369 |
-
output_path=color_path,
|
| 370 |
-
elevation=[20, -10, 60, -50],
|
| 371 |
-
num_images=12,
|
| 372 |
-
)
|
| 373 |
-
|
| 374 |
-
mesh = trimesh.Trimesh(
|
| 375 |
-
vertices=mesh_model.vertices.cpu().numpy(),
|
| 376 |
-
faces=mesh_model.faces.cpu().numpy(),
|
| 377 |
-
)
|
| 378 |
-
mesh.vertices = mesh.vertices @ np.array(mesh_add_rot)
|
| 379 |
-
mesh.vertices = mesh.vertices @ np.array(rot_matrix)
|
| 380 |
-
|
| 381 |
-
mesh_obj_path = os.path.join(user_dir, f"{filename}.obj")
|
| 382 |
-
mesh.export(mesh_obj_path)
|
| 383 |
-
|
| 384 |
-
mesh = backproject_api(
|
| 385 |
-
delight_model=DELIGHT,
|
| 386 |
-
imagesr_model=IMAGESR_MODEL,
|
| 387 |
-
color_path=color_path,
|
| 388 |
-
mesh_path=mesh_obj_path,
|
| 389 |
-
output_path=mesh_obj_path,
|
| 390 |
-
skip_fix_mesh=False,
|
| 391 |
-
delight=enable_delight,
|
| 392 |
-
texture_wh=[texture_size, texture_size],
|
| 393 |
-
elevation=[20, -10, 60, -50],
|
| 394 |
-
num_images=12,
|
| 395 |
-
)
|
| 396 |
-
|
| 397 |
-
mesh_glb_path = os.path.join(user_dir, f"{filename}.glb")
|
| 398 |
-
mesh.export(mesh_glb_path)
|
| 399 |
-
|
| 400 |
-
return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
def extract_3d_representations_v3(
|
| 404 |
-
state: dict,
|
| 405 |
-
enable_delight: bool,
|
| 406 |
-
texture_size: int,
|
| 407 |
-
req: gr.Request,
|
| 408 |
-
):
|
| 409 |
-
"""Back-Projection Version with Optimization-Based."""
|
| 410 |
-
output_root = TMP_DIR
|
| 411 |
-
user_dir = os.path.join(output_root, str(req.session_hash))
|
| 412 |
-
gs_model, mesh_model = unpack_state(state, device="cpu")
|
| 413 |
-
|
| 414 |
-
filename = "sample"
|
| 415 |
-
gs_path = os.path.join(user_dir, f"{filename}_gs.ply")
|
| 416 |
-
gs_model.save_ply(gs_path)
|
| 417 |
-
|
| 418 |
-
# Rotate mesh and GS by 90 degrees around Z-axis.
|
| 419 |
-
rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
|
| 420 |
-
gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
|
| 421 |
-
mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
|
| 422 |
-
|
| 423 |
-
# Addtional rotation for GS to align mesh.
|
| 424 |
-
gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
|
| 425 |
-
pose = GaussianOperator.trans_to_quatpose(gs_rot)
|
| 426 |
-
aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
|
| 427 |
-
GaussianOperator.resave_ply(
|
| 428 |
-
in_ply=gs_path,
|
| 429 |
-
out_ply=aligned_gs_path,
|
| 430 |
-
instance_pose=pose,
|
| 431 |
-
device="cpu",
|
| 432 |
-
)
|
| 433 |
-
|
| 434 |
-
mesh = trimesh.Trimesh(
|
| 435 |
-
vertices=mesh_model.vertices.cpu().numpy(),
|
| 436 |
-
faces=mesh_model.faces.cpu().numpy(),
|
| 437 |
-
)
|
| 438 |
-
mesh.vertices = mesh.vertices @ np.array(mesh_add_rot)
|
| 439 |
-
mesh.vertices = mesh.vertices @ np.array(rot_matrix)
|
| 440 |
-
|
| 441 |
-
mesh_obj_path = os.path.join(user_dir, f"{filename}.obj")
|
| 442 |
-
mesh.export(mesh_obj_path)
|
| 443 |
-
|
| 444 |
-
mesh = backproject_api_v3(
|
| 445 |
-
gs_path=aligned_gs_path,
|
| 446 |
-
mesh_path=mesh_obj_path,
|
| 447 |
-
output_path=mesh_obj_path,
|
| 448 |
-
skip_fix_mesh=False,
|
| 449 |
-
texture_size=texture_size,
|
| 450 |
-
)
|
| 451 |
-
|
| 452 |
-
mesh_glb_path = os.path.join(user_dir, f"{filename}.glb")
|
| 453 |
-
mesh.export(mesh_glb_path)
|
| 454 |
-
|
| 455 |
-
return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
def extract_urdf(
|
| 459 |
-
gs_path: str,
|
| 460 |
-
mesh_obj_path: str,
|
| 461 |
-
asset_cat_text: str,
|
| 462 |
-
height_range_text: str,
|
| 463 |
-
mass_range_text: str,
|
| 464 |
-
asset_version_text: str,
|
| 465 |
-
req: gr.Request = None,
|
| 466 |
-
):
|
| 467 |
-
output_root = TMP_DIR
|
| 468 |
-
if req is not None:
|
| 469 |
-
output_root = os.path.join(output_root, str(req.session_hash))
|
| 470 |
-
|
| 471 |
-
# Convert to URDF and recover attrs by GPT.
|
| 472 |
-
filename = "sample"
|
| 473 |
-
urdf_convertor = URDFGenerator(
|
| 474 |
-
GPT_CLIENT, render_view_num=4, decompose_convex=True
|
| 475 |
-
)
|
| 476 |
-
asset_attrs = {
|
| 477 |
-
"version": VERSION,
|
| 478 |
-
"gs_model": f"{urdf_convertor.output_mesh_dir}/{filename}_gs.ply",
|
| 479 |
-
}
|
| 480 |
-
if asset_version_text:
|
| 481 |
-
asset_attrs["version"] = asset_version_text
|
| 482 |
-
if asset_cat_text:
|
| 483 |
-
asset_attrs["category"] = asset_cat_text.lower()
|
| 484 |
-
if height_range_text:
|
| 485 |
-
try:
|
| 486 |
-
min_height, max_height = map(float, height_range_text.split("-"))
|
| 487 |
-
asset_attrs["min_height"] = min_height
|
| 488 |
-
asset_attrs["max_height"] = max_height
|
| 489 |
-
except ValueError:
|
| 490 |
-
return "Invalid height input format. Use the format: min-max."
|
| 491 |
-
if mass_range_text:
|
| 492 |
-
try:
|
| 493 |
-
min_mass, max_mass = map(float, mass_range_text.split("-"))
|
| 494 |
-
asset_attrs["min_mass"] = min_mass
|
| 495 |
-
asset_attrs["max_mass"] = max_mass
|
| 496 |
-
except ValueError:
|
| 497 |
-
return "Invalid mass input format. Use the format: min-max."
|
| 498 |
-
|
| 499 |
-
urdf_path = urdf_convertor(
|
| 500 |
-
mesh_path=mesh_obj_path,
|
| 501 |
-
output_root=f"{output_root}/URDF_{filename}",
|
| 502 |
-
**asset_attrs,
|
| 503 |
-
)
|
| 504 |
-
|
| 505 |
-
# Rescale GS and save to URDF/mesh folder.
|
| 506 |
-
real_height = urdf_convertor.get_attr_from_urdf(
|
| 507 |
-
urdf_path, attr_name="real_height"
|
| 508 |
-
)
|
| 509 |
-
out_gs = f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}/{filename}_gs.ply" # noqa
|
| 510 |
-
GaussianOperator.resave_ply(
|
| 511 |
-
in_ply=gs_path,
|
| 512 |
-
out_ply=out_gs,
|
| 513 |
-
real_height=real_height,
|
| 514 |
-
device="cpu",
|
| 515 |
-
)
|
| 516 |
-
|
| 517 |
-
# Quality check and update .urdf file.
|
| 518 |
-
mesh_out = f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}/{filename}.obj" # noqa
|
| 519 |
-
trimesh.load(mesh_out).export(mesh_out.replace(".obj", ".glb"))
|
| 520 |
-
# image_paths = render_asset3d(
|
| 521 |
-
# mesh_path=mesh_out,
|
| 522 |
-
# output_root=f"{output_root}/URDF_{filename}",
|
| 523 |
-
# output_subdir="qa_renders",
|
| 524 |
-
# num_images=8,
|
| 525 |
-
# elevation=(30, -30),
|
| 526 |
-
# distance=5.5,
|
| 527 |
-
# )
|
| 528 |
-
|
| 529 |
-
image_dir = f"{output_root}/URDF_{filename}/{urdf_convertor.output_render_dir}/image_color" # noqa
|
| 530 |
-
image_paths = glob(f"{image_dir}/*.png")
|
| 531 |
-
images_list = []
|
| 532 |
-
for checker in CHECKERS:
|
| 533 |
-
images = image_paths
|
| 534 |
-
if isinstance(checker, ImageSegChecker):
|
| 535 |
-
images = [
|
| 536 |
-
f"{TMP_DIR}/{req.session_hash}/raw_image.png",
|
| 537 |
-
f"{TMP_DIR}/{req.session_hash}/seg_image.png",
|
| 538 |
-
]
|
| 539 |
-
images_list.append(images)
|
| 540 |
-
|
| 541 |
-
results = BaseChecker.validate(CHECKERS, images_list)
|
| 542 |
-
urdf_convertor.add_quality_tag(urdf_path, results)
|
| 543 |
-
|
| 544 |
-
# Zip urdf files
|
| 545 |
-
urdf_zip = zip_files(
|
| 546 |
-
input_paths=[
|
| 547 |
-
f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}",
|
| 548 |
-
f"{output_root}/URDF_{filename}/{filename}.urdf",
|
| 549 |
-
],
|
| 550 |
-
output_zip=f"{output_root}/urdf_{filename}.zip",
|
| 551 |
-
)
|
| 552 |
-
|
| 553 |
-
estimated_type = urdf_convertor.estimated_attrs["category"]
|
| 554 |
-
estimated_height = urdf_convertor.estimated_attrs["height"]
|
| 555 |
-
estimated_mass = urdf_convertor.estimated_attrs["mass"]
|
| 556 |
-
estimated_mu = urdf_convertor.estimated_attrs["mu"]
|
| 557 |
-
|
| 558 |
-
return (
|
| 559 |
-
urdf_zip,
|
| 560 |
-
estimated_type,
|
| 561 |
-
estimated_height,
|
| 562 |
-
estimated_mass,
|
| 563 |
-
estimated_mu,
|
| 564 |
-
)
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
@spaces.GPU(duration=300)
|
| 568 |
-
def text2image_fn(
|
| 569 |
-
prompt: str,
|
| 570 |
-
guidance_scale: float,
|
| 571 |
-
infer_step: int = 50,
|
| 572 |
-
ip_image: Image.Image | str = None,
|
| 573 |
-
ip_adapt_scale: float = 0.3,
|
| 574 |
-
image_wh: int | tuple[int, int] = [1024, 1024],
|
| 575 |
-
rmbg_tag: str = "rembg",
|
| 576 |
-
seed: int = None,
|
| 577 |
-
enable_pre_resize: bool = True,
|
| 578 |
-
n_sample: int = 3,
|
| 579 |
-
req: gr.Request = None,
|
| 580 |
-
):
|
| 581 |
-
if isinstance(image_wh, int):
|
| 582 |
-
image_wh = (image_wh, image_wh)
|
| 583 |
-
output_root = TMP_DIR
|
| 584 |
-
if req is not None:
|
| 585 |
-
output_root = os.path.join(output_root, str(req.session_hash))
|
| 586 |
-
os.makedirs(output_root, exist_ok=True)
|
| 587 |
-
|
| 588 |
-
pipeline = PIPELINE_IMG if ip_image is None else PIPELINE_IMG_IP
|
| 589 |
-
if ip_image is not None:
|
| 590 |
-
pipeline.set_ip_adapter_scale([ip_adapt_scale])
|
| 591 |
-
|
| 592 |
-
images = text2img_gen(
|
| 593 |
-
prompt=prompt,
|
| 594 |
-
n_sample=n_sample,
|
| 595 |
-
guidance_scale=guidance_scale,
|
| 596 |
-
pipeline=pipeline,
|
| 597 |
-
ip_image=ip_image,
|
| 598 |
-
image_wh=image_wh,
|
| 599 |
-
infer_step=infer_step,
|
| 600 |
-
seed=seed,
|
| 601 |
-
)
|
| 602 |
-
|
| 603 |
-
for idx in range(len(images)):
|
| 604 |
-
image = images[idx]
|
| 605 |
-
images[idx], _ = preprocess_image_fn(
|
| 606 |
-
image, rmbg_tag, enable_pre_resize
|
| 607 |
-
)
|
| 608 |
-
|
| 609 |
-
save_paths = []
|
| 610 |
-
for idx, image in enumerate(images):
|
| 611 |
-
save_path = f"{output_root}/sample_{idx}.png"
|
| 612 |
-
image.save(save_path)
|
| 613 |
-
save_paths.append(save_path)
|
| 614 |
-
|
| 615 |
-
logger.info(f"Images saved to {output_root}")
|
| 616 |
-
|
| 617 |
-
gc.collect()
|
| 618 |
-
torch.cuda.empty_cache()
|
| 619 |
-
|
| 620 |
-
return save_paths + save_paths
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
@spaces.GPU(duration=120)
|
| 624 |
-
def generate_condition(mesh_path: str, req: gr.Request, uuid: str = "sample"):
|
| 625 |
-
output_root = os.path.join(TMP_DIR, str(req.session_hash))
|
| 626 |
-
|
| 627 |
-
_ = render_api(
|
| 628 |
-
mesh_path=mesh_path,
|
| 629 |
-
output_root=f"{output_root}/condition",
|
| 630 |
-
uuid=str(uuid),
|
| 631 |
-
)
|
| 632 |
-
|
| 633 |
-
gc.collect()
|
| 634 |
-
torch.cuda.empty_cache()
|
| 635 |
-
|
| 636 |
-
return None, None, None
|
| 637 |
-
|
| 638 |
-
|
| 639 |
-
@spaces.GPU(duration=300)
|
| 640 |
-
def generate_texture_mvimages(
|
| 641 |
-
prompt: str,
|
| 642 |
-
controlnet_cond_scale: float = 0.55,
|
| 643 |
-
guidance_scale: float = 9,
|
| 644 |
-
strength: float = 0.9,
|
| 645 |
-
num_inference_steps: int = 50,
|
| 646 |
-
seed: int = 0,
|
| 647 |
-
ip_adapt_scale: float = 0,
|
| 648 |
-
ip_img_path: str = None,
|
| 649 |
-
uid: str = "sample",
|
| 650 |
-
sub_idxs: tuple[tuple[int]] = ((0, 1, 2), (3, 4, 5)),
|
| 651 |
-
req: gr.Request = None,
|
| 652 |
-
) -> list[str]:
|
| 653 |
-
output_root = os.path.join(TMP_DIR, str(req.session_hash))
|
| 654 |
-
use_ip_adapter = True if ip_img_path and ip_adapt_scale > 0 else False
|
| 655 |
-
PIPELINE_IP.set_ip_adapter_scale([ip_adapt_scale])
|
| 656 |
-
img_save_paths = infer_pipe(
|
| 657 |
-
index_file=f"{output_root}/condition/index.json",
|
| 658 |
-
controlnet_cond_scale=controlnet_cond_scale,
|
| 659 |
-
guidance_scale=guidance_scale,
|
| 660 |
-
strength=strength,
|
| 661 |
-
num_inference_steps=num_inference_steps,
|
| 662 |
-
ip_adapt_scale=ip_adapt_scale,
|
| 663 |
-
ip_img_path=ip_img_path,
|
| 664 |
-
uid=uid,
|
| 665 |
-
prompt=prompt,
|
| 666 |
-
save_dir=f"{output_root}/multi_view",
|
| 667 |
-
sub_idxs=sub_idxs,
|
| 668 |
-
pipeline=PIPELINE_IP if use_ip_adapter else PIPELINE,
|
| 669 |
-
seed=seed,
|
| 670 |
-
)
|
| 671 |
-
|
| 672 |
-
gc.collect()
|
| 673 |
-
torch.cuda.empty_cache()
|
| 674 |
-
|
| 675 |
-
return img_save_paths + img_save_paths
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
def backproject_texture(
|
| 679 |
-
mesh_path: str,
|
| 680 |
-
input_image: str,
|
| 681 |
-
texture_size: int,
|
| 682 |
-
uuid: str = "sample",
|
| 683 |
-
req: gr.Request = None,
|
| 684 |
-
) -> str:
|
| 685 |
-
output_root = os.path.join(TMP_DIR, str(req.session_hash))
|
| 686 |
-
output_dir = os.path.join(output_root, "texture_mesh")
|
| 687 |
-
os.makedirs(output_dir, exist_ok=True)
|
| 688 |
-
command = [
|
| 689 |
-
"backproject-cli",
|
| 690 |
-
"--mesh_path",
|
| 691 |
-
mesh_path,
|
| 692 |
-
"--input_image",
|
| 693 |
-
input_image,
|
| 694 |
-
"--output_root",
|
| 695 |
-
output_dir,
|
| 696 |
-
"--uuid",
|
| 697 |
-
f"{uuid}",
|
| 698 |
-
"--texture_size",
|
| 699 |
-
str(texture_size),
|
| 700 |
-
"--skip_fix_mesh",
|
| 701 |
-
]
|
| 702 |
-
|
| 703 |
-
_ = subprocess.run(
|
| 704 |
-
command, capture_output=True, text=True, encoding="utf-8"
|
| 705 |
-
)
|
| 706 |
-
output_obj_mesh = os.path.join(output_dir, f"{uuid}.obj")
|
| 707 |
-
output_glb_mesh = os.path.join(output_dir, f"{uuid}.glb")
|
| 708 |
-
_ = trimesh.load(output_obj_mesh).export(output_glb_mesh)
|
| 709 |
-
|
| 710 |
-
zip_file = zip_files(
|
| 711 |
-
input_paths=[
|
| 712 |
-
output_glb_mesh,
|
| 713 |
-
output_obj_mesh,
|
| 714 |
-
os.path.join(output_dir, "material.mtl"),
|
| 715 |
-
os.path.join(output_dir, "material_0.png"),
|
| 716 |
-
],
|
| 717 |
-
output_zip=os.path.join(output_dir, f"{uuid}.zip"),
|
| 718 |
-
)
|
| 719 |
-
|
| 720 |
-
gc.collect()
|
| 721 |
-
torch.cuda.empty_cache()
|
| 722 |
-
|
| 723 |
-
return output_glb_mesh, output_obj_mesh, zip_file
|
| 724 |
-
|
| 725 |
-
|
| 726 |
-
@spaces.GPU(duration=300)
|
| 727 |
-
def backproject_texture_v2(
|
| 728 |
-
mesh_path: str,
|
| 729 |
-
input_image: str,
|
| 730 |
-
texture_size: int,
|
| 731 |
-
enable_delight: bool = True,
|
| 732 |
-
fix_mesh: bool = False,
|
| 733 |
-
no_mesh_post_process: bool = False,
|
| 734 |
-
uuid: str = "sample",
|
| 735 |
-
req: gr.Request = None,
|
| 736 |
-
) -> str:
|
| 737 |
-
output_root = os.path.join(TMP_DIR, str(req.session_hash))
|
| 738 |
-
output_dir = os.path.join(output_root, "texture_mesh")
|
| 739 |
-
os.makedirs(output_dir, exist_ok=True)
|
| 740 |
-
|
| 741 |
-
textured_mesh = backproject_api(
|
| 742 |
-
delight_model=DELIGHT,
|
| 743 |
-
imagesr_model=IMAGESR_MODEL,
|
| 744 |
-
color_path=input_image,
|
| 745 |
-
mesh_path=mesh_path,
|
| 746 |
-
output_path=f"{output_dir}/{uuid}.obj",
|
| 747 |
-
skip_fix_mesh=not fix_mesh,
|
| 748 |
-
delight=enable_delight,
|
| 749 |
-
texture_wh=[texture_size, texture_size],
|
| 750 |
-
no_mesh_post_process=no_mesh_post_process,
|
| 751 |
-
)
|
| 752 |
-
|
| 753 |
-
output_obj_mesh = os.path.join(output_dir, f"{uuid}.obj")
|
| 754 |
-
output_glb_mesh = os.path.join(output_dir, f"{uuid}.glb")
|
| 755 |
-
_ = textured_mesh.export(output_glb_mesh)
|
| 756 |
-
|
| 757 |
-
zip_file = zip_files(
|
| 758 |
-
input_paths=[
|
| 759 |
-
output_glb_mesh,
|
| 760 |
-
output_obj_mesh,
|
| 761 |
-
os.path.join(output_dir, "material.mtl"),
|
| 762 |
-
os.path.join(output_dir, "material_0.png"),
|
| 763 |
-
],
|
| 764 |
-
output_zip=os.path.join(output_dir, f"{uuid}.zip"),
|
| 765 |
-
)
|
| 766 |
-
|
| 767 |
-
gc.collect()
|
| 768 |
-
torch.cuda.empty_cache()
|
| 769 |
-
|
| 770 |
-
return output_glb_mesh, output_obj_mesh, zip_file
|
| 771 |
-
|
| 772 |
-
|
| 773 |
-
@spaces.GPU(duration=120)
|
| 774 |
-
def render_result_video(
|
| 775 |
-
mesh_path: str, video_size: int, req: gr.Request, uuid: str = ""
|
| 776 |
-
) -> str:
|
| 777 |
-
output_root = os.path.join(TMP_DIR, str(req.session_hash))
|
| 778 |
-
output_dir = os.path.join(output_root, "texture_mesh")
|
| 779 |
-
|
| 780 |
-
_ = render_api(
|
| 781 |
-
mesh_path=mesh_path,
|
| 782 |
-
output_root=output_dir,
|
| 783 |
-
num_images=90,
|
| 784 |
-
elevation=[20],
|
| 785 |
-
with_mtl=True,
|
| 786 |
-
pbr_light_factor=1,
|
| 787 |
-
uuid=str(uuid),
|
| 788 |
-
gen_color_mp4=True,
|
| 789 |
-
gen_glonormal_mp4=True,
|
| 790 |
-
distance=5.5,
|
| 791 |
-
resolution_hw=(video_size, video_size),
|
| 792 |
-
)
|
| 793 |
-
|
| 794 |
-
gc.collect()
|
| 795 |
-
torch.cuda.empty_cache()
|
| 796 |
-
|
| 797 |
-
return f"{output_dir}/color.mp4"
|
|
|
|
| 15 |
# permissions and limitations under the License.
|
| 16 |
|
| 17 |
import spaces
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
import gc
|
| 19 |
import logging
|
| 20 |
import os
|
|
|
|
| 28 |
import numpy as np
|
| 29 |
import torch
|
| 30 |
import trimesh
|
| 31 |
+
from PIL import Image
|
| 32 |
+
from embodied_gen.data.utils import trellis_preprocess, zip_files
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
from embodied_gen.models.segment_model import (
|
| 34 |
BMGG14Remover,
|
| 35 |
RembgRemover,
|
| 36 |
SAMPredictor,
|
| 37 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
| 39 |
from embodied_gen.utils.process_media import (
|
| 40 |
filter_image_small_connected_components,
|
| 41 |
keep_largest_connected_component,
|
| 42 |
merge_images_video,
|
| 43 |
)
|
| 44 |
+
from embodied_gen.utils.tags import VERSION
|
| 45 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
logging.basicConfig(
|
| 48 |
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
|
|
|
|
| 52 |
os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
|
| 53 |
os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")
|
| 54 |
MAX_SEED = 100000
|
| 55 |
+
|
|
|
|
|
|
|
|
|
|
| 56 |
if os.getenv("GRADIO_APP").startswith("imageto3d"):
|
| 57 |
RBG_REMOVER = RembgRemover()
|
| 58 |
RBG14_REMOVER = BMGG14Remover()
|
| 59 |
SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
TMP_DIR = os.path.join(
|
| 61 |
os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
|
| 62 |
)
|
| 63 |
os.makedirs(TMP_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
|
| 66 |
def start_session(req: gr.Request) -> None:
|
|
|
|
| 179 |
|
| 180 |
return (image, masks), seg_image
|
| 181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
embodied_gen/utils/monkey_patch/sam3d.py
CHANGED
|
@@ -40,7 +40,7 @@ def monkey_patch_sam3d():
|
|
| 40 |
if sam3d_root not in sys.path:
|
| 41 |
sys.path.insert(0, sam3d_root)
|
| 42 |
|
| 43 |
-
def
|
| 44 |
"""Patches InferencePipelinePointMap.run to handle pointmap generation and 3D structure sampling."""
|
| 45 |
try:
|
| 46 |
from sam3d_objects.pipeline.inference_pipeline_pointmap import (
|
|
@@ -202,7 +202,7 @@ def monkey_patch_sam3d():
|
|
| 202 |
|
| 203 |
InferencePipelinePointMap.run = patch_run
|
| 204 |
|
| 205 |
-
def
|
| 206 |
"""Patches InferencePipeline.__init__ to allow CPU offloading during model initialization."""
|
| 207 |
import torch
|
| 208 |
|
|
@@ -380,7 +380,7 @@ def monkey_patch_sam3d():
|
|
| 380 |
|
| 381 |
InferencePipeline.__init__ = patch_init
|
| 382 |
|
| 383 |
-
#
|
| 384 |
-
#
|
| 385 |
|
| 386 |
return
|
|
|
|
| 40 |
if sam3d_root not in sys.path:
|
| 41 |
sys.path.insert(0, sam3d_root)
|
| 42 |
|
| 43 |
+
def patch_pointmap_infer_pipeline():
|
| 44 |
"""Patches InferencePipelinePointMap.run to handle pointmap generation and 3D structure sampling."""
|
| 45 |
try:
|
| 46 |
from sam3d_objects.pipeline.inference_pipeline_pointmap import (
|
|
|
|
| 202 |
|
| 203 |
InferencePipelinePointMap.run = patch_run
|
| 204 |
|
| 205 |
+
def patch_infer_init():
|
| 206 |
"""Patches InferencePipeline.__init__ to allow CPU offloading during model initialization."""
|
| 207 |
import torch
|
| 208 |
|
|
|
|
| 380 |
|
| 381 |
InferencePipeline.__init__ = patch_init
|
| 382 |
|
| 383 |
+
# patch_pointmap_infer_pipeline()
|
| 384 |
+
# patch_infer_init()
|
| 385 |
|
| 386 |
return
|