# trip_w_oblaka / app.py
# Author: keeendaaa
# Last change: "Improve HF URL handling for API image input" (commit bc929c6)
# NOTE: these four lines were scraped from the GitHub file view and were not
# valid Python; preserved here as a comment header.
import os
import sys
import uuid
import shutil
import gradio as gr
import numpy as np
import torch
from huggingface_hub import snapshot_download, hf_hub_download
import trimesh
from urllib.parse import urlparse
from urllib.request import urlretrieve, Request, urlopen
import re
import base64
from io import BytesIO
from PIL import Image
# Work around older gradio_client versions that crash on boolean JSON schemas
# (True/False are legal JSON Schema documents). Best effort: if gradio_client
# is missing or its private internals changed, leave everything untouched.
try:
    import gradio_client.utils as gc_utils

    _orig_get_type = gc_utils.get_type
    _orig_json_schema = gc_utils._json_schema_to_python_type

    def _safe_get_type(schema):
        """Map bare-boolean schemas to "Any"; defer everything else."""
        return "Any" if isinstance(schema, bool) else _orig_get_type(schema)

    def _safe_json_schema(schema, defs=None):
        """Boolean-tolerant wrapper around _json_schema_to_python_type."""
        return "Any" if isinstance(schema, bool) else _orig_json_schema(schema, defs)

    gc_utils.get_type = _safe_get_type
    gc_utils._json_schema_to_python_type = _safe_json_schema
except Exception:
    pass
# Use the Hugging Face Spaces GPU decorator when running on Spaces; anywhere
# else, substitute a no-op decorator factory with the same call shape.
try:
    import spaces

    gpu = spaces.GPU
except Exception:
    def gpu(*_args, **_kwargs):
        """No-op stand-in for spaces.GPU: returns the function unchanged."""
        return lambda fn: fn
# Run on GPU with fp16 when available; otherwise CPU with fp32.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# TripoSG inference code is cloned from GitHub at startup, not pip-installed.
TRIPOSG_REPO_URL = "https://github.com/VAST-AI-Research/TripoSG.git"
TRIPOSG_CODE_DIR = "./triposg"
CHECKPOINT_DIR = "checkpoints"
RMBG_PRETRAINED_MODEL = os.path.join(CHECKPOINT_DIR, "RMBG-1.4")
TRIPOSG_PRETRAINED_MODEL = os.path.join(CHECKPOINT_DIR, "TripoSG")

# Per-session scratch space for uploaded images and exported meshes.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
os.makedirs(TMP_DIR, exist_ok=True)

# Clone the TripoSG sources once, then make their modules importable.
if not os.path.exists(TRIPOSG_CODE_DIR):
    os.system(f"git clone {TRIPOSG_REPO_URL} {TRIPOSG_CODE_DIR}")
sys.path.append(TRIPOSG_CODE_DIR)
sys.path.append(os.path.join(TRIPOSG_CODE_DIR, "scripts"))

# These imports resolve inside the cloned repo (sys.path entries above), so
# they must stay below the clone step.
from image_process import prepare_image
from briarmbg import BriaRMBG
from triposg.pipelines.pipeline_triposg import TripoSGPipeline
from utils import simplify_mesh

# Background-removal model (RMBG-1.4) and the TripoSG pipeline are downloaded
# from the Hugging Face Hub and loaded eagerly at import time.
snapshot_download("briaai/RMBG-1.4", local_dir=RMBG_PRETRAINED_MODEL)
rmbg_net = BriaRMBG.from_pretrained(RMBG_PRETRAINED_MODEL).to(DEVICE)
rmbg_net.eval()
snapshot_download("VAST-AI/TripoSG", local_dir=TRIPOSG_PRETRAINED_MODEL)
triposg_pipe = TripoSGPipeline.from_pretrained(TRIPOSG_PRETRAINED_MODEL).to(
    DEVICE, DTYPE
)
def _session_dir(req: gr.Request | None) -> str:
    """Return (creating if needed) the per-session scratch dir under TMP_DIR.

    Falls back to the shared TMP_DIR when no request context is available.
    """
    if req is None:
        return TMP_DIR
    path = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(path, exist_ok=True)
    return path
