|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
|
import os |
|
|
import random |
|
|
import sys |
|
|
from glob import glob |
|
|
from shutil import copy, copytree, rmtree |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import trimesh |
|
|
from PIL import Image |
|
|
from embodied_gen.data.backproject_v2 import entrypoint as backproject_api |
|
|
from embodied_gen.data.utils import delete_dir, trellis_preprocess |
|
|
from embodied_gen.models.delight_model import DelightingModel |
|
|
from embodied_gen.models.gs_model import GaussianOperator |
|
|
from embodied_gen.models.segment_model import RembgRemover |
|
|
from embodied_gen.models.sr_model import ImageRealESRGAN |
|
|
from embodied_gen.scripts.render_gs import entrypoint as render_gs_api |
|
|
from embodied_gen.utils.gpt_clients import GPT_CLIENT |
|
|
from embodied_gen.utils.log import logger |
|
|
from embodied_gen.utils.process_media import merge_images_video |
|
|
from embodied_gen.utils.tags import VERSION |
|
|
from embodied_gen.utils.trender import render_video |
|
|
from embodied_gen.validators.quality_checkers import ( |
|
|
BaseChecker, |
|
|
ImageAestheticChecker, |
|
|
ImageSegChecker, |
|
|
MeshGeoChecker, |
|
|
) |
|
|
from embodied_gen.validators.urdf_convertor import URDFGenerator |
|
|
|
|
|
# Put the directory two levels above this file on ``sys.path`` so that the
# ``thirdparty`` package imported on the next statement can be resolved.
# The import must stay after the ``sys.path`` change.
current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
sys.path.append(os.path.join(current_dir, "../.."))
from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline

# Process-wide environment configuration, applied in the original order.
os.environ.update(
    {
        "TORCH_EXTENSIONS_DIR": os.path.expanduser(
            "~/.cache/torch_extensions"
        ),
        "GRADIO_ANALYTICS_ENABLED": "false",
        "SPCONV_ALGO": "native",
    }
)
# Seed Python's RNG; ``entrypoint`` draws retry seeds from it.
random.seed(0)

# Module-level model singletons, created once at import time and shared by
# every call to ``entrypoint``.
logger.info("Loading Image3D Models...")
DELIGHT = DelightingModel()
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
RBG_REMOVER = RembgRemover()
PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
    "microsoft/TRELLIS-image-large"
)

# Quality checkers; ``CHECKERS`` fixes the order in which they are applied
# when validating a finished asset.
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
AESTHETIC_CHECKER = ImageAestheticChecker()
CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
|
|
|
|
|
|
|
|
def parse_args():
    """Parse the command-line options of the image-to-3D pipeline.

    Command-line arguments that the parser does not recognize are ignored
    rather than rejected, because ``parse_known_args`` is used.

    Returns:
        argparse.Namespace: The parsed options. ``height_range`` and
        ``mass_range`` are kept as raw strings; ``entrypoint`` splits them
        on ``"-"`` into a minimum and a maximum float.
    """
    parser = argparse.ArgumentParser(description="Image to 3D pipeline args.")
    parser.add_argument(
        "--image_path", type=str, nargs="+", help="Path to the input images."
    )
    parser.add_argument(
        "--image_root", type=str, help="Path to the input images folder."
    )
    parser.add_argument(
        "--output_root",
        type=str,
        help="Root directory for saving outputs.",
    )
    parser.add_argument(
        "--height_range",
        type=str,
        default=None,
        help=(
            'The height range in meters, formatted as "min-max", '
            "to restore the mesh real size."
        ),
    )
    parser.add_argument(
        "--mass_range",
        type=str,
        default=None,
        help=(
            'The mass range in kg, formatted as "min-max", '
            "to restore the mesh real weight."
        ),
    )
    # One entry per input image, matched by position in ``entrypoint``.
    parser.add_argument("--asset_type", type=str, nargs="+", default=None)
    parser.add_argument("--skip_exists", action="store_true")
    parser.add_argument("--version", type=str, default=VERSION)
    parser.add_argument("--keep_intermediate", action="store_true")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--n_retry",
        type=int,
        default=2,
    )
    parser.add_argument("--disable_decompose_convex", action="store_true")
    parser.add_argument(
        "--texture_wh", type=int, nargs=2, default=[2048, 2048]
    )
    # Unknown arguments are tolerated and discarded.
    args, _ = parser.parse_known_args()

    return args
