Spaces:
Running
on
Zero
Running
on
Zero
| import spaces | |
| import os | |
| import numpy as np | |
| from PIL import Image | |
| from omegaconf import OmegaConf | |
| from functools import partial | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| import torch | |
| from torchvision import transforms | |
| import rembg | |
| import cv2 | |
| from pytorch_lightning import seed_everything | |
| from src.visualizer import CameraVisualizer | |
| from src.pose_estimation import load_model_from_config, estimate_poses, estimate_elevs | |
| from src.pose_funcs import find_optimal_poses | |
| from src.utils import spherical_to_cartesian, elu_to_c2w | |
# Use the first CUDA device when torch reports one, otherwise stay on the CPU.
_device_ = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Model configuration shipped inside the repository.
_config_path_ = 'src/configs/sd-objaverse-finetune-c_concat-256.yaml'

# Checkpoints are pulled from the Hugging Face Hub model repo 'tokenid/ID-Pose'.
_ckpt_path_ = hf_hub_download(repo_id='tokenid/ID-Pose', filename='ckpts/zero123-xl.ckpt', repo_type='model')
_matcher_ckpt_path_ = hf_hub_download(repo_id='tokenid/ID-Pose', filename='ckpts/indoor_ds_new.ckpt', repo_type='model')

# Load the model on the CPU first, then move it to the selected device and
# switch it to evaluation mode.
_config_ = OmegaConf.load(_config_path_)
_model_ = load_model_from_config(_config_, _ckpt_path_, device='cpu').to(_device_)
_model_.eval()
def rgba_to_rgb(img):
    """Flatten an RGBA PIL image onto a white background.

    Args:
        img: PIL image whose ``mode`` is ``'RGBA'``.

    Returns:
        A new RGB PIL image.

    Raises:
        AssertionError: if ``img.mode`` is not ``'RGBA'``.
    """
    assert img.mode == 'RGBA'
    pixels = np.asarray(img, dtype=np.float32)
    color = pixels[:, :, :3]
    alpha = pixels[..., 3:]
    # Alpha blend over white: color * a/255 + 255 * (1 - a/255),
    # where the second term simplifies to (255 - a).
    blended = color * (alpha / 255.) + (255 - alpha)
    blended = blended.clip(0, 255).astype(np.uint8)
    return Image.fromarray(blended)
def remove_background(image, rembg_session = None, force = False, **rembg_kwargs):
    """Strip the background with rembg unless the image already has transparency.

    Args:
        image: PIL image to process.
        rembg_session: optional session forwarded to ``rembg.remove``.
        force: when true, run rembg even if the image already carries a
            non-opaque alpha channel.
        **rembg_kwargs: extra keyword arguments forwarded to ``rembg.remove``.

    Returns:
        The output of ``rembg.remove`` when removal ran, otherwise the input
        image unchanged.
    """
    # An RGBA image whose minimum alpha is below 255 is treated as already cut out.
    already_cut_out = image.mode == "RGBA" and image.getextrema()[3][0] < 255
    if force or not already_cut_out:
        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
    return image
