# prompt2shape / common.py
# thisaraniana's picture
# final output
# 9b1d80d
# Project Prompt2Shape
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# NOTE: these monkey patches must run BEFORE the heavy imports below
# (trellis / gradio-related modules), so they intentionally sit at the very
# top of the import section.
import spaces
from prompt2shape.utils.monkey_patch.trellis import monkey_path_trellis

monkey_path_trellis()

from prompt2shape.utils.monkey_patch.gradio import (
    _patch_open3d_cuda_device_count_bug,
)

_patch_open3d_cuda_device_count_bug()
import gc
import logging
import os
import shutil
import subprocess
import sys
from glob import glob
import cv2
import gradio as gr
import numpy as np
import torch
import trimesh
from PIL import Image
from prompt2shape.data.backproject_v2 import entrypoint as backproject_api
from prompt2shape.data.backproject_v3 import entrypoint as backproject_api_v3
from prompt2shape.data.differentiable_render import entrypoint as render_api
from prompt2shape.data.utils import trellis_preprocess, zip_files
from prompt2shape.models.delight_model import DelightingModel
from prompt2shape.models.gs_model import GaussianOperator
from prompt2shape.models.sam3d import Sam3dInference
from prompt2shape.models.segment_model import (
BMGG14Remover,
RembgRemover,
SAMPredictor,
)
from prompt2shape.models.sr_model import ImageRealESRGAN, ImageStableSR
from prompt2shape.scripts.render_gs import entrypoint as render_gs_api
from prompt2shape.scripts.render_mv import build_texture_gen_pipe, infer_pipe
from prompt2shape.models.text_model import (
build_text2img_ip_pipeline,
build_text2img_pipeline,
text2img_gen,
download_kolors_weights,
)
from prompt2shape.utils.gpt_clients import GPT_CLIENT
from prompt2shape.utils.process_media import (
filter_image_small_connected_components,
keep_largest_connected_component,
merge_images_video,
)
from prompt2shape.utils.tags import VERSION
from prompt2shape.utils.trender import pack_state, render_video, unpack_state
from prompt2shape.validators.quality_checkers import (
BaseChecker,
ImageAestheticChecker,
ImageSegChecker,
MeshGeoChecker,
)
from prompt2shape.validators.urdf_convertor import URDFGenerator
# Make the local `thirdparty` checkout importable (TRELLIS is vendored there).
current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
sys.path.append(current_dir)
from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)

# Disable Gradio telemetry; provide a placeholder key so the GPT client can
# initialize even when no real key is configured (actual calls would fail).
os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")

MAX_SEED = 100000  # exclusive upper bound for randomly drawn seeds

# Lazily-initialized global singletons, populated by the get_* helpers below.
PIPELINE = None  # 3D generation pipeline (TRELLIS or SAM-3D)
PIPELINE_IMG_IP = None  # text->image pipeline with IP-Adapter
PIPELINE_IMG = None  # plain text->image pipeline
PIPELINE_IP = None  # texture-gen pipeline with IP-Adapter
PIPELINE_EDIT = None  # texture-gen pipeline without IP-Adapter
RBG_REMOVER = None  # rembg background remover
RBG14_REMOVER = None  # BMGG14 background remover
SAM_PREDICTOR = None  # SAM point-prompt segmentation predictor
SEG_CHECKER = None  # GPT-based segmentation quality checker
GEO_CHECKER = None  # GPT-based mesh geometry checker
AESTHETIC_CHECKER = None  # image aesthetics checker
CHECKERS = []  # filled by get_checkers()
DELIGHT = None  # de-lighting model
IMAGESR_MODEL = None  # image super-resolution model
def get_image_pipeline():
    """Lazily build and cache the text-to-image pipelines.

    The pipelines are only constructed for the text-to-3d app variants;
    for other apps the cached (possibly ``None``) values are returned.

    Returns:
        Tuple ``(plain text2img pipeline, IP-Adapter text2img pipeline)``.
    """
    global PIPELINE_IMG_IP, PIPELINE_IMG
    app_name = os.getenv("GRADIO_APP", "")
    if app_name.startswith("textto3d"):
        kolors_dir = "weights/Kolors"
        if PIPELINE_IMG_IP is None:
            PIPELINE_IMG_IP = build_text2img_ip_pipeline(
                kolors_dir, ref_scale=0.3
            )
        if PIPELINE_IMG is None:
            PIPELINE_IMG = build_text2img_pipeline(kolors_dir)
    return PIPELINE_IMG, PIPELINE_IMG_IP
