File size: 3,946 Bytes
14affff ab27da4 a7638ff 14affff ab27da4 14affff ab27da4 55862cf ab27da4 14affff 10cc12f ab27da4 10cc12f 14affff ab27da4 de11222 ab27da4 684b020 ab27da4 0a26118 ab27da4 169d248 ab27da4 fdbdc38 ab27da4 a7638ff ab27da4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import gradio as gr
import os
import torch
import torch
from PIL import Image
from diffusers import (
AutoencoderKL,
)
from transformers import CLIPTextModel, CLIPTokenizer
from apdepth.marigold_pipeline import MarigoldPipeline
from apdepth.modules.unet_2d_condition import UNet2DConditionModel
def load_example(example_images):
    """Return the selected example image(s) unchanged.

    A pass-through helper: whatever object is given is handed straight back.
    """
    selected = example_images
    return selected
device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "developy/ApDepth"  # Replace with the model you would like to use
torch_dtype = torch.float32

# Load each pipeline component from its own subfolder of the model repository.
vae = AutoencoderKL.from_pretrained(model_repo_id, subfolder="vae", torch_dtype=torch_dtype, allow_pickle=False)
unet = UNet2DConditionModel.from_pretrained(model_repo_id, subfolder="unet", torch_dtype=torch_dtype, allow_pickle=False)
text_encoder = CLIPTextModel.from_pretrained(model_repo_id, subfolder="text_encoder", torch_dtype=torch_dtype)
# A tokenizer holds no tensors, so no dtype is passed to it.
tokenizer = CLIPTokenizer.from_pretrained(model_repo_id, subfolder="tokenizer")
pipe = MarigoldPipeline(vae=vae, unet=unet, text_encoder=text_encoder, tokenizer=tokenizer)

# xformers attention is an optional speed-up, so failing to enable it is not fatal.
# NOTE(review): presumably inherited from diffusers' DiffusionPipeline, which raises
# ModuleNotFoundError (an ImportError) when xformers is missing and ValueError when
# CUDA is unavailable; both fall back to the default attention implementation.
try:
    pipe.enable_xformers_memory_efficient_attention()
except (ImportError, ValueError) as err:
    print(f"Running without xformers memory-efficient attention: {err}")
pipe = pipe.to(device)
# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(
    input_image,
    progress=gr.Progress(track_tqdm=True),
):
    """Estimate depth for one image and return a colorized depth map.

    Args:
        input_image: Image from the Gradio input component (created with
            ``type="pil"``), or ``None`` when nothing has been uploaded.
        progress: Gradio progress tracker; ``track_tqdm=True`` lets Gradio
            mirror tqdm progress bars (presumably the pipeline's progress bar,
            enabled below via ``show_progress_bar=True``).

    Returns:
        Image.Image: the depth prediction rendered with the "Spectral" colormap.

    Raises:
        gr.Error: if no input image was provided.
    """
    if input_image is None:
        # Show a readable message in the UI instead of a stack trace from the pipeline.
        raise gr.Error("Please upload an image first.")
    pipe_out = pipe(
        input_image,
        processing_res=768,
        match_input_res=True,  # presumably resizes the prediction back to the input size
        batch_size=1,
        color_map="Spectral",
        show_progress_bar=True,
        resample_method="bilinear",
    )
    # The raw prediction is presumably also available as pipe_out.depth_np if needed.
    depth_colored: Image.Image = pipe_out.depth_colored
    return depth_colored
# Default example image paths: example/00.jpg through example/08.jpg.
example_images = [f"example/{index:02d}.jpg" for index in range(9)]
# Custom CSS for the demo, keyed on the `elem_id`s passed to the Gradio components.
# The output image is created with elem_id="depth-map", so that selector shares the
# 80vh height cap with the legacy `#img-display-output` id and the input image.
css = """
#img-display-container {
    max-height: 100vh;
}
#img-display-input {
    max-height: 80vh;
}
#img-display-output, #depth-map {
    max-height: 80vh;
}
#download {
    height: 62px;
}
"""
# Markdown rendered at the top of the demo page.
title = "# ApDepth"
description = (
    "**Official demo for ApDepth**(We provide models trained using Depth Anything v2-base here, "
    "as the Hugging Face space is limited to 1GB.).\n"
    "Please refer to our [website](https://haruko386.github.io/research/) for more details."
)
with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(" ### Depth Estimation with ApDepth.")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil", elem_id="img-display-input")
        with gr.Column():
            depth_map = gr.Image(label="Depth Image", type="pil", interactive=False, elem_id="depth-map")
    # Button that triggers depth estimation.
    compute_button = gr.Button(value="Compute Depth")
    # Run `infer` on the input image and show the result in the depth image.
    compute_button.click(
        fn=infer,  # callback
        inputs=[input_image],
        outputs=[depth_map],
    )
    # Build the example list from image files in ./example only, so stray files
    # (e.g. .DS_Store, READMEs) are not offered as examples and a missing folder
    # does not crash the app at startup.
    example_dir = "example"
    image_suffixes = (".jpg", ".jpeg", ".png", ".webp", ".bmp")
    example_files = (
        sorted(
            os.path.join(example_dir, filename)
            for filename in os.listdir(example_dir)
            if filename.lower().endswith(image_suffixes)
        )
        if os.path.isdir(example_dir)
        else []
    )
    if example_files:
        examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[depth_map], fn=infer)
# Launch the Gradio app only when run as a script, so importing this module
# (e.g. by Gradio's reload mode) does not start a server.
if __name__ == "__main__":
    demo.queue().launch(share=True)
|