Spaces:
Runtime error
Runtime error
update for v3
Browse files- README.md +11 -5
- app.py +169 -63
- data/selected_tags.csv +0 -0
- tagger/common.py +56 -4
README.md
CHANGED
|
@@ -9,12 +9,18 @@ app_file: app.py
|
|
| 9 |
pinned: false
|
| 10 |
short_description: A WD Tagger Space for pi-chan to use
|
| 11 |
preload_from_hub:
|
| 12 |
-
- SmilingWolf/wd-
|
| 13 |
-
- SmilingWolf/wd-
|
| 14 |
-
- SmilingWolf/wd-
|
| 15 |
-
- SmilingWolf/wd-v1-4-
|
| 16 |
-
- SmilingWolf/wd-v1-4-
|
|
|
|
|
|
|
|
|
|
| 17 |
models:
|
|
|
|
|
|
|
|
|
|
| 18 |
- SmilingWolf/wd-v1-4-moat-tagger-v2
|
| 19 |
- SmilingWolf/wd-v1-4-swinv2-tagger-v2
|
| 20 |
- SmilingWolf/wd-v1-4-convnext-tagger-v2
|
|
|
|
| 9 |
pinned: false
|
| 10 |
short_description: A WD Tagger Space for pi-chan to use
|
| 11 |
preload_from_hub:
|
| 12 |
+
- SmilingWolf/wd-vit-tagger-v3 model.onnx,selected_tags.csv
|
| 13 |
+
- SmilingWolf/wd-swinv2-tagger-v3 model.onnx,selected_tags.csv
|
| 14 |
+
- SmilingWolf/wd-convnext-tagger-v3 model.onnx,selected_tags.csv
|
| 15 |
+
- SmilingWolf/wd-v1-4-moat-tagger-v2 model.onnx,selected_tags.csv
|
| 16 |
+
- SmilingWolf/wd-v1-4-swinv2-tagger-v2 model.onnx,selected_tags.csv
|
| 17 |
+
- SmilingWolf/wd-v1-4-convnext-tagger-v2 model.onnx,selected_tags.csv
|
| 18 |
+
- SmilingWolf/wd-v1-4-convnextv2-tagger-v2 model.onnx,selected_tags.csv
|
| 19 |
+
- SmilingWolf/wd-v1-4-vit-tagger-v2 model.onnx,selected_tags.csv
|
| 20 |
models:
|
| 21 |
+
- SmilingWolf/wd-vit-tagger-v3
|
| 22 |
+
- SmilingWolf/wd-swinv2-tagger-v3
|
| 23 |
+
- SmilingWolf/wd-convnext-tagger-v3
|
| 24 |
- SmilingWolf/wd-v1-4-moat-tagger-v2
|
| 25 |
- SmilingWolf/wd-v1-4-swinv2-tagger-v2
|
| 26 |
- SmilingWolf/wd-v1-4-convnext-tagger-v2
|
app.py
CHANGED
|
@@ -7,25 +7,41 @@ import numpy as np
|
|
| 7 |
import onnxruntime as rt
|
| 8 |
from PIL import Image
|
| 9 |
|
| 10 |
-
from tagger.common import LabelData,
|
| 11 |
from tagger.model import create_session
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
HF_TOKEN = getenv("HF_TOKEN", None)
|
| 14 |
-
WORK_DIR = Path.cwd().resolve()
|
| 15 |
|
| 16 |
MODEL_VARIANTS: dict[str, str] = {
|
| 17 |
-
"
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
}
|
|
|
|
|
|
|
|
|
|
| 23 |
|
|
|
|
|
|
|
| 24 |
# allowed extensions
|
| 25 |
IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]
|
| 26 |
|
| 27 |
-
#
|
| 28 |
-
IMAGE_SIZE = 448
|
| 29 |
example_images = sorted(
|
| 30 |
[
|
| 31 |
str(x.relative_to(WORK_DIR))
|
|
@@ -33,34 +49,51 @@ example_images = sorted(
|
|
| 33 |
if x.is_file() and x.suffix.lower() in IMAGE_EXTENSIONS
|
| 34 |
]
|
| 35 |
)
|
| 36 |
-
loaded_models: dict[str, Optional[rt.InferenceSession]] = {k: None for k, _ in MODEL_VARIANTS.items()}
|
| 37 |
|
| 38 |
|
| 39 |
-
def load_model(variant: str) -> rt.InferenceSession:
|
| 40 |
global loaded_models
|
| 41 |
|
| 42 |
# resolve the repo name
|
| 43 |
-
model_repo = MODEL_VARIANTS.get(variant, None)
|
| 44 |
if model_repo is None:
|
| 45 |
-
raise ValueError(f"Unknown model variant: {variant}")
|
| 46 |
|
| 47 |
-
|
|
|
|
| 48 |
# save model to cache
|
| 49 |
-
loaded_models[
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
|
| 54 |
def predict(
|
| 55 |
image: Image.Image,
|
|
|
|
| 56 |
variant: str,
|
| 57 |
-
|
| 58 |
-
|
|
|
|
|
|
|
| 59 |
):
|
| 60 |
-
#
|
| 61 |
-
model: rt.InferenceSession = load_model(variant)
|
| 62 |
# load labels
|
| 63 |
-
labels: LabelData =
|
| 64 |
|
| 65 |
# get input size and name
|
| 66 |
_, h, w, _ = model.get_inputs()[0].shape
|
|
@@ -85,13 +118,21 @@ def predict(
|
|
| 85 |
rating_labels = dict([probs[i] for i in labels.rating])
|
| 86 |
|
| 87 |
# General labels, pick any where prediction confidence > threshold
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
gen_labels = [probs[i] for i in labels.general]
|
| 89 |
-
gen_labels = dict([x for x in gen_labels if x[1] >
|
| 90 |
gen_labels = dict(sorted(gen_labels.items(), key=lambda item: item[1], reverse=True))
|
| 91 |
|
| 92 |
# Character labels, pick any where prediction confidence > threshold
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
char_labels = [probs[i] for i in labels.character]
|
| 94 |
-
char_labels = dict([x for x in char_labels if x[1] >
|
| 95 |
char_labels = dict(sorted(char_labels.items(), key=lambda item: item[1], reverse=True))
|
| 96 |
|
| 97 |
# Combine general and character labels, sort by confidence
|
|
@@ -102,64 +143,129 @@ def predict(
|
|
| 102 |
caption = ", ".join(combined_names)
|
| 103 |
booru = caption.replace("_", " ").replace("(", "\(").replace(")", "\)")
|
| 104 |
|
| 105 |
-
return image, caption, booru, rating_labels, char_labels, gen_labels
|
| 106 |
|
| 107 |
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
with gr.Row(equal_height=False):
|
| 110 |
-
with gr.Column():
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
with gr.Row():
|
| 122 |
-
submit = gr.Button(value="Submit", variant="primary", size="lg")
|
| 123 |
clear = gr.ClearButton(
|
| 124 |
components=[],
|
| 125 |
variant="secondary",
|
| 126 |
size="lg",
|
| 127 |
)
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
],
|
| 135 |
-
inputs=[img_input, variant, gen_thresh, char_thresh],
|
| 136 |
-
)
|
| 137 |
-
with gr.Column():
|
| 138 |
-
img_output = gr.Image(label="Preprocessed", type="pil", image_mode="RGB", scale=1, visible=False)
|
| 139 |
with gr.Group():
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
)
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
# tell clear button which components to clear
|
| 151 |
-
clear.add([img_input, img_output,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
|
| 153 |
# show/hide processed image
|
| 154 |
-
def
|
| 155 |
-
return gr.update(visible=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
-
|
|
|
|
| 158 |
|
| 159 |
submit.click(
|
| 160 |
predict,
|
| 161 |
-
inputs=[img_input, variant,
|
| 162 |
-
outputs=[img_output,
|
| 163 |
api_name="predict",
|
| 164 |
)
|
| 165 |
|
|
|
|
| 7 |
import onnxruntime as rt
|
| 8 |
from PIL import Image
|
| 9 |
|
| 10 |
+
from tagger.common import LabelData, load_labels_hf, preprocess_image
|
| 11 |
from tagger.model import create_session
|
| 12 |
|
| 13 |
+
TITLE = "WaifuDiffusion Tagger"
|
| 14 |
+
DESCRIPTION = """
|
| 15 |
+
Tag images with the WaifuDiffusion Tagger models!
|
| 16 |
+
|
| 17 |
+
Primarily used as a backend for a Discord bot.
|
| 18 |
+
"""
|
| 19 |
HF_TOKEN = getenv("HF_TOKEN", None)
|
|
|
|
| 20 |
|
| 21 |
# Hugging Face repo IDs of the supported tagger models, grouped first by model
# version ("v3", "v2") and then by variant name. The annotation reflects the
# two-level nesting: version -> variant -> repo ID.
MODEL_VARIANTS: dict[str, dict[str, str]] = {
    "v3": {
        "SwinV2": "SmilingWolf/wd-swinv2-tagger-v3",
        "ConvNeXT": "SmilingWolf/wd-convnext-tagger-v3",
        "ViT": "SmilingWolf/wd-vit-tagger-v3",
    },
    "v2": {
        "MOAT": "SmilingWolf/wd-v1-4-moat-tagger-v2",
        "SwinV2": "SmilingWolf/wd-v1-4-swinv2-tagger-v2",
        "ConvNeXT": "SmilingWolf/wd-v1-4-convnext-tagger-v2",
        "ConvNeXTv2": "SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
        "ViT": "SmilingWolf/wd-v1-4-vit-tagger-v2",
    },
}
# prepopulate cache keys in model cache: one "<version>-<variant>" key per
# model, in the same order as MODEL_VARIANTS
cache_keys = [
    f"{version_name}-{variant_name}"
    for version_name, variants in MODEL_VARIANTS.items()
    for variant_name in variants
]
|
| 37 |
+
loaded_models: dict[str, Optional[rt.InferenceSession]] = {k: None for k in cache_keys}
|
| 38 |
|
| 39 |
+
# get the repo root (or the current working directory if running in ipython)
|
| 40 |
+
WORK_DIR = Path(__file__).parent.resolve() if "__file__" in globals() else Path().resolve()
|
| 41 |
# allowed extensions
|
| 42 |
IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]
|
| 43 |
|
| 44 |
+
# get the example images
|
|
|
|
| 45 |
example_images = sorted(
|
| 46 |
[
|
| 47 |
str(x.relative_to(WORK_DIR))
|
|
|
|
| 49 |
if x.is_file() and x.suffix.lower() in IMAGE_EXTENSIONS
|
| 50 |
]
|
| 51 |
)
|
|
|
|
| 52 |
|
| 53 |
|
| 54 |
+
def load_model(version: str, variant: str) -> rt.InferenceSession:
    """Return the inference session for a model, creating it on first use.

    Args:
        version: Top-level key of ``MODEL_VARIANTS`` (e.g. ``"v3"``).
        variant: Variant name under that version (e.g. ``"ViT"``).

    Returns:
        The session stored in ``loaded_models`` under ``"<version>-<variant>"``.

    Raises:
        ValueError: If the version/variant pair is not in ``MODEL_VARIANTS``.
    """
    # resolve the repo name; an unknown version falls through to the same
    # error as an unknown variant
    variants_for_version = MODEL_VARIANTS.get(version, {})
    model_repo = variants_for_version.get(variant, None)
    if model_repo is None:
        raise ValueError(f"Unknown model variant: {version}-{variant}")

    cache_key = f"{version}-{variant}"
    session = loaded_models.get(cache_key, None)
    if session is None:
        # first request for this model: build the session and save it to cache
        session = create_session(model_repo, token=HF_TOKEN)
        loaded_models[cache_key] = session

    return session
|
| 68 |
+
|
| 69 |
|
| 70 |
+
def mcut_threshold(probs: np.ndarray) -> float:
    """
    Maximum Cut Thresholding (MCut)
    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
    for Multi-label Classification. In 11th International Symposium, IDA 2012
    (pp. 172-183).

    The scores are sorted in descending order and the threshold is placed at
    the midpoint of the widest gap between two consecutive scores.

    Args:
        probs: One-dimensional array (or array-like, e.g. a list) of scores.

    Returns:
        The midpoint of the largest gap, as a Python float.

    Raises:
        ValueError: If fewer than two scores are given; there is no gap to cut.
    """
    scores = np.asarray(probs)
    if scores.size < 2:
        # Without this guard NumPy fails on the empty difference array with an
        # unhelpful "argmax of an empty sequence" message.
        raise ValueError(
            f"mcut_threshold needs at least two probabilities, got {scores.size}"
        )

    ranked = np.sort(scores)[::-1]  # descending order
    gaps = ranked[:-1] - ranked[1:]
    cut = int(gaps.argmax())  # first position of the widest gap
    return float((ranked[cut] + ranked[cut + 1]) / 2)
|
| 82 |
|
| 83 |
|
| 84 |
def predict(
|
| 85 |
image: Image.Image,
|
| 86 |
+
version: str,
|
| 87 |
variant: str,
|
| 88 |
+
gen_threshold: float = 0.35,
|
| 89 |
+
gen_use_mcut: bool = False,
|
| 90 |
+
char_threshold: float = 0.85,
|
| 91 |
+
char_use_mcut: bool = False,
|
| 92 |
):
|
| 93 |
+
# join variant for cache key
|
| 94 |
+
model: rt.InferenceSession = load_model(version, variant)
|
| 95 |
# load labels
|
| 96 |
+
labels: LabelData = load_labels_hf(MODEL_VARIANTS[version][variant])
|
| 97 |
|
| 98 |
# get input size and name
|
| 99 |
_, h, w, _ = model.get_inputs()[0].shape
|
|
|
|
| 118 |
rating_labels = dict([probs[i] for i in labels.rating])
|
| 119 |
|
| 120 |
# General labels, pick any where prediction confidence > threshold
|
| 121 |
+
if gen_use_mcut:
|
| 122 |
+
gen_array = np.array([probs[i][1] for i in labels.general])
|
| 123 |
+
gen_threshold = mcut_threshold(gen_array)
|
| 124 |
+
|
| 125 |
gen_labels = [probs[i] for i in labels.general]
|
| 126 |
+
gen_labels = dict([x for x in gen_labels if x[1] > gen_threshold])
|
| 127 |
gen_labels = dict(sorted(gen_labels.items(), key=lambda item: item[1], reverse=True))
|
| 128 |
|
| 129 |
# Character labels, pick any where prediction confidence > threshold
|
| 130 |
+
if char_use_mcut:
|
| 131 |
+
char_array = np.array([probs[i][1] for i in labels.character])
|
| 132 |
+
char_threshold = round(mcut_threshold(char_array), 2)
|
| 133 |
+
|
| 134 |
char_labels = [probs[i] for i in labels.character]
|
| 135 |
+
char_labels = dict([x for x in char_labels if x[1] > char_threshold])
|
| 136 |
char_labels = dict(sorted(char_labels.items(), key=lambda item: item[1], reverse=True))
|
| 137 |
|
| 138 |
# Combine general and character labels, sort by confidence
|
|
|
|
| 143 |
caption = ", ".join(combined_names)
|
| 144 |
booru = caption.replace("_", " ").replace("(", "\(").replace(")", "\)")
|
| 145 |
|
| 146 |
+
return image, caption, booru, rating_labels, char_labels, char_threshold, gen_labels, gen_threshold
|
| 147 |
|
| 148 |
|
| 149 |
+
css = """
|
| 150 |
+
#gen_mcut, #char_mcut {
|
| 151 |
+
padding-top: var(--scale-3);
|
| 152 |
+
}
|
| 153 |
+
#gen_threshold.dimmed, #char_threshold.dimmed {
|
| 154 |
+
filter: brightness(75%);
|
| 155 |
+
}
|
| 156 |
+
"""
|
| 157 |
+
|
| 158 |
+
with gr.Blocks(theme="NoCrypt/miku", analytics_enabled=False, title=TITLE, css=css) as demo:
|
| 159 |
with gr.Row(equal_height=False):
|
| 160 |
+
with gr.Column(min_width=720):
|
| 161 |
+
with gr.Group():
|
| 162 |
+
img_input = gr.Image(
|
| 163 |
+
label="Input",
|
| 164 |
+
type="pil",
|
| 165 |
+
image_mode="RGB",
|
| 166 |
+
sources=["upload", "clipboard"],
|
| 167 |
+
)
|
| 168 |
+
show_processed = gr.Checkbox(label="Show Preprocessed Image", value=False)
|
| 169 |
+
with gr.Row():
|
| 170 |
+
version = gr.Radio(
|
| 171 |
+
choices=list(MODEL_VARIANTS.keys()),
|
| 172 |
+
label="Model Version",
|
| 173 |
+
value="v3",
|
| 174 |
+
min_width=160,
|
| 175 |
+
scale=1,
|
| 176 |
+
) # gen_threshold > div.wrap.hide
|
| 177 |
+
variant = gr.Radio(
|
| 178 |
+
choices=list(MODEL_VARIANTS[version.value].keys()),
|
| 179 |
+
label="Model Variant",
|
| 180 |
+
value="ConvNeXT",
|
| 181 |
+
min_width=560,
|
| 182 |
+
)
|
| 183 |
+
with gr.Group():
|
| 184 |
+
with gr.Row():
|
| 185 |
+
gen_threshold = gr.Slider(
|
| 186 |
+
minimum=0.0,
|
| 187 |
+
maximum=1.0,
|
| 188 |
+
value=0.35,
|
| 189 |
+
step=0.01,
|
| 190 |
+
label="General Tag Threshold",
|
| 191 |
+
scale=5,
|
| 192 |
+
elem_id="gen_threshold",
|
| 193 |
+
)
|
| 194 |
+
gen_mcut = gr.Checkbox(label="Use Max-Cut", value=False, scale=1, elem_id="gen_mcut")
|
| 195 |
+
with gr.Row():
|
| 196 |
+
char_threshold = gr.Slider(
|
| 197 |
+
minimum=0.0,
|
| 198 |
+
maximum=1.0,
|
| 199 |
+
value=0.85,
|
| 200 |
+
step=0.01,
|
| 201 |
+
label="Character Tag Threshold",
|
| 202 |
+
scale=5,
|
| 203 |
+
elem_id="char_threshold",
|
| 204 |
+
)
|
| 205 |
+
char_mcut = gr.Checkbox(label="Use Max-Cut", value=False, scale=1, elem_id="char_mcut")
|
| 206 |
with gr.Row():
|
|
|
|
| 207 |
clear = gr.ClearButton(
|
| 208 |
components=[],
|
| 209 |
variant="secondary",
|
| 210 |
size="lg",
|
| 211 |
)
|
| 212 |
+
submit = gr.Button(value="Submit", variant="primary", size="lg")
|
| 213 |
+
|
| 214 |
+
with gr.Column(min_width=720):
|
| 215 |
+
img_output = gr.Image(
|
| 216 |
+
label="Preprocessed Image", type="pil", image_mode="RGB", scale=1, visible=False
|
| 217 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
with gr.Group():
|
| 219 |
+
caption = gr.Textbox(label="Caption", show_copy_button=True)
|
| 220 |
+
tags = gr.Textbox(label="Tags", show_copy_button=True)
|
| 221 |
+
with gr.Group():
|
| 222 |
+
rating = gr.Label(label="Rating")
|
| 223 |
+
with gr.Group():
|
| 224 |
+
char_mcut_out = gr.Number(label="Max-Cut Threshold", precision=2, visible=False)
|
| 225 |
+
character = gr.Label(label="Character")
|
| 226 |
+
with gr.Group():
|
| 227 |
+
gen_mcut_out = gr.Number(label="Max-Cut Threshold", precision=2, visible=False)
|
| 228 |
+
general = gr.Label(label="General")
|
| 229 |
+
|
| 230 |
+
with gr.Row():
|
| 231 |
+
examples = [[imgpath, 0.35, mc, 0.85, mc] for mc in [False, True] for imgpath in example_images]
|
| 232 |
+
|
| 233 |
+
examples = gr.Examples(
|
| 234 |
+
examples=examples,
|
| 235 |
+
inputs=[img_input, gen_threshold, gen_mcut, char_threshold, char_mcut],
|
| 236 |
+
)
|
| 237 |
|
| 238 |
# tell clear button which components to clear
|
| 239 |
+
clear.add([img_input, img_output, caption, rating, character, general])
|
| 240 |
+
|
| 241 |
+
def on_select_variant(evt: gr.SelectData, variant: str):
    """Refresh the variant choices after a selection event.

    Despite its name, ``variant`` is used as a top-level key of
    ``MODEL_VARIANTS``; the listener wiring passes the version radio as the
    only input.

    Returns:
        An update carrying the variant names for that key with the first one
        selected, or an empty update when the event is not a selection.
    """
    if not evt.selected:
        return gr.update()

    choices = list(MODEL_VARIANTS[variant])
    return gr.update(choices=choices, value=choices[0])
|
| 246 |
+
|
| 247 |
+
version.select(on_select_variant, inputs=[version], outputs=[variant])
|
| 248 |
|
| 249 |
# show/hide processed image
|
| 250 |
+
def on_change_show(val: bool):
    """Show or hide the output component according to ``val``.

    Args:
        val: Value of the input component wired to this callback (the
            "Show Preprocessed Image" checkbox), i.e. a bool rather than the
            component object itself.

    Returns:
        An update that sets ``visible`` to ``val``.
    """
    return gr.update(visible=val)
|
| 252 |
+
|
| 253 |
+
show_processed.select(on_change_show, inputs=[show_processed], outputs=[img_output])
|
| 254 |
+
|
| 255 |
+
# handle mcut thresholding (auto-calculate threshold from probs, disable slider)
|
| 256 |
+
def on_change_mcut(val: bool):
    """Build the two updates applied when a "Use Max-Cut" checkbox changes.

    Args:
        val: Value of the checkbox wired as input, i.e. a bool rather than the
            component object itself.

    Returns:
        A pair of updates, in the order of the listener's ``outputs``: the
        first makes its target non-interactive and adds the ``dimmed`` CSS
        class while ``val`` is true (and reverts both otherwise); the second
        makes its target visible only while ``val`` is true.
    """
    slider_update = gr.update(
        interactive=not val,
        elem_classes=["dimmed"] if val else [],
    )
    readout_update = gr.update(visible=val)
    return slider_update, readout_update
|
| 261 |
|
| 262 |
+
gen_mcut.change(on_change_mcut, inputs=[gen_mcut], outputs=[gen_threshold, gen_mcut_out])
|
| 263 |
+
char_mcut.change(on_change_mcut, inputs=[char_mcut], outputs=[char_threshold, char_mcut_out])
|
| 264 |
|
| 265 |
submit.click(
|
| 266 |
predict,
|
| 267 |
+
inputs=[img_input, version, variant, gen_threshold, gen_mcut, char_threshold, char_mcut],
|
| 268 |
+
outputs=[img_output, caption, tags, rating, character, char_threshold, general, gen_threshold],
|
| 269 |
api_name="predict",
|
| 270 |
)
|
| 271 |
|
data/selected_tags.csv
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tagger/common.py
CHANGED
|
@@ -3,10 +3,12 @@ from dataclasses import asdict, dataclass
|
|
| 3 |
from functools import lru_cache
|
| 4 |
from os import PathLike
|
| 5 |
from pathlib import Path
|
| 6 |
-
from typing import Any
|
| 7 |
|
| 8 |
import numpy as np
|
| 9 |
import pandas as pd
|
|
|
|
|
|
|
| 10 |
from PIL import Image
|
| 11 |
|
| 12 |
|
|
@@ -36,10 +38,36 @@ class ImageLabels(DictJsonMixin):
|
|
| 36 |
|
| 37 |
|
| 38 |
@lru_cache(maxsize=5)
|
| 39 |
-
def load_labels(
|
| 40 |
-
|
|
|
|
| 41 |
if not csv_path.is_file():
|
| 42 |
-
raise FileNotFoundError("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
|
| 45 |
tag_data = LabelData(
|
|
@@ -101,3 +129,27 @@ def preprocess_image(
|
|
| 101 |
image.thumbnail(size_px, Image.BICUBIC)
|
| 102 |
|
| 103 |
return image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from functools import lru_cache
|
| 4 |
from os import PathLike
|
| 5 |
from pathlib import Path
|
| 6 |
+
from typing import Any, Optional
|
| 7 |
|
| 8 |
import numpy as np
|
| 9 |
import pandas as pd
|
| 10 |
+
from huggingface_hub import hf_hub_download
|
| 11 |
+
from huggingface_hub.utils import HfHubHTTPError
|
| 12 |
from PIL import Image
|
| 13 |
|
| 14 |
|
|
|
|
| 38 |
|
| 39 |
|
| 40 |
@lru_cache(maxsize=5)
def load_labels(version: str = "v3", data_dir: PathLike = "./data") -> LabelData:
    """Read ``selected_tags_<version>.csv`` from ``data_dir`` into a LabelData.

    Results are memoised per ``(version, data_dir)`` by ``lru_cache``.

    Args:
        version: Suffix used in the CSV file name.
        data_dir: Directory containing the CSV; resolved to an absolute path.

    Returns:
        LabelData holding every tag name plus the row indices whose
        ``category`` column equals 9 (rating), 0 (general) and 4 (character).

    Raises:
        FileNotFoundError: If the CSV file does not exist in ``data_dir``.
    """
    data_dir = Path(data_dir).resolve()
    csv_path = data_dir / f"selected_tags_{version}.csv"
    if not csv_path.is_file():
        raise FileNotFoundError(f"{csv_path.name} not found in {data_dir}")

    df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
    categories = df["category"]

    def rows_in(category_code: int) -> list:
        # positional row indices whose category matches the given code
        return list(np.where(categories == category_code)[0])

    return LabelData(
        names=df["name"].tolist(),
        rating=rows_in(9),
        general=rows_in(0),
        character=rows_in(4),
    )
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@lru_cache(maxsize=5)
|
| 59 |
+
def load_labels_hf(
|
| 60 |
+
repo_id: str,
|
| 61 |
+
revision: Optional[str] = None,
|
| 62 |
+
token: Optional[str] = None,
|
| 63 |
+
) -> LabelData:
|
| 64 |
+
try:
|
| 65 |
+
csv_path = hf_hub_download(
|
| 66 |
+
repo_id=repo_id, filename="selected_tags.csv", revision=revision, token=token
|
| 67 |
+
)
|
| 68 |
+
csv_path = Path(csv_path).resolve()
|
| 69 |
+
except HfHubHTTPError as e:
|
| 70 |
+
raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from e
|
| 71 |
|
| 72 |
df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
|
| 73 |
tag_data = LabelData(
|
|
|
|
| 129 |
image.thumbnail(size_px, Image.BICUBIC)
|
| 130 |
|
| 131 |
return image
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
|
| 135 |
+
# Emoticon-style tag names, each built around an underscore.
# NOTE(review): presumably these are meant to be exempt from underscore-to-space
# replacement when formatting tags; no use is visible here — confirm against callers.
kaomojis = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    "<o>_<o>",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]
|