Vo Minh Vu
update req
bfa1afe
import io
import os
import shlex
import subprocess
import tempfile
import threading
import zipfile
from functools import partial

import numpy as np
import rembg
import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, StreamingResponse
from PIL import Image

from tsr.system import TSR
from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
# ------------------------------------------------------------
# 1. Model & utils initialization (executed once, at import time)
# ------------------------------------------------------------
# NOTE(review): nothing here installs local wheels; `sys` is imported but not
# used in the visible code. Kept in case other parts of the file rely on it.
import sys

# Target device: the first CUDA GPU when available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Pretrained TripoSR weights + config, loaded once and shared by all requests.
model = TSR.from_pretrained(
    "stabilityai/TripoSR",
    config_name="config.yaml",
    weight_name="model.ckpt",
)
# Presumably bounds how many points the renderer evaluates per batch
# (memory vs. speed trade-off) — TODO confirm against tsr.system.
model.renderer.set_chunk_size(131072)
model.to(device)

# Single rembg session reused for background removal in preprocess().
rembg_session = rembg.new_session()
def check_input_image(image: Image.Image):
    """Ensure an image was actually supplied.

    Returns None when *image* is present.

    Raises:
        HTTPException: status 400 with detail "No image uploaded!" when
            *image* is None.
    """
    if image is not None:
        return
    raise HTTPException(status_code=400, detail="No image uploaded!")
def preprocess(
    input_image: Image.Image, do_remove_background: bool, foreground_ratio: float
) -> Image.Image:
    """
    Mimics the Gradio preprocess(...) function.

    Args:
        input_image: the decoded upload, in whatever PIL mode the file was
            stored in ("RGB", "RGBA", "L", "P", "LA", "CMYK", ...).
        do_remove_background: if True, remove the background with rembg,
            re-frame the foreground with ``resize_foreground`` and composite
            the result onto a grey background.
        foreground_ratio: forwarded to ``resize_foreground``; only used when
            ``do_remove_background`` is True.

    Returns:
        An RGB image ready to be passed to the model. Wherever transparency
        is present, it is alpha-composited onto a mid-grey (0.5) background.

    In the no-removal branch the upload's mode is normalised first. Images
    carrying alpha or transparency information are converted to RGBA and then
    composited; all other non-RGB modes are converted to RGB. This ensures the
    model never receives a greyscale, palette or CMYK image.
    """

    def fill_background(image: Image.Image) -> Image.Image:
        # Alpha-composite the RGB channels over 0.5 grey; expects 4 channels (RGBA).
        arr = np.array(image).astype(np.float32) / 255.0
        arr = arr[:, :, :3] * arr[:, :, 3:4] + (1 - arr[:, :, 3:4]) * 0.5
        out = (arr * 255.0).astype(np.uint8)
        return Image.fromarray(out)

    if do_remove_background:
        image = input_image.convert("RGB")
        image = remove_background(image, rembg_session)
        image = resize_foreground(image, foreground_ratio)
        image = fill_background(image)
        return image

    image = input_image
    # Palette/greyscale images may carry transparency either as an alpha
    # channel ("LA", "PA") or as a colour key in `info["transparency"]`;
    # Image.convert("RGBA") turns both into a real alpha channel.
    has_alpha = image.mode in ("RGBA", "LA", "PA") or "transparency" in image.info
    if has_alpha and image.mode != "RGBA":
        image = image.convert("RGBA")
    elif not has_alpha and image.mode != "RGB":
        image = image.convert("RGB")
    if image.mode == "RGBA":
        image = fill_background(image)
    return image
def _make_temp_path(suffix: str) -> str:
    """Create an empty temporary file, close it, and return its path.

    ``delete=False`` keeps the file after the handle closes. Closing the
    handle immediately avoids leaking a file descriptor and lets the path be
    reopened for writing on every platform, including Windows.
    """
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        return tmp.name


def generate(
    image: Image.Image, mc_resolution: int, formats=("obj", "glb")
) -> tuple[str, str]:
    """
    Mimics the Gradio generate(...) function.

    Runs the model on *image*, extracts a mesh at *mc_resolution*, and writes
    it to two temporary files.

    Args:
        image: the preprocessed image to pass to ``model``.
        mc_resolution: resolution passed to ``model.extract_mesh``.
        formats: accepted for signature compatibility but currently ignored;
            both OBJ and GLB are always written. The default is an immutable
            tuple so no list object is shared between calls.

    Returns:
        ``(obj_path, glb_path)``: paths to the .obj and .glb files on disk.
        The caller owns these files and is responsible for deleting them.

    Raises:
        Whatever the model, mesh extraction or ``mesh.export`` raise. Any
        temp files already created are removed before the error propagates.
    """
    # 1. inference. Autograd is disabled because there is no backward pass;
    # this avoids retaining activations (memory) on every request.
    with torch.no_grad():
        scene_codes = model(image, device=device)
        mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
    mesh = to_gradio_3d_orientation(mesh)

    created: list[str] = []
    try:
        # 2. export GLB
        glb_path = _make_temp_path(".glb")
        created.append(glb_path)
        mesh.export(glb_path)
        # 3. export OBJ (flip x-axis so OBJ is not mirrored). apply_scale
        # mutates `mesh`, so this must run after the GLB export above.
        obj_path = _make_temp_path(".obj")
        created.append(obj_path)
        mesh.apply_scale([-1, 1, 1])
        mesh.export(obj_path)
    except Exception:
        # Don't leave orphaned temp files behind when an export fails.
        for path in created:
            try:
                os.remove(path)
            except OSError:
                pass
        raise
    return obj_path, glb_path
