Update app.py
Browse files
app.py
CHANGED
|
@@ -27,10 +27,10 @@ css = """
|
|
| 27 |
}
|
| 28 |
"""
|
| 29 |
|
| 30 |
-
# ======
|
| 31 |
-
DEVICE =
|
| 32 |
|
| 33 |
-
# ======
|
| 34 |
model = Bridge()
|
| 35 |
filepath = hf_hub_download(repo_id=f"Dingning/BRIDGE", filename=f"bridge.pth", repo_type="model")
|
| 36 |
state_dict = torch.load(filepath, map_location="cpu")
|
|
@@ -39,7 +39,7 @@ state_dict = torch.load(filepath, map_location="cpu")
|
|
| 39 |
model.load_state_dict(state_dict)
|
| 40 |
model = model.to(DEVICE).eval()
|
| 41 |
|
| 42 |
-
# ======
|
| 43 |
title = "# Bridge Simplified Demo"
|
| 44 |
description = """
|
| 45 |
Official demo for Bridge using Gradio.
|
|
@@ -49,7 +49,7 @@ Official demo for Bridge using Gradio.
|
|
| 49 |
|
| 50 |
cmap = matplotlib.colormaps.get_cmap("Spectral_r")
|
| 51 |
|
| 52 |
-
# ======
|
| 53 |
@spaces.GPU
|
| 54 |
def predict_depth(image: np.ndarray) -> np.ndarray:
|
| 55 |
"""Run depth inference on an RGB image (numpy)."""
|
|
@@ -59,24 +59,24 @@ def on_submit(image: np.ndarray):
|
|
| 59 |
original_image = image.copy()
|
| 60 |
depth = predict_depth(image)
|
| 61 |
|
| 62 |
-
#
|
| 63 |
raw_depth = Image.fromarray(depth.astype("uint16"))
|
| 64 |
tmp_raw_depth = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
|
| 65 |
raw_depth.save(tmp_raw_depth.name)
|
| 66 |
|
| 67 |
-
#
|
| 68 |
depth_norm = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
|
| 69 |
depth_uint8 = depth_norm.astype(np.uint8)
|
| 70 |
colored_depth = (cmap(depth_uint8)[:, :, :3] * 255).astype(np.uint8)
|
| 71 |
|
| 72 |
-
#
|
| 73 |
gray_depth = Image.fromarray(depth_uint8)
|
| 74 |
tmp_gray_depth = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
|
| 75 |
gray_depth.save(tmp_gray_depth.name)
|
| 76 |
|
| 77 |
return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name]
|
| 78 |
|
| 79 |
-
# ====== Gradio
|
| 80 |
with gr.Blocks(css=css) as demo:
|
| 81 |
gr.Markdown(title)
|
| 82 |
gr.Markdown(description)
|
|
@@ -99,7 +99,7 @@ with gr.Blocks(css=css) as demo:
|
|
| 99 |
outputs=[depth_image_slider, gray_depth_file, raw_file]
|
| 100 |
)
|
| 101 |
|
| 102 |
-
#
|
| 103 |
if os.path.exists("assets/examples"):
|
| 104 |
example_files = sorted(os.listdir("assets/examples"))
|
| 105 |
example_files = [os.path.join("assets/examples", f) for f in example_files]
|
|
|
|
| 27 |
}
|
| 28 |
"""
|
| 29 |
|
| 30 |
+
# ====== device ======
|
| 31 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 32 |
|
| 33 |
+
# ====== model load ======
|
| 34 |
model = Bridge()
|
| 35 |
filepath = hf_hub_download(repo_id=f"Dingning/BRIDGE", filename=f"bridge.pth", repo_type="model")
|
| 36 |
state_dict = torch.load(filepath, map_location="cpu")
|
|
|
|
| 39 |
model.load_state_dict(state_dict)
|
| 40 |
model = model.to(DEVICE).eval()
|
| 41 |
|
| 42 |
+
# ====== description ======
|
| 43 |
title = "# Bridge Simplified Demo"
|
| 44 |
description = """
|
| 45 |
Official demo for Bridge using Gradio.
|
|
|
|
| 49 |
|
| 50 |
cmap = matplotlib.colormaps.get_cmap("Spectral_r")
|
| 51 |
|
| 52 |
+
# ====== inference ======
|
| 53 |
@spaces.GPU
|
| 54 |
def predict_depth(image: np.ndarray) -> np.ndarray:
|
| 55 |
"""Run depth inference on an RGB image (numpy)."""
|
|
|
|
| 59 |
original_image = image.copy()
|
| 60 |
depth = predict_depth(image)
|
| 61 |
|
| 62 |
+
# 16-bit depth map
|
| 63 |
raw_depth = Image.fromarray(depth.astype("uint16"))
|
| 64 |
tmp_raw_depth = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
|
| 65 |
raw_depth.save(tmp_raw_depth.name)
|
| 66 |
|
| 67 |
+
# normalize and colorize
|
| 68 |
depth_norm = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
|
| 69 |
depth_uint8 = depth_norm.astype(np.uint8)
|
| 70 |
colored_depth = (cmap(depth_uint8)[:, :, :3] * 255).astype(np.uint8)
|
| 71 |
|
| 72 |
+
# save depth map
|
| 73 |
gray_depth = Image.fromarray(depth_uint8)
|
| 74 |
tmp_gray_depth = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
|
| 75 |
gray_depth.save(tmp_gray_depth.name)
|
| 76 |
|
| 77 |
return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name]
|
| 78 |
|
| 79 |
+
# ====== Gradio UI ======
|
| 80 |
with gr.Blocks(css=css) as demo:
|
| 81 |
gr.Markdown(title)
|
| 82 |
gr.Markdown(description)
|
|
|
|
| 99 |
outputs=[depth_image_slider, gray_depth_file, raw_file]
|
| 100 |
)
|
| 101 |
|
| 102 |
+
# examples
|
| 103 |
if os.path.exists("assets/examples"):
|
| 104 |
example_files = sorted(os.listdir("assets/examples"))
|
| 105 |
example_files = [os.path.join("assets/examples", f) for f in example_files]
|