RaynWu2002 commited on
Commit
9778377
Β·
verified Β·
1 Parent(s): e0bdf31

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -156
app.py CHANGED
@@ -1,156 +1,156 @@
1
- import gradio as gr
2
- import numpy as np
3
- import torch
4
- import os
5
- import cv2
6
- from PIL import Image
7
- import torchvision.transforms as transforms
8
- from net.dornet import Net
9
- from net.dornet_ddp import Net_ddp
10
-
11
- # init
12
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
13
- net = Net(tiny_model=False).to(device)
14
- model_ckpt_map = {
15
- "RGB-D-D": "./checkpoints/RGBDD.pth",
16
- "TOFDSR": "./checkpoints/TOFDSR.pth"
17
- }
18
-
19
- # load model
20
- def load_model(model_type: str):
21
- global net
22
- ckpt_path = model_ckpt_map[model_type]
23
- print(f"Loading weights from: {ckpt_path}")
24
- if model_type == "RGB-D-D":
25
- net = Net(tiny_model=False).to(device)
26
- elif model_type == "TOFDSR":
27
- net = Net_ddp(tiny_model=False).srn.to(device)
28
- else:
29
- raise ValueError(f"Unknown model_type: {model_type}")
30
-
31
- net.load_state_dict(torch.load(ckpt_path, map_location=device))
32
- net.eval()
33
-
34
- load_model("RGB-D-D")
35
-
36
-
37
- # data process
38
- def preprocess_inputs(rgb_image: Image.Image, lr_depth: Image.Image):
39
- image = np.array(rgb_image.convert("RGB")).astype(np.float32)
40
- h, w, _ = image.shape
41
- lr = np.array(lr_depth.resize((w, h), Image.BICUBIC)).astype(np.float32)
42
- # Normalize depth
43
- max_out, min_out = 5000.0, 0.0
44
- lr = (lr - min_out) / (max_out - min_out)
45
- # Normalize RGB
46
- maxx, minn = np.max(image), np.min(image)
47
- image = (image - minn) / (maxx - minn)
48
- # To tensor
49
- data_transform = transforms.Compose([transforms.ToTensor()])
50
- image = data_transform(image).float()
51
- lr = data_transform(np.expand_dims(lr, 2)).float()
52
- # Add batch dimension
53
- lr = lr.unsqueeze(0).to(device)
54
- image = image.unsqueeze(0).to(device)
55
- return image, lr, min_out, max_out
56
-
57
-
58
- # model inference
59
- @torch.no_grad()
60
- def infer(rgb_image: Image.Image, lr_depth: Image.Image, model_type: str):
61
- load_model(model_type) # reset weight
62
-
63
- image, lr, min_out, max_out = preprocess_inputs(rgb_image, lr_depth)
64
-
65
- if model_type == "RGB-D-D":
66
- out = net(x_query=lr, rgb=image)
67
- elif model_type == "TOFDSR":
68
- out, _ = net(x_query=lr, rgb=image)
69
-
70
- pred = out[0, 0] * (max_out - min_out) + min_out
71
- pred = pred.cpu().numpy().astype(np.uint16)
72
- # raw
73
- pred_gray = Image.fromarray(pred)
74
-
75
- # heat
76
- pred_norm = (pred - np.min(pred)) / (np.max(pred) - np.min(pred)) * 255
77
- pred_vis = pred_norm.astype(np.uint8)
78
- pred_heat = cv2.applyColorMap(pred_vis, cv2.COLORMAP_PLASMA)
79
- pred_heat = cv2.cvtColor(pred_heat, cv2.COLOR_BGR2RGB)
80
- return pred_gray, Image.fromarray(pred_heat)
81
-
82
-
83
- # Gradio
84
- # demo = gr.Interface(
85
- # fn=infer,
86
- # inputs=[
87
- # gr.Image(label="RGB Image", type="pil"),
88
- # gr.Image(label="Low-res Depth", type="pil", image_mode="I"),
89
- # gr.Dropdown(choices=["RGB-D-D", "TOFDSR"], label="Model Type", value="RGB-D-D")
90
- # ],
91
- # outputs=[
92
- # gr.Image(label="DORNet Output", type="pil", elem_classes=["output-image"]),
93
- # gr.Image(label="Normalized Output", type="pil", elem_classes=["output-image"])
94
- # ],
95
- # examples=[
96
- # ["examples/RGB-D-D/20200518160957_RGB.jpg", "examples/RGB-D-D/20200518160957_LR_fill_depth.png", "RGB-D-D"],
97
- # ["examples/TOFDSR/2020_09_08_13_59_59_435_rgb_rgb_crop.png", "examples/TOFDSR/2020_09_08_13_59_59_435_rgb_depth_crop_fill.png", "TOFDSR"],
98
- # ],
99
- # allow_flagging="never",
100
- # title="DORNet: A Degradation Oriented and Regularized Network for Blind Depth Super-Resolution \n CVPR 2025 (Oral Presentation)",
101
- # css="""
102
- # .output-image {
103
- # display: flex;
104
- # justify-content: center;
105
- # align-items: center;
106
- # }
107
- # .output-image img {
108
- # margin: auto;
109
- # display: block;
110
- # }
111
- # """
112
- # )
113
- #
114
- # demo.launch(share=True)
115
- Intro = """
116
- ## DORNet: A Degradation Oriented and Regularized Network for Blind Depth Super-Resolution
117
- [πŸ“„ Paper](https://arxiv.org/pdf/2410.11666) β€’ [πŸ’» Code](https://github.com/yanzq95/DORNet) β€’ [πŸ“¦ Model](https://huggingface.co/wzxwyx/DORNet/tree/main)
118
- """
119
-
120
- with gr.Blocks(css="""
121
- .output-image {
122
- display: flex;
123
- justify-content: center;
124
- align-items: center;
125
- }
126
- .output-image img {
127
- margin: auto;
128
- display: block;
129
- }
130
- """) as demo:
131
- gr.Markdown(Intro)
132
-
133
- with gr.Row():
134
- with gr.Column():
135
- rgb_input = gr.Image(label="RGB Image", type="pil")
136
- lr_input = gr.Image(label="Low-res Depth", type="pil", image_mode="I")
137
- with gr.Column():
138
- output1 = gr.Image(label="DORNet Output", type="pil", elem_classes=["output-image"])
139
- output2 = gr.Image(label="Normalized Output", type="pil", elem_classes=["output-image"])
140
-
141
- model_selector = gr.Dropdown(choices=["RGB-D-D", "TOFDSR"], label="Model Type", value="RGB-D-D")
142
- run_button = gr.Button("Run Inference")
143
-
144
- gr.Examples(
145
- examples=[
146
- ["examples/RGB-D-D/20200518160957_RGB.jpg", "examples/RGB-D-D/20200518160957_LR_fill_depth.png", "RGB-D-D"],
147
- ["examples/TOFDSR/2020_09_08_13_59_59_435_rgb_rgb_crop.png", "examples/TOFDSR/2020_09_08_13_59_59_435_rgb_depth_crop_fill.png", "TOFDSR"],
148
- ],
149
- inputs=[rgb_input, lr_input, model_selector],
150
- outputs=[output1, output2],
151
- label="Try Examples ↓"
152
- )
153
-
154
- run_button.click(fn=infer, inputs=[rgb_input, lr_input, model_selector], outputs=[output1, output2])
155
-
156
- demo.launch(share=True)
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ import os
5
+ import cv2
6
+ from PIL import Image
7
+ import torchvision.transforms as transforms
8
+ from net.dornet import Net
9
+ from net.dornet_ddp import Net_ddp
10
+
11
+ # init
12
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
13
+ net = Net(tiny_model=False).to(device)
14
+ model_ckpt_map = {
15
+ "RGB-D-D": "./checkpoints/RGBDD.pth",
16
+ "TOFDSR": "./checkpoints/TOFDSR.pth"
17
+ }
18
+
19
+ # load model
20
+ def load_model(model_type: str):
21
+ global net
22
+ ckpt_path = model_ckpt_map[model_type]
23
+ print(f"Loading weights from: {ckpt_path}")
24
+ if model_type == "RGB-D-D":
25
+ net = Net(tiny_model=False).to(device)
26
+ elif model_type == "TOFDSR":
27
+ net = Net_ddp(tiny_model=False).srn.to(device)
28
+ else:
29
+ raise ValueError(f"Unknown model_type: {model_type}")
30
+
31
+ net.load_state_dict(torch.load(ckpt_path, map_location=device))
32
+ net.eval()
33
+
34
+ load_model("RGB-D-D")
35
+
36
+
37
+ # data process
38
+ def preprocess_inputs(rgb_image: Image.Image, lr_depth: Image.Image):
39
+ image = np.array(rgb_image.convert("RGB")).astype(np.float32)
40
+ h, w, _ = image.shape
41
+ lr = np.array(lr_depth.resize((w, h), Image.BICUBIC)).astype(np.float32)
42
+ # Normalize depth
43
+ max_out, min_out = 5000.0, 0.0
44
+ lr = (lr - min_out) / (max_out - min_out)
45
+ # Normalize RGB
46
+ maxx, minn = np.max(image), np.min(image)
47
+ image = (image - minn) / (maxx - minn)
48
+ # To tensor
49
+ data_transform = transforms.Compose([transforms.ToTensor()])
50
+ image = data_transform(image).float()
51
+ lr = data_transform(np.expand_dims(lr, 2)).float()
52
+ # Add batch dimension
53
+ lr = lr.unsqueeze(0).to(device)
54
+ image = image.unsqueeze(0).to(device)
55
+ return image, lr, min_out, max_out
56
+
57
+
58
+ # model inference
59
+ @torch.no_grad()
60
+ def infer(rgb_image: Image.Image, lr_depth: Image.Image, model_type: str):
61
+ load_model(model_type) # reset weight
62
+
63
+ image, lr, min_out, max_out = preprocess_inputs(rgb_image, lr_depth)
64
+
65
+ if model_type == "RGB-D-D":
66
+ out = net(x_query=lr, rgb=image)
67
+ elif model_type == "TOFDSR":
68
+ out, _ = net(x_query=lr, rgb=image)
69
+
70
+ pred = out[0, 0] * (max_out - min_out) + min_out
71
+ pred = pred.cpu().numpy().astype(np.uint16)
72
+ # raw
73
+ pred_gray = Image.fromarray(pred)
74
+
75
+ # heat
76
+ pred_norm = (pred - np.min(pred)) / (np.max(pred) - np.min(pred)) * 255
77
+ pred_vis = pred_norm.astype(np.uint8)
78
+ pred_heat = cv2.applyColorMap(pred_vis, cv2.COLORMAP_PLASMA)
79
+ pred_heat = cv2.cvtColor(pred_heat, cv2.COLOR_BGR2RGB)
80
+ return pred_gray, Image.fromarray(pred_heat)
81
+
82
+
83
+ # Gradio
84
+ # demo = gr.Interface(
85
+ # fn=infer,
86
+ # inputs=[
87
+ # gr.Image(label="RGB Image", type="pil"),
88
+ # gr.Image(label="Low-res Depth", type="pil", image_mode="I"),
89
+ # gr.Dropdown(choices=["RGB-D-D", "TOFDSR"], label="Model Type", value="RGB-D-D")
90
+ # ],
91
+ # outputs=[
92
+ # gr.Image(label="DORNet Output", type="pil", elem_classes=["output-image"]),
93
+ # gr.Image(label="Normalized Output", type="pil", elem_classes=["output-image"])
94
+ # ],
95
+ # examples=[
96
+ # ["examples/RGB-D-D/20200518160957_RGB.jpg", "examples/RGB-D-D/20200518160957_LR_fill_depth.png", "RGB-D-D"],
97
+ # ["examples/TOFDSR/2020_09_08_13_59_59_435_rgb_rgb_crop.png", "examples/TOFDSR/2020_09_08_13_59_59_435_rgb_depth_crop_fill.png", "TOFDSR"],
98
+ # ],
99
+ # allow_flagging="never",
100
+ # title="DORNet: A Degradation Oriented and Regularized Network for Blind Depth Super-Resolution \n CVPR 2025 (Oral Presentation)",
101
+ # css="""
102
+ # .output-image {
103
+ # display: flex;
104
+ # justify-content: center;
105
+ # align-items: center;
106
+ # }
107
+ # .output-image img {
108
+ # margin: auto;
109
+ # display: block;
110
+ # }
111
+ # """
112
+ # )
113
+ #
114
+ # demo.launch(share=True)
115
+ Intro = """
116
+ ## DORNet: A Degradation Oriented and Regularized Network for Blind Depth Super-Resolution
117
+ [πŸ“„ Paper](https://arxiv.org/pdf/2410.11666) β€’ [πŸ’» Code](https://github.com/yanzq95/DORNet) β€’ [πŸ“¦ Model](https://huggingface.co/wzxwyx/DORNet/tree/main)
118
+ """
119
+
120
+ with gr.Blocks(css="""
121
+ .output-image {
122
+ display: flex;
123
+ justify-content: center;
124
+ align-items: center;
125
+ }
126
+ .output-image img {
127
+ margin: auto;
128
+ display: block;
129
+ }
130
+ """) as demo:
131
+ gr.Markdown(Intro)
132
+
133
+ with gr.Row():
134
+ with gr.Column():
135
+ rgb_input = gr.Image(label="RGB Image", type="pil")
136
+ lr_input = gr.Image(label="Low-res Depth", type="pil", image_mode="I")
137
+ with gr.Column():
138
+ output1 = gr.Image(label="DORNet Output", type="pil", elem_classes=["output-image"])
139
+ output2 = gr.Image(label="Normalized Output", type="pil", elem_classes=["output-image"])
140
+
141
+ model_selector = gr.Dropdown(choices=["RGB-D-D", "TOFDSR"], label="Model Type", value="RGB-D-D")
142
+ run_button = gr.Button("Run Inference")
143
+
144
+ gr.Examples(
145
+ examples=[
146
+ ["examples/RGB-D-D/20200518160957_RGB.jpg", "examples/RGB-D-D/20200518160957_LR_fill_depth.png", "RGB-D-D"],
147
+ ["examples/TOFDSR/2020_09_08_13_59_59_435_rgb_rgb_crop.png", "examples/TOFDSR/2020_09_08_13_59_59_435_rgb_depth_crop_fill.png", "TOFDSR"],
148
+ ],
149
+ inputs=[rgb_input, lr_input, model_selector],
150
+ outputs=[output1, output2],
151
+ label="Try Examples ↓"
152
+ )
153
+
154
+ run_button.click(fn=infer, inputs=[rgb_input, lr_input, model_selector], outputs=[output1, output2])
155
+
156
+ demo.launch()