def get_3d_pipeline():
    """Lazily build and cache the image-to-3D pipeline singleton.

    Chooses SAM-3D when the app name contains "sam3d"; otherwise loads the
    TRELLIS image-to-3D pipeline. Only builds for the 3d app variants.
    """
    global PIPELINE
    if PIPELINE is not None:
        return PIPELINE
    app_name = os.getenv("GRADIO_APP", "")
    if app_name.startswith(("imageto3d", "textto3d")):
        if "sam3d" in app_name:
            PIPELINE = Sam3dInference(device="cuda")
        else:
            PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
                "microsoft/TRELLIS-image-large"
            )
    return PIPELINE
def get_texture_edit_pipeline():
    """Lazily build and cache the texture-editing model stack.

    Initializes the de-lighting model, the super-resolution model, and two
    texture-generation pipelines (with and without IP-Adapter conditioning).

    Returns:
        Tuple ``(IP-Adapter pipeline, plain edit pipeline)``.
    """
    global PIPELINE_IP, PIPELINE_EDIT, DELIGHT, IMAGESR_MODEL
    if DELIGHT is None:
        DELIGHT = DelightingModel()
    if IMAGESR_MODEL is None:
        IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
    if PIPELINE_IP is None:
        # ip_adapt_scale > 0 enables IP-Adapter conditioning in this pipeline.
        PIPELINE_IP = build_texture_gen_pipe(
            base_ckpt_dir="./weights",
            ip_adapt_scale=0.7,
            device="cuda",
        )
    if PIPELINE_EDIT is None:
        PIPELINE_EDIT = build_texture_gen_pipe(
            base_ckpt_dir="./weights",
            ip_adapt_scale=0,
            device="cuda",
        )
    return PIPELINE_IP, PIPELINE_EDIT
def get_removers():
    """Lazily build and cache both background-remover models.

    Returns:
        Tuple ``(rembg remover, BMGG14 remover)``.
    """
    global RBG_REMOVER, RBG14_REMOVER
    if RBG_REMOVER is None:
        RBG_REMOVER = RembgRemover()
    if RBG14_REMOVER is None:
        RBG14_REMOVER = BMGG14Remover()
    return RBG_REMOVER, RBG14_REMOVER
def get_sam_predictor():
    """Lazily build and cache the SAM predictor (vit_h, kept on CPU)."""
    global SAM_PREDICTOR
    if SAM_PREDICTOR is None:
        SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
    return SAM_PREDICTOR
def get_checkers():
    """Lazily build and cache the asset quality checkers.

    Returns:
        List ``[geometry checker, segmentation checker, aesthetic checker]``.
    """
    global SEG_CHECKER, GEO_CHECKER, AESTHETIC_CHECKER, CHECKERS
    if SEG_CHECKER is None:
        SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
    if GEO_CHECKER is None:
        GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
    if AESTHETIC_CHECKER is None:
        AESTHETIC_CHECKER = ImageAestheticChecker()
    if not CHECKERS:
        CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
    return CHECKERS
def download_all_models():
    """Pre-downloads all model weights at startup to avoid timeouts in ZeroGPU functions."""
    logger.info("Initializing pre-download of all required models...")

    # Kolors text-to-image weights; failures are logged, not fatal.
    try:
        download_kolors_weights("weights/Kolors")
    except Exception as e:
        logger.error(f"Error downloading Kolors weights: {e}")

    # SAM-3D-objects weights, only if not already on disk.
    sam3d_local_dir = "weights/sam-3d-objects"
    if not os.path.exists(sam3d_local_dir):
        try:
            from huggingface_hub import snapshot_download

            snapshot_download(
                "tuandao-zenai/sam-3d-objects", local_dir=sam3d_local_dir
            )
        except Exception as e:
            logger.error(f"Error downloading SAM-3D weights: {e}")

    # TRELLIS weights; diffusers can also fetch these lazily, so warn only.
    try:
        from huggingface_hub import snapshot_download

        snapshot_download(
            "microsoft/TRELLIS-image-large",
            local_dir="weights/TRELLIS-image-large",
        )
    except Exception as e:
        logger.warning(f"Error pre-downloading Trellis (might be handled by diffusers): {e}")

    logger.info("Pre-download process completed.")
# Initialize environment and pre-download models
gradio_app_name = os.getenv("GRADIO_APP", "")
# Only the known app variants need the heavyweight downloads at startup.
if gradio_app_name.startswith("imageto3d") or gradio_app_name.startswith("textto3d") or gradio_app_name == "texture_edit":
    download_all_models()

# Map the (possibly suffixed) app name to a canonical session subfolder name.
if gradio_app_name.startswith("imageto3d"):
    app_type = "imageto3d"
elif gradio_app_name.startswith("textto3d"):
    app_type = "textto3d"
