wop commited on
Commit
9d05699
·
verified ·
1 Parent(s): c70fe14

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -17
app.py CHANGED
@@ -27,35 +27,53 @@ css = """
27
  height: 62px;
28
  }
29
  """
 
30
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 
31
  model_configs = {
32
  'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
33
  'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
34
  'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
35
- 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
36
  }
 
37
  encoder2name = {
38
  'vits': 'Small',
39
  'vitb': 'Base',
40
  'vitl': 'Large',
41
- 'vitg': 'Giant', # we are undergoing company review procedures to release our giant model checkpoint
42
  }
 
43
  encoder = 'vits'
44
  model_name = encoder2name[encoder]
 
45
  model = DepthAnythingV2(**model_configs[encoder])
46
- filepath = hf_hub_download(repo_id=f"depth-anything/Depth-Anything-V2-{model_name}", filename=f"depth_anything_v2_{encoder}.pth", repo_type="model")
 
 
 
 
47
  state_dict = torch.load(filepath, map_location="cpu")
48
  model.load_state_dict(state_dict)
49
  model = model.to(DEVICE).eval()
50
 
51
  title = "# Depth Anything V2"
52
  description = """Official demo for **Depth Anything V2**.
53
- Please refer to our [paper](https://arxiv.org/abs/2406.09414), [project page](https://depth-anything-v2.github.io), and [github](https://github.com/DepthAnything/Depth-Anything-V2) for more details."""
 
 
54
 
55
  @spaces.GPU
56
  def predict_depth(image):
57
  return model.infer_image(image)
58
 
 
 
 
 
 
 
 
59
  with gr.Blocks(css=css) as demo:
60
  gr.Markdown(title)
61
  gr.Markdown(description)
@@ -63,18 +81,20 @@ with gr.Blocks(css=css) as demo:
63
 
64
  with gr.Row():
65
  input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
66
- depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output', position=0.5)
 
 
 
 
 
67
  submit = gr.Button(value="Compute Depth")
68
- gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download",)
69
- raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download",)
70
 
71
  cmap = matplotlib.colormaps.get_cmap('Spectral_r')
72
 
73
  def on_submit(image):
74
  original_image = image.copy()
75
-
76
- h, w = image.shape[:2]
77
-
78
  depth = predict_depth(image[:, :, ::-1])
79
 
80
  raw_depth = Image.fromarray(depth.astype('uint16'))
@@ -91,13 +111,22 @@ with gr.Blocks(css=css) as demo:
91
 
92
  return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name]
93
 
94
- submit.click(on_submit, inputs=[input_image], outputs=[depth_image_slider, gray_depth_file, raw_file])
95
-
96
- example_files = os.listdir('assets/examples')
97
- example_files.sort()
98
- example_files = [os.path.join('assets/examples', filename) for filename in example_files]
99
- examples = gr.Examples(cache_examples=False, examples=example_files, inputs=[input_image], outputs=[depth_image_slider, gray_depth_file, raw_file], fn=on_submit)
100
-
 
 
 
 
 
 
 
 
 
101
 
102
  if __name__ == '__main__':
103
  demo.queue().launch(share=True)
 
27
  height: 62px;
28
  }
29
  """
30
+
31
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
32
+
33
  model_configs = {
34
  'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
35
  'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
36
  'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
37
+ 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]},
38
  }
39
+
40
  encoder2name = {
41
  'vits': 'Small',
42
  'vitb': 'Base',
43
  'vitl': 'Large',
44
+ 'vitg': 'Giant',
45
  }
46
+
47
  encoder = 'vits'
48
  model_name = encoder2name[encoder]
49
+
50
  model = DepthAnythingV2(**model_configs[encoder])
51
+ filepath = hf_hub_download(
52
+ repo_id=f"depth-anything/Depth-Anything-V2-{model_name}",
53
+ filename=f"depth_anything_v2_{encoder}.pth",
54
+ repo_type="model"
55
+ )
56
  state_dict = torch.load(filepath, map_location="cpu")
57
  model.load_state_dict(state_dict)
58
  model = model.to(DEVICE).eval()
59
 
60
  title = "# Depth Anything V2"
61
  description = """Official demo for **Depth Anything V2**.
62
+ Please refer to our [paper](https://arxiv.org/abs/2406.09414),
63
+ [project page](https://depth-anything-v2.github.io),
64
+ and [github](https://github.com/DepthAnything/Depth-Anything-V2) for more details."""
65
 
66
  @spaces.GPU
67
  def predict_depth(image):
68
  return model.infer_image(image)
69
 
70
+ # -------------------------------------
71
+ # OLD GRADIO COMPATIBILITY PATCH
72
+ # -------------------------------------
73
+ if not hasattr(gr.Blocks, "get_api_info"):
74
+ gr.Blocks.get_api_info = lambda self: {}
75
+ # -------------------------------------
76
+
77
  with gr.Blocks(css=css) as demo:
78
  gr.Markdown(title)
79
  gr.Markdown(description)
 
81
 
82
  with gr.Row():
83
  input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
84
+ depth_image_slider = ImageSlider(
85
+ label="Depth Map with Slider View",
86
+ elem_id='img-display-output',
87
+ position=0.5
88
+ )
89
+
90
  submit = gr.Button(value="Compute Depth")
91
+ gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
92
+ raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download")
93
 
94
  cmap = matplotlib.colormaps.get_cmap('Spectral_r')
95
 
96
  def on_submit(image):
97
  original_image = image.copy()
 
 
 
98
  depth = predict_depth(image[:, :, ::-1])
99
 
100
  raw_depth = Image.fromarray(depth.astype('uint16'))
 
111
 
112
  return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name]
113
 
114
+ submit.click(
115
+ on_submit,
116
+ inputs=[input_image],
117
+ outputs=[depth_image_slider, gray_depth_file, raw_file]
118
+ )
119
+
120
+ if os.path.exists('assets/examples'):
121
+ example_files = sorted(os.listdir('assets/examples'))
122
+ example_files = [os.path.join('assets/examples', f) for f in example_files]
123
+ gr.Examples(
124
+ cache_examples=False,
125
+ examples=example_files,
126
+ inputs=[input_image],
127
+ outputs=[depth_image_slider, gray_depth_file, raw_file],
128
+ fn=on_submit
129
+ )
130
 
131
  if __name__ == '__main__':
132
  demo.queue().launch(share=True)