gangweix commited on
Commit
5833bb7
·
verified ·
1 Parent(s): 412cb1c

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +110 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2
3
+ import matplotlib
4
+ import numpy as np
5
+ import os
6
+ from PIL import Image
7
+ import spaces
8
+ import torch
9
+ import tempfile
10
+ from gradio_imageslider import ImageSlider
11
+ from huggingface_hub import hf_hub_download
12
+
13
+ from ppd.utils.set_seed import set_seed
14
+ from ppd.models.ppd import PixelPerfectDepth
15
+
16
# Page-level CSS injected into the Gradio app: caps the image container to the
# viewport height, limits the input/output panes to 80% of it, and fixes the
# height of the two download-file widgets.
css = """
#img-display-container {
    max-height: 100vh;
}
#img-display-input {
    max-height: 80vh;
}
#img-display-output {
    max-height: 80vh;
}
#download {
    height: 62px;
}
"""
30
# Select the compute device; fall back to CPU when no GPU is available.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build Pixel-Perfect Depth with a Depth-Anything-V2 ViT-L checkpoint for the
# semantics branch and 10 diffusion sampling steps at inference time.
# NOTE(review): both checkpoint paths are relative — the app must be launched
# from the repository root for them to resolve.
model = PixelPerfectDepth(semantics_pth='checkpoints/depth_anything_v2_vitl.pth', sampling_steps=10)

# strict=False tolerates extra keys in the checkpoint, but it also hides
# genuinely missing weights; surface any mismatch instead of ignoring it.
missing, unexpected = model.load_state_dict(
    torch.load('checkpoints/ppd.pth', map_location='cpu'), strict=False
)
if missing or unexpected:
    print(f'load_state_dict: {len(missing)} missing / {len(unexpected)} unexpected keys')

model = model.to(DEVICE).eval()
55
+
56
+
57
+
58
# Markdown rendered at the top of the Gradio page.
title = "# Pixel-Perfect Depth"
# NOTE(review): the paper/project/github links below point to Depth-Anything-V2
# resources, not Pixel-Perfect Depth — these look like leftovers from the demo
# this app was adapted from; confirm and update the URLs.
description = """Official demo for **Pixel-Perfect Depth**.
Please refer to our [paper](https://arxiv.org/abs/2406.09414), [project page](https://depth-anything-v2.github.io), and [github](https://github.com/DepthAnything/Depth-Anything-V2) for more details."""
61
+
62
@spaces.GPU
def predict_depth(image):
    """Run Pixel-Perfect Depth inference on a single image.

    Decorated with ``spaces.GPU`` so ZeroGPU hardware is attached for the
    duration of the call. The image is forwarded unchanged to the model's
    test-time entry point and the predicted depth map is returned as-is.
    """
    prediction = model.forward_test(image)
    return prediction
65
+
66
with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown("### Depth Prediction demo")

    with gr.Row():
        input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
        depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output', position=0.5)
    submit = gr.Button(value="Predict Depth")
    gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download",)
    raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download",)

    # Colormap used to render the predicted depth in the slider view.
    cmap = matplotlib.colormaps.get_cmap('Spectral')

    def on_submit(image):
        """Predict depth for one uploaded image and build the three outputs.

        Returns [(original, colored_depth), gray_png_path, raw16_png_path] —
        the slider pair plus the two downloadable files.
        """
        original_image = image.copy()

        # Gradio delivers RGB; the channel flip hands the model a BGR array.
        # Assumes forward_test expects OpenCV-style BGR — TODO confirm.
        depth = predict_depth(image[:, :, ::-1])

        # Save the un-normalized prediction as a 16-bit PNG before any
        # rescaling so downstream users get the raw values.
        raw_depth = Image.fromarray(depth.astype('uint16'))
        tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
        # Close the open handle first: PIL re-opens the path, which fails on
        # Windows while the NamedTemporaryFile handle is still held.
        tmp_raw_depth.close()
        raw_depth.save(tmp_raw_depth.name)

        # Normalize to [0, 255] for display; guard against a constant depth
        # map (max == min), which would otherwise divide by zero / yield NaN.
        span = depth.max() - depth.min()
        if span > 0:
            depth = (depth - depth.min()) / span * 255.0
        else:
            depth = np.zeros_like(depth)
        depth = depth.astype(np.uint8)
        colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)

        gray_depth = Image.fromarray(depth)
        tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
        tmp_gray_depth.close()
        gray_depth.save(tmp_gray_depth.name)

        # NOTE(review): delete=False means these temp files are never removed;
        # acceptable for an ephemeral Space, a long-running server would leak.
        return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name]

    submit.click(on_submit, inputs=[input_image], outputs=[depth_image_slider, gray_depth_file, raw_file])

    # Thumbnails for the examples gallery; without cache_examples=True the fn
    # is not pre-run — clicking an example only fills the input.
    example_files = sorted(
        os.path.join('assets/examples', filename)
        for filename in os.listdir('assets/examples')
    )
    examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[depth_image_slider, gray_depth_file, raw_file], fn=on_submit)


if __name__ == '__main__':
    demo.queue().launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
gradio_imageslider
gradio==4.36.0
torch
torchvision
opencv-python
matplotlib
huggingface_hub
numpy
Pillow
spaces