"""TripoSR single-image 3D reconstruction service.

Serves a FastAPI API (``POST /api/generate``, ``GET /api/health``) and a
Gradio UI mounted at ``/``. Both share one TripoSR model loaded at import.
"""

import io
import logging
import os
import tempfile
import threading

import gradio as gr
import numpy as np
import rembg
import spaces
import torch
import uvicorn
from fastapi import BackgroundTasks, FastAPI, File, Form, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse, JSONResponse
from gradio.routes import mount_gradio_app
from PIL import Image
from tsr.system import TSR
from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
logger.info("TripoSR using device: %s", device)

model = TSR.from_pretrained(
    "stabilityai/TripoSR",
    config_name="config.yaml",
    weight_name="model.ckpt",
)
# NOTE(review): presumably bounds how many points the renderer queries per
# chunk (memory vs. speed trade-off) — confirm against tsr.system.
model.renderer.set_chunk_size(131072)
model.to(device)

rembg_session = rembg.new_session()

# Serializes API-driven generations so concurrent requests don't drive the
# shared model in parallel. Previously they were serialized only because the
# blocking call stalled the whole event loop.
_api_generation_lock = threading.Lock()

# File extensions picked up from the examples directory for the Gradio gallery.
_EXAMPLE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".bmp"}


def preprocess(input_image, do_remove_background=True, foreground_ratio=0.85):
    """Prepare a PIL image for TripoSR.

    Args:
        input_image: PIL image to process.
        do_remove_background: If True, drop any alpha channel, strip the
            background with rembg, rescale the foreground via
            ``resize_foreground`` and composite onto grey. If False, the image
            is used as-is, except that RGBA images are composited onto grey.
        foreground_ratio: Ratio forwarded to ``resize_foreground``.

    Returns:
        The processed PIL image.
    """

    def fill_background(image):
        # Alpha-composite the RGB channels over a uniform 50% grey background.
        image = np.array(image).astype(np.float32) / 255.0
        image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
        image = Image.fromarray((image * 255.0).astype(np.uint8))
        return image

    if do_remove_background:
        image = input_image.convert("RGB")
        # Assumes remove_background/resize_foreground return an RGBA image,
        # since fill_background reads channel 3 — TODO confirm in tsr.utils.
        image = remove_background(image, rembg_session)
        image = resize_foreground(image, foreground_ratio)
        image = fill_background(image)
    else:
        image = input_image
        if image.mode == "RGBA":
            image = fill_background(image)
    return image


def _new_temp_path(suffix):
    """Create an empty temp file that survives closing and return its path."""
    # Context manager closes the handle deterministically; delete=False keeps
    # the file on disk so the caller can write to and serve it.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
        return handle.name


@spaces.GPU
def generate_mesh(image, mc_resolution=256):
    """Reconstruct a mesh from a preprocessed image and export it twice.

    Args:
        image: Preprocessed image (PIL image from the API, or whatever the
            Gradio ``processed_image`` component yields — presumably a numpy
            array, which TSR is assumed to accept; confirm).
        mc_resolution: Marching-cubes grid resolution.

    Returns:
        ``(obj_path, glb_path)`` — paths of persistent temp files. The caller
        owns them and is responsible for deleting them.
    """
    # Pure inference: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        scene_codes = model(image, device=device)
        # Gradio sliders may deliver floats; the grid resolution must be an int.
        # NOTE(review): newer TripoSR releases add a positional
        # `has_vertex_color` argument to extract_mesh — confirm the pinned version.
        mesh = model.extract_mesh(scene_codes, resolution=int(mc_resolution))[0]
    mesh = to_gradio_3d_orientation(mesh)

    glb_path = _new_temp_path(".glb")
    mesh.export(glb_path)

    obj_path = _new_temp_path(".obj")
    # Mirror X before the OBJ export only (the GLB above is already written).
    # Presumably compensates for a handedness difference in the OBJ viewer — confirm.
    mesh.apply_scale([-1, 1, 1])
    mesh.export(obj_path)
    return obj_path, glb_path


# ── FastAPI ──────────────────────────────────────────────────
app = FastAPI(title="TripoSR 3D Generation")


def _remove_file_quietly(path):
    """Delete ``path``; log rather than raise if removal fails."""
    try:
        os.remove(path)
    except OSError:
        logger.warning("Could not remove temporary file %s", path, exc_info=True)


def _run_generation(img, remove_bg, foreground_ratio, mc_resolution):
    """Preprocess and reconstruct one image; returns ``(obj_path, glb_path)``.

    Intended to run in a worker thread; the lock keeps API requests one at a time.
    """
    with _api_generation_lock:
        processed = preprocess(img, remove_bg, foreground_ratio)
        return generate_mesh(processed, mc_resolution)


