Spaces:
Build error
Build error
| import gradio as gr | |
| import torch | |
| import torchvision.transforms.functional as torchvision_F | |
| import numpy as np | |
| import os | |
| import shutil | |
| import importlib | |
| import trimesh | |
| import tempfile | |
| import subprocess | |
| import utils.options as options | |
| import shlex | |
| import time | |
| import rembg | |
| from utils.util import EasyDict as edict | |
| from PIL import Image | |
| from utils.eval_3D import get_dense_3D_grid, compute_level_grid, convert_to_explicit | |
def get_1d_bounds(arr):
    """Return the first and last flat indices at which ``arr`` is non-zero.

    The lookup goes through ``np.flatnonzero``, so an input without any
    non-zero entry fails with ``IndexError`` on the first index access.
    """
    nonzero_idx = np.flatnonzero(arr)
    first, last = nonzero_idx[0], nonzero_idx[-1]
    return first, last
def get_bbox_from_mask(mask, thr):
    """Return the bounding box ``(x0, y0, x1, y1)`` of entries of ``mask`` above ``thr``.

    The mask is thresholded with ``mask > thr``; the x bounds come from the
    sums over axis -2 and the y bounds from the sums over axis -1.
    Both upper bounds are inclusive indices.

    Raises:
        AssertionError: if no entry of ``mask`` exceeds ``thr``.
    """
    foreground = (mask > thr).astype(np.float32)
    assert foreground.sum() > 0, "Empty mask!"
    # Collapsing axis -2 leaves one count per position along the last axis (x).
    hits_along_x = foreground.sum(axis=-2)
    # Collapsing axis -1 leaves one count per position along axis -2 (y).
    hits_along_y = foreground.sum(axis=-1)
    x0, x1 = get_1d_bounds(hits_along_x)
    y0, y1 = get_1d_bounds(hits_along_y)
    return x0, y0, x1, y1
def square_crop(image, bbox, crop_ratio=1.):
    """Crop a square window centred on ``bbox`` out of ``image``.

    Args:
        image: image accepted by ``torchvision.transforms.functional.crop``.
        bbox: ``(x1, y1, x2, y2)`` box coordinates.
        crop_ratio: extra multiplier applied to the window side length.

    Returns:
        The cropped image. The window side is 1.2 times the longer bbox edge,
        multiplied by ``crop_ratio`` and truncated to an integer.
    """
    left_edge, top_edge, right_edge, bottom_edge = bbox
    box_h = bottom_edge - top_edge
    box_w = right_edge - left_edge
    center_y = (top_edge + bottom_edge) / 2
    center_x = (left_edge + right_edge) / 2
    # 20% margin around the longer edge, then the caller-supplied ratio.
    side = max(box_h, box_w) * 1.2 * crop_ratio
    return torchvision_F.crop(
        image,
        top=int(center_y - side / 2),
        left=int(center_x - side / 2),
        height=int(side),
        width=int(side),
    )
def preprocess_image(opt, image, bbox):
    """Crop, resize and tensorise an image whose 4th channel is the mask.

    Args:
        opt: options object; reads ``opt.W``, ``opt.H`` and ``opt.data.bgcolor``.
        image: image with at least four channels (the caller in this file
            passes an RGBA PIL image).
        bbox: ``(x0, y0, x1, y1)`` box forwarded to ``square_crop``.

    Returns:
        ``(rgb, mask)``: the first three channels (composited over
        ``opt.data.bgcolor`` when that is not ``None``) and the remaining
        channel(s) binarised at 0.5 as a float tensor.
    """
    cropped = square_crop(image, bbox=bbox)
    already_sized = cropped.size[0] == opt.W and cropped.size[1] == opt.H
    if not already_sized:
        cropped = cropped.resize((opt.W, opt.H))
    tensor = torchvision_F.to_tensor(cropped)
    rgb = tensor[:3]
    mask = tensor[3:]
    bgcolor = opt.data.bgcolor
    if bgcolor is not None:
        # Blend with the soft (not yet binarised) mask to replace the background.
        rgb = rgb * mask + bgcolor * (1 - mask)
    binary_mask = (mask > 0.5).float()
    return rgb, binary_mask
