fix threshold slider and matplotlib warning
Browse files
app.py
CHANGED
|
@@ -12,7 +12,7 @@ from torchvision.transforms import InterpolationMode
|
|
| 12 |
import torchvision.transforms.functional as TF
|
| 13 |
from huggingface_hub import hf_hub_download
|
| 14 |
import numpy as np
|
| 15 |
-
import matplotlib.
|
| 16 |
|
| 17 |
class Fit(torch.nn.Module):
|
| 18 |
def __init__(
|
|
@@ -170,7 +170,7 @@ def run_classifier(image: Image.Image, threshold):
|
|
| 170 |
tag_score[allowed_tags[indices[i]]] = values[i].item()
|
| 171 |
sorted_tag_score = dict(sorted(tag_score.items(), key=lambda item: item[1], reverse=True))
|
| 172 |
|
| 173 |
-
return *create_tags(threshold, sorted_tag_score), img
|
| 174 |
|
| 175 |
def create_tags(threshold, sorted_tag_score: dict):
|
| 176 |
filtered_tag_score = {key: value for key, value in sorted_tag_score.items() if value > threshold}
|
|
@@ -178,7 +178,7 @@ def create_tags(threshold, sorted_tag_score: dict):
|
|
| 178 |
return text_no_impl, filtered_tag_score
|
| 179 |
|
| 180 |
def clear_image():
|
| 181 |
-
return "", {}, None
|
| 182 |
|
| 183 |
def cam_inference(img, threshold, evt: gr.SelectData):
|
| 184 |
target_tag = evt.value
|
|
@@ -274,6 +274,7 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
|
|
| 274 |
Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
|
| 275 |
""")
|
| 276 |
original_image_state = gr.State() # stash a copy of the input image
|
|
|
|
| 277 |
with gr.Row():
|
| 278 |
with gr.Column():
|
| 279 |
image_input = gr.Image(label="Source", sources=['upload'], type='pil', height=512, show_label=False)
|
|
@@ -285,18 +286,18 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
|
|
| 285 |
image_input.upload(
|
| 286 |
fn=run_classifier,
|
| 287 |
inputs=[image_input, threshold_slider],
|
| 288 |
-
outputs=[tag_string, label_box, original_image_state]
|
| 289 |
)
|
| 290 |
|
| 291 |
image_input.clear(
|
| 292 |
fn=clear_image,
|
| 293 |
inputs=[],
|
| 294 |
-
outputs=[tag_string, label_box, original_image_state]
|
| 295 |
)
|
| 296 |
|
| 297 |
threshold_slider.input(
|
| 298 |
fn=create_tags,
|
| 299 |
-
inputs=[threshold_slider],
|
| 300 |
outputs=[tag_string, label_box]
|
| 301 |
)
|
| 302 |
|
|
|
|
| 12 |
import torchvision.transforms.functional as TF
|
| 13 |
from huggingface_hub import hf_hub_download
|
| 14 |
import numpy as np
|
| 15 |
+
import matplotlib.colormaps as cm
|
| 16 |
|
| 17 |
class Fit(torch.nn.Module):
|
| 18 |
def __init__(
|
|
|
|
| 170 |
tag_score[allowed_tags[indices[i]]] = values[i].item()
|
| 171 |
sorted_tag_score = dict(sorted(tag_score.items(), key=lambda item: item[1], reverse=True))
|
| 172 |
|
| 173 |
+
return *create_tags(threshold, sorted_tag_score), img, sorted_tag_score
|
| 174 |
|
| 175 |
def create_tags(threshold, sorted_tag_score: dict):
|
| 176 |
filtered_tag_score = {key: value for key, value in sorted_tag_score.items() if value > threshold}
|
|
|
|
| 178 |
return text_no_impl, filtered_tag_score
|
| 179 |
|
| 180 |
def clear_image():
|
| 181 |
+
return "", {}, None, {}
|
| 182 |
|
| 183 |
def cam_inference(img, threshold, evt: gr.SelectData):
|
| 184 |
target_tag = evt.value
|
|
|
|
| 274 |
Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
|
| 275 |
""")
|
| 276 |
original_image_state = gr.State() # stash a copy of the input image
|
| 277 |
+
sorted_tag_score_state = gr.State(value={}) # stash a copy of the input image
|
| 278 |
with gr.Row():
|
| 279 |
with gr.Column():
|
| 280 |
image_input = gr.Image(label="Source", sources=['upload'], type='pil', height=512, show_label=False)
|
|
|
|
| 286 |
image_input.upload(
|
| 287 |
fn=run_classifier,
|
| 288 |
inputs=[image_input, threshold_slider],
|
| 289 |
+
outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state]
|
| 290 |
)
|
| 291 |
|
| 292 |
image_input.clear(
|
| 293 |
fn=clear_image,
|
| 294 |
inputs=[],
|
| 295 |
+
outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state]
|
| 296 |
)
|
| 297 |
|
| 298 |
threshold_slider.input(
|
| 299 |
fn=create_tags,
|
| 300 |
+
inputs=[threshold_slider, sorted_tag_score_state],
|
| 301 |
outputs=[tag_string, label_box]
|
| 302 |
)
|
| 303 |
|