# Bridge / app.py
# Dingning — last commit 272c56d "Update app.py" (verified)
import os
import cv2
import torch
import tempfile
import numpy as np
import matplotlib
import gradio as gr
from PIL import Image
import spaces
from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download
from bridge.dpt import Bridge
# ====== Gradio CSS styles ======
# Custom CSS for the layout: caps the heights of the image display panels and
# gives the elements with id "download" a fixed height.
css = (
    "\n"
    "#img-display-container {\nmax-height: 100vh;\n}\n"
    "#img-display-input {\nmax-height: 80vh;\n}\n"
    "#img-display-output {\nmax-height: 80vh;\n}\n"
    "#download {\nheight: 62px;\n}\n"
)
# ====== device ======
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ====== model load ======
# Build the network, fetch the checkpoint from the Hugging Face Hub, load the
# weights on CPU first (map_location="cpu"), then move the model to DEVICE in
# eval mode.
model = Bridge()
filepath = hf_hub_download(
    repo_id="Dingning/BRIDGE", filename="bridge.pth", repo_type="model"
)
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code
# if the checkpoint is untrusted. Since the result is fed straight to
# load_state_dict it is presumably a plain state_dict, so weights_only=True
# should work on torch versions that support it — confirm before enabling.
state_dict = torch.load(filepath, map_location="cpu")
model.load_state_dict(state_dict)
model = model.to(DEVICE).eval()
# ====== description ======
# Markdown shown at the top of the demo page (rendered via gr.Markdown).
title = "# Bridge Simplified Demo"
description = (
    "\n"
    "Official demo for Bridge using Gradio.\n"
    "[project page](https://dingning-liu.github.io/bridge.github.io/),\n"
    "[github](https://github.com/lnbxldn/BRIDGE).\n"
)
# Reversed "Spectral" colormap; indexed with the 0-255 uint8 depth map to
# produce the colorized preview.
cmap = matplotlib.colormaps["Spectral_r"]
# ====== inference ======
@spaces.GPU
def predict_depth(image: np.ndarray) -> np.ndarray:
    """Run depth inference on an RGB image.

    Args:
        image: 3-D array indexed as (H, W, C) with channels in RGB order.

    Returns:
        The output of ``model.infer_image``; callers treat it as a numeric
        depth array (presumably H x W — TODO confirm against bridge.dpt).
    """
    # Reverse the channel axis: RGB -> BGR. infer_image presumably expects
    # OpenCV-style BGR input — confirm against bridge.dpt.
    return model.infer_image(image[:, :, ::-1])
def _save_temp_png(img: Image.Image) -> str:
    """Write *img* to a new temporary PNG file and return its path.

    The file is created with ``delete=False`` so it survives this call and can
    be served for download; the handle itself is closed before returning.
    """
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        img.save(tmp, format="PNG")
    return tmp.name


def on_submit(image: np.ndarray):
    """Predict depth for *image* and build the three UI outputs.

    Args:
        image: RGB image array from the Gradio input, or None if nothing was
            uploaded.

    Returns:
        A list of:
          - ``(original_image, colored_depth)`` for the ImageSlider,
          - path to an 8-bit grayscale PNG of the min-max normalized depth,
          - path to a 16-bit PNG of the depth values cast to uint16.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    original_image = image.copy()
    depth = predict_depth(image)

    # 16-bit depth map: the values are cast to uint16 as-is (not rescaled).
    # NOTE(review): values outside [0, 65535] are not clipped — confirm the
    # model's output range.
    raw_depth_path = _save_temp_png(Image.fromarray(depth.astype("uint16")))

    # Min-max normalize to [0, 255]. A constant depth map would give 0/0 = NaN,
    # so fall back to an all-zero map in that case.
    depth_min = depth.min()
    depth_range = depth.max() - depth_min
    if depth_range > 0:
        depth_uint8 = ((depth - depth_min) / depth_range * 255.0).astype(np.uint8)
    else:
        depth_uint8 = np.zeros(depth.shape, dtype=np.uint8)

    # Integer inputs index the colormap's lookup table directly; drop alpha.
    colored_depth = (cmap(depth_uint8)[:, :, :3] * 255).astype(np.uint8)

    gray_depth_path = _save_temp_png(Image.fromarray(depth_uint8))

    return [(original_image, colored_depth), gray_depth_path, raw_depth_path]
# ====== Gradio UI ======
EXAMPLES_DIR = "assets/examples"
# File extensions (lower-case) accepted as example images.
_EXAMPLE_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tif", ".tiff"}


def _list_example_images(folder: str) -> list:
    """Return the sorted paths of image files directly inside *folder*.

    Entries that are not regular files with an image extension (e.g.
    ``.DS_Store``, READMEs, sub-directories) are skipped so they are never
    handed to ``gr.Examples`` as image inputs. Returns an empty list when
    *folder* is not a directory.
    """
    if not os.path.isdir(folder):
        return []
    paths = []
    for name in sorted(os.listdir(folder)):
        path = os.path.join(folder, name)
        if (
            os.path.splitext(name)[1].lower() in _EXAMPLE_IMAGE_EXTENSIONS
            and os.path.isfile(path)
        ):
            paths.append(path)
    return paths


with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown("### Depth Prediction Demo")

    with gr.Row():
        input_image = gr.Image(
            label="Input Image", type="numpy", elem_id="img-display-input"
        )
        depth_image_slider = ImageSlider(
            label="Depth Map with Slider View", elem_id="img-display-output", position=0.5
        )
    submit = gr.Button(value="Compute Depth")
    gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
    raw_file = gr.File(label="16-bit raw output", elem_id="download")

    submit.click(
        on_submit,
        inputs=[input_image],
        outputs=[depth_image_slider, gray_depth_file, raw_file],
    )

    # examples: only shown when the folder contains at least one image.
    example_files = _list_example_images(EXAMPLES_DIR)
    if example_files:
        gr.Examples(
            examples=example_files,
            inputs=[input_image],
            outputs=[depth_image_slider, gray_depth_file, raw_file],
            fn=on_submit,
        )
if __name__ == "__main__":
    # queue() returns the Blocks app, which is then launched.
    queued_app = demo.queue()
    queued_app.launch(share=True)