Spaces:
Runtime error
Runtime error
File size: 4,247 Bytes
9778377 3dcfad8 9778377 3dcfad8 9778377 3dcfad8 9778377 3dcfad8 9778377 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import os

# Provides the @spaces.GPU decorator used below. It is imported before torch,
# as recommended for ZeroGPU Spaces.
import spaces

import cv2
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

from net.dornet import Net
from net.dornet_ddp import Net_ddp
# ---------------------------------------------------------------------------
# Initialisation
# ---------------------------------------------------------------------------
# Use the first CUDA device when one is visible, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Active network. load_model() builds it and fills in its weights (it is
# called once at start-up and again for every request), so no throw-away
# instance is constructed here.
net = None

# Checkpoint file for each selectable model variant. Paths are relative to
# the directory the app is launched from.
model_ckpt_map = {
    "RGB-D-D": "./checkpoints/RGBDD.pth",
    "TOFDSR": "./checkpoints/TOFDSR.pth",
}
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
@spaces.GPU
def load_model(model_type: str):
    """(Re)build the network for *model_type* and load its checkpoint.

    The loaded model is switched to eval mode and then stored in the
    module-level ``net``.

    Args:
        model_type: A key of ``model_ckpt_map`` ("RGB-D-D" or "TOFDSR").

    Raises:
        ValueError: If *model_type* is not a known model.
    """
    global net

    # Validate before the lookup so callers get a descriptive ValueError
    # rather than a bare KeyError.
    if model_type not in model_ckpt_map:
        raise ValueError(f"Unknown model_type: {model_type}")
    ckpt_path = model_ckpt_map[model_type]
    print(f"Loading weights from: {ckpt_path}")

    if model_type == "RGB-D-D":
        model = Net(tiny_model=False).to(device)
    elif model_type == "TOFDSR":
        # The TOFDSR checkpoint is loaded into the ``srn`` sub-network of
        # Net_ddp only.
        model = Net_ddp(tiny_model=False).srn.to(device)
    else:
        # Reached only if model_ckpt_map gains a key with no branch here.
        raise ValueError(f"Unknown model_type: {model_type}")

    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.eval()
    # Publish only a fully loaded model, so a failed load never leaves
    # ``net`` pointing at a randomly initialised network.
    net = model


# Load the default model once at start-up.
load_model("RGB-D-D")
# ---------------------------------------------------------------------------
# Data processing
# ---------------------------------------------------------------------------
@spaces.GPU
def preprocess_inputs(rgb_image: Image.Image, lr_depth: Image.Image):
    """Convert the PIL inputs into normalised, batched tensors on ``device``.

    The depth map is resized (bicubic) to the RGB resolution and scaled from
    the fixed range [min_out, max_out] = [0, 5000] into [0, 1]. The RGB image
    is min-max normalised into [0, 1].

    Args:
        rgb_image: Guidance colour image; it is converted to RGB.
        lr_depth: Low-resolution depth map. It is assumed to be
            single-channel (the UI requests mode "I").

    Returns:
        ``(image, lr, min_out, max_out)``:
        - ``image`` is a (1, 3, H, W) float tensor on ``device``.
        - ``lr`` is a (1, 1, H, W) float tensor on ``device``.
        - ``min_out`` / ``max_out`` are the depth range used for
          normalisation, needed to de-normalise the network output.
    """
    image = np.array(rgb_image.convert("RGB")).astype(np.float32)
    h, w, _ = image.shape
    lr = np.array(lr_depth.resize((w, h), Image.BICUBIC)).astype(np.float32)

    # Normalise depth with a fixed range so the prediction can be mapped back.
    max_out, min_out = 5000.0, 0.0
    lr = (lr - min_out) / (max_out - min_out)

    # Min-max normalise RGB. A constant image would otherwise divide by zero
    # and feed NaNs to the network.
    maxx, minn = np.max(image), np.min(image)
    span = maxx - minn
    image = (image - minn) / span if span > 0 else np.zeros_like(image)

    # ToTensor turns an HWC ndarray into a CHW tensor. Float input is not
    # rescaled. Then add the batch dimension and move to the target device.
    to_tensor = transforms.ToTensor()
    image = to_tensor(image).float().unsqueeze(0).to(device)
    lr = to_tensor(np.expand_dims(lr, 2)).float().unsqueeze(0).to(device)
    return image, lr, min_out, max_out