def group_recenter(images, ratio=1.5, mask_thres=127, bkg_color=(255, 255, 255, 255)):
    """Crop a group of RGBA images around their foregrounds using one shared scale.

    For every image the foreground bounding box is taken from the pixels whose
    alpha exceeds ``mask_thres``. The largest box extent over the whole group,
    expressed as a fraction of the image size and multiplied by ``ratio``,
    defines a common relative crop size. Each image is then cropped around the
    centre of its own box (zero-padding where the crop leaves the image),
    composited onto ``bkg_color`` and resized to 256x256.

    Args:
        images: iterable of RGBA images (anything ``np.asarray`` turns into an
            ``(H, W, 4)`` array).
        ratio: margin factor applied to the largest normalized extent.
        mask_thres: alpha threshold separating foreground from background.
        bkg_color: RGBA background colour; the RGB part is used for blending.

    Returns:
        List of ``(256, 256, 3)`` uint8 arrays, one per input image.

    Raises:
        ValueError: if ``images`` is empty or an image has no pixel with alpha
            above ``mask_thres``.
    """
    arrays = [np.asarray(img) for img in images]

    # Foreground bounding box (y0, y1, x0, x1) of each image, computed once.
    boxes = []
    for idx, arr in enumerate(arrays):
        yy, xx = np.where(arr[:, :, 3] > mask_thres)
        if yy.size == 0:
            raise ValueError(
                f'image {idx} has no pixel with alpha above {mask_thres}; '
                'cannot locate the foreground'
            )
        boxes.append((yy.min(), yy.max(), xx.min(), xx.max()))

    # Normalize each extent by its own axis: width by shape[1], height by
    # shape[0]. (The two were previously swapped, which could clip the object
    # on non-square inputs; square inputs are unaffected.)
    ws = [float(x1 - x0) / arr.shape[1] for arr, (_, _, x0, x1) in zip(arrays, boxes)]
    hs = [float(y1 - y0) / arr.shape[0] for arr, (y0, y1, _, _) in zip(arrays, boxes)]
    sz = max(ratio * max(ws), ratio * max(hs))

    bkg_rgb = np.array(bkg_color)[None, None, :3]
    out_rgbs = []
    for rgba, (y0, y1, x0, x1) in zip(arrays, boxes):
        height, width = rgba.shape[:2]
        cy = (y0 + y1) // 2
        cx = (x0 + x1) // 2

        # Crop window centred on the box; it may extend past the image borders.
        top = cy - int(np.floor(sz * height / 2))
        bottom = cy + int(np.ceil(sz * height / 2))
        left = cx - int(np.floor(sz * width / 2))
        right = cx + int(np.ceil(sz * width / 2))

        out = rgba[max(top, 0):min(bottom, height), max(left, 0):min(right, width), :].copy()
        # Zero-pad (transparent) whatever part of the window fell outside the image.
        pads = [
            (max(0 - top, 0), max(bottom - height, 0)),
            (max(0 - left, 0), max(right - width, 0)),
            (0, 0),
        ]
        out = np.pad(out, pads, mode='constant', constant_values=0)

        # Alpha-blend the foreground over the background colour.
        alpha = out[..., 3:] / 255.
        out[:, :, :3] = out[:, :, :3] * alpha + bkg_rgb * (1 - alpha)
        out[:, :, -1] = bkg_color[-1]

        out = cv2.resize(out.astype(np.uint8), (256, 256))
        out_rgbs.append(out[:, :, :3])
    return out_rgbs
def run_preprocess(image1, image2, preprocess_chk, seed_value):
    """Optionally remove backgrounds and recenter both input images.

    Args:
        image1: first input image.
        image2: second input image.
        preprocess_chk: when truthy, run background removal and recentering;
            otherwise the inputs are returned untouched.
        seed_value: seed passed to ``seed_everything``.

    Returns:
        Tuple ``(image1, image2)``: PIL images built from the recentered
        arrays when preprocessing ran, else the original inputs.
    """
    seed_everything(seed_value)
    if not preprocess_chk:
        return image1, image2

    # One rembg session is shared by both removals.
    session = rembg.new_session()
    cutouts = [
        remove_background(img, force=True, rembg_session=session)
        for img in (image1, image2)
    ]
    centered = group_recenter(cutouts)
    return Image.fromarray(centered[0]), Image.fromarray(centered[1])
def image_to_tensor(img, width=256, height=256):
    """Convert an image into a batched tensor, rescaled and resized.

    Args:
        img: image accepted by ``torchvision.transforms.ToTensor``.
        width: target width in pixels.
        height: target height in pixels.

    Returns:
        Tensor with a leading batch dimension of 1, resized to
        ``[height, width]``.
    """
    batch = transforms.ToTensor()(img).unsqueeze(0)
    # Affine rescale x -> 2x - 1 (maps [0, 1] to [-1, 1]).
    batch = batch * 2 - 1
    return transforms.functional.resize(batch, [height, width])
