triposr-3d / app.py
NanoBotAIAgent's picture
Upload app.py
592328d verified
import io
import logging
import os
import tempfile
import gradio as gr
import numpy as np
import rembg
import spaces
import torch
import uvicorn
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import FileResponse, JSONResponse
from gradio.routes import mount_gradio_app
from PIL import Image
from tsr.system import TSR
from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
# Module-level setup: logging, device selection, model weights and the
# background-removal session are created once at import time and shared by
# every request handler below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use the first CUDA device when torch reports one, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
logger.info(f"TripoSR using device: {device}")

model = TSR.from_pretrained(
    "stabilityai/TripoSR",
    config_name="config.yaml",
    weight_name="model.ckpt",
)
# NOTE(review): presumably bounds how much work the renderer does per chunk
# (memory vs. speed trade-off) - confirm against the tsr package.
model.renderer.set_chunk_size(131072)
model.to(device)

# Single rembg session reused by every preprocess() call.
rembg_session = rembg.new_session()
def _fill_background(image):
    """Composite an RGBA image onto a flat mid-gray (0.5) background.

    Returns a new 3-channel PIL image; the alpha channel is consumed by the
    blend and dropped.
    """
    rgba = np.array(image).astype(np.float32) / 255.0
    alpha = rgba[:, :, 3:4]
    blended = rgba[:, :, :3] * alpha + (1 - alpha) * 0.5
    return Image.fromarray((blended * 255.0).astype(np.uint8))


def preprocess(input_image, do_remove_background=True, foreground_ratio=0.85):
    """Turn an uploaded PIL image into the RGB image handed to the model.

    Args:
        input_image: PIL image in any mode.
        do_remove_background: When true, the image is converted to RGB, the
            background is removed with the shared rembg session, the
            foreground is resized with ``foreground_ratio`` and the result is
            composited onto gray.
        foreground_ratio: Passed through to ``resize_foreground``; only used
            when ``do_remove_background`` is true.

    Returns:
        A PIL image. RGB inputs with background removal disabled are returned
        unchanged (same object); every other input yields a new RGB image.
    """
    if do_remove_background:
        image = input_image.convert("RGB")
        image = remove_background(image, rembg_session)
        image = resize_foreground(image, foreground_ratio)
        return _fill_background(image)

    image = input_image
    # Uploads are not guaranteed to be RGB/RGBA (grayscale JPEGs, palette
    # PNGs, CMYK, ...). Promote anything that carries transparency to RGBA so
    # it gets the same gray compositing as a native RGBA image.
    has_alpha = image.mode in ("LA", "PA", "La", "RGBa") or (
        image.mode == "P" and "transparency" in image.info
    )
    if has_alpha:
        image = image.convert("RGBA")

    if image.mode == "RGBA":
        return _fill_background(image)
    if image.mode != "RGB":
        # Remaining opaque modes: hand the model a plain 3-channel image
        # instead of passing the original mode through untouched.
        image = image.convert("RGB")
    return image
@spaces.GPU
def generate_mesh(image, mc_resolution=256):
    """Reconstruct a mesh from a preprocessed image and export it twice.

    Args:
        image: Image passed straight to the TripoSR model.
        mc_resolution: Forwarded as ``resolution`` to ``model.extract_mesh``.

    Returns:
        ``(obj_path, glb_path)``: paths of two temporary files. They are
        created with ``delete=False``, so the caller owns them and is
        responsible for removing them.
    """
    # Inference only: no gradients are needed, so skip autograd bookkeeping
    # for the forward pass.
    with torch.no_grad():
        scene_codes = model(image, device=device)
    # NOTE(review): newer TripoSR releases appear to take an extra
    # `has_vertex_color` argument here - confirm against the pinned tsr version.
    mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
    mesh = to_gradio_3d_orientation(mesh)

    # Reserve the file names and close the handles right away instead of
    # leaving the open file objects to the garbage collector.
    with tempfile.NamedTemporaryFile(suffix=".glb", delete=False) as glb_file:
        glb_path = glb_file.name
    mesh.export(glb_path)

    with tempfile.NamedTemporaryFile(suffix=".obj", delete=False) as obj_file:
        obj_path = obj_file.name
    # The OBJ export uses an X scale of -1; the GLB was already written above
    # and is not affected.
    mesh.apply_scale([-1, 1, 1])
    mesh.export(obj_path)
    return obj_path, glb_path
