Update app.py
Browse files
app.py
CHANGED
|
@@ -1,18 +1,17 @@
|
|
| 1 |
-
import json
|
| 2 |
-
|
| 3 |
-
import gradio as gr
|
| 4 |
from PIL import Image
|
| 5 |
-
import
|
| 6 |
-
import
|
| 7 |
-
import
|
| 8 |
-
from timm.models import VisionTransformer
|
| 9 |
import torch
|
| 10 |
from torchvision.transforms import transforms
|
| 11 |
from torchvision.transforms import InterpolationMode
|
| 12 |
import torchvision.transforms.functional as TF
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
from huggingface_hub import hf_hub_download
|
| 14 |
-
import numpy as np
|
| 15 |
-
import matplotlib.cm as cm
|
| 16 |
|
| 17 |
class Fit(torch.nn.Module):
|
| 18 |
def __init__(
|
|
@@ -147,12 +146,13 @@ cached_model = hf_hub_download(
|
|
| 147 |
safetensors.torch.load_model(model, cached_model)
|
| 148 |
model.eval()
|
| 149 |
|
| 150 |
-
with open("tagger_tags.json", "
|
| 151 |
-
tags = json.
|
| 152 |
-
allowed_tags = list(tags.keys())
|
| 153 |
|
| 154 |
-
for
|
| 155 |
-
|
|
|
|
|
|
|
| 156 |
|
| 157 |
@spaces.GPU(duration=5)
|
| 158 |
def run_classifier(image: Image.Image, threshold):
|
|
@@ -161,11 +161,10 @@ def run_classifier(image: Image.Image, threshold):
|
|
| 161 |
|
| 162 |
with torch.no_grad():
|
| 163 |
probits = model(tensor)[0] # type: torch.Tensor
|
| 164 |
-
values, indices = probits.topk(250)
|
|
|
|
|
|
|
| 165 |
|
| 166 |
-
tag_score = dict()
|
| 167 |
-
for i in range(indices.size(0)):
|
| 168 |
-
tag_score[allowed_tags[indices[i]]] = values[i].item()
|
| 169 |
sorted_tag_score = dict(sorted(tag_score.items(), key=lambda item: item[1], reverse=True))
|
| 170 |
|
| 171 |
return *create_tags(threshold, sorted_tag_score), img, sorted_tag_score
|
|
@@ -178,8 +177,9 @@ def create_tags(threshold, sorted_tag_score: dict):
|
|
| 178 |
def clear_image():
|
| 179 |
return "", {}, None, {}, None
|
| 180 |
|
|
|
|
| 181 |
def cam_inference(img, threshold, alpha, evt: gr.SelectData):
|
| 182 |
-
|
| 183 |
tensor = transform(img).unsqueeze(0)
|
| 184 |
|
| 185 |
gradients = {}
|
|
@@ -191,7 +191,6 @@ def cam_inference(img, threshold, alpha, evt: gr.SelectData):
|
|
| 191 |
def hook_backward(module, grad_in, grad_out):
|
| 192 |
gradients['value'] = grad_out[0]
|
| 193 |
|
| 194 |
-
target_tag_index = allowed_tags.index(target_tag)
|
| 195 |
handle_forward = model.norm.register_forward_hook(hook_forward)
|
| 196 |
handle_backward = model.norm.register_full_backward_hook(hook_backward)
|
| 197 |
|
|
@@ -287,11 +286,11 @@ with gr.Blocks(css=custom_css) as demo:
|
|
| 287 |
with gr.Row():
|
| 288 |
with gr.Column():
|
| 289 |
image = gr.Image(label="Source", sources=['upload', 'clipboard'], type='pil', show_label=False, elem_id="image_container")
|
| 290 |
-
threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Tag Threshold")
|
| 291 |
cam_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.40, label="CAM Threshold", elem_classes="inferno-slider")
|
| 292 |
alpha_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.60, label="CAM Alpha")
|
| 293 |
with gr.Column():
|
| 294 |
tag_string = gr.Textbox(label="Tag String")
|
|
|
|
| 295 |
label_box = gr.Label(label="Tag Predictions", num_top_classes=250, show_label=False)
|
| 296 |
|
| 297 |
gr.Markdown("""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from PIL import Image
|
| 2 |
+
import numpy as np
|
| 3 |
+
import matplotlib.cm as cm
|
| 4 |
+
import msgspec
|
|
|
|
| 5 |
import torch
|
| 6 |
from torchvision.transforms import transforms
|
| 7 |
from torchvision.transforms import InterpolationMode
|
| 8 |
import torchvision.transforms.functional as TF
|
| 9 |
+
import timm
|
| 10 |
+
from timm.models import VisionTransformer
|
| 11 |
+
import safetensors.torch
|
| 12 |
+
import gradio as gr
|
| 13 |
+
import spaces
|
| 14 |
from huggingface_hub import hf_hub_download
|
|
|
|
|
|
|
| 15 |
|
| 16 |
class Fit(torch.nn.Module):
|
| 17 |
def __init__(
|
|
|
|
| 146 |
safetensors.torch.load_model(model, cached_model)
|
| 147 |
model.eval()
|
| 148 |
|
| 149 |
+
with open("tagger_tags.json", "rb") as file:
|
| 150 |
+
tags = msgspec.json.decode(file.read(), type=dict[str, int])
|
|
|
|
| 151 |
|
| 152 |
+
for tag in tags.keys():
|
| 153 |
+
tags[tag.replace("_", " ")] = tags.pop(tag)
|
| 154 |
+
|
| 155 |
+
allowed_tags = list(tags.keys())
|
| 156 |
|
| 157 |
@spaces.GPU(duration=5)
|
| 158 |
def run_classifier(image: Image.Image, threshold):
|
|
|
|
| 161 |
|
| 162 |
with torch.no_grad():
|
| 163 |
probits = model(tensor)[0] # type: torch.Tensor
|
| 164 |
+
values, indices = probits.cpu().topk(250)
|
| 165 |
+
|
| 166 |
+
tag_score = {allowed_tags[idx.item()]: val.item() for idx, val in zip(indices, values)}
|
| 167 |
|
|
|
|
|
|
|
|
|
|
| 168 |
sorted_tag_score = dict(sorted(tag_score.items(), key=lambda item: item[1], reverse=True))
|
| 169 |
|
| 170 |
return *create_tags(threshold, sorted_tag_score), img, sorted_tag_score
|
|
|
|
| 177 |
def clear_image():
|
| 178 |
return "", {}, None, {}, None
|
| 179 |
|
| 180 |
+
@spaces.GPU(duration=5)
|
| 181 |
def cam_inference(img, threshold, alpha, evt: gr.SelectData):
|
| 182 |
+
target_tag_index = tags[evt.value]
|
| 183 |
tensor = transform(img).unsqueeze(0)
|
| 184 |
|
| 185 |
gradients = {}
|
|
|
|
| 191 |
def hook_backward(module, grad_in, grad_out):
|
| 192 |
gradients['value'] = grad_out[0]
|
| 193 |
|
|
|
|
| 194 |
handle_forward = model.norm.register_forward_hook(hook_forward)
|
| 195 |
handle_backward = model.norm.register_full_backward_hook(hook_backward)
|
| 196 |
|
|
|
|
| 286 |
with gr.Row():
|
| 287 |
with gr.Column():
|
| 288 |
image = gr.Image(label="Source", sources=['upload', 'clipboard'], type='pil', show_label=False, elem_id="image_container")
|
|
|
|
| 289 |
cam_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.40, label="CAM Threshold", elem_classes="inferno-slider")
|
| 290 |
alpha_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.60, label="CAM Alpha")
|
| 291 |
with gr.Column():
|
| 292 |
tag_string = gr.Textbox(label="Tag String")
|
| 293 |
+
threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Tag Threshold")
|
| 294 |
label_box = gr.Label(label="Tag Predictions", num_top_classes=250, show_label=False)
|
| 295 |
|
| 296 |
gr.Markdown("""
|