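"""Gradio demo for DORNet blind depth super-resolution.

Takes an RGB image and a low-resolution depth map and returns the
super-resolved depth, both as a raw 16-bit image and as a colorized view.
Two checkpoints are supported, one per training dataset (RGB-D-D and
TOFDSR); paths assume the Space's layout (./checkpoints/, ./examples/).
"""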
import gradio as gr
import numpy as np
import spaces
import torch
import cv2
import mmcv
from PIL import Image
import torchvision.transforms as transforms

from net.dornet import Net
from net.dornet_ddp import Net_ddp
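# `net.dornet` / `net.dornet_ddp` are the model definitions from the DORNet
# repo (https://github.com/yanzq95/DORNet); this file assumes they sit in a
# `net/` package next to it.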
print("=" * 50)
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"torch version: {torch.__version__}")
print(f"CUDA version: {torch.version.cuda}")
print(f"mmcv version: {mmcv.__version__}")
print("=" * 50)

# init
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = None  # populated by load_model() below
model_ckpt_map = {
    "RGB-D-D": "./checkpoints/RGBDD.pth",
    "TOFDSR": "./checkpoints/TOFDSR.pth",
}

# load model
def load_model(model_type: str):
    """Build the backbone for the selected dataset and load its checkpoint."""
    global net
    if model_type == "RGB-D-D":
        net = Net(tiny_model=False).to(device)
    elif model_type == "TOFDSR":
        # TOFDSR weights belong to the `srn` sub-module of the DDP-trained variant
        net = Net_ddp(tiny_model=False).srn.to(device)
    else:
        raise ValueError(f"Unknown model_type: {model_type}")
    ckpt_path = model_ckpt_map[model_type]
    print(f"Loading weights from: {ckpt_path}")
    net.load_state_dict(torch.load(ckpt_path, map_location=device))
    net.eval()

# Preload the default model so the first request does not rebuild it
load_model("RGB-D-D")

# data process
def preprocess_inputs(rgb_image: Image.Image, lr_depth: Image.Image):
    image = np.array(rgb_image.convert("RGB")).astype(np.float32)
    h, w, _ = image.shape
    # Upsample the low-res depth to the RGB resolution (PIL resize takes (width, height))
    lr = np.array(lr_depth.resize((w, h), Image.BICUBIC)).astype(np.float32)
    # Normalize depth with the fixed 0-5000 range (matched in infer() for de-normalization)
    max_out, min_out = 5000.0, 0.0
    lr = (lr - min_out) / (max_out - min_out)
    # Normalize RGB per image to [0, 1]
    maxx, minn = np.max(image), np.min(image)
    image = (image - minn) / (maxx - minn)
    # To tensor (HWC -> CHW)
    data_transform = transforms.Compose([transforms.ToTensor()])
    image = data_transform(image).float()
    lr = data_transform(np.expand_dims(lr, 2)).float()
    # Add batch dimension and move to device
    lr = lr.unsqueeze(0).to(device)
    image = image.unsqueeze(0).to(device)
    return image, lr, min_out, max_out
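
# Example: a 640x480 RGB input yields `image` of shape (1, 3, 480, 640) and
# `lr` of shape (1, 1, 480, 640): float32 tensors ready for the network.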

# model inference
@spaces.GPU  # needed on ZeroGPU Spaces; inferred from the otherwise-unused `spaces` import
def infer(rgb_image: Image.Image, lr_depth: Image.Image, model_type: str):
    load_model(model_type)  # reset weights for the selected model
    image, lr, min_out, max_out = preprocess_inputs(rgb_image, lr_depth)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        if model_type == "RGB-D-D":
            out = net(x_query=lr, rgb=image)
        elif model_type == "TOFDSR":
            out, _ = net(x_query=lr, rgb=image)
    # De-normalize back to the 0-5000 depth range
    pred = out[0, 0] * (max_out - min_out) + min_out
    pred = pred.cpu().numpy().astype(np.uint16)
    # raw 16-bit depth map
    pred_gray = Image.fromarray(pred)
    # heatmap: min-max normalize to 8 bits, then colorize
    pred_norm = (pred - np.min(pred)) / (np.max(pred) - np.min(pred)) * 255
    pred_vis = pred_norm.astype(np.uint8)
    pred_heat = cv2.applyColorMap(pred_vis, cv2.COLORMAP_PLASMA)
    pred_heat = cv2.cvtColor(pred_heat, cv2.COLOR_BGR2RGB)
    return pred_gray, Image.fromarray(pred_heat)

Intro = """
## DORNet: A Degradation Oriented and Regularized Network for Blind Depth Super-Resolution

[📄 Paper](https://arxiv.org/pdf/2410.11666) • [💻 Code](https://github.com/yanzq95/DORNet) • [📦 Model](https://huggingface.co/wzxwyx/DORNet/tree/main)
"""
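
# Small CSS override so output images sit centered in their panes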
with gr.Blocks(css="""
.output-image {
    display: flex;
    justify-content: center;
    align-items: center;
}
.output-image img {
    margin: auto;
    display: block;
}
""") as demo:
    gr.Markdown(Intro)
    with gr.Row():
        with gr.Column():
            rgb_input = gr.Image(label="RGB Image", type="pil")
            # image_mode="I" (32-bit int) keeps 16-bit depth PNG values intact
            lr_input = gr.Image(label="Low-res Depth", type="pil", image_mode="I")
        with gr.Column():
            output1 = gr.Image(label="DORNet Output", type="pil", elem_classes=["output-image"])
            output2 = gr.Image(label="Normalized Output", type="pil", elem_classes=["output-image"])
    model_selector = gr.Dropdown(choices=["RGB-D-D", "TOFDSR"], label="Model Type", value="RGB-D-D")
    run_button = gr.Button("Run Inference")
    gr.Examples(
        examples=[
            ["examples/RGB-D-D/20200518160957_RGB.jpg", "examples/RGB-D-D/20200518160957_LR_fill_depth.png", "RGB-D-D"],
            ["examples/TOFDSR/2020_09_08_13_59_59_435_rgb_rgb_crop.png", "examples/TOFDSR/2020_09_08_13_59_59_435_rgb_depth_crop_fill.png", "TOFDSR"],
        ],
        inputs=[rgb_input, lr_input, model_selector],
        outputs=[output1, output2],
        label="Try Examples ↓",
    )
    run_button.click(fn=infer, inputs=[rgb_input, lr_input, model_selector], outputs=[output1, output2])

demo.launch()
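
# demo.launch() blocks and serves the app (http://localhost:7860 by default
# when run locally; Spaces hosting wires this up automatically).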