# ------------------------------------------------------------
# 2. FastAPI app
# ------------------------------------------------------------
# The ASGI application object; its title is shown in the generated API docs.
app = FastAPI(title="TripoSR FastAPI Demo")
# CORS middleware, so browser-based front-ends served from other origins can
# call this API.
# NOTE(review): allow_origins=["*"] accepts requests from any origin. That is
# fine for a demo, but consider restricting it for a public deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["POST", "GET", "OPTIONS"],
    allow_headers=["*"],
)
# Guards the shared `model` and `rembg_session`. The original handler ran
# inference directly on the event loop, which implicitly processed one request
# at a time. Inference now runs in a worker thread, and this lock keeps that
# one-at-a-time behaviour, and hence peak GPU/CPU memory, unchanged.
_MODEL_LOCK = threading.Lock()


def _decode_upload(contents: bytes) -> Image.Image:
    """Decode uploaded bytes into a fully loaded PIL image, or raise HTTP 400."""
    try:
        img = Image.open(io.BytesIO(contents))
        # Image.open is lazy. Force the decode here so truncated or corrupt
        # files are reported as 400 instead of failing later in preprocess().
        img.load()
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid image file") from exc
    return img


def _remove_quietly(path: str) -> None:
    """Delete *path*, ignoring a file that is already gone."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def _build_zip(
    contents: bytes,
    do_remove_background: bool,
    foreground_ratio: float,
    mc_resolution: int,
) -> bytes:
    """Decode, preprocess, generate and package the result; return ZIP bytes.

    This is blocking (CPU/GPU-bound), so it is meant to run in a worker thread.
    """
    pil_img = _decode_upload(contents)
    check_input_image(pil_img)

    with _MODEL_LOCK:
        processed = preprocess(pil_img, do_remove_background, foreground_ratio)
        obj_path, glb_path = generate(processed, mc_resolution)

    try:
        zip_buffer = io.BytesIO()
        with zipfile.ZipFile(
            zip_buffer, mode="w", compression=zipfile.ZIP_DEFLATED
        ) as zf:
            # processed image
            png_buf = io.BytesIO()
            processed.save(png_buf, format="PNG")
            zf.writestr("processed.png", png_buf.getvalue())
            # meshes, stored under their temp-file basenames
            zf.write(obj_path, arcname=os.path.basename(obj_path))
            zf.write(glb_path, arcname=os.path.basename(glb_path))
    finally:
        # Always clean up the temp files, even if packaging failed.
        _remove_quietly(obj_path)
        _remove_quietly(glb_path)
    return zip_buffer.getvalue()


@app.post("/generate", response_class=Response)
async def generate_endpoint(
    image_file: UploadFile = File(...),
    do_remove_background: bool = Form(True),
    foreground_ratio: float = Form(0.85),
    mc_resolution: int = Form(256),
):
    """
    1. Validate parameters, then read & decode the image
    2. Preprocess
    3. Generate mesh
    4. Package processed image + .obj + .glb into a ZIP

    Steps 1 (decode) to 4 run in a worker thread so the event loop stays free
    to serve other requests while inference is in progress.

    Returns:
        An ``application/zip`` attachment named ``tripo_output.zip``. It
        contains ``processed.png`` plus the generated ``.obj`` and ``.glb``.

    Raises:
        HTTPException(400): invalid parameters or an undecodable image.
    """
    # NOTE(review): mc_resolution has no upper bound, so a very large value
    # from an untrusted client can exhaust memory. Consider capping it.
    if mc_resolution < 1:
        raise HTTPException(
            status_code=400, detail="mc_resolution must be a positive integer"
        )
    # foreground_ratio is only consumed by resize_foreground (background
    # removal path). There, 0 divides by zero and values > 1 give a canvas
    # smaller than the foreground. The `not (...)` form also rejects NaN.
    if do_remove_background and not (0.0 < foreground_ratio <= 1.0):
        raise HTTPException(
            status_code=400, detail="foreground_ratio must be in (0, 1]"
        )

    # 1) Read image bytes (async I/O, stays on the event loop)
    contents = await image_file.read()

    # 2-4) Heavy, blocking work off the event loop
    zip_bytes = await run_in_threadpool(
        _build_zip, contents, do_remove_background, foreground_ratio, mc_resolution
    )

    headers = {
        "Content-Disposition": 'attachment; filename="tripo_output.zip"'
    }
    # The ZIP is already fully in memory. A plain Response sends it in one go,
    # whereas StreamingResponse over a BytesIO would iterate it line by line
    # (on b"\n"), producing thousands of tiny chunks.
    return Response(content=zip_bytes, media_type="application/zip", headers=headers)