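"""Gradio demo for ApDepth monocular depth estimation.

Loads the ApDepth model components from the Hugging Face Hub, assembles a
Marigold-style diffusion pipeline, and serves a simple image-to-depth-map UI.
"""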
import gradio as gr
import os
import torch

from PIL import Image
from diffusers import (
    AutoencoderKL,
)

from transformers import CLIPTextModel, CLIPTokenizer
from apdepth.marigold_pipeline import MarigoldPipeline
from apdepth.modules.unet_2d_condition import UNet2DConditionModel

def load_example(example_images):
    # Return the selected example image
    return example_images


device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "developy/ApDepth"  # Replace with the model you would like to use

torch_dtype = torch.float32

# Load each component explicitly so the custom UNet implementation is used
vae = AutoencoderKL.from_pretrained(model_repo_id, subfolder="vae", torch_dtype=torch_dtype, use_safetensors=True)
unet = UNet2DConditionModel.from_pretrained(model_repo_id, subfolder="unet", torch_dtype=torch_dtype, use_safetensors=True)
text_encoder = CLIPTextModel.from_pretrained(model_repo_id, subfolder="text_encoder", torch_dtype=torch_dtype)
tokenizer = CLIPTokenizer.from_pretrained(model_repo_id, subfolder="tokenizer")  # tokenizers carry no weights, so no dtype is needed
pipe = MarigoldPipeline(vae=vae, unet=unet, text_encoder=text_encoder, tokenizer=tokenizer)


# xformers memory-efficient attention is an optional optimization
try:
    pipe.enable_xformers_memory_efficient_attention()
except ImportError:
    pass  # run without xformers

pipe = pipe.to(device)


# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(
    input_image,
    progress=gr.Progress(track_tqdm=True),
):
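    """Run ApDepth on a single input image and return the colorized depth map."""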

    pipe_out = pipe(
        input_image,
        processing_res=768,
        match_input_res=True,
        batch_size=1,
        color_map="Spectral",
        show_progress_bar=True,
        resample_method="bilinear",
    )

    # depth_pred: np.ndarray = pipe_out.depth_np
    depth_colored: Image.Image = pipe_out.depth_colored
    return depth_colored


# Default example image paths
example_images = [
    "example/00.jpg",
    "example/01.jpg",
    "example/02.jpg",
    "example/03.jpg",
    "example/04.jpg",
    "example/05.jpg",
    "example/06.jpg",
    "example/07.jpg",
    "example/08.jpg",
]
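# Note: the UI below populates its example gallery by scanning example/ at
# startup, so this static list is kept only for reference.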

# css = """
# #col-container {
#     margin: 0 auto;
#     max-width: 640px;
# }
# #example-gallery {
#     height: 80px; /* thumbnail height */
#     width: auto;  /* keep the aspect ratio */
#     margin: 0 auto;  /* spacing between images */
#     cursor: pointer; /* show a pointer cursor on hover */
# }
# """

css = """
#img-display-container {
    max-height: 100vh;
}
#img-display-input {
    max-height: 80vh;
}
#img-display-output {
    max-height: 80vh;
}
#download {
    height: 62px;
}
"""

title = "# ApDepth"
description = """**Official demo for ApDepth.** We provide a model trained using Depth Anything v2-base here, since Hugging Face Spaces are limited to 1 GB.
Please refer to our [website](https://haruko386.github.io/research/) for more details."""


with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(" ### Depth Estimation with ApDepth.")
    # with gr.Column(elem_id="col-container"):
    #     gr.Markdown(" # Depth Estimation")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil", elem_id="img-display-input")
        with gr.Column():
            # depth_img_slider = ImageSlider(label="Depth Map with Slider View", elem_id="img-display-output", position=0.5)
            depth_map = gr.Image(label="Depth Image", type="pil", interactive=False, elem_id="img-display-output")

    # Compute button
    compute_button = gr.Button(value="Compute Depth")

    # Run inference when the button is clicked
    compute_button.click(
        fn=infer,
        inputs=[input_image],
        outputs=[depth_map],
    )

    # Build the example gallery from the images bundled in example/
    example_files = os.listdir('example')
    example_files.sort()
    example_files = [os.path.join('example', filename) for filename in example_files]
    examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[depth_map], fn=infer)


# Launch the Gradio app
demo.queue().launch(share=True)