"""Gradio app exposing TripoSG single-image-to-3D generation as a UI and API.

On import this module clones the TripoSG repository (if missing), downloads
the RMBG-1.4 and TripoSG checkpoints, loads both models onto ``DEVICE`` and
launches a Gradio ``Blocks`` app on port 7860.
"""

import base64
import os
import re
import shutil
import subprocess
import sys
import uuid
from io import BytesIO
from urllib.parse import unquote, urlparse
from urllib.request import Request, urlopen, urlretrieve  # noqa: F401

import gradio as gr
import numpy as np
import torch
import trimesh
from huggingface_hub import hf_hub_download, snapshot_download
from PIL import Image

# Make gradio_client's schema -> Python-type helpers tolerate boolean JSON
# schemas (bare True/False) by reporting them as "Any"; presumably the
# originals fail on such input when building the API description. This is a
# best-effort patch: if gradio_client or its private helpers are missing,
# the app simply runs unpatched.
try:
    import gradio_client.utils as gc_utils

    _orig_get_type = gc_utils.get_type
    _orig_json_schema = gc_utils._json_schema_to_python_type

    def _safe_get_type(schema):
        """Return "Any" for boolean schemas, else defer to gradio_client."""
        if isinstance(schema, bool):
            return "Any"
        return _orig_get_type(schema)

    def _safe_json_schema(schema, defs=None):
        """Return "Any" for boolean schemas, else defer to gradio_client."""
        if isinstance(schema, bool):
            return "Any"
        return _orig_json_schema(schema, defs)

    gc_utils.get_type = _safe_get_type
    gc_utils._json_schema_to_python_type = _safe_json_schema
except (ImportError, AttributeError):
    pass

# Use the Hugging Face ZeroGPU ``spaces.GPU`` decorator when available;
# otherwise substitute a no-op decorator factory with the same call shape
# (``@gpu(duration=...)``) so the app also runs outside Spaces.
try:
    import spaces

    gpu = spaces.GPU
except Exception:

    def gpu(*_args, **_kwargs):
        """No-op stand-in for ``spaces.GPU``: returns fn unchanged."""

        def _wrap(fn):
            return fn

        return _wrap


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision on GPU, full precision on CPU.
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

TRIPOSG_REPO_URL = "https://github.com/VAST-AI-Research/TripoSG.git"
TRIPOSG_CODE_DIR = "./triposg"
CHECKPOINT_DIR = "checkpoints"
RMBG_PRETRAINED_MODEL = os.path.join(CHECKPOINT_DIR, "RMBG-1.4")
TRIPOSG_PRETRAINED_MODEL = os.path.join(CHECKPOINT_DIR, "TripoSG")

TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
os.makedirs(TMP_DIR, exist_ok=True)

# Seconds to wait on a remote image server before giving up.
_DOWNLOAD_TIMEOUT_S = 30

if not os.path.exists(TRIPOSG_CODE_DIR):
    # Argument-list form avoids the shell; check=True makes a failed clone
    # abort here instead of surfacing later as a confusing ImportError.
    subprocess.run(["git", "clone", TRIPOSG_REPO_URL, TRIPOSG_CODE_DIR], check=True)

# TripoSG is not an installed package: make its sources and scripts importable.
sys.path.append(TRIPOSG_CODE_DIR)
sys.path.append(os.path.join(TRIPOSG_CODE_DIR, "scripts"))

from image_process import prepare_image  # noqa: E402
from briarmbg import BriaRMBG  # noqa: E402
from triposg.pipelines.pipeline_triposg import TripoSGPipeline  # noqa: E402
from utils import simplify_mesh  # noqa: E402

# Background-removal network used by prepare_image.
snapshot_download("briaai/RMBG-1.4", local_dir=RMBG_PRETRAINED_MODEL)
rmbg_net = BriaRMBG.from_pretrained(RMBG_PRETRAINED_MODEL).to(DEVICE)
rmbg_net.eval()

# Image-to-mesh diffusion pipeline.
snapshot_download("VAST-AI/TripoSG", local_dir=TRIPOSG_PRETRAINED_MODEL)
triposg_pipe = TripoSGPipeline.from_pretrained(TRIPOSG_PRETRAINED_MODEL).to(
    DEVICE, DTYPE
)

# Hugging Face "resolve" URLs:
#   https://huggingface.co/[spaces/|datasets/]<owner>/<name>/resolve/<rev>/<path>
# The path group stops before any "?query" or "#fragment" so that URLs like
# ".../x.png?download=true" still map to the repo file "x.png".
_HF_RESOLVE_RE = re.compile(
    r"^https?://huggingface\.co/(?:(?P<type>spaces|datasets)/)?"
    r"(?P<repo>[^/]+/[^/]+)/resolve/(?P<rev>[^/]+)/(?P<path>[^?#]+)"
)


def _session_dir(req: gr.Request | None) -> str:
    """Return (creating it) the scratch directory for this Gradio session.

    Falls back to the shared ``TMP_DIR`` when no request object is supplied.
    """
    if req is None:
        return TMP_DIR
    save_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(save_dir, exist_ok=True)
    return save_dir


