import contextlib
import io
import os
import shlex
import subprocess
import sys
import tempfile
import zipfile
from functools import partial

from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
import numpy as np
import rembg
import torch
from PIL import Image

from tsr.system import TSR
from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation

# ------------------------------------------------------------
# 1. Model & utils initialization (runs at startup)
# ------------------------------------------------------------

# device
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# load model
model = TSR.from_pretrained(
    "stabilityai/TripoSR",
    config_name="config.yaml",
    weight_name="model.ckpt",
)
model.renderer.set_chunk_size(131072)
model.to(device)

# background removal
rembg_session = rembg.new_session()


def _new_temp_path(suffix: str) -> str:
    """Create an empty temporary file ending in *suffix* and return its path.

    The handle is closed before returning so no file descriptor is leaked and
    the path can be reopened by an exporter. The caller owns the file and is
    responsible for deleting it.
    """
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
        return handle.name


def _remove_if_exists(path: str) -> None:
    """Delete *path*, ignoring the case where it is already gone."""
    with contextlib.suppress(FileNotFoundError):
        os.remove(path)


def check_input_image(image: Image.Image):
    """Reject a missing image.

    Raises:
        HTTPException: 400 if *image* is ``None``.
    """
    if image is None:
        raise HTTPException(status_code=400, detail="No image uploaded!")


def preprocess(
    input_image: Image.Image, do_remove_background: bool, foreground_ratio: float
) -> Image.Image:
    """Prepare an uploaded image for the model.

    Mimics the Gradio preprocess(...) function.

    Args:
        input_image: The decoded upload.
        do_remove_background: When true, the image is converted to RGB, passed
            through ``remove_background`` and ``resize_foreground``, and then
            composited onto gray. When false, the image is used as-is, except
            that an RGBA image is composited onto gray.
        foreground_ratio: Forwarded to ``resize_foreground``; only used when
            ``do_remove_background`` is true.

    Returns:
        The processed image.
    """

    def fill_background(image: Image.Image) -> Image.Image:
        # Alpha-composite the RGB channels over a constant 0.5 (mid-gray)
        # background, working in [0, 1] floats and converting back to uint8.
        arr = np.array(image).astype(np.float32) / 255.0
        arr = arr[:, :, :3] * arr[:, :, 3:4] + (1 - arr[:, :, 3:4]) * 0.5
        out = (arr * 255.0).astype(np.uint8)
        return Image.fromarray(out)

    if do_remove_background:
        image = input_image.convert("RGB")
        image = remove_background(image, rembg_session)
        image = resize_foreground(image, foreground_ratio)
        image = fill_background(image)
    else:
        image = input_image
        if image.mode == "RGBA":
            image = fill_background(image)
    return image


def generate(
    image: Image.Image, mc_resolution: int, formats=("obj", "glb")
) -> tuple[str, str]:
    """Run the model on *image* and export the mesh as OBJ and GLB files.

    Mimics the Gradio generate(...) function.

    Args:
        image: Preprocessed input image.
        mc_resolution: Passed to ``model.extract_mesh`` as ``resolution``.
        formats: Accepted for call compatibility; currently unused, both
            formats are always written. The default is an immutable tuple so
            it cannot be mutated across calls.

    Returns:
        ``(obj_path, glb_path)``: paths of the exported files on disk. The
        caller is responsible for deleting them.

    Raises:
        Exception: Anything raised by the model or the exporters is
            propagated; temp files created so far are removed first.
    """
    # 1. inference
    scene_codes = model(image, device=device)
    mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
    mesh = to_gradio_3d_orientation(mesh)

    created: list[str] = []
    try:
        # 2. export GLB
        glb_path = _new_temp_path(".glb")
        created.append(glb_path)
        mesh.export(glb_path)

        # 3. export OBJ (flip x-axis so OBJ is not mirrored)
        # The scale is applied in place, so the GLB must be exported first.
        obj_path = _new_temp_path(".obj")
        created.append(obj_path)
        mesh.apply_scale([-1, 1, 1])
        mesh.export(obj_path)
    except Exception:
        # Do not leave orphaned temp files behind when an export fails.
        for path in created:
            _remove_if_exists(path)
        raise

    return obj_path, glb_path


# ------------------------------------------------------------
# 2. FastAPI app
# ------------------------------------------------------------
app = FastAPI(title="TripoSR FastAPI Demo")

# If you need CORS (e.g. calling from a browser-based front-end)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["POST", "GET", "OPTIONS"],
    allow_headers=["*"],
)


@app.post("/generate", response_class=StreamingResponse)
async def generate_endpoint(
    image_file: UploadFile = File(...),
    do_remove_background: bool = Form(True),
    foreground_ratio: float = Form(0.85),
    mc_resolution: int = Form(256),
):
    """Turn an uploaded image into a ZIP of processed image + meshes.

    1. Read & validate image
    2. Preprocess
    3. Generate mesh
    4. Package processed image + .obj + .glb into a ZIP

    Returns:
        A ``StreamingResponse`` with media type ``application/zip`` served as
        the attachment ``tripo_output.zip``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    # NOTE(review): preprocess() and generate() are synchronous and are called
    # directly from this coroutine, so they run on the event loop.

    # 1) Read image bytes
    contents = await image_file.read()
    try:
        pil_img = Image.open(io.BytesIO(contents))
        # Image.open only parses the header; force a full decode here so a
        # truncated or corrupt file is reported as a 400 instead of failing
        # later during preprocessing.
        pil_img.load()
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid image file") from exc
    check_input_image(pil_img)

    # 2) Preprocess
    processed = preprocess(pil_img, do_remove_background, foreground_ratio)

    # 3) Generate mesh
    obj_path, glb_path = generate(processed, mc_resolution)

    # 4) Create in-memory ZIP
    zip_buffer = io.BytesIO()
    try:
        with zipfile.ZipFile(zip_buffer, mode="w") as zf:
            # processed image
            buf = io.BytesIO()
            processed.save(buf, format="PNG")
            zf.writestr("processed.png", buf.getvalue())

            # .obj
            with open(obj_path, "rb") as f:
                zf.writestr(os.path.basename(obj_path), f.read())

            # .glb
            with open(glb_path, "rb") as f:
                zf.writestr(os.path.basename(glb_path), f.read())
    finally:
        # Cleanup temp files, also when building the archive fails.
        _remove_if_exists(obj_path)
        _remove_if_exists(glb_path)

    zip_buffer.seek(0)

    headers = {
        "Content-Disposition": 'attachment; filename="tripo_output.zip"'
    }
    return StreamingResponse(
        zip_buffer, media_type="application/zip", headers=headers
    )