elif gradio_app_name == "texture_edit":
    app_type = "texture_edit"
else:
    app_type = "default"

# Per-app scratch root next to this file; each browser session gets its own
# subdirectory keyed by session hash (see start_session / end_session).
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), f"sessions/{app_type}")
os.makedirs(TMP_DIR, exist_ok=True)
def start_session(req: gr.Request) -> None:
    """Create the per-session scratch directory when a browser session opens."""
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)
def end_session(req: gr.Request) -> None:
    """Delete the per-session scratch directory when the session closes."""
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    if os.path.exists(session_dir):
        shutil.rmtree(session_dir)
def preprocess_image_fn(
    image: str | np.ndarray | Image.Image,
    rmbg_tag: str = "rembg",
    preprocess: bool = True,
) -> tuple[Image.Image, Image.Image]:
    """Remove the background of *image* and optionally TRELLIS-preprocess it.

    Args:
        image: Path, numpy array, or PIL image.
        rmbg_tag: "rembg" selects the rembg remover, anything else BMGG14.
        preprocess: Whether to apply TRELLIS preprocessing after segmentation.

    Returns:
        Tuple ``(processed image, copy of the input before segmentation)``.
    """
    if isinstance(image, str):
        image = Image.open(image)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image_cache = image.copy()  # resize_pil(image.copy(), 1024)
    rembg_remover, bmgg_remover = get_removers()
    remover = rembg_remover if rmbg_tag == "rembg" else bmgg_remover
    image = keep_largest_connected_component(remover(image))
    if preprocess:
        image = trellis_preprocess(image)
    return image, image_cache