# ── FastAPI ──────────────────────────────────────────────────
# ASGI application that carries the REST endpoints registered below. The same
# name is later rebound to the result of mount_gradio_app(), which attaches
# the Gradio UI to it.
app = FastAPI(title="TripoSR 3D Generation")
@app.post("/api/generate")
async def api_generate(
    image: UploadFile = File(...),
    remove_bg: bool = Form(True),
    foreground_ratio: float = Form(0.85),
    mc_resolution: int = Form(256),
    format: str = Form("glb"),
):
    """Generate a 3D mesh from an uploaded image and return it as a file.

    Form fields:
        image: The image file to reconstruct.
        remove_bg: Whether to run background removal before reconstruction.
        foreground_ratio: Foreground ratio used when removing the background.
        mc_resolution: Marching-cubes resolution passed to mesh extraction.
        format: ``"obj"`` (case-insensitive) returns the OBJ export; any other
            value returns the GLB export.

    Returns:
        A ``FileResponse`` with the mesh on success, a 400 JSON error when the
        upload cannot be decoded as an image, or a 500 JSON error for any
        other failure.
    """
    # Function-scope imports keep this handler self-contained; both names ship
    # with the FastAPI package this module already depends on.
    from fastapi import BackgroundTasks
    from fastapi.concurrency import run_in_threadpool

    def _remove_quietly(path):
        # Best-effort cleanup: a missing file must not surface as an error.
        try:
            os.remove(path)
        except OSError:
            logger.warning("Could not remove temporary mesh file %s", path)

    try:
        contents = await image.read()

        def _decode():
            decoded = Image.open(io.BytesIO(contents))
            # Image.open is lazy; force the decode so corrupt or truncated
            # uploads fail here rather than deep inside the pipeline.
            decoded.load()
            return decoded

        try:
            img = await run_in_threadpool(_decode)
        except OSError as exc:  # includes PIL.UnidentifiedImageError
            return JSONResponse(
                status_code=400,
                content={"error": f"Invalid image upload: {exc}"},
            )

        def _run_pipeline():
            processed = preprocess(img, remove_bg, foreground_ratio)
            return generate_mesh(processed, mc_resolution)

        # Preprocessing and inference are blocking; running them directly in
        # this coroutine would stall the event loop (health checks and the
        # mounted UI included) for the whole generation.
        obj_path, glb_path = await run_in_threadpool(_run_pipeline)

        # generate_mesh leaves both exports on disk; remove them once the
        # response has been sent so temp files do not pile up per request.
        cleanup = BackgroundTasks()
        cleanup.add_task(_remove_quietly, obj_path)
        cleanup.add_task(_remove_quietly, glb_path)

        if format.lower() == "obj":
            return FileResponse(
                obj_path,
                filename="model.obj",
                media_type="application/octet-stream",
                background=cleanup,
            )
        return FileResponse(
            glb_path,
            filename="model.glb",
            media_type="model/gltf-binary",
            background=cleanup,
        )
    except Exception as e:
        # Top-level boundary for the endpoint: log the traceback and report
        # the failure as JSON instead of an unhandled server error.
        logger.exception("Generation failed")
        return JSONResponse(status_code=500, content={"error": str(e)})
@app.get("/api/health")
async def health():
    """Report liveness together with the device string chosen at import time."""
    return dict(status="ok", device=device)