def _unique_glb_path(save_dir: str) -> str:
    """Return a collision-free ``triposg_<hex>.glb`` path inside save_dir."""
    return os.path.join(save_dir, f"triposg_{uuid.uuid4().hex}.glb")


def _download_url(url: str, save_dir: str) -> str:
    """Download an http(s) image URL into save_dir and return the local path.

    Hugging Face ``/resolve/`` URLs are tried through ``hf_hub_download``
    first; if that fails for any reason, a plain HTTP GET is used instead.

    Raises:
        gr.Error: if the plain HTTP download fails as well.
    """
    hf_match = _HF_RESOLVE_RE.match(url)
    if hf_match:
        # "spaces" -> "space", "datasets" -> "dataset"; no prefix means a model repo.
        repo_type = hf_match.group("type")[:-1] if hf_match.group("type") else "model"
        try:
            return hf_hub_download(
                repo_id=hf_match.group("repo"),
                # URL components are percent-encoded (e.g. "%20", "refs%2Fpr%2F1").
                filename=unquote(hf_match.group("path")),
                repo_type=repo_type,
                revision=unquote(hf_match.group("rev")),
                local_dir=save_dir,
            )
        except Exception:
            # Deliberate best-effort: fall back to the generic download below.
            pass

    parsed = urlparse(url)
    suffix = os.path.splitext(parsed.path)[1] or ".png"
    local_path = os.path.join(save_dir, f"input{suffix}")

    def _fetch(target) -> None:
        # Stream to disk rather than buffering the whole body in memory; the
        # output file is only opened once the connection has succeeded.
        with urlopen(target, timeout=_DOWNLOAD_TIMEOUT_S) as resp, open(
            local_path, "wb"
        ) as out:
            shutil.copyfileobj(resp, out)

    try:
        # Some hosts reject urllib's default User-Agent, so try a browser-like one first.
        _fetch(Request(url, headers={"User-Agent": "Mozilla/5.0"}))
    except Exception:
        try:
            # Retry once with urllib defaults (no custom headers).
            _fetch(url)
        except Exception as err:
            raise gr.Error(f"Failed to download image URL: {err}") from err
    return local_path


def _decode_data_url(data_url: str, save_dir: str) -> str:
    """Decode a base64 ``data:image/...`` URL, save it as ``input.png``, return the path.

    Raises:
        gr.Error: if the URL has no payload, the base64 is malformed, or the
            bytes are not an image PIL can read or save as PNG.
    """
    try:
        _header, b64_data = data_url.split(",", 1)
        image = Image.open(BytesIO(base64.b64decode(b64_data)))
        local_path = os.path.join(save_dir, "input.png")
        image.save(local_path)
    except (ValueError, OSError) as err:
        # ValueError covers a missing comma and bad base64 padding
        # (binascii.Error); OSError covers PIL's UnidentifiedImageError and
        # unsupported-mode save failures.
        raise gr.Error(f"Failed to decode image data URL: {err}") from err
    return local_path


def _resolve_image_path(image_input, save_dir: str) -> str:
    """Turn any supported image input into a local file path.

    Accepts a Gradio file dict (``path``/``url`` keys), a PIL image, a local
    path, an http(s) URL, or a ``data:image`` URL. Downloaded or decoded
    images are written into save_dir. Inputs of any other non-string type
    are returned unchanged.

    Raises:
        gr.Error: if the input is empty, cannot be downloaded or decoded, or
            is a string that matches none of the supported forms.
    """
    if isinstance(image_input, dict):
        image_input = image_input.get("path") or image_input.get("url")
    if not image_input:
        raise gr.Error("Upload an image first.")
    if isinstance(image_input, Image.Image):
        local_path = os.path.join(save_dir, "input.png")
        image_input.save(local_path)
        return local_path
    if not isinstance(image_input, str):
        # Unknown type: hand it to prepare_image unchanged.
        return image_input
    if os.path.exists(image_input):
        return image_input
    if image_input.startswith(("http://", "https://")):
        return _download_url(image_input, save_dir)
    if image_input.startswith("data:image"):
        return _decode_data_url(image_input, save_dir)
    # A non-existent path would otherwise fail obscurely inside prepare_image.
    raise gr.Error(f"Image not found: {image_input[:200]}")