def preprocess_sam_image_fn(
    image: Image.Image,
) -> tuple[Image.Image, Image.Image]:
    """Prepare *image* for SAM point-prompting and prime the predictor.

    Side effect: sets the image on the cached SAM predictor so later clicks
    (see select_point) can generate masks.

    Returns:
        Tuple ``(preprocessed image, copy of the preprocessed image)``.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    predictor = get_sam_predictor()
    sam_image = predictor.preprocess_image(image)
    image_cache = sam_image.copy()
    predictor.predictor.set_image(sam_image)
    return sam_image, image_cache
def active_btn_by_content(content: gr.Image) -> gr.Button:
    """Enable the button only when *content* holds an image."""
    return gr.Button(interactive=content is not None)
def active_btn_by_text_content(content: gr.Textbox) -> gr.Button:
    """Enable the button only when *content* is a non-empty string."""
    interactive = content is not None and len(content) > 0
    return gr.Button(interactive=interactive)
def get_selected_image(
    choice: str, sample1: str, sample2: str, sample3: str
) -> str:
    """Return the sample path matching *choice*.

    Args:
        choice: One of "sample1", "sample2", "sample3".
        sample1/sample2/sample3: Candidate image paths.

    Raises:
        ValueError: If *choice* is not one of the three sample names.
    """
    samples = {"sample1": sample1, "sample2": sample2, "sample3": sample3}
    try:
        return samples[choice]
    except KeyError:
        raise ValueError(f"Invalid choice: {choice}") from None
def get_cached_image(image_path: str) -> Image.Image:
    """Load *image_path* resized to 512x512; pass PIL images through as-is."""
    if not isinstance(image_path, Image.Image):
        image_path = Image.open(image_path).resize((512, 512))
    return image_path
def get_seed(randomize_seed: bool, seed: int, max_seed: int = MAX_SEED) -> int:
    """Return *seed*, or a uniform random seed in ``[0, max_seed)`` when randomizing."""
    if randomize_seed:
        return np.random.randint(0, max_seed)
    return seed
def select_point(
    image: np.ndarray,
    sel_pix: list,
    point_type: str,
    evt: gr.SelectData,
):
    """Record a SAM prompt click, re-segment, and draw all point markers.

    Mutates *sel_pix* (appends the clicked point with its label) and draws
    markers onto *image* in place.

    Returns:
        ``((annotated image, masks), segmented image)``.
    """
    # Label 1 = foreground, 0 = background; unknown types default to foreground.
    label = 0 if point_type == "background_point" else 1
    sel_pix.append((evt.index, label))
    predictor = get_sam_predictor()
    masks = predictor.generate_masks(image, sel_pix)
    seg_image = predictor.get_segmented_image(image, masks)
    for point, lbl in sel_pix:
        is_background = lbl == 0
        cv2.drawMarker(
            image,
            point,
            (255, 0, 0) if is_background else (0, 255, 0),
            markerType=1 if is_background else 5,
            markerSize=15,
            thickness=10,
        )
    torch.cuda.empty_cache()
    return (image, masks), seg_image
@spaces.GPU
def image_to_3d(
    image: Image.Image,
    seed: int,
    ss_sampling_steps: int,
    slat_sampling_steps: int,
    raw_image_cache: Image.Image,
    ss_guidance_strength: float,
    slat_guidance_strength: float,
    sam_image: Image.Image = None,
    is_sam_image: bool = False,
    req: gr.Request = None,
) -> tuple[object, str]:
    """Generate a 3D asset (Gaussian splats + mesh) from a segmented image.

    Args:
        image: Pre-segmented input image (used when ``is_sam_image`` is False).
        seed: Random seed forwarded to the 3D pipeline.
        ss_sampling_steps: Sparse-structure (stage 1) sampling steps.
        slat_sampling_steps: SLat (stage 2) sampling steps.
        raw_image_cache: Original (unsegmented) image, saved for later QA.
        ss_guidance_strength: Stage-1 CFG strength (TRELLIS branch only).
        slat_guidance_strength: Stage-2 CFG strength (TRELLIS branch only).
        sam_image: SAM-segmented image (used when ``is_sam_image`` is True).
        is_sam_image: Selects the SAM-segmented input over ``image``.
        req: Gradio request; its session hash keys the output directory.

    Returns:
        ``(packed GS+mesh state, path to the preview mp4)``.
    """
    if is_sam_image:
        # Clean up stray mask fragments from the interactive SAM segmentation.
        seg_image = filter_image_small_connected_components(sam_image)
        seg_image = Image.fromarray(seg_image, mode="RGBA")
    else:
        seg_image = image
    if isinstance(seg_image, np.ndarray):
        seg_image = Image.fromarray(seg_image)
    logger.info("Start generating 3D representation from image...")
    pipeline = get_3d_pipeline()
    if isinstance(pipeline, Sam3dInference):
        outputs = pipeline.run(
            seg_image,
            seed=seed,
            stage1_inference_steps=ss_sampling_steps,
            stage2_inference_steps=slat_sampling_steps,
        )
    else:
        # TRELLIS branch: move to GPU only for the duration of the run.
        pipeline.cuda()
        seg_image = trellis_preprocess(seg_image)
        outputs = pipeline.run(
            seg_image,
            seed=seed,
            formats=["gaussian", "mesh"],
            preprocess_image=False,
            sparse_structure_sampler_params={
                "steps": ss_sampling_steps,
                "cfg_strength": ss_guidance_strength,
            },
            slat_sampler_params={
                "steps": slat_sampling_steps,
                "cfg_strength": slat_guidance_strength,
            },
        )
        # Set back to cpu for memory saving.
        # NOTE(review): assumes only the TRELLIS pipeline exposes .cpu();
        # the SAM-3D branch is not moved back — confirm intended.
        PIPELINE.cpu()
    gs_model = outputs["gaussian"][0]
    mesh_model = outputs["mesh"][0]
    # Render orbiting previews of both representations and merge side by side.
    color_images = render_video(gs_model, r=1.85)["color"]
    normal_images = render_video(mesh_model, r=1.85)["normal"]
    output_root = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(output_root, exist_ok=True)
    # Cache both images for later quality checking (see extract_urdf).
    seg_image.save(f"{output_root}/seg_image.png")
    raw_image_cache.save(f"{output_root}/raw_image.png")
    video_path = os.path.join(output_root, "gs_mesh.mp4")
    merge_images_video(color_images, normal_images, video_path)
    state = pack_state(gs_model, mesh_model)
    gc.collect()
    torch.cuda.empty_cache()
    return state, video_path
def extract_3d_representations_v2(
    state: object,
    enable_delight: bool,
    texture_size: int,
    req: gr.Request,
):
    """Back-Projection Version of Texture Super-Resolution.

    Unpacks the packed GS/mesh state, aligns both to a shared frame,
    renders multi-view colors from the Gaussian splats, and back-projects
    them onto the mesh texture.

    Returns:
        ``(glb path, raw gs path, textured obj path, aligned gs path)``.
    """
    output_root = TMP_DIR
    user_dir = os.path.join(output_root, str(req.session_hash))
    gs_model, mesh_model = unpack_state(state, device="cpu")

    filename = "sample"
    # FIX: the f-strings below were missing their {filename} placeholders,
    # leaving `filename` unused and producing literal file names.
    gs_path = os.path.join(user_dir, f"{filename}_gs.ply")
    gs_model.save_ply(gs_path)

    # Rotate mesh and GS by 90 degrees around Z-axis.
    rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
    gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
    mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]

    # Additional rotation for GS to align mesh.
    gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
    pose = GaussianOperator.trans_to_quatpose(gs_rot)
    aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
    GaussianOperator.resave_ply(
        in_ply=gs_path,
        out_ply=aligned_gs_path,
        instance_pose=pose,
        device="cpu",
    )
    # Render multi-view colors of the aligned GS for back-projection.
    color_path = os.path.join(user_dir, "color.png")
    render_gs_api(
        input_gs=aligned_gs_path,
        output_path=color_path,
        elevation=[20, -10, 60, -50],
        num_images=12,
    )
    mesh = trimesh.Trimesh(
        vertices=mesh_model.vertices.cpu().numpy(),
        faces=mesh_model.faces.cpu().numpy(),
    )
    mesh.vertices = mesh.vertices @ np.array(mesh_add_rot)
    mesh.vertices = mesh.vertices @ np.array(rot_matrix)
    mesh_obj_path = os.path.join(user_dir, f"{filename}.obj")
    mesh.export(mesh_obj_path)

    _, _ = get_texture_edit_pipeline()  # Ensure DELIGHT and IMAGESR_MODEL are loaded
    mesh = backproject_api(
        delight_model=DELIGHT,
        imagesr_model=IMAGESR_MODEL,
        color_path=color_path,
        mesh_path=mesh_obj_path,
        output_path=mesh_obj_path,
        skip_fix_mesh=False,
        delight=enable_delight,
        texture_wh=[texture_size, texture_size],
        elevation=[20, -10, 60, -50],
        num_images=12,
    )
    mesh_glb_path = os.path.join(user_dir, f"{filename}.glb")
    mesh.export(mesh_glb_path)

    return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
def extract_3d_representations_v3(
    state: object,
    enable_delight: bool,
    texture_size: int,
    req: gr.Request,
):
    """Back-Projection Version with Optimization-Based.

    Same alignment as v2, but texture back-projection is delegated to the
    optimization-based v3 entrypoint (``enable_delight`` is accepted for
    interface parity but not used by this variant).

    Returns:
        ``(glb path, raw gs path, textured obj path, aligned gs path)``.
    """
    output_root = TMP_DIR
    user_dir = os.path.join(output_root, str(req.session_hash))
    gs_model, mesh_model = unpack_state(state, device="cpu")

    filename = "sample"
    # FIX: the f-strings below were missing their {filename} placeholders,
    # leaving `filename` unused and producing literal file names.
    gs_path = os.path.join(user_dir, f"{filename}_gs.ply")
    gs_model.save_ply(gs_path)

    # Rotate mesh and GS by 90 degrees around Z-axis.
    rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
    gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
    mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]

    # Additional rotation for GS to align mesh.
    gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
    pose = GaussianOperator.trans_to_quatpose(gs_rot)
    aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
    GaussianOperator.resave_ply(
        in_ply=gs_path,
        out_ply=aligned_gs_path,
        instance_pose=pose,
        device="cpu",
    )
    mesh = trimesh.Trimesh(
        vertices=mesh_model.vertices.cpu().numpy(),
        faces=mesh_model.faces.cpu().numpy(),
    )
    mesh.vertices = mesh.vertices @ np.array(mesh_add_rot)
    mesh.vertices = mesh.vertices @ np.array(rot_matrix)
    mesh_obj_path = os.path.join(user_dir, f"{filename}.obj")
    mesh.export(mesh_obj_path)

    mesh = backproject_api_v3(
        gs_path=aligned_gs_path,
        mesh_path=mesh_obj_path,
        output_path=mesh_obj_path,
        skip_fix_mesh=False,
        texture_size=texture_size,
    )
    mesh_glb_path = os.path.join(user_dir, f"{filename}.glb")
    mesh.export(mesh_glb_path)

    return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
def extract_urdf(
    gs_path: str,
    mesh_obj_path: str,
    asset_cat_text: str,
    height_range_text: str,
    mass_range_text: str,
    asset_version_text: str,
    req: gr.Request = None,
):
    """Convert the textured mesh to a URDF asset and run quality checks.

    Args:
        gs_path: Path to the (unaligned) Gaussian-splat .ply.
        mesh_obj_path: Path to the textured .obj mesh.
        asset_cat_text: Optional category override (lower-cased).
        height_range_text: Optional "min-max" height range in meters.
        mass_range_text: Optional "min-max" mass range in kg.
        asset_version_text: Optional version tag override.
        req: Gradio request; its session hash keys the output directory.

    Returns:
        ``(zip path, estimated category, height, mass, mu)`` on success,
        or an error-message string for malformed range inputs.
    """
    output_root = TMP_DIR
    if req is not None:
        output_root = os.path.join(output_root, str(req.session_hash))

    # Convert to URDF and recover attrs by GPT.
    filename = "sample"
    urdf_convertor = URDFGenerator(
        GPT_CLIENT, render_view_num=4, decompose_convex=True
    )
    # FIX: the f-strings in this function were missing their {filename}
    # placeholders, leaving `filename` unused and producing literal names.
    asset_attrs = {
        "version": VERSION,
        "gs_model": f"{urdf_convertor.output_mesh_dir}/{filename}_gs.ply",
    }
    if asset_version_text:
        asset_attrs["version"] = asset_version_text
    if asset_cat_text:
        asset_attrs["category"] = asset_cat_text.lower()
    if height_range_text:
        try:
            min_height, max_height = map(float, height_range_text.split("-"))
            asset_attrs["min_height"] = min_height
            asset_attrs["max_height"] = max_height
        except ValueError:
            return "Invalid height input format. Use the format: min-max."
    if mass_range_text:
        try:
            min_mass, max_mass = map(float, mass_range_text.split("-"))
            asset_attrs["min_mass"] = min_mass
            asset_attrs["max_mass"] = max_mass
        except ValueError:
            return "Invalid mass input format. Use the format: min-max."
    urdf_path = urdf_convertor(
        mesh_path=mesh_obj_path,
        output_root=f"{output_root}/URDF_{filename}",
        **asset_attrs,
    )

    # Rescale GS and save to URDF/mesh folder.
    real_height = urdf_convertor.get_attr_from_urdf(
        urdf_path, attr_name="real_height"
    )
    out_gs = f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}/{filename}_gs.ply"  # noqa
    GaussianOperator.resave_ply(
        in_ply=gs_path,
        out_ply=out_gs,
        real_height=real_height,
        device="cpu",
    )

    # Quality check and update .urdf file.
    mesh_out = f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}/{filename}.obj"  # noqa
    trimesh.load(mesh_out).export(mesh_out.replace(".obj", ".glb"))
    image_dir = f"{output_root}/URDF_{filename}/{urdf_convertor.output_render_dir}/image_color"  # noqa
    image_paths = glob(f"{image_dir}/*.png")
    # FIX: initialize the checkers BEFORE iterating them. The global CHECKERS
    # list is empty until get_checkers() runs, so the previous order silently
    # skipped every quality check on the first call.
    checkers = get_checkers()
    images_list = []
    for checker in checkers:
        images = image_paths
        if isinstance(checker, ImageSegChecker):
            # The segmentation checker compares the raw image with the
            # segmented one cached by image_to_3d.
            images = [
                f"{TMP_DIR}/{req.session_hash}/raw_image.png",
                f"{TMP_DIR}/{req.session_hash}/seg_image.png",
            ]
        images_list.append(images)
    results = BaseChecker.validate(checkers, images_list)
    urdf_convertor.add_quality_tag(urdf_path, results)

    # Zip urdf files
    urdf_zip = zip_files(
        input_paths=[
            f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}",
            f"{output_root}/URDF_{filename}/{filename}.urdf",
        ],
        output_zip=f"{output_root}/urdf_{filename}.zip",
    )

    estimated_type = urdf_convertor.estimated_attrs["category"]
    estimated_height = urdf_convertor.estimated_attrs["height"]
    estimated_mass = urdf_convertor.estimated_attrs["mass"]
    estimated_mu = urdf_convertor.estimated_attrs["mu"]

    return (
        urdf_zip,
        estimated_type,
        estimated_height,
        estimated_mass,
        estimated_mu,
    )
@spaces.GPU
def text2image_fn(
    prompt: str,
    guidance_scale: float,
    infer_step: int = 50,
    ip_image: Image.Image | str = None,
    ip_adapt_scale: float = 0.3,
    image_wh: int | tuple[int, int] = (1024, 1024),
    rmbg_tag: str = "rembg",
    seed: int = None,
    enable_pre_resize: bool = True,
    n_sample: int = 1,
    req: gr.Request = None,
):
    """Generate background-removed sample images from a text prompt.

    Args:
        prompt: Text prompt for the image generator.
        guidance_scale: Classifier-free guidance strength.
        infer_step: Number of denoising steps.
        ip_image: Optional reference image; switches to the IP-Adapter pipeline.
        ip_adapt_scale: IP-Adapter conditioning strength.
        image_wh: Output size as an int (square) or a (width, height) pair.
        rmbg_tag: Background remover to use ("rembg" or BMGG14).
        seed: Optional fixed seed for reproducibility.
        enable_pre_resize: Whether to TRELLIS-preprocess each output.
        n_sample: Number of images to generate.
        req: Gradio request; its session hash keys the output directory.

    Returns:
        Six values: up to 3 gallery image paths followed by the same 3 as
        hidden state, padded with None when fewer samples are generated.
    """
    # FIX: the default for image_wh was a mutable list ([1024, 1024]);
    # use an immutable tuple to avoid the shared-mutable-default pitfall.
    if isinstance(image_wh, int):
        image_wh = (image_wh, image_wh)
    output_root = TMP_DIR
    if req is not None:
        output_root = os.path.join(output_root, str(req.session_hash))
    os.makedirs(output_root, exist_ok=True)

    pipeline_img, pipeline_img_ip = get_image_pipeline()
    pipeline = pipeline_img if ip_image is None else pipeline_img_ip
    if ip_image is not None:
        pipeline.set_ip_adapter_scale([ip_adapt_scale])
    images = text2img_gen(
        prompt=prompt,
        n_sample=n_sample,
        guidance_scale=guidance_scale,
        pipeline=pipeline,
        ip_image=ip_image,
        image_wh=image_wh,
        infer_step=infer_step,
        seed=seed,
    )
    # Background-remove each generated sample in place.
    for idx, image in enumerate(images):
        images[idx], _ = preprocess_image_fn(
            image, rmbg_tag, enable_pre_resize
        )

    save_paths = []
    for idx, image in enumerate(images):
        save_path = f"{output_root}/sample_{idx}.png"
        image.save(save_path)
        save_paths.append(save_path)
    logger.info(f"Images saved to {output_root}")

    gc.collect()
    torch.cuda.empty_cache()

    # Ensure we always return exactly 6 values (3 for gallery, 3 for hidden state)
    full_save_paths = save_paths + [None] * (3 - len(save_paths))
    return full_save_paths + full_save_paths
@spaces.GPU
def generate_condition(mesh_path: str, req: gr.Request, uuid: str = "sample"):
    """Render conditioning views of *mesh_path* into the session's condition dir."""
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    render_api(
        mesh_path=mesh_path,
        output_root=f"{session_dir}/condition",
        uuid=str(uuid),
    )
    gc.collect()
    torch.cuda.empty_cache()
    return None, None, None