def get_image(opt, image_fname, mask_fname):
    """Load an image and its mask from disk and return preprocessed tensors.

    The image is read as RGB and the mask as 8-bit grayscale. A binarised
    copy of the mask (values above 127 become 1, the rest 0) is used only to
    locate the bounding box; the original grayscale mask becomes the alpha
    channel of the RGBA image handed to ``preprocess_image``.

    Returns:
        The ``(rgb, mask)`` pair produced by ``preprocess_image``.
    """
    rgb_image = Image.open(image_fname).convert("RGB")
    mask_image = Image.open(mask_fname).convert("L")
    mask_values = np.array(mask_image)
    # Binarise: strictly above 127 is foreground (1), everything else is 0.
    binary_mask = np.where(mask_values > 127, 1, 0).astype(mask_values.dtype)
    rgba_image = Image.merge("RGBA", (*rgb_image.split(), mask_image))
    bbox = get_bbox_from_mask(binary_mask, 0.5)
    rgb_input_map, mask_input_map = preprocess_image(opt, rgba_image, bbox=bbox)
    return rgb_input_map, mask_input_map
def get_intr(opt):
    """Build the fixed 3x3 camera intrinsics matrix for an ``opt.W`` x ``opt.H`` image.

    The focal length is 1.3875 times the image width/height and the principal
    point is the image centre. Returns a float32 tensor of shape ``[3, 3]``.
    """
    focal = 1.3875
    width, height = opt.W, opt.H
    rows = [
        [focal * width, 0, width / 2],
        [0, focal * height, height / 2],
        [0, 0, 1],
    ]
    return torch.tensor(rows).float()
def get_pixel_grid(H, W, device='cuda'):
    """Return homogeneous pixel coordinates ``(x, y, 1)`` for an ``H`` x ``W`` image.

    The result is a float32 tensor of shape ``[H*W, 3]`` on ``device``, in
    row-major order (x varies fastest).
    """
    rows = torch.arange(H, dtype=torch.float32).to(device)
    cols = torch.arange(W, dtype=torch.float32).to(device)
    grid_y, grid_x = torch.meshgrid(rows, cols, indexing='ij')
    # ones_like inherits device and dtype from grid_y.
    homogeneous = torch.ones_like(grid_y)
    stacked = torch.stack((grid_x, grid_y, homogeneous), dim=-1)
    return stacked.view(-1, 3)
