Spaces:
Runtime error
Runtime error
commit
Browse files
app.py
CHANGED
|
@@ -4,6 +4,7 @@ from gradio_client import Client, handle_file
|
|
| 4 |
from pathlib import Path
|
| 5 |
from gradio.utils import get_cache_folder
|
| 6 |
|
|
|
|
| 7 |
|
| 8 |
class Examples(gr.helpers.Examples):
|
| 9 |
def __init__(self, *args, cached_folder=None, **kwargs):
|
|
@@ -21,37 +22,104 @@ client = Client("Canyu/Diception",
|
|
| 21 |
hf_token=HF_TOKEN)
|
| 22 |
|
| 23 |
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
if path_input is None:
|
| 26 |
raise gr.Error(
|
| 27 |
"Missing image in the left pane: please upload an image first."
|
| 28 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
return client.predict(
|
| 32 |
-
|
| 33 |
-
api_name="/
|
| 34 |
)
|
| 35 |
|
| 36 |
def clear_cache():
|
| 37 |
return None, None
|
| 38 |
|
| 39 |
def run_demo_server():
|
|
|
|
| 40 |
gradio_theme = gr.themes.Default()
|
| 41 |
with gr.Blocks(
|
| 42 |
theme=gradio_theme,
|
| 43 |
title="Matting",
|
| 44 |
) as demo:
|
| 45 |
with gr.Row():
|
| 46 |
-
gr.Markdown("#
|
| 47 |
with gr.Row():
|
| 48 |
-
gr.Markdown("###
|
|
|
|
|
|
|
|
|
|
| 49 |
with gr.Row():
|
| 50 |
with gr.Column():
|
| 51 |
matting_image_input = gr.Image(
|
| 52 |
label="Input Image",
|
| 53 |
type="filepath",
|
| 54 |
)
|
|
|
|
| 55 |
with gr.Row():
|
| 56 |
matting_image_submit_btn = gr.Button(
|
| 57 |
value="Estimate Matting", variant="primary"
|
|
@@ -80,7 +148,7 @@ def run_demo_server():
|
|
| 80 |
|
| 81 |
matting_image_submit_btn.click(
|
| 82 |
fn=process_image_check,
|
| 83 |
-
inputs=matting_image_input,
|
| 84 |
outputs=None,
|
| 85 |
preprocess=False,
|
| 86 |
queue=False,
|
|
|
|
| 4 |
from pathlib import Path
|
| 5 |
from gradio.utils import get_cache_folder
|
| 6 |
|
| 7 |
+
from PIL import Image
|
| 8 |
|
| 9 |
class Examples(gr.helpers.Examples):
|
| 10 |
def __init__(self, *args, cached_folder=None, **kwargs):
|
|
|
|
| 22 |
hf_token=HF_TOKEN)
|
| 23 |
|
| 24 |
|
| 25 |
+
map_prompt = {
|
| 26 |
+
'depth': '[[image2depth]]',
|
| 27 |
+
'normal': '[[image2normal]]',
|
| 28 |
+
'pose': '[[image2pose]]',
|
| 29 |
+
'entity segmentation': '[[image2panoptic coarse]]',
|
| 30 |
+
'point segmentation': '[[image2segmentation]]',
|
| 31 |
+
'semantic segmentation': '[[image2semantic]]',
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
def download_additional_params(model_name, filename="add_params.bin"):
|
| 35 |
+
# 下载文件并返回文件路径
|
| 36 |
+
file_path = hf_hub_download(repo_id=model_name, filename=filename, use_auth_token=HF_TOKEN)
|
| 37 |
+
return file_path
|
| 38 |
+
|
| 39 |
+
# 加载 additional_params.bin 文件
|
| 40 |
+
def load_additional_params(model_name):
|
| 41 |
+
# 下载 additional_params.bin
|
| 42 |
+
params_path = download_additional_params(model_name)
|
| 43 |
+
|
| 44 |
+
# 使用 torch.load() 加载文件内容
|
| 45 |
+
additional_params = torch.load(params_path, map_location='cpu')
|
| 46 |
+
|
| 47 |
+
# 返回加载的参数内容
|
| 48 |
+
return additional_params
|
| 49 |
+
|
| 50 |
+
def process_image_check(path_input, prompt):
|
| 51 |
if path_input is None:
|
| 52 |
raise gr.Error(
|
| 53 |
"Missing image in the left pane: please upload an image first."
|
| 54 |
)
|
| 55 |
+
if len(prompt) == 0:
|
| 56 |
+
raise gr.Error(
|
| 57 |
+
"At least 1 prediction type is needed."
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def process_image_4(image_path, prompt):
|
| 63 |
+
|
| 64 |
+
inputs = []
|
| 65 |
+
for p in prompt:
|
| 66 |
+
image = Image.open(image_path)
|
| 67 |
+
|
| 68 |
+
w, h = image.size
|
| 69 |
+
|
| 70 |
+
coor_point = torch.zeros((1,5,2)).to(torch.float32)
|
| 71 |
+
point_labels = torch.zeros((1,5,1)).to(torch.float32)
|
| 72 |
|
| 73 |
+
image = image.resize((768, 768), Image.LANCZOS).convert('RGB')
|
| 74 |
+
to_tensor = transforms.ToTensor()
|
| 75 |
+
image = (to_tensor(image) - 0.5) * 2
|
| 76 |
+
|
| 77 |
+
cur_input = {
|
| 78 |
+
'input_images': image.unsqueeze(0),
|
| 79 |
+
'original_size': torch.tensor([[w,h]]),
|
| 80 |
+
'target_size': torch.tensor([[768, 768]]),
|
| 81 |
+
'prompt': [p],
|
| 82 |
+
'coor_point': coor_point,
|
| 83 |
+
'point_labels': point_labels,
|
| 84 |
+
'generator': generator
|
| 85 |
+
}
|
| 86 |
+
inputs.append(cur_input)
|
| 87 |
+
|
| 88 |
+
return inputs
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def infer_image_matting(image_path, prompt):
|
| 92 |
+
inputs = process_image_4(image_path, prompt)
|
| 93 |
+
return None
|
| 94 |
return client.predict(
|
| 95 |
+
batch=inputs,
|
| 96 |
+
api_name="/inf"
|
| 97 |
)
|
| 98 |
|
| 99 |
def clear_cache():
|
| 100 |
return None, None
|
| 101 |
|
| 102 |
def run_demo_server():
|
| 103 |
+
options = ['depth', 'normal', 'entity', 'pose']
|
| 104 |
gradio_theme = gr.themes.Default()
|
| 105 |
with gr.Blocks(
|
| 106 |
theme=gradio_theme,
|
| 107 |
title="Matting",
|
| 108 |
) as demo:
|
| 109 |
with gr.Row():
|
| 110 |
+
gr.Markdown("# Diception Demo")
|
| 111 |
with gr.Row():
|
| 112 |
+
gr.Markdown("### All results are generated using the same single model. To facilitate input processing, we separate point-prompted segmentation and semantic segmentation, as they require input points and segmentation targets.")
|
| 113 |
+
with gr.Row():
|
| 114 |
+
checkbox_group = gr.CheckboxGroup(choices=options, label="Select options:")
|
| 115 |
+
|
| 116 |
with gr.Row():
|
| 117 |
with gr.Column():
|
| 118 |
matting_image_input = gr.Image(
|
| 119 |
label="Input Image",
|
| 120 |
type="filepath",
|
| 121 |
)
|
| 122 |
+
|
| 123 |
with gr.Row():
|
| 124 |
matting_image_submit_btn = gr.Button(
|
| 125 |
value="Estimate Matting", variant="primary"
|
|
|
|
| 148 |
|
| 149 |
matting_image_submit_btn.click(
|
| 150 |
fn=process_image_check,
|
| 151 |
+
inputs=[matting_image_input, checkbox_group],
|
| 152 |
outputs=None,
|
| 153 |
preprocess=False,
|
| 154 |
queue=False,
|