Spaces:
Build error
Build error
| import io | |
| import os | |
| import shlex | |
| import subprocess | |
| import tempfile | |
| import zipfile | |
| from functools import partial | |
| from fastapi import FastAPI, File, UploadFile, Form, HTTPException | |
| from fastapi.responses import StreamingResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import numpy as np | |
| import rembg | |
| import torch | |
| from PIL import Image | |
| from tsr.system import TSR | |
| from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation | |
# ------------------------------------------------------------
# 1. Model & utils initialization (runs at startup)
# ------------------------------------------------------------
# NOTE(review): the original comment here mentioned installing local wheels,
# but no installation step exists in the visible code; `sys` is imported and
# kept only so that nothing else in the file loses the name.
import sys

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Load the pretrained TripoSR checkpoint once, at import time, so every request
# reuses the same model instance.
model = TSR.from_pretrained(
    "stabilityai/TripoSR", config_name="config.yaml", weight_name="model.ckpt"
)
# Configure the renderer chunk size before moving the model to its device.
model.renderer.set_chunk_size(131072)
model.to(device)

# Shared rembg session used by preprocess() for background removal.
rembg_session = rembg.new_session()
def check_input_image(image: Image.Image):
    """Ensure an image was supplied.

    Raises:
        HTTPException: with status 400 when ``image`` is ``None``.
    """
    if image is not None:
        return
    raise HTTPException(status_code=400, detail="No image uploaded!")
def preprocess(
    input_image: Image.Image, do_remove_background: bool, foreground_ratio: float
) -> Image.Image:
    """Prepare an uploaded image for the model (mirrors the Gradio demo).

    When ``do_remove_background`` is true the image is converted to RGB, passed
    through ``remove_background`` and ``resize_foreground`` (using
    ``foreground_ratio``), and then composited onto a mid-gray background.
    Otherwise the image is returned as-is, except that an ``RGBA`` image is
    composited onto mid-gray first.

    Returns:
        The preprocessed image.
    """

    def composite_on_gray(rgba: Image.Image) -> Image.Image:
        # Alpha-blend the colour channels over a constant 0.5 (mid-gray)
        # background. Indexing channel 3 requires a four-channel image.
        pixels = np.array(rgba).astype(np.float32) / 255.0
        rgb = pixels[:, :, :3]
        alpha = pixels[:, :, 3:4]
        blended = rgb * alpha + (1 - alpha) * 0.5
        return Image.fromarray((blended * 255.0).astype(np.uint8))

    if not do_remove_background:
        # Caller asked to keep the background: only flatten transparency.
        if input_image.mode == "RGBA":
            return composite_on_gray(input_image)
        return input_image

    # presumably remove_background returns an image with an alpha channel,
    # since composite_on_gray reads channel 3 - verify against tsr.utils
    cutout = remove_background(input_image.convert("RGB"), rembg_session)
    cutout = resize_foreground(cutout, foreground_ratio)
    return composite_on_gray(cutout)
def generate(
    image: Image.Image, mc_resolution: int, formats=("obj", "glb")
) -> tuple[str, str]:
    """Run the model on ``image`` and export the mesh as GLB and OBJ files.

    Mirrors the Gradio demo's ``generate(...)`` function.

    Args:
        image: Preprocessed input image.
        mc_resolution: Resolution passed to ``model.extract_mesh``.
        formats: Accepted for signature compatibility only; it is not read.
            Both an ``.obj`` and a ``.glb`` file are always produced. The
            default is an immutable tuple so it cannot be shared-and-mutated
            across calls.

    Returns:
        ``(obj_path, glb_path)`` - paths of the exported files on disk. The
        caller owns these files and is responsible for deleting them.

    Raises:
        Any exception from inference or export is re-raised after the
        temporary files created here have been removed.
    """
    # 1. Inference. Gradients are never needed here, so disable autograd to
    #    avoid building a graph for every request.
    with torch.no_grad():
        scene_codes = model(image, device=device)
        # NOTE(review): some TripoSR versions require an extra positional
        # ``has_vertex_color`` argument to extract_mesh - confirm against the
        # installed ``tsr`` package.
        mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
    mesh = to_gradio_3d_orientation(mesh)

    # 2. Reserve the output paths. The handles are closed immediately (the
    #    files persist because delete=False) so no descriptors leak and the
    #    exporter can reopen the paths on every platform.
    with tempfile.NamedTemporaryFile(suffix=".glb", delete=False) as glb_tmp:
        glb_path = glb_tmp.name
    with tempfile.NamedTemporaryFile(suffix=".obj", delete=False) as obj_tmp:
        obj_path = obj_tmp.name

    try:
        # 3. Export GLB first, with the mesh in its Gradio orientation.
        mesh.export(glb_path)
        # 4. Export OBJ (flip x-axis so OBJ is not mirrored). The flip is
        #    applied in place, which is why it must come after the GLB export.
        mesh.apply_scale([-1, 1, 1])
        mesh.export(obj_path)
    except Exception:
        # Do not leave orphaned temp files behind when an export fails.
        for path in (glb_path, obj_path):
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
        raise

    return obj_path, glb_path