def unproj_depth(depth, intr):
    '''
    Back-project a depth map to 3D points in camera coordinates.

    depth: [B, H, W]
    intr: [B, 3, 3]
    returns: [B, H, W, 3]
    '''
    num_batch, H, W = depth.shape
    # Inverse intrinsics on the same device as the depth map: [B, 3, 3]
    K_inv = torch.linalg.inv(intr.to(depth.device)).float()
    # Homogeneous pixel coordinates replicated per batch item: [B, H*W, 3]
    pixels = get_pixel_grid(H, W, depth.device).unsqueeze(0).repeat(num_batch, 1, 1)
    # Ray directions: [B, 3, H*W] -> [B, H*W, 3]
    rays = (K_inv @ pixels.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous()
    # Scale each ray by its depth value.
    points = rays * depth.view(num_batch, H * W, 1)
    return points.view(num_batch, H, W, 3)
def prepare_data(opt, image_path, mask_path):
    """Assemble the single-sample input container for the model.

    Loads the image/mask pair, builds the intrinsics and stores everything on
    ``opt.device`` with a leading batch dimension of 1.

    Returns:
        An ``edict`` with ``rgb_input_map``, ``mask_input_map``, ``intr``,
        ``idx`` (a long tensor holding 0) and ``pose_gt`` (``False``).
    """
    device = opt.device
    var = edict()
    rgb, mask = get_image(opt, image_path, mask_path)
    intrinsics = get_intr(opt)
    var.rgb_input_map = rgb.unsqueeze(0).to(device)
    var.mask_input_map = mask.unsqueeze(0).to(device)
    var.intr = intrinsics.unsqueeze(0).to(device)
    var.idx = torch.tensor([0]).to(device).long()
    var.pose_gt = False
    return var
def marching_cubes(opt, var, impl_network, visualize_attn=False):
    """Evaluate the implicit network on a dense grid and extract meshes.

    Reads ``var.latent_depth``, ``var.latent_semantic`` and
    ``var.rgb_input_map``; writes ``var.mesh_pred`` and, when the attention
    visualisation returned by ``compute_level_grid`` is truthy,
    ``var.attn_vis``. Returns the same ``var``.
    """
    grid_points = get_dense_3D_grid(opt, var)  # [B, N, N, N, 3]
    level_vox, attn_vis = compute_level_grid(
        opt,
        impl_network,
        var.latent_depth,
        var.latent_semantic,
        grid_points,
        var.rgb_input_map,
        visualize_attn,
    )
    if attn_vis:
        var.attn_vis = attn_vis
    # Split the batch into a list of per-sample grids
    # (a list of length B, each [N, N, N] per the original comment).
    level_grids = list(level_vox.cpu().numpy())
    var.mesh_pred = convert_to_explicit(opt, level_grids, isoval=0.5, to_pointcloud=False)
    return var
def infer_sample(opt, var, graph):
    """Run the model forward pass and return the first predicted mesh.

    ``graph.forward`` is called in evaluation mode without losses, then
    ``marching_cubes`` (with attention visualisation enabled) extracts the
    meshes; only the mesh at index 0 is returned.
    """
    forwarded = graph.forward(opt, var, training=False, get_loss=False)
    with_meshes = marching_cubes(opt, forwarded, graph.impl_network, visualize_attn=True)
    meshes = with_meshes.mesh_pred
    return meshes[0]
def _download_checkpoint(ckpt_path):
    """Download the pretrained checkpoint to ``ckpt_path`` unless it already exists.

    The file is fetched into a temporary sibling and moved into place only
    after ``wget`` succeeds, so a failed or interrupted download never leaves
    a truncated file at ``ckpt_path``.

    Raises:
        RuntimeError: if ``wget`` exits with an error or produces an empty file.
    """
    if os.path.isfile(ckpt_path):
        return
    print("Downloading checkpoint...")
    url = "https://www.dropbox.com/scl/fi/hv3w9z59dqytievwviko4/shape.ckpt?rlkey=a2gut89kavrldmnt8b3df92oi&dl=0"
    fd, partial_path = tempfile.mkstemp(suffix=".part", dir=os.path.dirname(ckpt_path) or ".")
    os.close(fd)
    try:
        # subprocess.run blocks until wget exits; check=True surfaces failures
        # instead of silently continuing with a missing/partial file.
        subprocess.run(["wget", "-q", "-O", partial_path, url], check=True)
        if os.path.getsize(partial_path) == 0:
            raise RuntimeError("Checkpoint download produced an empty file.")
        os.replace(partial_path, ckpt_path)
    except subprocess.CalledProcessError as err:
        raise RuntimeError("Failed to download the checkpoint.") from err
    finally:
        # Only present if the move into place did not happen.
        if os.path.exists(partial_path):
            os.remove(partial_path)


def infer(input_image_path, input_mask_path):
    """Reconstruct a 3D mesh from an image and its foreground mask.

    Builds the shape model, makes sure the checkpoint is available, runs
    inference and exports the predicted mesh (rotated by pi around the x
    axis) as a ``.glb`` file.

    Side effects: deletes and recreates ``<opt.datadir>/preds`` and sets
    ``opt.output_path``; the exported temporary ``.glb`` is not deleted here.

    Args:
        input_image_path: path of the input image.
        input_mask_path: path of the foreground mask image.

    Returns:
        Path of the exported ``.glb`` mesh file.

    Raises:
        RuntimeError: if the checkpoint cannot be downloaded.
    """
    opt_cmd = options.parse_arguments(["--yaml=options/shape.yaml", "--datadir=examples", "--eval.vox_res=128", "--ckpt=/data/shape.ckpt"])
    opt = options.set(opt_cmd=opt_cmd, safe_check=False)
    opt.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # build model
    print("Building model...")
    opt.pretrain.depth = None
    opt.arch.depth.pretrained = None
    module = importlib.import_module("model.compute_graph.graph_shape")
    graph = module.Graph(opt).to(opt.device)

    # download checkpoint (no-op when it is already on disk)
    _download_checkpoint(opt.ckpt)

    # load checkpoint
    print("Loading checkpoint...")
    checkpoint = torch.load(opt.ckpt, map_location=torch.device(opt.device))
    graph.load_state_dict(checkpoint["graph"], strict=True)
    graph.eval()

    # load the data
    print("Loading data...")
    var = prepare_data(opt, input_image_path, input_mask_path)

    # create the save dir
    save_folder = os.path.join(opt.datadir, 'preds')
    if os.path.isdir(save_folder):
        shutil.rmtree(save_folder)
    os.makedirs(save_folder)
    opt.output_path = opt.datadir

    # inference the model and save the results
    print("Inferencing...")
    mesh_pred = infer_sample(opt, var, graph)
    # rotate the mesh upside down
    mesh_pred.apply_transform(trimesh.transformations.rotation_matrix(np.pi, [1, 0, 0]))

    # Reserve a unique file name and close the handle before exporting to it.
    with tempfile.NamedTemporaryFile(suffix=".glb", delete=False) as mesh_file:
        mesh_path = mesh_file.name
    mesh_pred.export(mesh_path, file_type="glb")
    return mesh_path
def infer_wrapper_mask(input_image_path, input_mask_path):
    """Gradio callback for the tab where the user supplies the mask.

    Forwards both paths to ``infer`` and returns the exported mesh path.
    """
    mesh_path = infer(input_image_path, input_mask_path)
    return mesh_path
def infer_wrapper_nomask(input_image_path):
    """Gradio callback for the tab where the mask is estimated automatically.

    The image is segmented with ``rembg.remove``; the last band of the result
    is saved as a PNG mask and used for reconstruction.
    NOTE(review): the last band is presumably the alpha channel of an RGBA
    cut-out — confirm against the rembg version in use.

    Args:
        input_image_path: path of the input image.

    Returns:
        ``(mesh_path, mask_path)``: the exported mesh and the estimated mask
        file. Both are temporary files that are not deleted here.
    """
    source_image = Image.open(input_image_path)
    segmented = rembg.remove(source_image)
    estimated_mask = segmented.split()[-1]
    # Reserve a unique file name and close the handle before saving to it,
    # so no open file descriptor is leaked per request.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as mask_file:
        mask_path = mask_file.name
    estimated_mask.save(mask_path)
    return infer(input_image_path, mask_path), mask_path
def assert_input_image(input_image):
    """Validate that an input image was provided.

    Raises:
        gr.Error: if ``input_image`` is ``None``.
    """
    if input_image is not None:
        return
    raise gr.Error("No image selected or uploaded!")
def assert_mask_image(input_mask):
    """Validate that a foreground mask was provided.

    Raises:
        gr.Error: if ``input_mask`` is ``None``.
    """
    if input_mask is not None:
        return
    raise gr.Error("No mask selected or uploaded! Please check the box if you do not have the mask.")
def demo_gradio():
    """Build the Gradio Blocks UI of the demo and return it (not yet launched).

    The UI has two tabs: "Groundtruth Mask" (the user uploads image and mask)
    and "Estimated Mask" (the mask is estimated from the image). Each tab has
    a "Reconstruct" button whose click first validates the inputs and then
    runs the matching inference wrapper.

    Returns:
        The ``gr.Blocks`` instance.
    """
    # NOTE(review): the source this was recovered from had lost its
    # indentation; the nesting of the layout contexts below is a
    # reconstruction and should be checked against the rendered UI.

    # One list of example names drives both tabs so they cannot drift apart.
    example_names = (
        "armchair",
        "bolt",
        "bucket",
        "case",
        "dispenser",
        "hat",
        "teddy_bear",
        "tiger",
        "toy",
        "wedding_cake",
    )
    examples_tab1 = [
        [f"examples/images/{name}.png", f"examples/masks/{name}.png"]
        for name in example_names
    ]
    examples_tab2 = [[f"examples/images/{name}.png"] for name in example_names]

    with gr.Blocks(analytics_enabled=False) as demo_ui:
        # HEADERS
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown('# ZeroShape: Regression-based Zero-shot Shape Reconstruction')
                # Raw string: "\[" is not a valid Python escape sequence; the
                # backslashes are meant for Markdown and must be kept verbatim.
                gr.Markdown(r"[\[Arxiv\]](https://arxiv.org/pdf/2312.14198.pdf) | [\[Project\]](https://zixuanh.com/projects/zeroshape.html) | [\[GitHub\]](https://github.com/zxhuang1698/ZeroShape)")
                gr.Markdown("Please switch to the \"Estimated Mask\" tab if you do not have the foreground mask. The demo will try to estimate the mask for you.")

        # NOTE(review): all four image components below share
        # elem_id="content_image"; kept unchanged.

        # with mask
        with gr.Tab("Groundtruth Mask"):
            with gr.Row():
                input_image_tab1 = gr.Image(label="Input Image", image_mode="RGB", sources="upload", type="filepath", elem_id="content_image", width=300)
                mask_tab1 = gr.Image(label="Foreground Mask", image_mode="RGB", sources="upload", type="filepath", elem_id="content_image", width=300)
                output_mesh_tab1 = gr.Model3D(label="Output Mesh")
            with gr.Row():
                submit_tab1 = gr.Button('Reconstruct', elem_id="recon_button_tab1", variant='primary')
            # examples
            with gr.Row():
                gr.Examples(
                    examples=examples_tab1,
                    inputs=[input_image_tab1, mask_tab1],
                    outputs=[output_mesh_tab1],
                    fn=infer_wrapper_mask,
                    cache_examples=False,
                )

        # without mask
        with gr.Tab("Estimated Mask"):
            with gr.Row():
                input_image_tab2 = gr.Image(label="Input Image", image_mode="RGB", sources="upload", type="filepath", elem_id="content_image", width=300)
                mask_tab2 = gr.Image(label="Foreground Mask", image_mode="RGB", sources="upload", type="filepath", elem_id="content_image", width=300)
                output_mesh_tab2 = gr.Model3D(label="Output Mesh")
            with gr.Row():
                submit_tab2 = gr.Button('Reconstruct', elem_id="recon_button_tab2", variant='primary')
            # examples
            with gr.Row():
                gr.Examples(
                    examples=examples_tab2,
                    inputs=[input_image_tab2],
                    outputs=[output_mesh_tab2, mask_tab2],
                    fn=infer_wrapper_nomask,
                    cache_examples=False,
                )

        # Validate image, then mask, then reconstruct; each step runs only if
        # the previous one succeeded.
        submit_tab1.click(
            fn=assert_input_image,
            inputs=[input_image_tab1],
            queue=False
        ).success(
            fn=assert_mask_image,
            inputs=[mask_tab1],
            queue=False
        ).success(
            fn=infer_wrapper_mask,
            inputs=[input_image_tab1, mask_tab1],
            outputs=[output_mesh_tab1],
        )

        # Validate image, then estimate the mask and reconstruct.
        submit_tab2.click(
            fn=assert_input_image,
            inputs=[input_image_tab2],
            queue=False
        ).success(
            fn=infer_wrapper_nomask,
            inputs=[input_image_tab2],
            outputs=[output_mesh_tab2, mask_tab2],
        )

    return demo_ui
# Script entry point: build the UI, enable a request queue capped at
# 10 pending requests, then start the server.
if __name__ == "__main__":
    demo_ui = demo_gradio()
    demo_ui.queue(max_size=10)
    demo_ui.launch()