def _unique_glb_path(save_dir: str) -> str:
return os.path.join(save_dir, f"triposg_{uuid.uuid4().hex}.glb")
def _resolve_image_path(image_input, save_dir: str) -> str:
    """Normalize the many shapes of UI/API image input to a local file path.

    Accepts a gradio FileData dict, a PIL image, an existing local path, an
    http(s) URL (with a fast path for ``huggingface.co/.../resolve/...``
    URLs), or a base64 ``data:image`` URI. Downloads/decodes into *save_dir*
    as needed and returns the resulting local path.

    Raises gr.Error when the input is empty, a URL cannot be fetched, or a
    data URI cannot be decoded.
    """
    # Gradio may hand over a FileData dict; prefer the local path over the URL.
    if isinstance(image_input, dict):
        image_input = image_input.get("path") or image_input.get("url")
    if not image_input:
        raise gr.Error("Upload an image first.")
    if isinstance(image_input, Image.Image):
        local_path = os.path.join(save_dir, "input.png")
        image_input.save(local_path)
        return local_path
    if isinstance(image_input, str) and os.path.exists(image_input):
        return image_input
    if isinstance(image_input, str) and image_input.startswith(("http://", "https://")):
        # huggingface.co `resolve` URLs go through hf_hub_download so that
        # auth tokens and the local HF cache are honored.
        hf_match = re.match(
            r"^https?://huggingface\.co/(?:(?P<type>spaces|datasets)/)?(?P<repo>[^/]+/[^/]+)/resolve/(?P<rev>[^/]+)/(?P<path>.+)$",
            image_input,
        )
        if hf_match:
            repo_id = hf_match.group("repo")
            filename = hf_match.group("path")
            revision = hf_match.group("rev")
            # URL segment is plural ("spaces"/"datasets"); the Hub API wants
            # the singular form, and bare repos are type "model".
            repo_type = hf_match.group("type")[:-1] if hf_match.group("type") else "model"
            try:
                return hf_hub_download(
                    repo_id=repo_id,
                    filename=filename,
                    repo_type=repo_type,
                    revision=revision,
                    local_dir=save_dir,
                )
            except Exception:
                # Hub download failed (private repo, bad revision, ...);
                # fall through to a plain HTTP download below.
                pass
        parsed = urlparse(image_input)
        suffix = os.path.splitext(parsed.path)[1] or ".png"
        local_path = os.path.join(save_dir, f"input{suffix}")
        try:
            # Some hosts reject urllib's default User-Agent, so send a
            # browser-like one. Stream to disk instead of buffering the whole
            # body, and bound the wait so a dead host cannot hang the worker.
            req = Request(image_input, headers={"User-Agent": "Mozilla/5.0"})
            with urlopen(req, timeout=30) as resp, open(local_path, "wb") as out:
                shutil.copyfileobj(resp, out)
        except Exception:
            try:
                urlretrieve(image_input, local_path)
            except Exception as err:
                raise gr.Error(f"Failed to download image URL: {err}") from err
        return local_path
    if isinstance(image_input, str) and image_input.startswith("data:image"):
        try:
            _, b64_data = image_input.split(",", 1)
            image = Image.open(BytesIO(base64.b64decode(b64_data)))
        except Exception as err:
            # Surface malformed data URIs as gr.Error like every other
            # failure path, instead of a raw binascii/PIL exception.
            raise gr.Error(f"Invalid data URI image: {err}") from err
        local_path = os.path.join(save_dir, "input.png")
        image.save(local_path)
        return local_path
    # Unknown string input; return as-is and let downstream code report a
    # more specific error.
    return image_input
def _run_triposg(
    image_path: str,
    seed: int,
    num_inference_steps: int,
    guidance_scale: float,
    simplify: bool,
    target_face_num: int,
    req: gr.Request | None = None,
):
    """Segment the input image, run the TripoSG pipeline, and export a GLB.

    Returns a tuple of (segmented image, path to the exported .glb file).
    """
    workdir = _session_dir(req)
    resolved = _resolve_image_path(image_path, workdir)

    # Remove the background and composite onto white before conditioning.
    image_seg = prepare_image(
        resolved, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net
    )

    rng = torch.Generator(device=triposg_pipe.device).manual_seed(seed)
    outputs = triposg_pipe(
        image=image_seg,
        generator=rng,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).samples[0]

    # outputs[0]/outputs[1] feed trimesh as vertices/faces — presumably
    # (V, 3) floats and (F, 3) indices; confirm against the pipeline docs.
    vertices = outputs[0].astype(np.float32)
    faces = np.ascontiguousarray(outputs[1])
    mesh = trimesh.Trimesh(vertices, faces)
    if simplify:
        mesh = simplify_mesh(mesh, target_face_num)

    mesh_path = _unique_glb_path(workdir)
    mesh.export(mesh_path)
    return image_seg, mesh_path
