Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -16,9 +16,9 @@ Demo for the WaifuDiffusion tagger models
|
|
| 16 |
HF_TOKEN = os.environ.get("HF_TOKEN", "")
|
| 17 |
|
| 18 |
# Dataset v3 series of models:
|
| 19 |
-
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
|
| 20 |
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
|
| 21 |
CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
|
|
|
|
| 22 |
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
|
| 23 |
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
|
| 24 |
|
|
@@ -123,9 +123,9 @@ def main():
|
|
| 123 |
predictor = Predictor()
|
| 124 |
|
| 125 |
model_repos = [
|
| 126 |
-
VIT_MODEL_DSV3_REPO,
|
| 127 |
SWINV2_MODEL_DSV3_REPO,
|
| 128 |
CONV_MODEL_DSV3_REPO,
|
|
|
|
| 129 |
VIT_LARGE_MODEL_DSV3_REPO,
|
| 130 |
EVA02_LARGE_MODEL_DSV3_REPO,
|
| 131 |
# ---
|
|
@@ -177,12 +177,7 @@ def main():
|
|
| 177 |
"blank_censor",
|
| 178 |
"blur_censor",
|
| 179 |
"light_censor",
|
| 180 |
-
"mosaic_censoring"]
|
| 181 |
-
|
| 182 |
-
predefined_tags2 = [
|
| 183 |
-
"big, small:medium", # If either "big" or "small" is missing, add "medium"
|
| 184 |
-
"small hand, large hand:medium hand" # If either "small hand" or "large hand" is missing, add "medium hand"
|
| 185 |
-
]
|
| 186 |
|
| 187 |
with gr.Blocks(title=TITLE) as demo:
|
| 188 |
gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
|
|
@@ -213,33 +208,21 @@ def main():
|
|
| 213 |
placeholder="Add tags to filter out (e.g., winter, red, from above)",
|
| 214 |
lines=5
|
| 215 |
)
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
lines=3
|
| 221 |
-
)
|
| 222 |
-
submit = gr.Button(
|
| 223 |
-
value="Process Images", variant="primary"
|
| 224 |
-
)
|
| 225 |
|
| 226 |
with gr.Column():
|
| 227 |
output = gr.Textbox(label="Output", lines=10)
|
| 228 |
|
| 229 |
-
def process_images(files, model_repo, general_thresh, character_thresh, filter_tags
|
| 230 |
images = [Image.open(file.name) for file in files]
|
| 231 |
results = predictor.predict(images, model_repo, general_thresh, character_thresh)
|
| 232 |
-
|
| 233 |
# Parse filter tags
|
| 234 |
filter_set = set(tag.strip().lower() for tag in filter_tags.split(","))
|
| 235 |
-
|
| 236 |
-
# Parse custom tags and their fallback pairs
|
| 237 |
-
fallback_tags = {}
|
| 238 |
-
for pair in custom_tags_input.split(","):
|
| 239 |
-
if ":" in pair:
|
| 240 |
-
tag, fallback = pair.split(":")
|
| 241 |
-
fallback_tags[tag.strip().lower()] = fallback.strip().lower()
|
| 242 |
-
|
| 243 |
# Generate formatted output
|
| 244 |
prompts = []
|
| 245 |
for i, (general_tags, character_tags) in enumerate(results):
|
|
@@ -250,30 +233,24 @@ def main():
|
|
| 250 |
general_part = ", ".join(
|
| 251 |
tag.replace('_', ' ') for tag in general_tags if tag.lower() not in filter_set
|
| 252 |
)
|
| 253 |
-
|
| 254 |
-
# Check if custom tags are missing and apply fallback tags
|
| 255 |
-
all_tags = set(general_tags + character_tags)
|
| 256 |
-
for tag, fallback in fallback_tags.items():
|
| 257 |
-
if tag not in all_tags:
|
| 258 |
-
all_tags.add(fallback)
|
| 259 |
-
|
| 260 |
-
# Construct the final prompt
|
| 261 |
-
final_tags = ", ".join(tag.replace('_', ' ') for tag in all_tags if tag.lower() not in filter_set)
|
| 262 |
-
prompts.append(final_tags)
|
| 263 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
# Join all prompts with blank lines
|
| 265 |
return "\n\n".join(prompts)
|
| 266 |
|
| 267 |
-
|
| 268 |
submit.click(
|
| 269 |
process_images,
|
| 270 |
-
inputs=[image_files, model_repo, general_thresh, character_thresh, filter_tags
|
| 271 |
outputs=output
|
| 272 |
)
|
| 273 |
|
| 274 |
-
|
| 275 |
demo.queue(max_size=10)
|
| 276 |
demo.launch()
|
| 277 |
|
| 278 |
if __name__ == "__main__":
|
| 279 |
-
main()
|
|
|
|
| 16 |
HF_TOKEN = os.environ.get("HF_TOKEN", "")
|
| 17 |
|
| 18 |
# Dataset v3 series of models:
|
|
|
|
| 19 |
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
|
| 20 |
CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
|
| 21 |
+
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
|
| 22 |
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
|
| 23 |
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
|
| 24 |
|
|
|
|
| 123 |
predictor = Predictor()
|
| 124 |
|
| 125 |
model_repos = [
|
|
|
|
| 126 |
SWINV2_MODEL_DSV3_REPO,
|
| 127 |
CONV_MODEL_DSV3_REPO,
|
| 128 |
+
VIT_MODEL_DSV3_REPO,
|
| 129 |
VIT_LARGE_MODEL_DSV3_REPO,
|
| 130 |
EVA02_LARGE_MODEL_DSV3_REPO,
|
| 131 |
# ---
|
|
|
|
| 177 |
"blank_censor",
|
| 178 |
"blur_censor",
|
| 179 |
"light_censor",
|
| 180 |
+
"mosaic_censoring"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
| 182 |
with gr.Blocks(title=TITLE) as demo:
|
| 183 |
gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
|
|
|
|
| 208 |
placeholder="Add tags to filter out (e.g., winter, red, from above)",
|
| 209 |
lines=5
|
| 210 |
)
|
| 211 |
+
|
| 212 |
+
submit = gr.Button(
|
| 213 |
+
value="Process Images", variant="primary"
|
| 214 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
|
| 216 |
with gr.Column():
|
| 217 |
output = gr.Textbox(label="Output", lines=10)
|
| 218 |
|
| 219 |
+
def process_images(files, model_repo, general_thresh, character_thresh, filter_tags):
|
| 220 |
images = [Image.open(file.name) for file in files]
|
| 221 |
results = predictor.predict(images, model_repo, general_thresh, character_thresh)
|
| 222 |
+
|
| 223 |
# Parse filter tags
|
| 224 |
filter_set = set(tag.strip().lower() for tag in filter_tags.split(","))
|
| 225 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
# Generate formatted output
|
| 227 |
prompts = []
|
| 228 |
for i, (general_tags, character_tags) in enumerate(results):
|
|
|
|
| 233 |
general_part = ", ".join(
|
| 234 |
tag.replace('_', ' ') for tag in general_tags if tag.lower() not in filter_set
|
| 235 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
|
| 237 |
+
# Construct the prompt based on the presence of character_part
|
| 238 |
+
if character_part:
|
| 239 |
+
prompts.append(f"{character_part}, {general_part}")
|
| 240 |
+
else:
|
| 241 |
+
prompts.append(general_part)
|
| 242 |
+
|
| 243 |
# Join all prompts with blank lines
|
| 244 |
return "\n\n".join(prompts)
|
| 245 |
|
|
|
|
| 246 |
submit.click(
|
| 247 |
process_images,
|
| 248 |
+
inputs=[image_files, model_repo, general_thresh, character_thresh, filter_tags],
|
| 249 |
outputs=output
|
| 250 |
)
|
| 251 |
|
|
|
|
| 252 |
demo.queue(max_size=10)
|
| 253 |
demo.launch()
|
| 254 |
|
| 255 |
if __name__ == "__main__":
|
| 256 |
+
main()
|