gangweix commited on
Commit
b9b092b
·
verified ·
1 Parent(s): 709cfd2

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +169 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2
3
+ import matplotlib
4
+ import numpy as np
5
+ import os
6
+ from PIL import Image
7
+ import spaces
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import open3d as o3d
11
+ import tempfile
12
+ from gradio_imageslider import ImageSlider
13
+ from huggingface_hub import hf_hub_download
14
+
15
+ from ppd.utils.set_seed import set_seed
16
+ from ppd.utils.align_depth_func import recover_metric_depth_ransac
17
+ from ppd.utils.depth2pcd import depth2pcd
18
+ from moge.model.v2 import MoGeModel
19
+ from ppd.models.ppd import PixelPerfectDepth
20
+
21
+ css = """
22
+ #img-display-container {
23
+ max-height: 100vh;
24
+ }
25
+ #img-display-input {
26
+ max-height: 100vh;
27
+ }
28
+ #img-display-output {
29
+ max-height: 100vh;
30
+ }
31
+ #download {
32
+ height: 62px;
33
+ }
34
+
35
+ #img-display-output .image-slider-image {
36
+ object-fit: contain !important;
37
+ width: 100% !important;
38
+ height: 100% !important;
39
+ }
40
+ """
41
+
42
+ set_seed(666)
43
+
44
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
45
+ default_steps = 10
46
+ model = PixelPerfectDepth(sampling_steps=default_steps)
47
+ ckpt_path = hf_hub_download(
48
+ repo_id="gangweix/Pixel-Perfect-Depth",
49
+ filename="ppd.pth",
50
+ repo_type="model"
51
+ )
52
+ state_dict = torch.load(ckpt_path, map_location="cpu")
53
+ model.load_state_dict(state_dict, strict=False)
54
+ model = model.to(DEVICE).eval()
55
+
56
+ moge_model = MoGeModel.from_pretrained("Ruicheng/moge-2-vitl-normal").to(DEVICE).eval()
57
+
58
+
59
+
60
+ title = "# Pixel-Perfect Depth"
61
+ description = """Official demo for **Pixel-Perfect Depth**.
62
+ Please refer to our [paper](), [project page](https://pixel-perfect-depth.github.io), and [github](https://github.com/gangweix/pixel-perfect-depth) for more details."""
63
+
64
+ @spaces.GPU
65
+ def predict_depth(image, denoise_steps):
66
+ depth, resize_image = model.infer_image(image, sampling_steps=denoise_steps)
67
+ return depth, resize_image
68
+
69
+ @spaces.GPU
70
+ def predict_moge_depth(image):
71
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
72
+ image = torch.tensor(image / 255, dtype=torch.float32, device=DEVICE).permute(2, 0, 1)
73
+ metric_depth, mask, intrinsics = moge_model.infer(image)
74
+ metric_depth[~mask] = metric_depth[mask].max()
75
+ return metric_depth, mask, intrinsics
76
+
77
+ with gr.Blocks(css=css) as demo:
78
+ gr.Markdown(title)
79
+ gr.Markdown(description)
80
+ gr.Markdown("### Depth Prediction demo")
81
+
82
+ with gr.Row():
83
+ # Left: input image + settings
84
+ with gr.Column():
85
+ input_image = gr.Image(label="Input Image", image_mode="RGB", type='numpy', elem_id='img-display-input')
86
+ with gr.Accordion(label="Settings", open=False):
87
+ denoise_steps = gr.Slider(label="Denoising Steps", minimum=1, maximum=50, value=10, step=1)
88
+ apply_filter = gr.Checkbox(label="Apply filter points", value=True)
89
+ submit_btn = gr.Button(value="Predict Depth")
90
+
91
+ # Right: 3D point cloud + depth
92
+ with gr.Column():
93
+ with gr.Tabs():
94
+ with gr.Tab("3D View"):
95
+ model_3d = gr.Model3D(display_mode="solid", label="3D Point Map", clear_color=[1,1,1,1], height="60vh")
96
+ with gr.Tab("Depth"):
97
+ depth_map = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output', position=0.5)
98
+
99
+ concat_file = gr.File(label="Concatenated visualization (image+depth)", elem_id="image-depth-download")
100
+ raw_depth_file = gr.File(label="Raw depth output (saved as .npy)", elem_id="download")
101
+ pcd_file = gr.File(label="Point Cloud (.ply)", elem_id="download-ply")
102
+
103
+ cmap = matplotlib.colormaps.get_cmap('Spectral')
104
+
105
+ def on_submit(image, denoise_steps, apply_filter):
106
+
107
+ H, W = image.shape[:2]
108
+
109
+ ppd_depth, resize_image = predict_depth(image[:, :, ::-1], denoise_steps)
110
+
111
+ resize_H, resize_W = resize_image.shape[:2]
112
+
113
+ # moge provide metric depth and intrinsics
114
+ moge_depth, mask, intrinsics = predict_moge_depth(resize_image)
115
+
116
+ # relative depth -> metric depth
117
+ metric_depth = recover_metric_depth_ransac(ppd_depth.squeeze().cpu().numpy(), moge_depth, mask)
118
+ intrinsics[0, 0] *= resize_W
119
+ intrinsics[1, 1] *= resize_H
120
+ intrinsics[0, 2] *= resize_W
121
+ intrinsics[1, 2] *= resize_H
122
+
123
+ # metric depth -> point cloud
124
+ pcd = depth2pcd(metric_depth, intrinsics, color=cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB), input_mask=mask, ret_pcd=True)
125
+ if apply_filter:
126
+ cl, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=3.0)
127
+ pcd = pcd.select_by_index(ind)
128
+
129
+ # save pcd to temporary .ply for Model3D
130
+ tmp_ply = tempfile.NamedTemporaryFile(suffix='.ply', delete=False)
131
+ o3d.io.write_point_cloud(tmp_ply.name, pcd)
132
+
133
+ depth = F.interpolate(ppd_depth, size=(H, W), mode='bilinear', align_corners=False)[0, 0]
134
+ depth = depth.cpu().numpy()
135
+
136
+ # save raw depth (npy)
137
+ tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.npy', delete=False)
138
+ np.save(tmp_raw_depth.name, depth)
139
+
140
+ depth_vis = (depth - depth.min()) / (depth.max() - depth.min()+1e-5) * 255.0
141
+ depth_vis = depth_vis.astype(np.uint8)
142
+ colored_depth = (cmap(depth_vis)[:, :, :3] * 255).astype(np.uint8)
143
+
144
+ split_region = np.ones((image.shape[0], 50, 3), dtype=np.uint8) * 255
145
+ combined_result = cv2.hconcat([image[:, :, ::-1], split_region, colored_depth[:, :, ::-1]])
146
+ tmp_concat = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
147
+ cv2.imwrite(tmp_concat.name, combined_result)
148
+
149
+ return [(image, colored_depth), tmp_ply.name, tmp_concat.name, tmp_raw_depth.name, tmp_ply.name]
150
+
151
+ submit_btn.click(
152
+ on_submit,
153
+ inputs=[input_image, denoise_steps, apply_filter],
154
+ outputs=[depth_map, model_3d, concat_file, raw_depth_file, pcd_file]
155
+ )
156
+
157
+ example_files = os.listdir('assets/examples')
158
+ example_files.sort()
159
+ example_files = [os.path.join('assets/examples', filename) for filename in example_files]
160
+ examples = gr.Examples(
161
+ examples=example_files,
162
+ inputs=[input_image],
163
+ outputs=[depth_map, model_3d, concat_file, raw_depth_file, pcd_file],
164
+ fn=on_submit
165
+ )
166
+
167
+
168
+ if __name__ == '__main__':
169
+ demo.queue().launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio_imageslider
2
+ gradio==4.36.0
3
+ torch
4
+ torchvision
5
+ opencv-python
6
+ matplotlib
7
+ huggingface_hub
8
+ timm
9
+ open3d
10
+ scikit-learn