developy committed
Commit ab27da4 · verified · 1 Parent(s): 6735308

Update app.py

Files changed (1): app.py (+128, -20)
app.py CHANGED
@@ -1,29 +1,137 @@
  import gradio as gr
- from diffusers import MarigoldDepthPipeline, DDIMScheduler
  import torch
  from PIL import Image

- CHECKPOINT = "developy/ApDepth"

- device = "cpu"
- dtype = torch.float32

- pipe = MarigoldDepthPipeline.from_pretrained(CHECKPOINT)
- pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
- pipe = pipe.to(device=device, dtype=dtype)

- def predict(image: Image.Image):
-     out = pipe(image)
-     depth_vis = pipe.image_processor.visualize_depth(out.prediction)[0]
-     return depth_vis

- demo = gr.Interface(
-     fn=predict,
-     inputs=gr.Image(type="pil", label="Input Image"),
-     outputs=gr.Image(type="pil", label="Depth Map"),
-     title="ApDepth Demo",
-     description="Monocular Depth Estimation based on Marigold"
- )

- if __name__ == "__main__":
-     demo.launch()
 
  import gradio as gr
+ import os
+
  import torch
  from PIL import Image
+ from diffusers import (
+     AutoencoderKL,
+ )

+ from transformers import CLIPTextModel, CLIPTokenizer
+ from apdepth import MarigoldPipeline
+ from apdepth.modules.unet_2d_condition import UNet2DConditionModel

+ def load_example(example_image):
+     # Return the selected image
+     return example_image

+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model_repo_id = "developy/ApDepth"  # Replace with the model you would like to use
+
+ torch_dtype = torch.float32
+
+ vae = AutoencoderKL.from_pretrained(model_repo_id, subfolder="vae", torch_dtype=torch_dtype)
+ unet = UNet2DConditionModel.from_pretrained(model_repo_id, subfolder="unet", torch_dtype=torch_dtype)
+ text_encoder = CLIPTextModel.from_pretrained(model_repo_id, subfolder="text_encoder", torch_dtype=torch_dtype)
+ tokenizer = CLIPTokenizer.from_pretrained(model_repo_id, subfolder="tokenizer")
+ pipe = MarigoldPipeline(vae=vae, unet=unet, text_encoder=text_encoder, tokenizer=tokenizer)
+
+
+ try:
+     pipe.enable_xformers_memory_efficient_attention()
+ except ImportError:
+     pass  # run without xformers
+
+ pipe = pipe.to(device)
+
+
+ # @spaces.GPU  # [uncomment to use ZeroGPU]
+ def infer(
+     input_image,
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     pipe_out = pipe(
+         input_image,
+         processing_res=768,
+         match_input_res=True,
+         batch_size=1,
+         color_map="Spectral",
+         show_progress_bar=True,
+         resample_method="bilinear",
+     )
+
+     # depth_pred: np.ndarray = pipe_out.depth_np
+     depth_colored: Image.Image = pipe_out.depth_colored
+
+     return depth_colored
+
+
+ # Default example image paths
+ example_images = [
+     "example/example_0.jpg",
+     "example/example_1.jpg",
+     "example/example_2.jpg",
+     "example/example_3.jpg",
+     "example/example_4.jpg",
+     "example/example_5.jpg",
+     "example/example_6.jpg",
+ ]
+
+ # css = """
+ # #col-container {
+ #     margin: 0 auto;
+ #     max-width: 640px;
+ # }
+ # #example-gallery {
+ #     height: 80px;    /* thumbnail height */
+ #     width: auto;     /* keep the aspect ratio */
+ #     margin: 0 auto;  /* image spacing */
+ #     cursor: pointer; /* show a hand cursor */
+ # }
+ # """
+
+ css = """
+ #img-display-container {
+     max-height: 100vh;
+ }
+ #img-display-input {
+     max-height: 80vh;
+ }
+ #img-display-output {
+     max-height: 80vh;
+ }
+ #download {
+     height: 62px;
+ }
+ """
+
+ title = "# ApDepth"
+ description = """**Official demo for ApDepth**.
+ Please refer to our [website](https://haruko386.github.io/research/) for more details."""
+
+
+ with gr.Blocks(css=css) as demo:
+     gr.Markdown(title)
+     gr.Markdown(description)
+     gr.Markdown(" ### Depth Estimation with ApDepth.")
+     # with gr.Column(elem_id="col-container"):
+     #     gr.Markdown(" # Depth Estimation")
+     with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(label="Input Image", type="pil", elem_id="img-display-input")
+         with gr.Column():
+             # depth_img_slider = ImageSlider(label="Depth Map with Slider View", elem_id="img-display-output", position=0.5)
+             depth_map = gr.Image(label="Depth Map", type="pil", interactive=False, elem_id="depth-map")
+
+     # Compute button
+     compute_button = gr.Button(value="Compute Depth")
+
+     # Wire the compute button to the inference callback
+     compute_button.click(
+         fn=infer,
+         inputs=[input_image],
+         outputs=[depth_map],
+     )
+
+     example_files = os.listdir('example')
+     example_files.sort()
+     example_files = [os.path.join('example', filename) for filename in example_files]
+     examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[depth_map], fn=infer)

+ # Launch the Gradio app
+ demo.queue().launch(share=True)
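
For quick local testing, a minimal sketch of the new inference path outside Gradio. It only reuses the imports, the pipeline constructor, and the call arguments that appear in the diff above; the apdepth package API and the component layout of developy/ApDepth are assumed from this commit rather than from documentation, and the input/output file paths are placeholders.

import torch
from PIL import Image
from diffusers import AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
from apdepth import MarigoldPipeline
from apdepth.modules.unet_2d_condition import UNet2DConditionModel

repo = "developy/ApDepth"
dtype = torch.float32
device = "cuda" if torch.cuda.is_available() else "cpu"

# Assemble the pipeline from the same per-component checkpoints as app.py
pipe = MarigoldPipeline(
    vae=AutoencoderKL.from_pretrained(repo, subfolder="vae", torch_dtype=dtype),
    unet=UNet2DConditionModel.from_pretrained(repo, subfolder="unet", torch_dtype=dtype),
    text_encoder=CLIPTextModel.from_pretrained(repo, subfolder="text_encoder", torch_dtype=dtype),
    tokenizer=CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer"),
).to(device)

# Same settings as the demo's infer(); the input path is a placeholder
out = pipe(
    Image.open("example/example_0.jpg"),
    processing_res=768,
    match_input_res=True,
    batch_size=1,
    color_map="Spectral",
    show_progress_bar=True,
    resample_method="bilinear",
)
out.depth_colored.save("depth_colored.png")  # colorized depth map (PIL Image)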