@app.post("/api/generate")
async def api_generate(
    image: UploadFile = File(...),
    remove_bg: bool = Form(True),
    foreground_ratio: float = Form(0.85),
    mc_resolution: int = Form(256),
    format: str = Form("glb"),  # shadows the builtin, but it is the public form-field name
):
    """Generate a mesh from an uploaded image.

    Form fields:
        image: Image file to reconstruct.
        remove_bg: Whether to strip the background before reconstruction.
        foreground_ratio: Forwarded to ``preprocess``.
        mc_resolution: Marching-cubes resolution (not range-checked here;
            large values may be slow or exhaust GPU memory).
        format: ``"obj"`` (case-insensitive) returns an OBJ file; any other
            value returns a GLB file.

    Returns:
        The mesh file, or a JSON ``{"error": ...}`` body with status 400 for
        an undecodable image or 500 for any other failure. The generated temp
        files are deleted after the response has been sent.
    """
    try:
        contents = await image.read()
        try:
            img = Image.open(io.BytesIO(contents))
            img.load()  # Image.open is lazy; force a decode so bad data fails here.
        except (OSError, Image.DecompressionBombError) as e:
            logger.warning("Rejected undecodable upload: %s", e)
            return JSONResponse(status_code=400, content={"error": str(e)})
        # Mirror the Gradio input (image_mode="RGBA") so palette/greyscale/CMYK
        # uploads don't reach the model unconverted when remove_bg is False.
        if img.mode not in ("RGB", "RGBA"):
            img = img.convert("RGBA")
        # Inference takes seconds; run it off the event loop so health checks
        # and the mounted Gradio app stay responsive.
        obj_path, glb_path = await run_in_threadpool(
            _run_generation, img, remove_bg, foreground_ratio, mc_resolution
        )
    except Exception as e:
        logger.exception("Generation failed")
        return JSONResponse(status_code=500, content={"error": str(e)})

    # Both exports exist regardless of the requested format; remove them once sent.
    cleanup = BackgroundTasks()
    cleanup.add_task(_remove_file_quietly, obj_path)
    cleanup.add_task(_remove_file_quietly, glb_path)

    if format.lower() == "obj":
        return FileResponse(
            obj_path,
            filename="model.obj",
            media_type="application/octet-stream",
            background=cleanup,
        )
    return FileResponse(
        glb_path,
        filename="model.glb",
        media_type="model/gltf-binary",
        background=cleanup,
    )


@app.get("/api/health")
async def health():
    """Liveness probe reporting the torch device in use."""
    return {"status": "ok", "device": device}


# ── Gradio UI ──────────────────────────────────────────────
HEADER = """
# TripoSR 3D Reconstruction

**Fast feedforward 3D reconstruction from a single image.**

Developed by [Tripo AI](https://www.tripo3d.ai/) & [Stability AI](https://stability.ai/).

**API:** `POST /api/generate` — accepts image file, returns GLB/OBJ mesh.
"""


def check_input_image(input_image):
    """Abort the Gradio event chain with a user-facing error if no image is set."""
    if input_image is None:
        raise gr.Error("No image uploaded!")


def gradio_generate(image, do_remove_background, foreground_ratio, mc_resolution):
    """One-shot helper: preprocess then reconstruct; returns ``(obj_path, glb_path)``."""
    processed = preprocess(image, do_remove_background, foreground_ratio)
    return generate_mesh(processed, mc_resolution)


def _list_example_images(example_dir):
    """Return sorted image paths in ``example_dir``, or ``[]`` if it doesn't exist."""
    if not os.path.isdir(example_dir):
        return []
    return [
        os.path.join(example_dir, name)
        for name in sorted(os.listdir(example_dir))
        if os.path.splitext(name)[1].lower() in _EXAMPLE_EXTENSIONS
    ]


with gr.Blocks() as demo:
    gr.Markdown(HEADER)
    with gr.Row(variant="panel"):
        with gr.Column():
            with gr.Row():
                input_image = gr.Image(
                    label="Input Image",
                    image_mode="RGBA",
                    sources="upload",
                    type="pil",
                    elem_id="content_image",
                )
                processed_image = gr.Image(label="Processed Image", interactive=False)
            with gr.Row():
                with gr.Group():
                    do_remove_background = gr.Checkbox(
                        label="Remove Background", value=True
                    )
                    foreground_ratio = gr.Slider(
                        label="Foreground Ratio",
                        minimum=0.5,
                        maximum=1.0,
                        value=0.85,
                        step=0.05,
                    )
                    mc_resolution = gr.Slider(
                        label="Marching Cubes Resolution",
                        minimum=32,
                        maximum=320,
                        value=256,
                        step=32,
                    )
            with gr.Row():
                submit = gr.Button("Generate", elem_id="generate", variant="primary")
        with gr.Column():
            with gr.Tab("OBJ"):
                output_model_obj = gr.Model3D(
                    label="Output Model (OBJ)", interactive=False
                )
            with gr.Tab("GLB"):
                output_model_glb = gr.Model3D(
                    label="Output Model (GLB)", interactive=False
                )
    with gr.Row(variant="panel"):
        example_dir = "examples"
        examples = _list_example_images(example_dir)
        # Only build the gallery when there is something to show.
        if examples:
            gr.Examples(
                examples=examples,
                inputs=[input_image],
                label="Examples",
                examples_per_page=20,
            )

    # Gradio passes `inputs` positionally to `fn`, so each step's inputs must
    # match its function's parameters.
    submit.click(fn=check_input_image, inputs=[input_image]).success(
        fn=preprocess,
        inputs=[input_image, do_remove_background, foreground_ratio],
        outputs=[processed_image],
    ).success(
        # The processed image is already computed and displayed; only
        # reconstruction remains, so call generate_mesh directly.
        fn=generate_mesh,
        inputs=[processed_image, mc_resolution],
        outputs=[output_model_obj, output_model_glb],
    )

# The /api/* routes were registered above, before this catch-all mount at "/".
app = mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)