johnrobinsn committed on
Commit
b2a227f
·
verified ·
1 Parent(s): ff24905

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -12
app.py CHANGED
@@ -1,34 +1,40 @@
1
  import gradio as gr
2
- from transformers import DPTImageProcessor, DPTForDepthEstimation
3
  import torch
4
  import numpy as np
5
  from PIL import Image
6
  from pathlib import Path
7
  from depth_viewer import depthviewer2html
8
 
9
- feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
10
- model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
 
 
 
 
 
 
 
 
 
11
 
12
  def process_image(image_path):
13
  if image_path is None:
14
  return ""
15
 
 
 
16
  image_path = Path(image_path)
17
  image = Image.open(image_path)
18
 
19
- # if wider than 512 pixels let's resample to keep it performant
20
  if image.size[0] > 512:
21
  image = image.resize((512, int(512 * image.size[1] / image.size[0])), Image.Resampling.LANCZOS)
22
 
23
- # prepare image for the model
24
  inputs = feature_extractor(images=image, return_tensors="pt")
25
 
26
- # forward pass
27
  with torch.no_grad():
28
  outputs = model(**inputs)
29
  predicted_depth = outputs.predicted_depth
30
 
31
- # interpolate to original size
32
  prediction = torch.nn.functional.interpolate(
33
  predicted_depth.unsqueeze(1),
34
  size=image.size[::-1],
@@ -42,22 +48,21 @@ def process_image(image_path):
42
  return depthviewer2html(image, depth)
43
 
44
  title = "3D Visualization of Depth Maps Generated using MiDaS"
45
- description = "Improved 3D interactive depth viewer using Three.js embedded in a Gradio app."
46
 
47
  with gr.Blocks(css="#depth-viewer { height: 600px; }") as demo:
48
  gr.Markdown(f"# {title}")
49
  gr.Markdown(description)
50
 
51
- with gr.Row():
52
- input_image = gr.Image(type="filepath", label="Input Image")
53
-
54
  output_html = gr.HTML(label="Depth Viewer", elem_id="depth-viewer")
55
 
56
  input_image.change(fn=process_image, inputs=input_image, outputs=output_html)
57
 
58
  gr.Examples(
59
  examples=[["examples/owl1.jpg"], ["examples/marsattacks.jpg"], ["examples/kitten.jpg"]],
60
- inputs=input_image
 
61
  )
62
 
63
  demo.launch(server_name="0.0.0.0")
 
1
  import gradio as gr
 
2
  import torch
3
  import numpy as np
4
  from PIL import Image
5
  from pathlib import Path
6
  from depth_viewer import depthviewer2html
7
 
8
# Model singletons — populated lazily on first use so that importing this
# module stays fast and does not trigger a weight download.
_model = None
_feature_extractor = None


def get_model():
    """Return ``(feature_extractor, model)`` for DPT depth estimation.

    Loads the Intel/dpt-large processor and model on the first call and
    caches them in module-level globals; subsequent calls reuse the cached
    pair. ``transformers`` is imported lazily for the same reason.
    """
    global _model, _feature_extractor
    if _model is not None:
        return _feature_extractor, _model
    from transformers import DPTImageProcessor, DPTForDepthEstimation
    _feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
    _model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
    return _feature_extractor, _model
19
 
20
  def process_image(image_path):
21
  if image_path is None:
22
  return ""
23
 
24
+ feature_extractor, model = get_model()
25
+
26
  image_path = Path(image_path)
27
  image = Image.open(image_path)
28
 
 
29
  if image.size[0] > 512:
30
  image = image.resize((512, int(512 * image.size[1] / image.size[0])), Image.Resampling.LANCZOS)
31
 
 
32
  inputs = feature_extractor(images=image, return_tensors="pt")
33
 
 
34
  with torch.no_grad():
35
  outputs = model(**inputs)
36
  predicted_depth = outputs.predicted_depth
37
 
 
38
  prediction = torch.nn.functional.interpolate(
39
  predicted_depth.unsqueeze(1),
40
  size=image.size[::-1],
 
48
  return depthviewer2html(image, depth)
49
 
50
title = "3D Visualization of Depth Maps Generated using MiDaS"
description = "Improved 3D interactive depth viewer using Three.js"

# Bundled sample inputs; caching is disabled so the model only runs when a
# user actually picks an example.
_EXAMPLES = [["examples/owl1.jpg"], ["examples/marsattacks.jpg"], ["examples/kitten.jpg"]]

with gr.Blocks(css="#depth-viewer { height: 600px; }") as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)

    input_image = gr.Image(type="filepath", label="Input Image")
    output_html = gr.HTML(label="Depth Viewer", elem_id="depth-viewer")

    # Re-render the depth viewer whenever a new image is supplied.
    input_image.change(fn=process_image, inputs=input_image, outputs=output_html)

    gr.Examples(
        examples=_EXAMPLES,
        inputs=input_image,
        cache_examples=False,
    )

demo.launch(server_name="0.0.0.0")