def _run_triposg(
    image_path: str,
    seed: int,
    num_inference_steps: int,
    guidance_scale: float,
    simplify: bool,
    target_face_num: int,
    req: gr.Request | None = None,
):
    """Segment the input image, run TripoSG, and export the mesh as GLB.

    Returns:
        ``(image_seg, mesh_path)``: the background-removed image produced by
        prepare_image, and the path of the exported GLB in the session dir.
    """
    save_dir = _session_dir(req)
    image_path = _resolve_image_path(image_path, save_dir)
    image_seg = prepare_image(
        image_path, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net
    )

    # Gradio sliders and API clients may deliver integral values as floats;
    # manual_seed and the step count require real ints.
    generator = torch.Generator(device=triposg_pipe.device).manual_seed(int(seed))
    outputs = triposg_pipe(
        image=image_seg,
        generator=generator,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=float(guidance_scale),
    ).samples[0]

    # NOTE(review): samples[0] is presumably a (vertices, faces) pair of numpy
    # arrays, judging by .astype / ascontiguousarray below — confirm upstream.
    vertices, faces = outputs[0], outputs[1]
    mesh = trimesh.Trimesh(vertices.astype(np.float32), np.ascontiguousarray(faces))
    if simplify:
        mesh = simplify_mesh(mesh, int(target_face_num))

    mesh_path = _unique_glb_path(save_dir)
    mesh.export(mesh_path)
    return image_seg, mesh_path


def _generate_and_release(*args):
    """Call _run_triposg(*args), always releasing cached CUDA memory afterwards.

    The ``finally`` ensures the cache is also freed when generation raises.
    """
    try:
        return _run_triposg(*args)
    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


@gpu(duration=60)
@torch.no_grad()
def generate_mesh(
    image_path: str,
    seed: int,
    num_inference_steps: int,
    guidance_scale: float,
    simplify: bool,
    target_face_num: int,
    req: gr.Request | None = None,
):
    """UI handler: return ``(segmentation preview, GLB path)``."""
    return _generate_and_release(
        image_path,
        seed,
        num_inference_steps,
        guidance_scale,
        simplify,
        target_face_num,
        req,
    )


@gpu(duration=60)
@torch.no_grad()
def api_generate(
    image_path: str,
    seed: int,
    num_inference_steps: int,
    guidance_scale: float,
    simplify: bool,
    target_face_num: int,
    req: gr.Request | None = None,
):
    """Programmatic entry point: return only the generated GLB path."""
    _, mesh_path = _generate_and_release(
        image_path,
        seed,
        num_inference_steps,
        guidance_scale,
        simplify,
        target_face_num,
        req,
    )
    return mesh_path


@gpu(duration=60)
@torch.no_grad()
def api_generate_text(
    image_text: str,
    seed: int,
    num_inference_steps: int,
    guidance_scale: float,
    simplify: bool,
    target_face_num: int,
    req: gr.Request | None = None,
):
    """Handler for the ``/predict`` API: the image arrives as a text value.

    The text may be a local path, an http(s) URL, or a ``data:image`` URL.
    Returns the generated GLB path.
    """
    _, mesh_path = _generate_and_release(
        image_text,
        seed,
        num_inference_steps,
        guidance_scale,
        simplify,
        target_face_num,
        req,
    )
    return mesh_path


def _cleanup_session(req: gr.Request):
    """Delete the session's scratch directory when the client disconnects."""
    save_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # ignore_errors also covers the "never created" / "already removed" cases.
    shutil.rmtree(save_dir, ignore_errors=True)


TITLE = "TripoSG Image-to-3D API"
DESCRIPTION = (
    "Upload a single-object image to generate a 3D mesh (GLB). "
    "This demo exposes a /predict API endpoint."
)

with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}\n\n{DESCRIPTION}")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Input Image", type="filepath")
            seg_output = gr.Image(
                label="Segmentation Preview", type="pil", format="png"
            )
            with gr.Accordion("Generation Settings", open=True):
                seed = gr.Slider(
                    label="Seed", minimum=0, maximum=2**31 - 1, step=1, value=0
                )
                steps = gr.Slider(
                    label="Inference Steps", minimum=8, maximum=50, step=1, value=50
                )
                guidance = gr.Slider(
                    label="CFG Scale", minimum=0.0, maximum=20.0, step=0.1, value=7.5
                )
                simplify = gr.Checkbox(label="Simplify Mesh", value=True)
                face_count = gr.Slider(
                    label="Target Face Count",
                    minimum=10000,
                    maximum=1000000,
                    step=1000,
                    value=100000,
                )
            generate_btn = gr.Button("Generate 3D", variant="primary")
        with gr.Column():
            model_output = gr.Model3D(label="Generated GLB", interactive=False)
            file_output = gr.File(label="Download GLB", interactive=False)

    # Hidden text input backing the /predict API (accepts paths and URLs).
    api_image_text = gr.Textbox(visible=False)

    generate_btn.click(
        generate_mesh,
        inputs=[image_input, seed, steps, guidance, simplify, face_count],
        outputs=[seg_output, model_output],
        api_name=False,
    ).then(lambda path: path, inputs=model_output, outputs=file_output, api_name=False)

    # Invisible button whose only purpose is to register the /predict endpoint.
    api_btn = gr.Button(visible=False)
    api_btn.click(
        api_generate_text,
        inputs=[api_image_text, seed, steps, guidance, simplify, face_count],
        outputs=[file_output],
        api_name="predict",
    )

    demo.unload(_cleanup_session)

demo.launch(server_name="0.0.0.0", server_port=7860)