def run_pose_exploration(image1, image2, probe_bsz, adj_bsz, adj_iters, seed_value):
    """Estimate the relative pose of ``image2`` with respect to ``image1``.

    Elevations are estimated first; unless the two images are (nearly)
    identical, ``estimate_poses`` is then run to obtain the relative pose.

    Args:
        image1: anchor image.
        image2: target image.
        probe_bsz: probe batch size (also the batch size of the sampled noise).
        adj_bsz: batch size of the adjustment stage.
        adj_iters: number of adjustment iterations.
        seed_value: seed passed to ``seed_everything``.

    Returns:
        ``(anchor_polar, (theta, azimuth, radius))`` as plain Python floats.
    """
    seed_everything(seed_value)

    anchor_img = image_to_tensor(image1).to(_device_)
    target_img = image_to_tensor(image2).to(_device_)
    views = [anchor_img, target_img]

    elevs, elev_ranges = estimate_elevs(
        _model_,
        views,
        est_type='all',
        matcher_ckpt_path=_matcher_ckpt_path_,
    )
    anchor_polar = elevs[0]

    if torch.mean(torch.abs(anchor_img - target_img)) < 0.005:
        # Mean absolute difference is negligible: report a zero relative pose.
        print('Identical images found!')
        theta = azimuth = radius = 0
    else:
        # Noise is drawn after elevation estimation, under the seeded RNG.
        search_opts = dict(
            seed_cand_num=8,
            explore_type='triangular',
            refine_type='triangular',
            probe_ts_range=[0.2, 0.21],
            ts_range=[0.2, 0.21],
            probe_bsz=probe_bsz,
            adjust_factor=10.0,
            adjust_iters=adj_iters,
            adjust_bsz=adj_bsz,
            refine_factor=1.0,
            refine_iters=0,
            refine_bsz=4,
            noise=np.random.randn(probe_bsz, 4, 32, 32),
            elevs=elevs,
            elev_ranges=elev_ranges,
        )
        result_poses, aux_data = estimate_poses(_model_, views, **search_opts)
        theta, azimuth, radius = result_poses[0]

    # Fall back to pi/2 when no elevation was estimated for the anchor.
    if anchor_polar is None:
        anchor_polar = np.pi/2

    explored_sph = (float(theta), float(azimuth), float(radius))
    return float(anchor_polar), explored_sph
def run_pose_refinement(image1, image2, est_result, refine_iters, seed_value):
    """Refine a previously explored pose and plot both cameras.

    Args:
        image1: anchor image.
        image2: target image.
        est_result: sequence ``(anchor_polar, explored_sph)`` from the
            exploration step.
        refine_iters: number of optimisation iterations.
        seed_value: seed passed to ``seed_everything``.

    Returns:
        ``((anchor_polar, final_sph), fig)`` where ``fig`` is the figure
        produced by ``CameraVisualizer.update_figure``.
    """
    seed_everything(seed_value)

    anchor_polar, explored_sph = est_result[0], est_result[1]

    # Tensors are permuted from (0, 1, 2, 3) to (0, 2, 3, 1) before optimisation.
    views = [
        image_to_tensor(img).to(_device_).permute(0, 2, 3, 1)
        for img in (image1, image2)
    ]

    out_poses, _, _loss = find_optimal_poses(
        _model_,
        views,
        1.0,
        bsz=1,
        n_iter=refine_iters,
        init_poses={1: explored_sph},
        ts_range=[0.2, 0.21],
        combinations=[(0, 1), (1, 0)],
        avg_last_n=20,
        print_n=100,
    )

    # NOTE(review): unlike run_pose_exploration, the pose is returned without
    # conversion to Python floats — confirm the result serialises as expected.
    final_sph = out_poses[0]
    theta, azimuth, radius = final_sph

    def to_c2w(sph):
        # Camera at the spherical position, looking at the origin, +Z up.
        position = spherical_to_cartesian(sph)
        return elu_to_c2w(position, np.zeros(3), np.array([0., 0., 1.]))

    c2w0 = to_c2w((anchor_polar, 0., 4.))
    c2w1 = to_c2w((theta + anchor_polar, 0. + azimuth, 4. + radius))

    thumbnails = [np.asarray(image1, dtype=np.uint8), np.asarray(image2, dtype=np.uint8)]
    cam_vis = CameraVisualizer([c2w0, c2w1], ['Image 1', 'Image 2'], ['red', 'blue'], images=thumbnails)
    fig = cam_vis.update_figure(
        5,
        base_radius=-1.2,
        font_size=16,
        show_background=True,
        show_grid=True,
        show_ticklabels=True,
    )
    return (anchor_polar, final_sph), fig
