File size: 3,363 Bytes
f3938be cf8bfe4 f3938be cf8bfe4 f3938be cf8bfe4 f3938be 272c56d f3938be cf8bfe4 f3938be cf8bfe4 f3938be cf8bfe4 f3938be cf8bfe4 f3938be cf8bfe4 f3938be 44550be f3938be cf8bfe4 f3938be |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import os
import cv2
import torch
import tempfile
import numpy as np
import matplotlib
import gradio as gr
from PIL import Image
import spaces
from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download
from bridge.dpt import Bridge
# ====== Gradio CSS styles ======
# Custom stylesheet passed to gr.Blocks(css=...). The #img-display-input,
# #img-display-output and #download selectors match the elem_id values given
# to the Gradio components in the UI section; they cap the image panels at
# 80% of the viewport height and fix the download widgets' height.
css = """
#img-display-container {
max-height: 100vh;
}
#img-display-input {
max-height: 80vh;
}
#img-display-output {
max-height: 80vh;
}
#download {
height: 62px;
}
"""
# ====== Device ======
# Use CUDA when available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ====== Model loading ======
model = Bridge()
# Fetch the checkpoint from the Hugging Face Hub; returns a local file path.
filepath = hf_hub_download(repo_id="Dingning/BRIDGE", filename="bridge.pth", repo_type="model")
# NOTE(review): torch.load unpickles the checkpoint, which can execute
# arbitrary code embedded in the file. If bridge.pth is a plain state_dict,
# consider passing weights_only=True (the default since PyTorch 2.6).
state_dict = torch.load(filepath, map_location="cpu")
# Weights are loaded on the CPU first; the model is then moved once to
# DEVICE and switched to inference mode.
model.load_state_dict(state_dict)
model = model.to(DEVICE).eval()
# ====== Page text ======
# Markdown rendered at the top of the Gradio page.
title = "# Bridge Simplified Demo"
description = """
Official demo for Bridge using Gradio.
[project page](https://dingning-liu.github.io/bridge.github.io/),
[github](https://github.com/lnbxldn/BRIDGE).
"""

# Colormap used to colorize the normalized depth map ("_r" = reversed Spectral).
cmap = matplotlib.colormaps["Spectral_r"]
# ====== inference ======
@spaces.GPU
def predict_depth(image: np.ndarray) -> np.ndarray:
    """Run the Bridge model on an RGB image and return its depth prediction.

    ``@spaces.GPU`` requests a GPU for the duration of this call when running
    on Hugging Face ZeroGPU Spaces.

    Args:
        image: H x W x 3 RGB array, as produced by ``gr.Image(type="numpy")``.

    Returns:
        Whatever ``model.infer_image`` returns; presumably an H x W float
        depth map (TODO confirm against ``bridge.dpt``).
    """
    # Reversing the channel axis turns Gradio's RGB input into BGR.
    # NOTE(review): presumably infer_image expects OpenCV-style BGR input
    # (as in Depth Anything's infer_image) — confirm in bridge.dpt.
    return model.infer_image(image[:, :, ::-1])  # RGB -> BGR
def _save_png_to_tempfile(img: Image.Image) -> str:
    """Write *img* to a new persistent ``.png`` temp file and return its path.

    ``delete=False`` keeps the file on disk after this function returns so
    Gradio can serve it for download; the ``with`` block only closes the OS
    handle created by ``NamedTemporaryFile`` before PIL writes to the path.
    """
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        path = tmp.name
    img.save(path)
    return path


def on_submit(image: np.ndarray):
    """Compute depth for *image* and build the three Gradio outputs.

    Args:
        image: RGB array from the ``gr.Image(type="numpy")`` input, or
            ``None`` when the user submits without uploading an image.

    Returns:
        ``[(original_image, colored_depth), gray_png_path, raw_png_path]``:
        the input/colorized-depth pair for the slider, the path of an 8-bit
        min-max-normalized depth PNG, and the path of a 16-bit PNG holding
        the raw depth values.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    original_image = image.copy()
    depth = predict_depth(image)

    # 16-bit raw depth map. Clip first: casting floats outside [0, 65535]
    # to uint16 is undefined in NumPy and typically wraps around.
    uint16_max = np.iinfo(np.uint16).max
    raw_depth = Image.fromarray(np.clip(depth, 0, uint16_max).astype("uint16"))
    raw_depth_path = _save_png_to_tempfile(raw_depth)

    # Min-max normalize to [0, 255]. A constant depth map would divide by
    # zero (NaN everywhere), so render it as all zeros instead.
    depth_min = depth.min()
    depth_range = depth.max() - depth_min
    if depth_range > 0:
        depth_norm = (depth - depth_min) / depth_range * 255.0
    else:
        depth_norm = np.zeros(depth.shape, dtype=np.float64)
    depth_uint8 = depth_norm.astype(np.uint8)
    # Integer inputs index the colormap's lookup table; [:, :, :3] drops alpha.
    colored_depth = (cmap(depth_uint8)[:, :, :3] * 255).astype(np.uint8)

    # 8-bit grayscale depth map for download.
    gray_depth_path = _save_png_to_tempfile(Image.fromarray(depth_uint8))

    return [(original_image, colored_depth), gray_depth_path, raw_depth_path]
# ====== Gradio UI ======
with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown("### Depth Prediction Demo")

    # Input image and the before/after depth slider, side by side.
    with gr.Row():
        input_image = gr.Image(
            label="Input Image", type="numpy", elem_id="img-display-input"
        )
        depth_image_slider = ImageSlider(
            label="Depth Map with Slider View", elem_id="img-display-output", position=0.5
        )

    submit = gr.Button(value="Compute Depth")
    # NOTE(review): both File components share elem_id="download"; HTML ids
    # should be unique — consider elem_classes plus a ".download" CSS rule.
    gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
    raw_file = gr.File(label="16-bit raw output", elem_id="download")

    # Must match the order of on_submit's return value:
    # (slider pair, grayscale PNG path, raw 16-bit PNG path).
    depth_outputs = [depth_image_slider, gray_depth_file, raw_file]
    submit.click(
        on_submit,
        inputs=[input_image],
        outputs=depth_outputs,
    )

    # Examples: only image files from the examples directory. Stray files such
    # as .DS_Store or README would otherwise break gr.Examples.
    examples_dir = "assets/examples"
    image_exts = {".png", ".jpg", ".jpeg", ".bmp", ".webp", ".tif", ".tiff"}
    if os.path.isdir(examples_dir):
        example_files = [
            os.path.join(examples_dir, name)
            for name in sorted(os.listdir(examples_dir))
            if os.path.splitext(name)[1].lower() in image_exts
            and os.path.isfile(os.path.join(examples_dir, name))
        ]
        if example_files:
            gr.Examples(
                examples=example_files,
                inputs=[input_image],
                outputs=depth_outputs,
                fn=on_submit,
            )
if __name__ == "__main__":
    # Enable request queuing, then start the server; share=True additionally
    # asks Gradio to create a public share link.
    queued_demo = demo.queue()
    queued_demo.launch(share=True)
|