# ------------------------------------------------------------
# 2. FastAPI app
# ------------------------------------------------------------
app = FastAPI(title="TripoSR FastAPI Demo")

# CORS settings, needed when the API is called from a browser-based front-end.
# Any origin and any request header are accepted; methods are limited to
# POST, GET and OPTIONS.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["POST", "GET", "OPTIONS"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on this
# handler in the file, so as written it is never registered on ``app``.
# Confirm the intended path and add the decorator.
async def generate_endpoint(
    image_file: UploadFile = File(...),
    do_remove_background: bool = Form(True),
    foreground_ratio: float = Form(0.85),
    mc_resolution: int = Form(256),
):
    """Turn an uploaded image into a ZIP of the processed image and meshes.

    Steps:
        1. Read & validate image
        2. Preprocess
        3. Generate mesh
        4. Package processed image + .obj + .glb into a ZIP

    Args:
        image_file: Uploaded image (multipart form field).
        do_remove_background: Forwarded to ``preprocess``.
        foreground_ratio: Forwarded to ``preprocess``.
        mc_resolution: Forwarded to ``generate``.

    Returns:
        A ``StreamingResponse`` with media type ``application/zip`` containing
        ``processed.png`` plus the ``.obj`` and ``.glb`` files, each stored
        under the base name of its temporary file.

    Raises:
        HTTPException: status 400 when the upload cannot be decoded as an
            image.
    """
    # 1) Read image bytes and decode them.
    contents = await image_file.read()
    try:
        pil_img = Image.open(io.BytesIO(contents))
        # Image.open only parses the header; force a full decode here so that
        # truncated or corrupt uploads are rejected as a 400 rather than
        # failing later, mid-pipeline, as an unhandled error.
        pil_img.load()
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid image file") from exc
    check_input_image(pil_img)

    # 2) Preprocess
    processed = preprocess(pil_img, do_remove_background, foreground_ratio)

    # 3) Generate mesh (creates two temp files that this handler must delete)
    obj_path, glb_path = generate(processed, mc_resolution)

    # 4) Create in-memory ZIP
    zip_buffer = io.BytesIO()
    try:
        with zipfile.ZipFile(zip_buffer, mode="w") as zf:
            # processed image
            png_buffer = io.BytesIO()
            processed.save(png_buffer, format="PNG")
            zf.writestr("processed.png", png_buffer.getvalue())
            # .obj, then .glb
            for mesh_path in (obj_path, glb_path):
                with open(mesh_path, "rb") as mesh_file:
                    zf.writestr(os.path.basename(mesh_path), mesh_file.read())
    finally:
        # Always clean up the temp files, including when packaging fails.
        for mesh_path in (obj_path, glb_path):
            try:
                os.remove(mesh_path)
            except FileNotFoundError:
                pass
    zip_buffer.seek(0)

    headers = {
        "Content-Disposition": 'attachment; filename="tripo_output.zip"'
    }
    return StreamingResponse(
        zip_buffer, media_type="application/zip", headers=headers
    )