"""Gradio demo for Pixel-Perfect Depth.

Loads the PixelPerfectDepth checkpoint from the Hugging Face Hub, builds a
small Blocks UI (input image, slider comparison view, two downloadable
files) and serves it when the file is run as a script.
"""

# NOTE(review): the original import order is preserved on purpose -- `spaces`
# stays ahead of `torch`; the ZeroGPU docs ask for `spaces` to be imported
# before CUDA-related packages. Confirm before regrouping these.
import gradio as gr
import cv2
import matplotlib
import numpy as np
import os
from PIL import Image
import spaces
import torch
import tempfile
from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download

from ppd.utils.set_seed import set_seed
from ppd.models.ppd import PixelPerfectDepth

css = """
#img-display-container {
    max-height: 100vh;
}
#img-display-input {
    max-height: 100vh;
}
#img-display-output {
    max-height: 100vh;
}
#download {
    height: 62px;
}
#img-display-output .image-slider-image {
    object-fit: contain !important;
    width: 100% !important;
    height: 100% !important;
}
"""

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

model = PixelPerfectDepth(sampling_steps=4)
ckpt_path = hf_hub_download(
    repo_id="gangweix/Pixel-Perfect-Depth",
    filename="ppd.pth",
    repo_type="model"
)
# SECURITY NOTE: torch.load unpickles the downloaded file. The checkpoint comes
# from a fixed Hub repository; if it is a plain tensor state dict, consider
# passing `weights_only=True` (TODO confirm the checkpoint loads that way).
state_dict = torch.load(ckpt_path, map_location="cpu")
# strict=False tolerates key mismatches between checkpoint and model.
model.load_state_dict(state_dict, strict=False)
model = model.to(DEVICE).eval()

title = "# Pixel-Perfect Depth"
description = """Official demo for **Pixel-Perfect Depth**. 
Please refer to our [paper](), [project page](https://pixel-perfect-depth.github.io), and [github](https://github.com/gangweix/pixel-perfect-depth) for more details."""


@spaces.GPU
def predict_depth(image):
    """Return ``model.infer_image(image)`` for a single image.

    The caller in this file passes the Gradio array with its channel axis
    reversed (presumably BGR -- TODO confirm what ``infer_image`` expects).
    """
    return model.infer_image(image)


with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown("### Depth Prediction demo")

    with gr.Row():
        input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
        depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output', position=0.5)
    submit = gr.Button(value="Predict Depth")
    concat_file = gr.File(label="Concatenated visualization (image+depth)", elem_id="image-depth-download")
    raw_file = gr.File(label="Raw depth output (saved as .npy)", elem_id="download",)

    cmap = matplotlib.colormaps.get_cmap('Spectral')

    def on_submit(image):
        """Predict depth for ``image`` and build the three UI outputs.

        Args:
            image: Array delivered by the ``gr.Image(type='numpy')`` input,
                or ``None`` when nothing has been uploaded.

        Returns:
            A list ``[(original_image, colored_depth), concat_path, raw_path]``
            where the tuple feeds the slider, ``concat_path`` is a PNG with
            the input and the colorized depth side by side, and ``raw_path``
            is the raw depth saved with ``np.save``.

        Raises:
            gr.Error: If no image was provided or the PNG cannot be written.
        """
        if image is None:
            # Surface a readable message instead of an AttributeError.
            raise gr.Error("Please provide an input image first.")

        original_image = image.copy()

        # Reverse the channel axis before handing the image to the model.
        depth = predict_depth(image[:, :, ::-1])

        # Save raw depth (npy). Write through the open handle and let the
        # context manager close it so no file descriptor is leaked.
        with tempfile.NamedTemporaryFile(suffix='.npy', delete=False) as tmp_raw_depth:
            np.save(tmp_raw_depth, depth)
            raw_path = tmp_raw_depth.name

        # Min-max normalize to [0, 255]; a constant depth map would otherwise
        # divide by zero and feed NaNs into the uint8 cast.
        depth_min = depth.min()
        depth_range = depth.max() - depth_min
        if depth_range > 0:
            depth_vis = (depth - depth_min) / depth_range * 255.0
        else:
            depth_vis = np.zeros_like(depth, dtype=np.float32)
        depth_vis = depth_vis.astype(np.uint8)

        # Drop the alpha channel returned by the colormap.
        colored_depth = (cmap(depth_vis)[:, :, :3] * 255).astype(np.uint8)

        # White 50-pixel gutter between the two panels.
        split_region = np.ones((image.shape[0], 50, 3), dtype=np.uint8) * 255
        # Channels are reversed because cv2.imwrite expects BGR ordering.
        combined_result = cv2.hconcat([image[:, :, ::-1], split_region, colored_depth[:, :, ::-1]])

        # Reserve a path, close the handle, then let OpenCV write to it.
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_concat:
            concat_path = tmp_concat.name
        if not cv2.imwrite(concat_path, combined_result):
            raise gr.Error("Failed to write the concatenated visualization.")

        return [(original_image, colored_depth), concat_path, raw_path]

    submit.click(
        on_submit,
        inputs=[input_image],
        outputs=[depth_image_slider, concat_file, raw_file]
    )

    # Every entry of assets/examples, in sorted order, becomes an example.
    example_files = sorted(os.listdir('assets/examples'))
    example_files = [os.path.join('assets/examples', filename) for filename in example_files]
    examples = gr.Examples(
        examples=example_files,
        inputs=[input_image],
        outputs=[depth_image_slider, concat_file, raw_file],
        fn=on_submit
    )


if __name__ == '__main__':
    demo.queue().launch(share=True)