@spaces.GPU
def generate_texture_mvimages(
    prompt: str,
    controlnet_cond_scale: float = 0.55,
    guidance_scale: float = 9,
    strength: float = 0.9,
    num_inference_steps: int = 50,
    seed: int = 0,
    ip_adapt_scale: float = 0,
    ip_img_path: str = None,
    uid: str = "sample",
    sub_idxs: tuple[tuple[int]] = ((0, 1, 2), (3, 4, 5)),
    req: gr.Request = None,
) -> list[str]:
    """Generate multi-view texture images guided by the rendered conditions.

    Uses the IP-Adapter pipeline when a reference image and a positive
    adapter scale are provided; otherwise the plain edit pipeline.

    Returns:
        The saved image paths, duplicated (gallery + hidden state).
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    use_ip_adapter = bool(ip_img_path) and ip_adapt_scale > 0
    ip_pipe, edit_pipe = get_texture_edit_pipeline()
    ip_pipe.set_ip_adapter_scale([ip_adapt_scale])
    img_save_paths = infer_pipe(
        index_file=f"{session_dir}/condition/index.json",
        controlnet_cond_scale=controlnet_cond_scale,
        guidance_scale=guidance_scale,
        strength=strength,
        num_inference_steps=num_inference_steps,
        ip_adapt_scale=ip_adapt_scale,
        ip_img_path=ip_img_path,
        uid=uid,
        prompt=prompt,
        save_dir=f"{session_dir}/multi_view",
        sub_idxs=sub_idxs,
        pipeline=ip_pipe if use_ip_adapter else edit_pipe,
        seed=seed,
    )
    gc.collect()
    torch.cuda.empty_cache()
    return img_save_paths + img_save_paths
def backproject_texture(
    mesh_path: str,
    input_image: str,
    texture_size: int,
    uuid: str = "sample",
    req: gr.Request = None,
) -> tuple[str, str, str]:
    """Back-project *input_image* onto *mesh_path* via the external CLI tool.

    Runs the ``backproject-cli`` subprocess, converts the resulting .obj to
    .glb, and zips mesh + material files for download.

    Returns:
        ``(glb path, obj path, zip path)``.
    """
    output_root = os.path.join(TMP_DIR, str(req.session_hash))
    output_dir = os.path.join(output_root, "texture_mesh")
    os.makedirs(output_dir, exist_ok=True)
    command = [
        "backproject-cli",
        "--mesh_path",
        mesh_path,
        "--input_image",
        input_image,
        "--output_root",
        output_dir,
        "--uuid",
        f"{uuid}",
        "--texture_size",
        str(texture_size),
        "--skip_fix_mesh",
    ]
    # NOTE(review): the subprocess result is discarded, so CLI failures are
    # silent here and would only surface as a missing .obj below — confirm
    # whether return-code checking is wanted.
    _ = subprocess.run(
        command, capture_output=True, text=True, encoding="utf-8"
    )
    output_obj_mesh = os.path.join(output_dir, f"{uuid}.obj")
    output_glb_mesh = os.path.join(output_dir, f"{uuid}.glb")
    _ = trimesh.load(output_obj_mesh).export(output_glb_mesh)
    # Bundle the mesh with its material files for a single download.
    zip_file = zip_files(
        input_paths=[
            output_glb_mesh,
            output_obj_mesh,
            os.path.join(output_dir, "material.mtl"),
            os.path.join(output_dir, "material_0.png"),
        ],
        output_zip=os.path.join(output_dir, f"{uuid}.zip"),
    )
    gc.collect()
    torch.cuda.empty_cache()
    return output_glb_mesh, output_obj_mesh, zip_file
@spaces.GPU
def backproject_texture_v2(
    mesh_path: str,
    input_image: str,
    texture_size: int,
    enable_delight: bool = True,
    fix_mesh: bool = False,
    no_mesh_post_process: bool = False,
    uuid: str = "sample",
    req: gr.Request = None,
) -> tuple[str, str, str]:
    """Back-project *input_image* onto *mesh_path* (in-process API version).

    Args:
        mesh_path: Path to the untextured mesh.
        input_image: Multi-view color image used for back-projection.
        texture_size: Output texture resolution (square).
        enable_delight: Whether to run the de-lighting model first.
        fix_mesh: Whether to run mesh fixing before texturing.
        no_mesh_post_process: Skip mesh post-processing when True.
        uuid: Base name for the output files.
        req: Gradio request; its session hash keys the output directory.

    Returns:
        ``(glb path, obj path, zip path)``.
    """
    global DELIGHT, IMAGESR_MODEL
    # FIX: DELIGHT / IMAGESR_MODEL are lazily-initialized globals and may
    # still be None when this entrypoint runs first. Initialize them the
    # same way get_texture_edit_pipeline() does, without also building the
    # heavy diffusion pipelines this function doesn't need.
    if DELIGHT is None:
        DELIGHT = DelightingModel()
    if IMAGESR_MODEL is None:
        IMAGESR_MODEL = ImageRealESRGAN(outscale=4)

    output_root = os.path.join(TMP_DIR, str(req.session_hash))
    output_dir = os.path.join(output_root, "texture_mesh")
    os.makedirs(output_dir, exist_ok=True)
    textured_mesh = backproject_api(
        delight_model=DELIGHT,
        imagesr_model=IMAGESR_MODEL,
        color_path=input_image,
        mesh_path=mesh_path,
        output_path=f"{output_dir}/{uuid}.obj",
        skip_fix_mesh=not fix_mesh,
        delight=enable_delight,
        texture_wh=[texture_size, texture_size],
        no_mesh_post_process=no_mesh_post_process,
    )
    output_obj_mesh = os.path.join(output_dir, f"{uuid}.obj")
    output_glb_mesh = os.path.join(output_dir, f"{uuid}.glb")
    _ = textured_mesh.export(output_glb_mesh)
    # Bundle the mesh with its material files for a single download.
    zip_file = zip_files(
        input_paths=[
            output_glb_mesh,
            output_obj_mesh,
            os.path.join(output_dir, "material.mtl"),
            os.path.join(output_dir, "material_0.png"),
        ],
        output_zip=os.path.join(output_dir, f"{uuid}.zip"),
    )
    gc.collect()
    torch.cuda.empty_cache()
    return output_glb_mesh, output_obj_mesh, zip_file
@spaces.GPU
def render_result_video(
    mesh_path: str, video_size: int, req: gr.Request, uuid: str = ""
) -> str:
    """Render a 90-frame turntable video of the textured mesh.

    Returns:
        Path to the rendered color mp4 inside the session directory.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    video_dir = os.path.join(session_dir, "texture_mesh")
    render_api(
        mesh_path=mesh_path,
        output_root=video_dir,
        num_images=90,
        elevation=[20],
        with_mtl=True,
        pbr_light_factor=1,
        uuid=str(uuid),
        gen_color_mp4=True,
        gen_glonormal_mp4=True,
        distance=5.5,
        resolution_hw=(video_size, video_size),
    )
    gc.collect()
    torch.cuda.empty_cache()
    return f"{video_dir}/color.mp4"