ZenosArrows commited on
Commit
7ca1c9a
·
verified ·
1 Parent(s): 7f2e027

Add looking glass and upscaling support

Browse files
app.py CHANGED
@@ -1,14 +1,13 @@
 
1
  import gradio as gr
2
- import cv2
3
- import matplotlib
4
  import numpy as np
5
- import os
6
- from PIL import Image
7
- import spaces
8
  import torch
9
  import tempfile
10
- from gradio_imageslider import ImageSlider
11
- from huggingface_hub import hf_hub_download
 
 
 
12
 
13
  from depth_anything_v2.dpt import DepthAnythingV2
14
 
@@ -25,78 +24,271 @@ css = """
25
  #download {
26
  height: 62px;
27
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  """
29
- DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
30
  model_configs = {
31
  'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
32
  'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
33
  'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
34
  'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
35
  }
36
- encoder2name = {
37
- 'vits': 'Small',
38
- 'vitb': 'Base',
39
- 'vitl': 'Large',
40
- 'vitg': 'Giant', # we are undergoing company review procedures to release our giant model checkpoint
41
- }
42
- encoder = 'vitl'
43
- model_name = encoder2name[encoder]
44
- model = DepthAnythingV2(**model_configs[encoder])
45
- filepath = hf_hub_download(repo_id=f"depth-anything/Depth-Anything-V2-{model_name}", filename=f"depth_anything_v2_{encoder}.pth", repo_type="model")
46
- state_dict = torch.load(filepath, map_location="cpu")
47
- model.load_state_dict(state_dict)
48
- model = model.to(DEVICE).eval()
49
 
50
  title = "# Depth Anything V2"
51
- description = """Official demo for **Depth Anything V2**.
52
- Please refer to our [paper](https://arxiv.org/abs/2406.09414), [project page](https://depth-anything-v2.github.io), and [github](https://github.com/DepthAnything/Depth-Anything-V2) for more details."""
53
 
54
- @spaces.GPU
55
- def predict_depth(image):
56
- return model.infer_image(image)
57
 
58
- with gr.Blocks(css=css) as demo:
59
- gr.Markdown(title)
60
- gr.Markdown(description)
61
- gr.Markdown("### Depth Prediction demo")
62
 
63
- with gr.Row():
64
- input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
65
- depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output', position=0.5)
66
- submit = gr.Button(value="Compute Depth")
67
- gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download",)
68
- raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download",)
 
 
 
69
 
70
- cmap = matplotlib.colormaps.get_cmap('Spectral_r')
 
 
 
 
 
 
 
 
71
 
72
- def on_submit(image):
73
- original_image = image.copy()
 
 
 
74
 
75
- h, w = image.shape[:2]
 
 
 
 
76
 
77
- depth = predict_depth(image[:, :, ::-1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- raw_depth = Image.fromarray(depth.astype('uint16'))
80
- tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
81
- raw_depth.save(tmp_raw_depth.name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
84
- depth = depth.astype(np.uint8)
85
- colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
86
 
87
- gray_depth = Image.fromarray(depth)
88
- tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
89
- gray_depth.save(tmp_gray_depth.name)
 
 
90
 
91
- return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name]
 
 
 
 
92
 
93
- submit.click(on_submit, inputs=[input_image], outputs=[depth_image_slider, gray_depth_file, raw_file])
 
 
 
 
94
 
95
- example_files = os.listdir('assets/examples')
96
- example_files.sort()
97
- example_files = [os.path.join('assets/examples', filename) for filename in example_files]
98
- examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[depth_image_slider, gray_depth_file, raw_file], fn=on_submit)
99
 
100
 
101
  if __name__ == '__main__':
102
- demo.queue().launch(share=True)
 
1
+ import glob
2
  import gradio as gr
 
 
3
  import numpy as np
 
 
 
4
  import torch
5
  import tempfile
6
+ import uuid
7
+ from PIL import Image, ImageOps, ImageEnhance
8
+ from pathlib import Path
9
+ from zipfile import ZipFile, is_zipfile
10
+ from pypdf import PdfReader
11
 
12
  from depth_anything_v2.dpt import DepthAnythingV2
13
 
 
24
  #download {
25
  height: 62px;
26
  }
