# DORNet Gradio demo: blind depth super-resolution of a low-resolution depth
# map, guided by an RGB image.
#
# Import order is kept as in the original: `spaces` is imported ahead of
# `torch` (Hugging Face ZeroGPU expects this before CUDA is initialized).
import gradio as gr
import numpy as np
import spaces
import torch
import os
import cv2
import mmcv
from PIL import Image
import torchvision.transforms as transforms

from net.dornet import Net
from net.dornet_ddp import Net_ddp

print("=" * 50)
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"torch version: {torch.__version__}")
print(f"CUDA version: {torch.version.cuda}")
print(f"mmcv version: {mmcv.__version__}")
print("=" * 50)

# init
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = Net(tiny_model=False).to(device)

# Checkpoint file for each model type selectable in the UI.
model_ckpt_map = {
    "RGB-D-D": "./checkpoints/RGBDD.pth",
    "TOFDSR": "./checkpoints/TOFDSR.pth",
}

# Largest value representable in the raw 16-bit depth output.
_UINT16_MAX = np.iinfo(np.uint16).max


# load model
@spaces.GPU
def load_model(model_type: str):
    """Rebuild the network for ``model_type`` and load its checkpoint into ``net``.

    Args:
        model_type: A key of ``model_ckpt_map`` ("RGB-D-D" or "TOFDSR").

    Raises:
        ValueError: If ``model_type`` is not a known model type.
    """
    global net
    # Validate up front so an unknown type raises ValueError rather than a
    # KeyError from the dictionary lookup.
    if model_type not in model_ckpt_map:
        raise ValueError(f"Unknown model_type: {model_type}")
    ckpt_path = model_ckpt_map[model_type]
    print(f"Loading weights from: {ckpt_path}")

    if model_type == "RGB-D-D":
        net = Net(tiny_model=False).to(device)
    elif model_type == "TOFDSR":
        # Only the `srn` sub-network of the DDP wrapper model is used.
        net = Net_ddp(tiny_model=False).srn.to(device)
    else:
        # Guards against a key added to model_ckpt_map without a builder here.
        raise ValueError(f"Unknown model_type: {model_type}")

    net.load_state_dict(torch.load(ckpt_path, map_location=device))
    net.eval()


load_model("RGB-D-D")


# data process
@spaces.GPU
def preprocess_inputs(rgb_image: Image.Image, lr_depth: Image.Image):
    """Convert the PIL inputs into normalized, batched tensors on ``device``.

    The depth map is resized (bicubic) to the RGB resolution and scaled by a
    fixed range [0, 5000]; the RGB image is min-max normalized to [0, 1].

    Returns:
        ``(image, lr, min_out, max_out)``: ``image`` with shape (1, 3, H, W),
        ``lr`` with shape (1, 1, H, W) (assuming a single-channel depth image),
        and the depth normalization bounds needed to undo the scaling.
    """
    image = np.array(rgb_image.convert("RGB")).astype(np.float32)
    h, w, _ = image.shape
    lr = np.array(lr_depth.resize((w, h), Image.BICUBIC)).astype(np.float32)

    # Normalize depth with a fixed range so the prediction can be mapped back.
    max_out, min_out = 5000.0, 0.0
    lr = (lr - min_out) / (max_out - min_out)

    # Min-max normalize RGB; a constant image would otherwise produce 0/0 = NaN.
    maxx, minn = np.max(image), np.min(image)
    span = maxx - minn
    if span > 0:
        image = (image - minn) / span
    else:
        image = np.zeros_like(image)

    # To tensor (HWC float arrays -> CHW tensors; float input is not rescaled).
    data_transform = transforms.Compose([transforms.ToTensor()])
    image = data_transform(image).float()
    lr = data_transform(np.expand_dims(lr, 2)).float()

    # Add batch dimension
    lr = lr.unsqueeze(0).to(device)
    image = image.unsqueeze(0).to(device)
    return image, lr, min_out, max_out


