File size: 3,363 Bytes
f3938be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf8bfe4
 
f3938be
cf8bfe4
f3938be
 
 
 
 
 
 
 
cf8bfe4
f3938be
 
 
272c56d
f3938be
 
 
 
 
cf8bfe4
f3938be
 
 
 
 
 
 
 
 
cf8bfe4
f3938be
 
 
 
cf8bfe4
f3938be
 
 
 
cf8bfe4
f3938be
 
 
 
 
 
cf8bfe4
f3938be
 
 
44550be
f3938be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf8bfe4
f3938be
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import cv2
import torch
import tempfile
import numpy as np
import matplotlib
import gradio as gr
from PIL import Image
import spaces
from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download
from bridge.dpt import Bridge  

# ====== Gradio CSS styles ======
# Inline CSS injected into the Blocks app; the elem_id values below are
# referenced by the components built in the UI section of this file.
css = """
#img-display-container {
    max-height: 100vh;
}
#img-display-input {
    max-height: 80vh;
}
#img-display-output {
    max-height: 80vh;
}
#download {
    height: 62px;
}
"""

# ====== device ======
# Prefer CUDA when available; fall back to CPU inference otherwise.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# ====== model load ======
model = Bridge()
# Download the pretrained checkpoint from the Hugging Face Hub (cached
# locally after the first run). Plain string literals here — the original
# used f-strings with no placeholders (ruff F541).
filepath = hf_hub_download(repo_id="Dingning/BRIDGE", filename="bridge.pth", repo_type="model")
# NOTE(review): torch.load deserializes via pickle; since the checkpoint is
# fetched from a remote repo, consider weights_only=True (PyTorch >= 2.0)
# if the file contains only tensors — verify before enabling.
state_dict = torch.load(filepath, map_location="cpu")

model.load_state_dict(state_dict)
# Move to the selected device and switch to inference mode.
model = model.to(DEVICE).eval()

# ====== description ======
# Markdown strings rendered at the top of the demo page.
title = "# Bridge Simplified Demo"
description = """
Official demo for Bridge using Gradio.  
[project page](https://dingning-liu.github.io/bridge.github.io/), 
[github](https://github.com/lnbxldn/BRIDGE).
"""

# Colormap used to colorize the 8-bit normalized depth map below.
cmap = matplotlib.colormaps.get_cmap("Spectral_r")

# ====== inference ======
@spaces.GPU
def predict_depth(image: np.ndarray) -> np.ndarray:
    """Infer a depth map for a single image.

    The last axis (channels) is reversed before inference — presumably the
    model expects the opposite channel order (RGB vs BGR) from what Gradio
    delivers; confirm against ``Bridge.infer_image``.
    """
    reordered = image[:, :, ::-1]
    return model.infer_image(reordered)

def on_submit(image: np.ndarray):
    """Run depth inference and build the three demo outputs.

    Parameters
    ----------
    image : np.ndarray
        RGB image supplied by the Gradio image component.

    Returns
    -------
    list
        ``[(original, colored_depth), gray_png_path, raw_png_path]`` — the
        slider pair plus paths to two downloadable PNG files.
    """
    original_image = image.copy()
    depth = predict_depth(image)

    # 16-bit PNG keeps the raw (un-normalized) depth values for download.
    # Close the handle before PIL writes to the same path: the open fd would
    # otherwise leak, and on Windows the save would fail outright.
    raw_depth = Image.fromarray(depth.astype("uint16"))
    tmp_raw_depth = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    tmp_raw_depth.close()
    raw_depth.save(tmp_raw_depth.name)

    # Normalize to [0, 255]. Guard the denominator: a constant depth map
    # (max == min) would otherwise divide by zero and yield NaNs.
    d_min, d_max = float(depth.min()), float(depth.max())
    span = (d_max - d_min) or 1.0
    depth_uint8 = ((depth - d_min) / span * 255.0).astype(np.uint8)
    colored_depth = (cmap(depth_uint8)[:, :, :3] * 255).astype(np.uint8)

    # 8-bit grayscale depth map for download.
    gray_depth = Image.fromarray(depth_uint8)
    tmp_gray_depth = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    tmp_gray_depth.close()
    gray_depth.save(tmp_gray_depth.name)

    return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name]

# ====== Gradio UI======
# Component creation order inside the Blocks context defines the page layout.
with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown("### Depth Prediction Demo")

    with gr.Row():
        # Input image as a numpy array (matches on_submit's signature).
        input_image = gr.Image(
            label="Input Image", type="numpy", elem_id="img-display-input"
        )
        # Before/after slider showing the original next to the colorized depth.
        depth_image_slider = ImageSlider(
            label="Depth Map with Slider View", elem_id="img-display-output", position=0.5
        )
    submit = gr.Button(value="Compute Depth")
    # NOTE(review): both File components share elem_id="download" — element
    # ids should be unique in the DOM; the CSS #download rule only needs a
    # class. Left as-is to avoid a behavioral change.
    gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
    raw_file = gr.File(label="16-bit raw output", elem_id="download")

    submit.click(
        on_submit,
        inputs=[input_image],
        outputs=[depth_image_slider, gray_depth_file, raw_file]
    )

    # examples
    # Only wire up the example gallery when the assets directory is present.
    if os.path.exists("assets/examples"):
        example_files = sorted(os.listdir("assets/examples"))
        example_files = [os.path.join("assets/examples", f) for f in example_files]
        gr.Examples(
            examples=example_files,
            inputs=[input_image],
            outputs=[depth_image_slider, gray_depth_file, raw_file],
            fn=on_submit
        )

if __name__ == "__main__":
    # queue() serializes GPU requests; share=True exposes a public URL.
    demo.queue().launch(share=True)