27
+ .thumbnail-item {
28
+ aspect-ratio: var(--ratio-wide)
29
+ }
30
+ .thumbnail-item img {
31
+ object-fit: contain
32
+ }
33
+ """
34
+ head = """
35
+ <script type="module">
36
+ import { BridgeClient, RGBDHologram } from "/gradio_api/file=assets/looking-glass-bridge.js";
37
+ window.BridgeClient = BridgeClient;
38
+ window.RGBDHologram = RGBDHologram;
39
+ window.updating = false;
40
+ window.settings = {
41
+ depthiness: 1.0,
42
+ focus: 0,
43
+ aspect: 1,
44
+ chroma_depth: 0,
45
+ depth_inversion: 0,
46
+ depth_loc: 2,
47
+ depth_cutoff: 1,
48
+ zoom: 1,
49
+ crop_pos_x: 0,
50
+ crop_pos_y: 0,
51
+ };
52
+ window.castHologram = async function() {
53
+ const uri = document.querySelector('#img-display-output .thumbnail-item.selected img').src;
54
+ if (!uri)
55
+ return;
56
+ const Bridge = BridgeClient.getInstance();
57
+ if (!Bridge.isConnected)
58
+ await Bridge.connect();
59
+ await Bridge.getDisplays();
60
+ if (Bridge.isCastPending)
61
+ return;
62
+ const rgbd = new RGBDHologram({ uri, settings });
63
+ await Bridge.cast(rgbd);
64
+ };
65
+ window.updateHologram = async function(value, parameter) {
66
+ settings[parameter] = value;
67
+ const Bridge = BridgeClient.getInstance();
68
+ if (!Bridge.isConnected || window.updating)
69
+ return;
70
+ const name = Bridge.getCurrentPlaylist().name;
71
+ window.updating = true;
72
+ await Bridge.updateCurrentHologram({ name, parameter, value });
73
+ window.updating = false;
74
+ };
75
+ </script>
76
  """
77
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
78
  model_configs = {
79
  'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
80
  'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
81
  'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
82
  'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
83
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  title = "# Depth Anything V2"
86
+ description = """Looking Glass demo for **Depth Anything V2**.
87
+ Please refer to our [paper](https://arxiv.org/abs/2406.09414), [project page](https://depth-anything-v2.github.io), or [github](https://github.com/DepthAnything/Depth-Anything-V2) for more details."""
88
 
89
+ def predict_depth(image, model):
90
+ w, h = image.size
 
91
 
92
+ depth = model.infer_image(np.array(image.convert("RGB"))[:, :, ::-1])
 
 
 
93
 
94
+ depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
95
+ depth = depth.astype(np.uint8)
96
+
97
+ gray_depth = Image.fromarray(depth)
98
+
99
+ rgbd = Image.new(image.mode, (w * 2, h))
100
+ rgbd.paste(image, (0, 0))
101
+ rgbd.paste(gray_depth, (w, 0))
102
+ return rgbd
103
 
104
+ def upscale_image(image, model, background, discard_alpha):
105
+ if image.mode == "RGBA":
106
+ if discard_alpha:
107
+ image = Image.alpha_composite(ImageOps.pad(background, image.size, color=(0, 0, 0)), image);
108
+ elif image.mode != "RGB":
109
+ image = image.convert("RGB")
110
+ if model is not None:
111
+ image = model.infer(image)
112
+ return image.convert("RGB") if discard_alpha else image
113
 
114
+ def on_submit(image, batch_images, book, config, upscale_model, upscale_method, denoise_level, discard_alpha, progress=gr.Progress()):
115
+ model = DepthAnythingV2(**model_configs[config])
116
+ state_dict = torch.load(f'checkpoints/depth_anything_v2_{config}.pth', map_location="cpu")
117
+ model.load_state_dict(state_dict)
118
+ model = model.to(DEVICE).eval()
119
 
120
+ superresolution = None
121
+ if upscale_method is not None:
122
+ superresolution = torch.hub.load("nagadomi/nunif:master", "waifu2x",
123
+ model_type=upscale_model, method=upscale_method, noise_level=denoise_level,
124
+ keep_alpha=not discard_alpha, trust_repo=True).to(DEVICE)
125
 
126
+ gradient = ImageEnhance.Brightness(Image.radial_gradient("L"))
127
+ background = ImageOps.invert(gradient.enhance(1.5)).convert("RGBA")
128
+
129
+ result = []
130
+ if image is not None:
131
+ image = upscale_image(image, superresolution, background, discard_alpha)
132
+ result.append((predict_depth(image, model), None))
133
+ if batch_images is not None:
134
+ for path in progress.tqdm(batch_images):
135
+ with Image.open(path) as img:
136
+ img = upscale_image(img, superresolution, background, discard_alpha)
137
+ result.append((predict_depth(img, model), Path(path).name))
138
+ if book is not None:
139
+ if is_zipfile(book):
140
+ with ZipFile(book, "r") as zf:
141
+ for entry in progress.tqdm(zf.infolist()):
142
+ with zf.open(entry) as file:
143
+ with Image.open(file) as img:
144
+ img = upscale_image(img, superresolution, background, discard_alpha)
145
+ result.append((predict_depth(img, model), entry.filename))
146
+ else:
147
+ reader = PdfReader(book)
148
+ for page in progress.tqdm(reader.pages):
149
+ for image_file_object in page.images:
150
+ img = upscale_image(image_file_object.image, superresolution, background, discard_alpha)
151
+ result.append((predict_depth(img, model), image_file_object.name))
152
+ return result
153
+
154
+ def zip_gallery(gallery, progress=gr.Progress()):
155
+ if gallery is None:
156
+ return None
157
+ if len(gallery) == 1:
158
+ return gallery[0][0]
159
+ temp = Path(tempfile.gettempdir()) / uuid.uuid4().hex
160
+ zip = temp.with_suffix(".zip")
161
+ with ZipFile(zip, "w") as zf:
162
+ for index, image in progress.tqdm(enumerate(gallery)):
163
+ fn = Path(image[0]).name if image[1] is None else Path(image[1]).with_suffix(".rgbd.png")
164
+ zf.write(image[0], "{:02d}_{}".format(index, fn))
165
+ return zip
166
+
167
+ gr.set_static_paths(paths=[Path.cwd().absolute()/"assets"])
168
+
169
+ with gr.Blocks(css=css, head=head) as demo:
170
+ gr.Markdown(title)
171
+ gr.Markdown(description)
172
+
173
+ with gr.Row():
174
+ with gr.Column():
175
+ with gr.Tab("Single Image"):
176
+ input_image = gr.Image(
177
+ label="Input Image",
178
+ elem_id='img-display-input',
179
+ type='pil',
180
+ image_mode=None
181
+ )
182
+ with gr.Tab("Batch Mode"):
183
+ batch_images = gr.File(
184
+ label="Images",
185
+ file_types=["image"],
186
+ file_count="multiple"
187
+ )
188
+ with gr.Tab("Document Mode"):
189
+ book = gr.File(
190
+ label="Document",
191
+ file_types=[".pdf", ".zip"],
192
+ )
193
+ with gr.Row():
194
+ clear = gr.ClearButton(components=[input_image, batch_images, book])
195
+ submit = gr.Button(value="Compute Depth", variant="primary")
196
+ model_size = gr.Radio(
197
+ label="Model Size",
198
+ choices=[('Small', 'vits'), ('Base', 'vitb'), ('Large', 'vitl')],
199
+ value="vitl"
200
+ )
201
+ upscale_method = gr.Radio(
202
+ label="Upscale Method",
203
+ choices=[("No Upscaling or Denoising", None), ("Denoise Only", "noise"), ("2x Upscaling", "scale2x"), ("4x Upscaling", "scale4x")]
204
+ )
205
+ upscale_model = gr.Dropdown(
206
+ choices=["art", "art_scan", "photo", "swin_unet/art", "swin_unet/art_scan", "swin_unet/photo", "cunet/art", "upconv_7/art", "upconv_7/photo"],
207
+ label="Upscaling Model",
208
+ value="art"
209
+ )
210
+ denoise_level = gr.Slider(
211
+ label="Denoise Level (-1 = None)",
212
+ value=0,
213
+ step=1,
214
+ minimum=-1,
215
+ maximum=4
216
+ )
217
+ discard_alpha = gr.Checkbox(label="Add radial gradient background to transparent images", value=True)
218
 
219
+ with gr.Column():
220
+ gallery = gr.Gallery(
221
+ label="RGBD Images",
222
+ elem_id='img-display-output',
223
+ format="png",
224
+ columns=4,
225
+ object_fit="contain",
226
+ preview=True,
227
+ interactive=True
228
+ )
229
+ download_btn = gr.DownloadButton()
230
+ depthiness = gr.Slider(
231
+ label="Depthiness",
232
+ elem_id="depthiness",
233
+ interactive=True,
234
+ minimum=0,
235
+ maximum=3,
236
+ value=1
237
+ )
238
+ focus = gr.Slider(
239
+ label="Focus",
240
+ interactive=True,
241
+ minimum=-0.03,
242
+ maximum=0.03,
243
+ value=0
244
+ )
245
+ zoom = gr.Slider(
246
+ label="Zoom",
247
+ interactive=True,
248
+ minimum=0,
249
+ maximum=10,
250
+ value=1
251
+ )
252
+ pos_x = gr.Slider(
253
+ label="Position X",
254
+ interactive=True,
255
+ minimum=-1,
256
+ maximum=1,
257
+ value=0
258
+ )
259
+ pos_y = gr.Slider(
260
+ label="Position Y",
261
+ interactive=True,
262
+ minimum=-1,
263
+ maximum=1,
264
+ value=0
265
+ )
266
+ reset = gr.Button(value="Reset All Parameters")
267
 
268
+ gallery.select(fn=None, js="castHologram")
269
+ gallery.change(fn=zip_gallery, inputs=gallery, outputs=download_btn).then(fn=None, js="castHologram")
 
270
 
271
+ submit.click(
272
+ on_submit,
273
+ inputs=[input_image, batch_images, book, model_size, upscale_model, upscale_method, denoise_level, discard_alpha],
274
+ outputs=[gallery]
275
+ ).then(fn=zip_gallery, inputs=gallery, outputs=download_btn).then(fn=None, js="castHologram")
276
 
277
+ depthiness.change(fn=None, inputs=depthiness, js="(value) => updateHologram (value, 'depthiness')")
278
+ focus.change(fn=None, inputs=focus, js="(value) => updateHologram (value, 'focus')")
279
+ zoom.change(fn=None, inputs=zoom, js="(value) => updateHologram (value, 'zoom')")
280
+ pos_x.change(fn=None, inputs=pos_x, js="(value) => updateHologram (value, 'crop_pos_x')")
281
+ pos_y.change(fn=None, inputs=pos_y, js="(value) => updateHologram (value, 'crop_pos_y')")
282
 
283
+ reset.click(fn=None, js="""
284
+ () => {
285
+ document.querySelectorAll('button.reset-button').forEach(b => b.click());
286
+ }
287
+ """)
288
 
289
+ example_files = glob.glob('assets/examples/*')
290
+ examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[gallery], fn=on_submit)
 
 
291
 
292
 
293
  if __name__ == '__main__':
294
+ demo.queue().launch()
assets/looking-glass-bridge.js ADDED
The diff for this file is too large to render. See raw diff
 
assets/looking-glass-bridge.js.map ADDED
The diff for this file is too large to render. See raw diff