def run_example(image1, image2):
    """Run preprocessing and pose exploration with fixed settings for an example pair.

    Preprocessing is always enabled and the seed is 0; exploration uses
    probe batch size 16, adjust batch size 4 and 10 adjust iterations.

    Returns:
        ``((anchor_polar, explored_sph), processed_image1, processed_image2)``.
    """
    processed1, processed2 = run_preprocess(image1, image2, True, 0)
    polar, sph = run_pose_exploration(processed1, processed2, 16, 4, 10, 0)
    return (polar, sph), processed1, processed2
def run_or_visualize(image1, image2, probe_bsz, adj_bsz, adj_iters, seed_value, est_result):
    """Explore the pose (or reuse a cached one) and plot both cameras.

    Args:
        image1: processed anchor image.
        image2: processed target image.
        probe_bsz: probe batch size for exploration.
        adj_bsz: adjust batch size for exploration.
        adj_iters: adjust iterations for exploration.
        seed_value: seed for exploration.
        est_result: cached ``(anchor_polar, explored_sph)``, or ``None`` to
            run a fresh exploration.

    Returns:
        ``((anchor_polar, explored_sph), fig, update)`` where ``update`` is a
        ``gr.update(interactive=True)`` for the component wired as third output.
    """
    if est_result is not None:
        anchor_polar, explored_sph = est_result[0], est_result[1]
        print('Using cache result.')
    else:
        anchor_polar, explored_sph = run_pose_exploration(
            image1, image2, probe_bsz, adj_bsz, adj_iters, seed_value
        )

    def to_c2w(sph):
        # Camera at the spherical position, looking at the origin, +Z up.
        position = spherical_to_cartesian(sph)
        return elu_to_c2w(position, np.zeros(3), np.array([0., 0., 1.]))

    c2w0 = to_c2w((anchor_polar, 0., 4.))
    c2w1 = to_c2w((explored_sph[0] + anchor_polar, 0. + explored_sph[1], 4. + explored_sph[2]))

    thumbnails = [np.asarray(image1, dtype=np.uint8), np.asarray(image2, dtype=np.uint8)]
    cam_vis = CameraVisualizer([c2w0, c2w1], ['Image 1', 'Image 2'], ['red', 'blue'], images=thumbnails)
    fig = cam_vis.update_figure(
        5,
        base_radius=-1.2,
        font_size=16,
        show_background=True,
        show_grid=True,
        show_ticklabels=True,
    )
    return (anchor_polar, explored_sph), fig, gr.update(interactive=True)
