Spaces:
Runtime error
Runtime error
commit
Browse files
app.py
CHANGED
|
@@ -20,8 +20,28 @@ class Examples(gr.helpers.Examples):
|
|
| 20 |
self.create()
|
| 21 |
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
# user click the image to get points, and show the points on the image
|
| 24 |
def get_point(img, sel_pix, evt: gr.SelectData):
|
|
|
|
| 25 |
if len(sel_pix) < 5:
|
| 26 |
sel_pix.append((evt.index, 1)) # default foreground_point
|
| 27 |
img = cv2.imread(img)
|
|
@@ -54,11 +74,11 @@ def undo_points(orig_img, sel_pix):
|
|
| 54 |
return temp, sel_pix
|
| 55 |
|
| 56 |
|
| 57 |
-
|
| 58 |
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
|
| 63 |
colors = [(255, 0, 0), (0, 255, 0)]
|
| 64 |
markers = [1, 5]
|
|
@@ -89,12 +109,6 @@ def load_additional_params(model_name):
|
|
| 89 |
return additional_params
|
| 90 |
|
| 91 |
def process_image_check(path_input, prompt, sel_points, semantic):
|
| 92 |
-
print('=========== PROCESS IMAGE CHECK ===========')
|
| 93 |
-
print(f"Image Path: {path_input}")
|
| 94 |
-
print(f"Prompt: {prompt}")
|
| 95 |
-
print(f"Selected Points (before processing): {sel_points}")
|
| 96 |
-
print(f"Semantic Input: {semantic}")
|
| 97 |
-
print('===========================================')
|
| 98 |
if path_input is None:
|
| 99 |
raise gr.Error(
|
| 100 |
"Missing image in the left pane: please upload an image first."
|
|
@@ -103,23 +117,6 @@ def process_image_check(path_input, prompt, sel_points, semantic):
|
|
| 103 |
raise gr.Error(
|
| 104 |
"At least 1 prediction type is needed."
|
| 105 |
)
|
| 106 |
-
if 'point segmentation' in prompt and len(sel_points) == 0:
|
| 107 |
-
raise gr.Error(
|
| 108 |
-
"At least 1 point is needed."
|
| 109 |
-
)
|
| 110 |
-
if 'point segmentation' not in prompt and len(sel_points) != 0:
|
| 111 |
-
raise gr.Error(
|
| 112 |
-
"You must select 'point segmentation' when performing point segmentation."
|
| 113 |
-
)
|
| 114 |
-
|
| 115 |
-
if 'semantic segmentation' in prompt and semantic == None:
|
| 116 |
-
raise gr.Error(
|
| 117 |
-
"Target category is needed."
|
| 118 |
-
)
|
| 119 |
-
if 'semantic segmentation' not in prompt and semantic != None:
|
| 120 |
-
raise gr.Error(
|
| 121 |
-
"You must select 'semantic segmentation' when performing semantic segmentation."
|
| 122 |
-
)
|
| 123 |
|
| 124 |
|
| 125 |
|
|
@@ -146,14 +143,51 @@ def process_image_4(image_path, prompt):
|
|
| 146 |
|
| 147 |
|
| 148 |
def inf(image_path, prompt, sel_points, semantic):
|
| 149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
# return None
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
api_name="/inf"
|
| 155 |
)
|
| 156 |
|
|
|
|
|
|
|
|
|
|
| 157 |
def clear_cache():
|
| 158 |
return None, None
|
| 159 |
|
|
@@ -162,18 +196,76 @@ def run_demo_server():
|
|
| 162 |
gradio_theme = gr.themes.Default()
|
| 163 |
with gr.Blocks(
|
| 164 |
theme=gradio_theme,
|
| 165 |
-
title="
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
) as demo:
|
| 167 |
selected_points = gr.State([]) # store points
|
| 168 |
original_image = gr.State(value=None) # store original image without points, default None
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
with gr.Row():
|
| 174 |
checkbox_group = gr.CheckboxGroup(choices=options, label="Select options:")
|
| 175 |
with gr.Row():
|
| 176 |
semantic_input = gr.Textbox(label="Category Name (for semantic segmentation only, in COCO)", placeholder="e.g. person/cat/dog/elephant......")
|
|
|
|
|
|
|
| 177 |
with gr.Row():
|
| 178 |
with gr.Column():
|
| 179 |
input_image = gr.Image(
|
|
@@ -184,20 +276,22 @@ def run_demo_server():
|
|
| 184 |
with gr.Column():
|
| 185 |
with gr.Row():
|
| 186 |
gr.Markdown('You can click on the image to select points prompt. At most 5 point.')
|
| 187 |
-
undo_button = gr.Button('Undo point')
|
| 188 |
|
| 189 |
-
with gr.Row():
|
| 190 |
matting_image_submit_btn = gr.Button(
|
| 191 |
-
value="
|
| 192 |
)
|
|
|
|
|
|
|
|
|
|
| 193 |
matting_image_reset_btn = gr.Button(value="Reset")
|
| 194 |
|
| 195 |
-
with gr.Row():
|
| 196 |
-
|
| 197 |
|
| 198 |
with gr.Column():
|
| 199 |
# matting_image_output = gr.Image(label='Output')
|
| 200 |
-
matting_image_output = gr.Image(label='
|
|
|
|
| 201 |
|
| 202 |
# label="Matting Output",
|
| 203 |
# type="filepath",
|
|
@@ -210,7 +304,7 @@ def run_demo_server():
|
|
| 210 |
|
| 211 |
|
| 212 |
|
| 213 |
-
img_clear_button.click(clear_cache, outputs=[input_image, matting_image_output])
|
| 214 |
|
| 215 |
matting_image_submit_btn.click(
|
| 216 |
fn=process_image_check,
|
|
@@ -230,11 +324,13 @@ def run_demo_server():
|
|
| 230 |
fn=lambda: (
|
| 231 |
None,
|
| 232 |
None,
|
|
|
|
| 233 |
),
|
| 234 |
inputs=[],
|
| 235 |
outputs=[
|
| 236 |
input_image,
|
| 237 |
matting_image_output,
|
|
|
|
| 238 |
],
|
| 239 |
queue=False,
|
| 240 |
)
|
|
|
|
| 20 |
self.create()
|
| 21 |
|
| 22 |
|
| 23 |
+
def postprocess(output, prompt):
    """Split a horizontally-concatenated result image into per-task slices.

    The backend returns one image in which the outputs for the ``n``
    selected tasks are laid out side by side.  This crops it into ``n``
    equal-width vertical slices and pairs each slice with its task name,
    in the ``(image, caption)`` shape expected by ``gr.Gallery``.

    Args:
        output: Path to the combined result image on disk.
        prompt: Sequence of task names, one per slice, left to right.

    Returns:
        List of ``(PIL.Image.Image, str)`` tuples, one per task.
    """
    result = []
    image = Image.open(output)
    w, h = image.size
    n = len(prompt)
    if n == 0:
        # No tasks selected -- avoid ZeroDivisionError in the slicing below.
        return result
    slice_width = w // n

    for i in range(n):
        left = i * slice_width
        # The last slice absorbs any remainder pixels from integer division.
        right = (i + 1) * slice_width if i < n - 1 else w
        cropped_img = image.crop((left, 0, right, h))

        # Caption each slice with the task that produced it.
        caption = prompt[i]

        result.append((cropped_img, caption))
    return result
|
| 41 |
+
|
| 42 |
# user click the image to get points, and show the points on the image
|
| 43 |
def get_point(img, sel_pix, evt: gr.SelectData):
|
| 44 |
+
print(sel_pix)
|
| 45 |
if len(sel_pix) < 5:
|
| 46 |
sel_pix.append((evt.index, 1)) # default foreground_point
|
| 47 |
img = cv2.imread(img)
|
|
|
|
| 74 |
return temp, sel_pix
|
| 75 |
|
| 76 |
|
| 77 |
+
HF_TOKEN = os.environ.get('HF_KEY')
|
| 78 |
|
| 79 |
+
client = Client("Canyu/Diception",
|
| 80 |
+
max_workers=3,
|
| 81 |
+
hf_token=HF_TOKEN)
|
| 82 |
|
| 83 |
colors = [(255, 0, 0), (0, 255, 0)]
|
| 84 |
markers = [1, 5]
|
|
|
|
| 109 |
return additional_params
|
| 110 |
|
| 111 |
def process_image_check(path_input, prompt, sel_points, semantic):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
if path_input is None:
|
| 113 |
raise gr.Error(
|
| 114 |
"Missing image in the left pane: please upload an image first."
|
|
|
|
| 117 |
raise gr.Error(
|
| 118 |
"At least 1 prediction type is needed."
|
| 119 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
|
| 122 |
|
|
|
|
| 143 |
|
| 144 |
|
| 145 |
def inf(image_path, prompt, sel_points, semantic):
    """Validate inputs, run inference on the remote Diception Space, and
    split the combined result into gallery entries.

    Args:
        image_path: Path of the uploaded input image.
        prompt: List of selected task names (checkbox values).
        sel_points: List of ``((x, y), label)`` point prompts.
        semantic: Target category name for semantic segmentation
            (may be ``None`` or ``''`` when not entered).

    Returns:
        List of ``(image, caption)`` tuples from :func:`postprocess`.

    Raises:
        gr.Error: If the point/semantic selections are inconsistent with
            the chosen tasks.
    """
    print('=========== PROCESS IMAGE CHECK ===========')
    print(f"Image Path: {image_path}")
    print(f"Prompt: {prompt}")
    print(f"Selected Points (before processing): {sel_points}")
    print(f"Semantic Input: {semantic}")
    print('===========================================')

    # gr.Error aborts the handler by raising, so no return is needed after it
    # (the original `return` statements here were unreachable dead code).
    if 'point segmentation' in prompt and len(sel_points) == 0:
        raise gr.Error(
            "At least 1 point is needed."
        )
    if 'point segmentation' not in prompt and len(sel_points) != 0:
        raise gr.Error(
            "You must select 'point segmentation' when performing point segmentation."
        )

    # Treat both None and '' as "no category entered": a reset gr.Textbox can
    # yield either, and comparing only against '' made None spuriously trip
    # the "must select semantic segmentation" branch.
    if 'semantic segmentation' in prompt and not semantic:
        raise gr.Error(
            "Target category is needed."
        )
    if 'semantic segmentation' not in prompt and semantic:
        raise gr.Error(
            "You must select 'semantic segmentation' when performing semantic segmentation."
        )

    # The Space endpoint expects the point prompts serialized as a string.
    prompt_str = str(sel_points)

    result = client.predict(
        input_image=handle_file(image_path),
        checkbox_group=prompt,
        selected_points=prompt_str,
        semantic_input=semantic,
        api_name="/inf"
    )

    # Split the concatenated result image into one entry per selected task.
    result = postprocess(result, prompt)
    return result
|
| 190 |
+
|
| 191 |
def clear_cache():
    """Reset the input image and the output gallery to their empty state."""
    return (None, None)
|
| 193 |
|
|
|
|
| 196 |
gradio_theme = gr.themes.Default()
|
| 197 |
with gr.Blocks(
|
| 198 |
theme=gradio_theme,
|
| 199 |
+
title="Diception",
|
| 200 |
+
css="""
|
| 201 |
+
#download {
|
| 202 |
+
height: 118px;
|
| 203 |
+
}
|
| 204 |
+
.slider .inner {
|
| 205 |
+
width: 5px;
|
| 206 |
+
background: #FFF;
|
| 207 |
+
}
|
| 208 |
+
.viewport {
|
| 209 |
+
aspect-ratio: 4/3;
|
| 210 |
+
}
|
| 211 |
+
.tabs button.selected {
|
| 212 |
+
font-size: 20px !important;
|
| 213 |
+
color: crimson !important;
|
| 214 |
+
}
|
| 215 |
+
h1 {
|
| 216 |
+
text-align: center;
|
| 217 |
+
display: block;
|
| 218 |
+
}
|
| 219 |
+
h2 {
|
| 220 |
+
text-align: center;
|
| 221 |
+
display: block;
|
| 222 |
+
}
|
| 223 |
+
h3 {
|
| 224 |
+
text-align: center;
|
| 225 |
+
display: block;
|
| 226 |
+
}
|
| 227 |
+
.md_feedback li {
|
| 228 |
+
margin-bottom: 0px !important;
|
| 229 |
+
}
|
| 230 |
+
""",
|
| 231 |
+
head="""
|
| 232 |
+
<script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
|
| 233 |
+
<script>
|
| 234 |
+
window.dataLayer = window.dataLayer || [];
|
| 235 |
+
function gtag() {dataLayer.push(arguments);}
|
| 236 |
+
gtag('js', new Date());
|
| 237 |
+
gtag('config', 'G-1FWSVCGZTG');
|
| 238 |
+
</script>
|
| 239 |
+
""",
|
| 240 |
) as demo:
|
| 241 |
selected_points = gr.State([]) # store points
|
| 242 |
original_image = gr.State(value=None) # store original image without points, default None
|
| 243 |
+
gr.Markdown(
|
| 244 |
+
"""
|
| 245 |
+
# DICEPTION: A Generalist Diffusion Model for Vision Perception
|
| 246 |
+
<p align="center">
|
| 247 |
+
<a title="arXiv" href="https://arxiv.org" target="_blank" rel="noopener noreferrer"
|
| 248 |
+
style="display: inline-block;">
|
| 249 |
+
<img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
|
| 250 |
+
</a>
|
| 251 |
+
<a title="Github" href="https://github.com/aim-uofa/Diception" target="_blank" rel="noopener noreferrer"
|
| 252 |
+
style="display: inline-block;">
|
| 253 |
+
<img src="https://img.shields.io/github/stars/aim-uofa/GenPercept?label=GitHub%20%E2%98%85&logo=github&color=C8C"
|
| 254 |
+
alt="badge-github-stars">
|
| 255 |
+
</a>
|
| 256 |
+
</p>
|
| 257 |
+
<p align="justify">
|
| 258 |
+
One single model solves multiple perception tasks, producing impressive results!
|
| 259 |
+
</p>
|
| 260 |
+
"""
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
with gr.Row():
|
| 264 |
checkbox_group = gr.CheckboxGroup(choices=options, label="Select options:")
|
| 265 |
with gr.Row():
|
| 266 |
semantic_input = gr.Textbox(label="Category Name (for semantic segmentation only, in COCO)", placeholder="e.g. person/cat/dog/elephant......")
|
| 267 |
+
with gr.Row():
|
| 268 |
+
gr.Markdown('For non-human image inputs, the pose results may have issues. Same when perform semantic segmentation with categories that are not in COCO.')
|
| 269 |
with gr.Row():
|
| 270 |
with gr.Column():
|
| 271 |
input_image = gr.Image(
|
|
|
|
| 276 |
with gr.Column():
|
| 277 |
with gr.Row():
|
| 278 |
gr.Markdown('You can click on the image to select points prompt. At most 5 point.')
|
|
|
|
| 279 |
|
|
|
|
| 280 |
matting_image_submit_btn = gr.Button(
|
| 281 |
+
value="Run", variant="primary"
|
| 282 |
)
|
| 283 |
+
|
| 284 |
+
with gr.Row():
|
| 285 |
+
undo_button = gr.Button('Undo point')
|
| 286 |
matting_image_reset_btn = gr.Button(value="Reset")
|
| 287 |
|
| 288 |
+
# with gr.Row():
|
| 289 |
+
# img_clear_button = gr.Button("Clear Cache")
|
| 290 |
|
| 291 |
with gr.Column():
|
| 292 |
# matting_image_output = gr.Image(label='Output')
|
| 293 |
+
# matting_image_output = gr.Image(label='Results')
|
| 294 |
+
matting_image_output = gr.Gallery(label="Results")
|
| 295 |
|
| 296 |
# label="Matting Output",
|
| 297 |
# type="filepath",
|
|
|
|
| 304 |
|
| 305 |
|
| 306 |
|
| 307 |
+
# img_clear_button.click(clear_cache, outputs=[input_image, matting_image_output])
|
| 308 |
|
| 309 |
matting_image_submit_btn.click(
|
| 310 |
fn=process_image_check,
|
|
|
|
| 324 |
fn=lambda: (
|
| 325 |
None,
|
| 326 |
None,
|
| 327 |
+
[]
|
| 328 |
),
|
| 329 |
inputs=[],
|
| 330 |
outputs=[
|
| 331 |
input_image,
|
| 332 |
matting_image_output,
|
| 333 |
+
selected_points
|
| 334 |
],
|
| 335 |
queue=False,
|
| 336 |
)
|