# ---------------------------------------------------------------------------
# Model inference
# ---------------------------------------------------------------------------
@spaces.GPU
@torch.no_grad()
def infer(rgb_image: Image.Image, lr_depth: Image.Image, model_type: str):
    """Super-resolve *lr_depth*, guided by *rgb_image*, with the chosen model.

    Args:
        rgb_image: Guidance colour image.
        lr_depth: Low-resolution depth map.
        model_type: "RGB-D-D" or "TOFDSR".

    Returns:
        ``(pred_gray, pred_heat)``:
        - ``pred_gray`` is the prediction as a 16-bit single-channel PIL
          image.
        - ``pred_heat`` is a PLASMA-coloured RGB visualisation of it.

    Raises:
        gr.Error: If either input image is missing.
    """
    if rgb_image is None or lr_depth is None:
        raise gr.Error("Please provide both an RGB image and a low-res depth map.")

    load_model(model_type)  # reset weights: rebuild and reload from disk
    image, lr, min_out, max_out = preprocess_inputs(rgb_image, lr_depth)

    if model_type == "RGB-D-D":
        out = net(x_query=lr, rgb=image)
    elif model_type == "TOFDSR":
        # This variant returns a pair; only the first element is the depth.
        out, _ = net(x_query=lr, rgb=image)
    else:
        raise ValueError(f"Unknown model_type: {model_type}")

    # Take the first sample and channel, then undo the depth normalisation.
    pred = out[0, 0] * (max_out - min_out) + min_out
    pred = pred.cpu().numpy()
    # Round and clamp into the uint16 range. Casting negative or >65535
    # floats directly to uint16 wraps around.
    pred = np.clip(np.rint(pred), 0, np.iinfo(np.uint16).max).astype(np.uint16)

    # raw 16-bit depth
    pred_gray = Image.fromarray(pred)

    # Heat map: stretch to 0..255. A constant prediction would divide by zero.
    lo, hi = float(pred.min()), float(pred.max())
    if hi > lo:
        pred_vis = ((pred - lo) / (hi - lo) * 255).astype(np.uint8)
    else:
        pred_vis = np.zeros(pred.shape, dtype=np.uint8)
    pred_heat = cv2.applyColorMap(pred_vis, cv2.COLORMAP_PLASMA)
    # OpenCV colour maps are BGR; PIL expects RGB.
    pred_heat = cv2.cvtColor(pred_heat, cv2.COLOR_BGR2RGB)
    return pred_gray, Image.fromarray(pred_heat)
# Markdown header shown at the top of the demo page.
Intro = """
## DORNet: A Degradation Oriented and Regularized Network for Blind Depth Super-Resolution
[📄 Paper](https://arxiv.org/pdf/2410.11666) • [💻 Code](https://github.com/yanzq95/DORNet) • [📦 Model](https://huggingface.co/wzxwyx/DORNet/tree/main)
"""
# ---------------------------------------------------------------------------
# User interface
# ---------------------------------------------------------------------------
with gr.Blocks(css="""
.output-image {
    display: flex;
    justify-content: center;
    align-items: center;
}
.output-image img {
    margin: auto;
    display: block;
}
""") as demo:
    gr.Markdown(Intro)
    with gr.Row():
        # Inputs: guidance RGB image and low-resolution depth (integer mode "I").
        with gr.Column():
            rgb_input = gr.Image(label="RGB Image", type="pil")
            lr_input = gr.Image(label="Low-res Depth", type="pil", image_mode="I")
        # Outputs: raw 16-bit prediction and its colour-mapped visualisation.
        with gr.Column():
            output1 = gr.Image(label="DORNet Output", type="pil", elem_classes=["output-image"])
            output2 = gr.Image(label="Normalized Output", type="pil", elem_classes=["output-image"])
    # NOTE(review): the source's indentation was lost; the selector and button
    # are placed below the image row. Confirm against the intended layout.
    model_selector = gr.Dropdown(choices=["RGB-D-D", "TOFDSR"], label="Model Type", value="RGB-D-D")
    run_button = gr.Button("Run Inference")
    gr.Examples(
        examples=[
            ["examples/RGB-D-D/20200518160957_RGB.jpg", "examples/RGB-D-D/20200518160957_LR_fill_depth.png", "RGB-D-D"],
            ["examples/TOFDSR/2020_09_08_13_59_59_435_rgb_rgb_crop.png", "examples/TOFDSR/2020_09_08_13_59_59_435_rgb_depth_crop_fill.png", "TOFDSR"],
        ],
        inputs=[rgb_input, lr_input, model_selector],
        outputs=[output1, output2],
        label="Try Examples ↓",
    )
    run_button.click(fn=infer, inputs=[rgb_input, lr_input, model_selector], outputs=[output1, output2])

# Launch only when run as a script, so importing this module (e.g. for
# reloading or testing) does not start a server.
if __name__ == "__main__":
    demo.launch()