def _colorize_depth(depth: np.ndarray) -> np.ndarray:
    """Min-max scale ``depth`` to 0..255 and apply the plasma colormap (RGB uint8)."""
    depth = depth.astype(np.float64)
    lo, hi = depth.min(), depth.max()
    if hi > lo:
        norm = (depth - lo) / (hi - lo) * 255
    else:
        # Constant map: avoid 0/0 = NaN, which has no defined uint8 value.
        norm = np.zeros_like(depth)
    heat = cv2.applyColorMap(norm.astype(np.uint8), cv2.COLORMAP_PLASMA)
    # OpenCV returns BGR; PIL / Gradio expect RGB.
    return cv2.cvtColor(heat, cv2.COLOR_BGR2RGB)


# model inference
@spaces.GPU
@torch.no_grad()
def infer(rgb_image: Image.Image, lr_depth: Image.Image, model_type: str):
    """Run DORNet on one RGB / low-resolution depth pair.

    Returns:
        ``(pred_gray, pred_heat)``: the raw prediction as a 16-bit PIL image
        and a plasma-colormapped RGB PIL visualization of it.

    Raises:
        gr.Error: If either input image is missing (shown in the UI).
        ValueError: If ``model_type`` is unknown.
    """
    if rgb_image is None or lr_depth is None:
        raise gr.Error("Please provide both an RGB image and a low-res depth map.")

    # Reset weights on every call. NOTE(review): presumably required because
    # state set inside a @spaces.GPU call is not guaranteed to persist between
    # calls — confirm before caching the model across requests.
    load_model(model_type)
    image, lr, min_out, max_out = preprocess_inputs(rgb_image, lr_depth)

    if model_type == "RGB-D-D":
        out = net(x_query=lr, rgb=image)
    elif model_type == "TOFDSR":
        # This variant returns a pair; only the first element is the depth.
        out, _ = net(x_query=lr, rgb=image)

    # Undo the depth normalization.
    pred = out[0, 0] * (max_out - min_out) + min_out
    # Clip before casting: converting negative or out-of-range floats to an
    # unsigned integer type is not well-defined and can wrap around.
    pred = np.clip(pred.cpu().numpy(), 0, _UINT16_MAX).astype(np.uint16)

    # raw
    pred_gray = Image.fromarray(pred)
    # heat
    pred_heat = _colorize_depth(pred)
    return pred_gray, Image.fromarray(pred_heat)


Intro = """
## DORNet: A Degradation Oriented and Regularized Network for Blind Depth Super-Resolution
[📄 Paper](https://arxiv.org/pdf/2410.11666) • [💻 Code](https://github.com/yanzq95/DORNet) • [📦 Model](https://huggingface.co/wzxwyx/DORNet/tree/main)
"""

with gr.Blocks(css="""
.output-image {
    display: flex;
    justify-content: center;
    align-items: center;
}
.output-image img {
    margin: auto;
    display: block;
}
""") as demo:
    gr.Markdown(Intro)

    with gr.Row():
        with gr.Column():
            rgb_input = gr.Image(label="RGB Image", type="pil")
            lr_input = gr.Image(label="Low-res Depth", type="pil", image_mode="I")
        with gr.Column():
            output1 = gr.Image(label="DORNet Output", type="pil", elem_classes=["output-image"])
            output2 = gr.Image(label="Normalized Output", type="pil", elem_classes=["output-image"])

    model_selector = gr.Dropdown(choices=["RGB-D-D", "TOFDSR"], label="Model Type", value="RGB-D-D")
    run_button = gr.Button("Run Inference")

    gr.Examples(
        examples=[
            ["examples/RGB-D-D/20200518160957_RGB.jpg",
             "examples/RGB-D-D/20200518160957_LR_fill_depth.png",
             "RGB-D-D"],
            ["examples/TOFDSR/2020_09_08_13_59_59_435_rgb_rgb_crop.png",
             "examples/TOFDSR/2020_09_08_13_59_59_435_rgb_depth_crop_fill.png",
             "TOFDSR"],
        ],
        inputs=[rgb_input, lr_input, model_selector],
        outputs=[output1, output2],
        label="Try Examples ↓"
    )

    run_button.click(fn=infer, inputs=[rgb_input, lr_input, model_selector], outputs=[output1, output2])

demo.launch()