# ── Gradio UI ──────────────────────────────────────────────
# Markdown shown at the top of the Gradio page. Blank lines separate the
# paragraphs so Markdown does not merge them into a single run of text.
HEADER = """
# TripoSR 3D Reconstruction

**Fast feedforward 3D reconstruction from a single image.**

Developed by [Tripo AI](https://www.tripo3d.ai/) & [Stability AI](https://stability.ai/).

**API:** `POST /api/generate` — accepts image file, returns GLB/OBJ mesh.
"""
def check_input_image(input_image):
    """Raise a Gradio error when no image was supplied.

    Only ``None`` counts as missing; any other value passes and the function
    returns ``None``.
    """
    image_missing = input_image is None
    if image_missing:
        raise gr.Error("No image uploaded!")
def gradio_generate(image, do_remove_background, foreground_ratio, mc_resolution):
    """Preprocess ``image`` and build its mesh in a single call.

    Returns whatever ``generate_mesh`` returns for the preprocessed image.
    """
    return generate_mesh(
        preprocess(image, do_remove_background, foreground_ratio),
        mc_resolution,
    )
# Gradio UI: inputs and settings in the left column, OBJ/GLB viewers in the
# right column, optional example images underneath.
with gr.Blocks() as demo:
    gr.Markdown(HEADER)
    with gr.Row(variant="panel"):
        with gr.Column():
            with gr.Row():
                input_image = gr.Image(
                    label="Input Image",
                    image_mode="RGBA",
                    sources="upload",
                    type="pil",
                    elem_id="content_image",
                )
                processed_image = gr.Image(label="Processed Image", interactive=False)
            with gr.Row():
                with gr.Group():
                    do_remove_background = gr.Checkbox(
                        label="Remove Background", value=True
                    )
                    foreground_ratio = gr.Slider(
                        label="Foreground Ratio",
                        minimum=0.5,
                        maximum=1.0,
                        value=0.85,
                        step=0.05,
                    )
                    mc_resolution = gr.Slider(
                        label="Marching Cubes Resolution",
                        minimum=32,
                        maximum=320,
                        value=256,
                        step=32,
                    )
            with gr.Row():
                submit = gr.Button("Generate", elem_id="generate", variant="primary")
        with gr.Column():
            with gr.Tab("OBJ"):
                output_model_obj = gr.Model3D(
                    label="Output Model (OBJ)", interactive=False
                )
            with gr.Tab("GLB"):
                output_model_glb = gr.Model3D(
                    label="Output Model (GLB)", interactive=False
                )
    with gr.Row(variant="panel"):
        # Offer every entry of ./examples as a clickable example; the
        # directory is optional, so fall back to an empty list without it.
        example_dir = "examples"
        if os.path.isdir(example_dir):
            examples = [
                os.path.join(example_dir, img_name)
                for img_name in sorted(os.listdir(example_dir))
            ]
        else:
            examples = []
        gr.Examples(
            examples=examples,
            inputs=[input_image],
            label="Examples",
            examples_per_page=20,
        )
    # Click pipeline: validate -> show the preprocessed image -> build meshes.
    # Each step only runs if the previous one finished without raising.
    submit.click(fn=check_input_image, inputs=[input_image]).success(
        fn=preprocess,
        inputs=[input_image, do_remove_background, foreground_ratio],
        outputs=[processed_image],
    ).success(
        fn=gradio_generate,
        # gradio_generate takes (image, do_remove_background,
        # foreground_ratio, mc_resolution) and preprocesses the raw upload
        # itself, so it must receive all four values; passing only the
        # processed image and the resolution fails with a TypeError.
        inputs=[input_image, do_remove_background, foreground_ratio, mc_resolution],
        outputs=[output_model_obj, output_model_glb],
    )
# Attach the Gradio Blocks UI to the FastAPI app at the root path. The result
# is rebound to `app`, which is the object uvicorn serves below.
app = mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    # When run as a script, listen on all interfaces on port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)