# TripoSG image-to-3D Gradio Space application (app.py).
import base64
import os
import re
import shutil
import subprocess
import sys
import uuid
from io import BytesIO
from urllib.parse import urlparse
from urllib.request import Request, urlopen, urlretrieve

import gradio as gr
import numpy as np
import torch
import trimesh
from huggingface_hub import hf_hub_download, snapshot_download
from PIL import Image
| try: | |
| import gradio_client.utils as gc_utils | |
| _orig_get_type = gc_utils.get_type | |
| _orig_json_schema = gc_utils._json_schema_to_python_type | |
| def _safe_get_type(schema): | |
| if isinstance(schema, bool): | |
| return "Any" | |
| return _orig_get_type(schema) | |
| def _safe_json_schema(schema, defs=None): | |
| if isinstance(schema, bool): | |
| return "Any" | |
| return _orig_json_schema(schema, defs) | |
| gc_utils.get_type = _safe_get_type | |
| gc_utils._json_schema_to_python_type = _safe_json_schema | |
| except Exception: | |
| pass | |
| try: | |
| import spaces | |
| gpu = spaces.GPU | |
| except Exception: | |
| def gpu(*_args, **_kwargs): | |
| def _wrap(fn): | |
| return fn | |
| return _wrap | |
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" | |
| DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32 | |
| TRIPOSG_REPO_URL = "https://github.com/VAST-AI-Research/TripoSG.git" | |
| TRIPOSG_CODE_DIR = "./triposg" | |
| CHECKPOINT_DIR = "checkpoints" | |
| RMBG_PRETRAINED_MODEL = os.path.join(CHECKPOINT_DIR, "RMBG-1.4") | |
| TRIPOSG_PRETRAINED_MODEL = os.path.join(CHECKPOINT_DIR, "TripoSG") | |
| TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp") | |
| os.makedirs(TMP_DIR, exist_ok=True) | |
| if not os.path.exists(TRIPOSG_CODE_DIR): | |
| os.system(f"git clone {TRIPOSG_REPO_URL} {TRIPOSG_CODE_DIR}") | |
| sys.path.append(TRIPOSG_CODE_DIR) | |
| sys.path.append(os.path.join(TRIPOSG_CODE_DIR, "scripts")) | |
| from image_process import prepare_image | |
| from briarmbg import BriaRMBG | |
| from triposg.pipelines.pipeline_triposg import TripoSGPipeline | |
| from utils import simplify_mesh | |
| snapshot_download("briaai/RMBG-1.4", local_dir=RMBG_PRETRAINED_MODEL) | |
| rmbg_net = BriaRMBG.from_pretrained(RMBG_PRETRAINED_MODEL).to(DEVICE) | |
| rmbg_net.eval() | |
| snapshot_download("VAST-AI/TripoSG", local_dir=TRIPOSG_PRETRAINED_MODEL) | |
| triposg_pipe = TripoSGPipeline.from_pretrained(TRIPOSG_PRETRAINED_MODEL).to( | |
| DEVICE, DTYPE | |
| ) | |
| def _session_dir(req: gr.Request | None) -> str: | |
| if req is None: | |
| return TMP_DIR | |
| save_dir = os.path.join(TMP_DIR, str(req.session_hash)) | |
| os.makedirs(save_dir, exist_ok=True) | |
| return save_dir | |
| def _unique_glb_path(save_dir: str) -> str: | |
| return os.path.join(save_dir, f"triposg_{uuid.uuid4().hex}.glb") | |
| def _resolve_image_path(image_input, save_dir: str) -> str: | |
| if isinstance(image_input, dict): | |
| image_input = image_input.get("path") or image_input.get("url") | |
| if not image_input: | |
| raise gr.Error("Upload an image first.") | |
| if isinstance(image_input, Image.Image): | |
| local_path = os.path.join(save_dir, "input.png") | |
| image_input.save(local_path) | |
| return local_path | |
| if isinstance(image_input, str) and os.path.exists(image_input): | |
| return image_input | |
| if isinstance(image_input, str) and image_input.startswith(("http://", "https://")): | |
| hf_match = re.match( | |
| r"^https?://huggingface\.co/(?:(?P<type>spaces|datasets)/)?(?P<repo>[^/]+/[^/]+)/resolve/(?P<rev>[^/]+)/(?P<path>.+)$", | |
| image_input, | |
| ) | |
| if hf_match: | |
| repo_id = hf_match.group("repo") | |
| filename = hf_match.group("path") | |
| revision = hf_match.group("rev") | |
| repo_type = hf_match.group("type")[:-1] if hf_match.group("type") else "model" | |
| try: | |
| return hf_hub_download( | |
| repo_id=repo_id, | |
| filename=filename, | |
| repo_type=repo_type, | |
| revision=revision, | |
| local_dir=save_dir, | |
| ) | |
| except Exception: | |
| pass | |
| parsed = urlparse(image_input) | |
| suffix = os.path.splitext(parsed.path)[1] or ".png" | |
| local_path = os.path.join(save_dir, f"input{suffix}") | |
| try: | |
| req = Request(image_input, headers={"User-Agent": "Mozilla/5.0"}) | |
| with urlopen(req) as resp, open(local_path, "wb") as out: | |
| out.write(resp.read()) | |
| except Exception: | |
| try: | |
| urlretrieve(image_input, local_path) | |
| except Exception as err: | |
| raise gr.Error(f"Failed to download image URL: {err}") from err | |
| return local_path | |
| if isinstance(image_input, str) and image_input.startswith("data:image"): | |
| header, b64_data = image_input.split(",", 1) | |
| image = Image.open(BytesIO(base64.b64decode(b64_data))) | |
| local_path = os.path.join(save_dir, "input.png") | |
| image.save(local_path) | |
| return local_path | |
| return image_input | |
| def _run_triposg( | |
| image_path: str, | |
| seed: int, | |
| num_inference_steps: int, | |
| guidance_scale: float, | |
| simplify: bool, | |
| target_face_num: int, | |
| req: gr.Request | None = None, | |
| ): | |
| save_dir = _session_dir(req) | |
| image_path = _resolve_image_path(image_path, save_dir) | |
| image_seg = prepare_image( | |
| image_path, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net | |
| ) | |
| generator = torch.Generator(device=triposg_pipe.device).manual_seed(seed) | |
| outputs = triposg_pipe( | |
| image=image_seg, | |
| generator=generator, | |
| num_inference_steps=num_inference_steps, | |
| guidance_scale=guidance_scale, | |
| ).samples[0] | |
| mesh = trimesh.Trimesh(outputs[0].astype(np.float32), np.ascontiguousarray(outputs[1])) | |
| if simplify: | |
| mesh = simplify_mesh(mesh, target_face_num) | |
| mesh_path = _unique_glb_path(save_dir) | |
| mesh.export(mesh_path) | |
| return image_seg, mesh_path | |
| def generate_mesh( | |
| image_path: str, | |
| seed: int, | |
| num_inference_steps: int, | |
| guidance_scale: float, | |
| simplify: bool, | |
| target_face_num: int, | |
| req: gr.Request | None = None, | |
| ): | |
| image_seg, mesh_path = _run_triposg( | |
| image_path, | |
| seed, | |
| num_inference_steps, | |
| guidance_scale, | |
| simplify, | |
| target_face_num, | |
| req, | |
| ) | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| return image_seg, mesh_path | |
| def api_generate( | |
| image_path: str, | |
| seed: int, | |
| num_inference_steps: int, | |
| guidance_scale: float, | |
| simplify: bool, | |
| target_face_num: int, | |
| req: gr.Request | None = None, | |
| ): | |
| _, mesh_path = _run_triposg( | |
| image_path, | |
| seed, | |
| num_inference_steps, | |
| guidance_scale, | |
| simplify, | |
| target_face_num, | |
| req, | |
| ) | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| return mesh_path | |
| def api_generate_text( | |
| image_text: str, | |
| seed: int, | |
| num_inference_steps: int, | |
| guidance_scale: float, | |
| simplify: bool, | |
| target_face_num: int, | |
| req: gr.Request | None = None, | |
| ): | |
| _, mesh_path = _run_triposg( | |
| image_text, | |
| seed, | |
| num_inference_steps, | |
| guidance_scale, | |
| simplify, | |
| target_face_num, | |
| req, | |
| ) | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| return mesh_path | |
| def _cleanup_session(req: gr.Request): | |
| save_dir = os.path.join(TMP_DIR, str(req.session_hash)) | |
| if os.path.exists(save_dir): | |
| shutil.rmtree(save_dir) | |
| TITLE = "TripoSG Image-to-3D API" | |
| DESCRIPTION = ( | |
| "Upload a single-object image to generate a 3D mesh (GLB). " | |
| "This demo exposes a /predict API endpoint." | |
| ) | |
| with gr.Blocks(title=TITLE) as demo: | |
| gr.Markdown(f"# {TITLE}\n\n{DESCRIPTION}") | |
| with gr.Row(): | |
| with gr.Column(): | |
| image_input = gr.Image(label="Input Image", type="filepath") | |
| seg_output = gr.Image( | |
| label="Segmentation Preview", type="pil", format="png" | |
| ) | |
| with gr.Accordion("Generation Settings", open=True): | |
| seed = gr.Slider( | |
| label="Seed", minimum=0, maximum=2**31 - 1, step=1, value=0 | |
| ) | |
| steps = gr.Slider( | |
| label="Inference Steps", minimum=8, maximum=50, step=1, value=50 | |
| ) | |
| guidance = gr.Slider( | |
| label="CFG Scale", minimum=0.0, maximum=20.0, step=0.1, value=7.5 | |
| ) | |
| simplify = gr.Checkbox(label="Simplify Mesh", value=True) | |
| face_count = gr.Slider( | |
| label="Target Face Count", | |
| minimum=10000, | |
| maximum=1000000, | |
| step=1000, | |
| value=100000, | |
| ) | |
| generate_btn = gr.Button("Generate 3D", variant="primary") | |
| with gr.Column(): | |
| model_output = gr.Model3D(label="Generated GLB", interactive=False) | |
| file_output = gr.File(label="Download GLB", interactive=False) | |
| api_image_text = gr.Textbox(visible=False) | |
| generate_btn.click( | |
| generate_mesh, | |
| inputs=[image_input, seed, steps, guidance, simplify, face_count], | |
| outputs=[seg_output, model_output], | |
| api_name=False, | |
| ).then(lambda path: path, inputs=model_output, outputs=file_output, api_name=False) | |
| api_btn = gr.Button(visible=False) | |
| api_btn.click( | |
| api_generate_text, | |
| inputs=[api_image_text, seed, steps, guidance, simplify, face_count], | |
| outputs=[file_output], | |
| api_name="predict", | |
| ) | |
| demo.unload(_cleanup_session) | |
| demo.launch(server_name="0.0.0.0", server_port=7860) | |