@gpu(duration=60)
@torch.no_grad()
def generate_mesh(
    image_path: str,
    seed: int,
    num_inference_steps: int,
    guidance_scale: float,
    simplify: bool,
    target_face_num: int,
    req: gr.Request | None = None,
):
    """UI entry point: return (segmentation preview, GLB path)."""
    result = _run_triposg(
        image_path,
        seed,
        num_inference_steps,
        guidance_scale,
        simplify,
        target_face_num,
        req,
    )
    # Release cached VRAM between requests so concurrent sessions fit.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return result
@gpu(duration=60)
@torch.no_grad()
def api_generate(
    image_path: str,
    seed: int,
    num_inference_steps: int,
    guidance_scale: float,
    simplify: bool,
    target_face_num: int,
    req: gr.Request | None = None,
):
    """API entry point: return only the GLB path (preview image discarded)."""
    _seg, mesh_path = _run_triposg(
        image_path,
        seed,
        num_inference_steps,
        guidance_scale,
        simplify,
        target_face_num,
        req,
    )
    # Release cached VRAM between requests so concurrent sessions fit.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return mesh_path
@gpu(duration=60)
@torch.no_grad()
def api_generate_text(
    image_text: str,
    seed: int,
    num_inference_steps: int,
    guidance_scale: float,
    simplify: bool,
    target_face_num: int,
    req: gr.Request | None = None,
):
    """Text-input API entry point (URL / path / data URI); returns GLB path."""
    _seg, mesh_path = _run_triposg(
        image_text,
        seed,
        num_inference_steps,
        guidance_scale,
        simplify,
        target_face_num,
        req,
    )
    # Release cached VRAM between requests so concurrent sessions fit.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return mesh_path
def _cleanup_session(req: gr.Request):
    """Delete this session's scratch directory when the client disconnects.

    Robustness fixes over the original: skip when there is no session hash
    (the original would compute TMP_DIR/"None" and delete that directory),
    and avoid the exists-then-rmtree race by letting rmtree ignore errors.
    """
    session_hash = getattr(req, "session_hash", None)
    if not session_hash:
        return
    save_dir = os.path.join(TMP_DIR, str(session_hash))
    # ignore_errors: the directory may never have been created, or another
    # unload/cleanup may be removing it concurrently.
    shutil.rmtree(save_dir, ignore_errors=True)
# ---------------------------------------------------------------------------
# UI definition and event wiring.
# ---------------------------------------------------------------------------
TITLE = "TripoSG Image-to-3D API"
DESCRIPTION = (
    "Upload a single-object image to generate a 3D mesh (GLB). "
    "This demo exposes a /predict API endpoint."
)

with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}\n\n{DESCRIPTION}")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Input Image", type="filepath")
            seg_output = gr.Image(
                label="Segmentation Preview", type="pil", format="png"
            )
            with gr.Accordion("Generation Settings", open=True):
                seed = gr.Slider(
                    label="Seed", minimum=0, maximum=2**31 - 1, step=1, value=0
                )
                steps = gr.Slider(
                    label="Inference Steps", minimum=8, maximum=50, step=1, value=50
                )
                guidance = gr.Slider(
                    label="CFG Scale", minimum=0.0, maximum=20.0, step=0.1, value=7.5
                )
                simplify = gr.Checkbox(label="Simplify Mesh", value=True)
                face_count = gr.Slider(
                    label="Target Face Count",
                    minimum=10000,
                    maximum=1000000,
                    step=1000,
                    value=100000,
                )
            generate_btn = gr.Button("Generate 3D", variant="primary")
        with gr.Column():
            model_output = gr.Model3D(label="Generated GLB", interactive=False)
            file_output = gr.File(label="Download GLB", interactive=False)

    # Hidden textbox lets the API submit an image as URL / path / data URI.
    api_image_text = gr.Textbox(visible=False)

    # UI flow: generate (api_name=False keeps it off the public API), then
    # mirror the resulting GLB path into the download widget.
    generate_btn.click(
        generate_mesh,
        inputs=[image_input, seed, steps, guidance, simplify, face_count],
        outputs=[seg_output, model_output],
        api_name=False,
    ).then(lambda path: path, inputs=model_output, outputs=file_output, api_name=False)

    # Hidden button exposes the text-input variant as the /predict endpoint.
    api_btn = gr.Button(visible=False)
    api_btn.click(
        api_generate_text,
        inputs=[api_image_text, seed, steps, guidance, simplify, face_count],
        outputs=[file_output],
        api_name="predict",
    )

    # Remove the per-session scratch directory when the client disconnects.
    demo.unload(_cleanup_session)

demo.launch(server_name="0.0.0.0", server_port=7860)