"""Gradio demo that turns a single image into a 3D mesh with TripoSR."""

import logging
import os
import tempfile
import time
from functools import partial

# Set before numpy/torch are imported so the OpenMP runtime they initialise
# picks it up (presumably to avoid CPU thread oversubscription).
os.environ["OMP_NUM_THREADS"] = "1"

import gradio as gr
import numpy as np
import rembg
import torch
from PIL import Image

from tsr.system import TSR
from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation

HEADER = """ """

# Prefer the first CUDA GPU when available; a DEVICE environment variable,
# when set, overrides the automatic choice.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
device_override = os.environ.get("DEVICE", None)
if device_override is not None:
    device = device_override

model = TSR.from_pretrained(
    "stabilityai/TripoSR",
    config_name="config.yaml",
    weight_name="model.ckpt",
)
model.renderer.set_chunk_size(131072)
model.to(device)

rembg_session = rembg.new_session()


def check_input_image(input_image):
    """Raise a user-facing Gradio error when no image has been uploaded."""
    if input_image is None:
        raise gr.Error("No image uploaded!")


def _fill_background(image):
    """Composite an RGBA image onto a uniform 50% grey background.

    The alpha channel (index 3) is used as the blend weight, so fully
    transparent pixels become mid-grey. Returns an RGB ``PIL.Image``.
    """
    pixels = np.array(image).astype(np.float32) / 255.0
    alpha = pixels[:, :, 3:4]
    pixels = pixels[:, :, :3] * alpha + (1 - alpha) * 0.5
    return Image.fromarray((pixels * 255.0).astype(np.uint8))


def preprocess(input_image, do_remove_background, foreground_ratio):
    """Prepare an uploaded image for the model.

    Args:
        input_image: PIL image from the upload widget.
        do_remove_background: when true, segment the foreground with rembg,
            rescale it to ``foreground_ratio`` of the frame and flatten it
            onto grey; otherwise only flatten RGBA input onto grey.
        foreground_ratio: target foreground size passed to
            ``resize_foreground``.

    Returns:
        The processed PIL image.
    """
    if do_remove_background:
        image = input_image.convert("RGB")
        image = remove_background(image, rembg_session)
        image = resize_foreground(image, foreground_ratio)
        return _fill_background(image)

    image = input_image
    if image.mode == "RGBA":
        image = _fill_background(image)
    return image


def _export_mesh(mesh, suffix):
    """Export ``mesh`` to a new temporary file with ``suffix`` and return its path."""
    # delete=False keeps the file on disk after the handle is closed so Gradio
    # can serve it. Closing the handle before exporting avoids leaking a file
    # descriptor per request and lets the exporter reopen the path on
    # platforms that do not allow opening an already-open temp file twice.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        path = tmp.name
    mesh.export(path)
    return path


def generate(image, mc_resolution):
    """Reconstruct a mesh from ``image`` and export it as OBJ and GLB.

    Args:
        image: preprocessed image as delivered by the "Processed Image"
            component.
        mc_resolution: marching-cubes grid resolution from the slider.

    Returns:
        Tuple ``(obj_path, glb_path)`` of temporary file paths.
    """
    try:
        with torch.no_grad():
            scene_codes = model(image, device=device)
            # Slider values may arrive as floats; the resolution is a grid size.
            mesh = model.extract_mesh(scene_codes, resolution=int(mc_resolution))[0]
        mesh = to_gradio_3d_orientation(mesh)
        obj_path = _export_mesh(mesh, ".obj")
        glb_path = _export_mesh(mesh, ".glb")
    finally:
        # Release cached GPU memory between requests even if generation failed.
        torch.cuda.empty_cache()
    return obj_path, glb_path


def run_example(image_pil):
    """Run the full pipeline on an example image without background removal."""
    preprocessed = preprocess(image_pil, False, 0.9)
    mesh_name, mesh_name2 = generate(preprocessed, 256)
    return preprocessed, mesh_name, mesh_name2


with gr.Blocks() as demo:
    gr.Markdown(HEADER)
    with gr.Row(variant="panel"):
        with gr.Column():
            with gr.Row():
                input_image = gr.Image(
                    label="Input Image",
                    image_mode="RGBA",
                    sources="upload",
                    type="pil",
                    elem_id="content_image",
                )
                processed_image = gr.Image(label="Processed Image", interactive=False)
            with gr.Row():
                with gr.Group():
                    do_remove_background = gr.Checkbox(
                        label="Remove Background", value=True
                    )
                    foreground_ratio = gr.Slider(
                        label="Foreground Ratio",
                        minimum=0.5,
                        maximum=1.0,
                        value=0.85,
                        step=0.05,
                    )
                    mc_resolution = gr.Slider(
                        label="Mesh Resolution",
                        minimum=128,
                        maximum=320,
                        value=256,
                        step=32,
                    )
            with gr.Row():
                submit = gr.Button("Generate", elem_id="generate", variant="primary")
        with gr.Column():
            with gr.Tab("obj"):
                output_model = gr.Model3D(
                    label="Output Model",
                    interactive=False,
                )
            with gr.Tab("glb"):
                output_model2 = gr.Model3D(
                    label="Output Model",
                    interactive=False,
                )

    # Validate -> preprocess -> generate; each step runs only if the
    # previous one succeeded.
    submit.click(fn=check_input_image, inputs=[input_image]).success(
        fn=preprocess,
        inputs=[input_image, do_remove_background, foreground_ratio],
        outputs=[processed_image],
    ).success(
        fn=generate,
        inputs=[processed_image, mc_resolution],
        outputs=[output_model, output_model2],
    )

demo.queue(max_size=10)
demo.launch()