# Markdown passed to gr.Markdown in run_demo: page title and usage notes.
_HEADER_ = '''
# Official 🤗 Gradio Demo for [ID-Pose: Sparse-view Camera Pose Estimation By Inverting Diffusion Models](https://github.com/xt4d/id-pose)
- ID-Pose accepts input images with NO overlapping appearance.
- The estimation takes about 1 minute. ZeroGPU may be halted during processing due to quota restrictions.
'''
# Markdown passed to gr.Markdown in run_demo: project links.
# NOTE(review): the last link has empty link text ("[](...)") — possibly a
# badge image that was lost; confirm against the upstream file.
_FOOTER_ = '''
[Project Page](https://xt4d.github.io/id-pose-web/) | ⭐ [Github](https://github.com/xt4d/id-pose) ⭐ [](https://github.com/xt4d/id-pose)
---
'''
# Markdown passed to gr.Markdown in run_demo: BibTeX citation in a fenced block.
# Raw string, so the braces and any backslashes are taken literally.
_CITE_ = r"""
```bibtex
@article{cheng2023id,
title={ID-Pose: Sparse-view Camera Pose Estimation by Inverting Diffusion Models},
author={Cheng, Weihao and Cao, Yan-Pei and Shan, Ying},
journal={arXiv preprint arXiv:2306.17140},
year={2023}
}
```
"""
def run_demo():
    """Build the ID-Pose Gradio interface, wire its events and launch it.

    Layout: a left column with the two input images, their processed
    versions, the options, the Estimate/Refine controls and the footer
    texts; a right column with the camera plot and three example galleries.

    Events:
        * Estimate: ``run_preprocess`` then, on success, ``run_or_visualize``
          (which also enables the Refine button).
        * Refine: ``run_pose_refinement``.
        * Clearing either input image resets the cached estimation result.

    The two top-level columns use integer scales 5 and 7 (the same 1 : 1.4
    ratio as before) because Gradio expects ``scale`` to be an integer.
    """
    example_dir = 'data/gradio_demo'
    # (gallery label, example base names) in display order.
    example_groups = [
        ('Examples (Captured)', ['duck', 'chair', 'foosball', 'bunny', 'circo']),
        ('Examples (Internet)', ['arc', 'husky', 'cybertruck', 'plane', 'christ']),
        ('Examples (Generated)', ['status', 'cat', 'ferrari', 'elon', 'ride_horse']),
    ]

    demo = gr.Blocks(title='ID-Pose: Sparse-view Camera Pose Estimation By Inverting Diffusion Models')
    with demo:
        # Hidden holder for the cached (anchor_polar, pose) estimation result.
        est_result = gr.JSON(visible=False)
        gr.Markdown(_HEADER_)

        with gr.Row(variant='panel'):
            with gr.Column(scale=5):
                with gr.Row():
                    input_image1 = gr.Image(type='pil', image_mode='RGBA', label='Input Image 1')
                    input_image2 = gr.Image(type='pil', image_mode='RGBA', label='Input Image 2')
                with gr.Row():
                    processed_image1 = gr.Image(type='numpy', image_mode='RGB', label='Processed Image 1', interactive=False)
                    processed_image2 = gr.Image(type='numpy', image_mode='RGB', label='Processed Image 2', interactive=False)
                with gr.Row():
                    preprocess_chk = gr.Checkbox(True, label='Remove background and recenter object')
                with gr.Accordion('Advanced options', open=False):
                    probe_bsz = gr.Slider(4, 32, value=16, step=4, label='Probe Batch Size')
                    adj_bsz = gr.Slider(1, 8, value=4, step=1, label='Adjust Batch Size')
                    adj_iters = gr.Slider(1, 20, value=10, step=1, label='Adjust Iterations')
                    seed_value = gr.Number(value=0, label="Seed Value", precision=0)
                with gr.Row():
                    run_btn = gr.Button('Estimate', variant='primary', interactive=True)
                with gr.Row():
                    refine_iters = gr.Slider(0, 1000, value=0, step=50, label='Refinement Iterations')
                with gr.Row():
                    # Disabled until an estimation has completed.
                    refine_btn = gr.Button('Refine', variant='primary', interactive=False)
                with gr.Row():
                    gr.Markdown(_FOOTER_)
                with gr.Row():
                    gr.Markdown(_CITE_)

            with gr.Column(scale=7):
                with gr.Row():
                    vis_output = gr.Plot(label='Camera Pose Results: anchor (red) and target (blue)')
                with gr.Row():
                    # One gallery column per example group, created in order.
                    for label, names in example_groups:
                        with gr.Column(min_width=200):
                            gr.Examples(
                                examples=[
                                    [f'{example_dir}/{name}_0.png', f'{example_dir}/{name}_1.png']
                                    for name in names
                                ],
                                inputs=[input_image1, input_image2],
                                fn=run_example,
                                outputs=[est_result, processed_image1, processed_image2],
                                label=label,
                                cache_examples='lazy',
                                examples_per_page=5,
                            )

        # Estimate: preprocess first; only on success explore/visualise.
        run_btn.click(
            fn=run_preprocess,
            inputs=[input_image1, input_image2, preprocess_chk, seed_value],
            outputs=[processed_image1, processed_image2],
        ).success(
            fn=run_or_visualize,
            inputs=[processed_image1, processed_image2, probe_bsz, adj_bsz, adj_iters, seed_value, est_result],
            outputs=[est_result, vis_output, refine_btn],
        )

        refine_btn.click(
            fn=run_pose_refinement,
            inputs=[processed_image1, processed_image2, est_result, refine_iters, seed_value],
            outputs=[est_result, vis_output],
        )

        # Clearing either input invalidates the cached estimation result.
        for input_image in (input_image1, input_image2):
            input_image.clear(
                fn=lambda: None,
                outputs=[est_result],
            )

    demo.launch()
| if __name__ == '__main__': | |
| run_demo() | |