|
|
|
|
|
|
|
|
def entrypoint(**kwargs):
    """Run the image-to-3D asset pipeline on every requested input image.

    Options come from ``parse_args`` and are then overridden by any keyword
    argument whose name matches a parsed option and whose value is not
    ``None``. When ``image_path`` is empty, the ``*.png``, ``*.jpg`` and
    ``*.jpeg`` files directly inside ``image_root`` are used instead.

    For each image the function:

    1. saves a raw copy, and a segmented/preprocessed conditioning image
       (images that are not already RGBA go through ``RBG_REMOVER``);
    2. runs ``PIPELINE`` up to ``n_retry`` times, saving and rendering the
       Gaussian-splat output and stopping early when ``GEO_CHECKER``
       returns ``True`` or ``None``;
    3. renders a preview video, exports the mesh as OBJ, textures it with
       ``backproject_api`` and exports a GLB;
    4. generates a URDF, re-saves the aligned Gaussian splat with the
       URDF's ``real_height``, and tags the URDF with the quality-check
       results;
    5. collects the URDF, its mesh directory and the video under
       ``<output_root>/result`` and, unless ``keep_intermediate`` is set,
       deletes the remaining intermediates.

    A failure while processing one image is logged and that image is
    skipped; the remaining images are still processed.

    Args:
        **kwargs: Optional overrides for the parsed options. Keys that are
            not parsed options and values that are ``None`` are ignored.

    Raises:
        AssertionError: If neither ``image_path`` nor ``image_root`` is
            provided.
    """
    args = parse_args()
    for k, v in kwargs.items():
        if hasattr(args, k) and v is not None:
            setattr(args, k, v)

    assert (
        args.image_path or args.image_root
    ), "Please provide either --image_path or --image_root."
    if not args.image_path:
        args.image_path = glob(os.path.join(args.image_root, "*.png"))
        args.image_path += glob(os.path.join(args.image_root, "*.jpg"))
        args.image_path += glob(os.path.join(args.image_root, "*.jpeg"))

    # Fixed axis transforms applied to the generated Gaussian splat and mesh.
    # They are defined once here (not inside the retry loop) because the
    # mesh export after the loop needs them regardless of how the loop ran.
    rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
    gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
    mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
    gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)

    for idx, image_path in enumerate(args.image_path):
        try:
            filename = os.path.basename(image_path).split(".")[0]
            output_root = args.output_root
            # Use a per-image sub-directory whenever several images may be
            # written below the same output root.
            if args.image_root is not None or len(args.image_path) > 1:
                output_root = os.path.join(output_root, filename)
            os.makedirs(output_root, exist_ok=True)

            mesh_out = f"{output_root}/{filename}.obj"
            if args.skip_exists and os.path.exists(mesh_out):
                logger.warning(
                    f"Skip {image_path}, already processed in {mesh_out}"
                )
                continue

            image = Image.open(image_path)
            image.save(f"{output_root}/{filename}_raw.png")

            # Segmentation: an RGBA input is used as-is.
            seg_path = f"{output_root}/{filename}_cond.png"
            seg_image = RBG_REMOVER(image) if image.mode != "RGBA" else image
            seg_image = trellis_preprocess(seg_image)
            seg_image.save(seg_path)

            seed = args.seed
            # Asset types are matched to images by position. A missing or
            # empty entry falls back to "unknown" instead of raising
            # IndexError and dropping the whole image.
            asset_type = None
            if isinstance(args.asset_type, list) and idx < len(
                args.asset_type
            ):
                asset_type = args.asset_type[idx]
            asset_node = asset_type if asset_type else "unknown"

            # These paths do not depend on the attempt, so they are built
            # once; the model handles are reset per image so a result from
            # a previous image can never leak into this one.
            gs_path = mesh_out.replace(".obj", "_gs.ply")
            aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
            color_path = os.path.join(output_root, "color.png")
            gs_model = None
            mesh_model = None

            for try_idx in range(args.n_retry):
                logger.info(
                    f"Try: {try_idx + 1}/{args.n_retry}, Seed: {seed}, Prompt: {seg_path}"
                )

                try:
                    PIPELINE.cuda()
                    outputs = PIPELINE.run(
                        seg_image,
                        preprocess_image=False,
                        seed=(
                            random.randint(0, 100000) if seed is None else seed
                        ),
                    )
                except Exception as e:
                    logger.error(
                        f"[Pipeline Failed] process {image_path}: {e}, skip."
                    )
                    continue
                finally:
                    # Always move the pipeline back to the CPU and release
                    # cached GPU memory, including when the run failed.
                    PIPELINE.cpu()
                    torch.cuda.empty_cache()

                gs_model = outputs["gaussian"][0]
                mesh_model = outputs["mesh"][0]

                # Save the raw Gaussian model.
                gs_model.save_ply(gs_path)

                # Re-save the Gaussian model with the fixed transform
                # applied, then render the views used for the geometry
                # check and, later, for texture back-projection.
                pose = GaussianOperator.trans_to_quatpose(gs_rot)
                GaussianOperator.resave_ply(
                    in_ply=gs_path,
                    out_ply=aligned_gs_path,
                    instance_pose=pose,
                    device="cpu",
                )
                render_gs_api(
                    input_gs=aligned_gs_path,
                    output_path=color_path,
                    elevation=[20, -10, 60, -50],
                    num_images=12,
                )

                # Only the top two thirds of the rendered image is passed
                # to the geometry checker.
                color_img = Image.open(color_path)
                keep_height = int(color_img.height * 2 / 3)
                crop_img = color_img.crop((0, 0, color_img.width, keep_height))
                geo_flag, geo_result = GEO_CHECKER([crop_img], text=asset_node)
                logger.warning(
                    f"{GEO_CHECKER.__class__.__name__}: {geo_result} for {seg_path}"
                )
                # Accept the attempt when the checker passes or returns
                # ``None``.
                if geo_flag is True or geo_flag is None:
                    break

                seed = random.randint(0, 100000) if seed is not None else None

            # No attempt produced an output (every run raised, or
            # ``n_retry`` is 0): there is nothing to export for this image.
            if gs_model is None or mesh_model is None:
                logger.error(
                    f"[Pipeline Failed] no 3D output generated for "
                    f"{image_path} after {args.n_retry} tries, skip."
                )
                continue

            # Render the preview video from the last generated models.
            color_images = render_video(gs_model)["color"]
            normal_images = render_video(mesh_model)["normal"]
            video_path = os.path.join(output_root, "gs_mesh.mp4")
            merge_images_video(color_images, normal_images, video_path)

            mesh = trimesh.Trimesh(
                vertices=mesh_model.vertices.cpu().numpy(),
                faces=mesh_model.faces.cpu().numpy(),
            )
            mesh.vertices = mesh.vertices @ np.array(mesh_add_rot)
            mesh.vertices = mesh.vertices @ np.array(rot_matrix)

            mesh_obj_path = os.path.join(output_root, f"{filename}.obj")
            mesh.export(mesh_obj_path)

            # Texture the mesh from the rendered colour views; the OBJ is
            # overwritten in place.
            mesh = backproject_api(
                delight_model=DELIGHT,
                imagesr_model=IMAGESR_MODEL,
                color_path=color_path,
                mesh_path=mesh_obj_path,
                output_path=mesh_obj_path,
                skip_fix_mesh=False,
                delight=True,
                texture_wh=args.texture_wh,
                elevation=[20, -10, 60, -50],
                num_images=12,
            )

            mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
            mesh.export(mesh_glb_path)

            urdf_convertor = URDFGenerator(
                GPT_CLIENT,
                render_view_num=4,
                decompose_convex=not args.disable_decompose_convex,
            )
            asset_attrs = {
                "version": VERSION,
                "gs_model": f"{urdf_convertor.output_mesh_dir}/{filename}_gs.ply",
            }
            # Ranges are given as "min-max" strings.
            if args.height_range:
                min_height, max_height = map(
                    float, args.height_range.split("-")
                )
                asset_attrs["min_height"] = min_height
                asset_attrs["max_height"] = max_height
            if args.mass_range:
                min_mass, max_mass = map(float, args.mass_range.split("-"))
                asset_attrs["min_mass"] = min_mass
                asset_attrs["max_mass"] = max_mass
            if asset_type:
                asset_attrs["category"] = asset_type
            if args.version:
                asset_attrs["version"] = args.version

            urdf_root = f"{output_root}/URDF_{filename}"
            urdf_path = urdf_convertor(
                mesh_path=mesh_obj_path,
                output_root=urdf_root,
                **asset_attrs,
            )

            # Re-save the aligned Gaussian model next to the URDF meshes,
            # passing the ``real_height`` recorded in the URDF.
            real_height = urdf_convertor.get_attr_from_urdf(
                urdf_path, attr_name="real_height"
            )
            out_gs = f"{urdf_root}/{urdf_convertor.output_mesh_dir}/{filename}_gs.ply"
            GaussianOperator.resave_ply(
                in_ply=aligned_gs_path,
                out_ply=out_gs,
                real_height=real_height,
                device="cpu",
            )

            # Convert the OBJ in the URDF mesh directory to GLB.
            urdf_mesh_obj = (
                f"{urdf_root}/{urdf_convertor.output_mesh_dir}/{filename}.obj"
            )
            trimesh.load(urdf_mesh_obj).export(
                urdf_mesh_obj.replace(".obj", ".glb")
            )

            # Build one image list per checker, in ``CHECKERS`` order: the
            # segmentation checker gets the raw/conditioning pair, the
            # others get the URDF colour renders.
            image_dir = f"{urdf_root}/{urdf_convertor.output_render_dir}/image_color"
            image_paths = glob(f"{image_dir}/*.png")
            images_list = []
            for checker in CHECKERS:
                images = image_paths
                if isinstance(checker, ImageSegChecker):
                    images = [
                        f"{output_root}/{filename}_raw.png",
                        f"{output_root}/{filename}_cond.png",
                    ]
                images_list.append(images)

            qa_results = BaseChecker.validate(CHECKERS, images_list)
            urdf_convertor.add_quality_tag(urdf_path, qa_results)

            # Collect the final deliverables in a fresh ``result`` folder.
            result_dir = f"{output_root}/result"
            if os.path.exists(result_dir):
                rmtree(result_dir, ignore_errors=True)
            os.makedirs(result_dir, exist_ok=True)
            copy(urdf_path, f"{result_dir}/{os.path.basename(urdf_path)}")
            copytree(
                f"{urdf_root}/{urdf_convertor.output_mesh_dir}",
                f"{result_dir}/{urdf_convertor.output_mesh_dir}",
            )
            copy(video_path, f"{result_dir}/video.mp4")
            if not args.keep_intermediate:
                delete_dir(output_root, keep_subs=["result"])

        except Exception as e:
            # Best-effort batch processing: log and move on to the next
            # image.
            logger.error(f"Failed to process {image_path}: {e}, skip.")
            continue

    logger.info(f"Processing complete. Outputs saved to {args.output_root}")
|
|
|
|
|
|
|
|
# Script entry point: run the pipeline with options taken from the command
# line only (no keyword overrides).
if __name__ == "__main__":
    entrypoint()
|
|
|