Spaces: Food Desert (Running)

Food Desert committed · Commit 1136048 · 1 Parent(s): 41d10ff
Update UI (scroll cues, tooltips, model-specific tags, enter-to-run) and fixes

Browse files:
- .gitignore        +6 −0
- app.py          +689 −122
- requirements.txt   +1 −1
.gitignore (ADDED)

@@ -0,0 +1,6 @@
+.venv/
+__pycache__/
+*.pyc
+*.log
+*.tmp
+.DS_Store
app.py (CHANGED)

--- app.py (old version): removed lines per hunk; a bare "-" or a line cut off mid-statement was not captured in full. Unchanged context lines appear with the new version below.

@@ -23,9 +23,28 @@ import itertools
-logging
@@ -35,7 +54,7 @@ faq_content="""
-This tool serves as a linguistic bridge to the e621 image board tag lexicon, on which many popular models such as Fluffyrock,
@@ -115,16 +134,123 @@ Each subsequent row of images was generated using the same process, but with a d
-css = """
-.scrollable-content
-
-
@@ -139,6 +265,127 @@ plain: /([^,\\\[\]():|]|\\.)+/
@@ -160,7 +407,38 @@ def remove_special_tags(original_string):
-
@@ -182,7 +460,15 @@ def load_model_components(file_path):
-
@@ -218,18 +504,41 @@ def is_artist(name):
-
-
-
-
-
-
@@ -321,7 +630,7 @@ def create_html_tables_for_tags(subtable_heading, item_heading, word_similarity_
-        word_with_escaped_parentheses = word
@@ -329,7 +638,10 @@
-            tag_element =
@@ -341,36 +653,34 @@
-    # Add a heading above the table
-    html_str += "<h1>Top Artists</h1>"
-    # Start the table with increased font size and no borders between rows
-
-
-        similarity_percentage = "{:.1f}%".format(score * 100)
-        html_str +=
-    # Close the table HTML
-
-def construct_pseudo_vector(pseudo_doc_terms,
-
-
-
-
-
-
-
-
-
-    return pseudo_vector.reshape(1, -1)
@@ -388,36 +698,32 @@ def get_tfidf_reduced_similar_tags(pseudo_doc_terms, allow_nsfw_tags):
-    reduced_matrix = tf_idf_components['reduced_matrix']
-    #
-    #
-
-
-    # Compute cosine similarities in the reduced space
-    cosine_similarities_reduced = cosine_similarity(reduced_pseudo_vector, reduced_matrix).flatten()
-    #
-
-
-
-
-        tag_similarity_dict = {
-
-    tag_similarity_dict = {"by " + tag if is_artist(tag) else tag: sim for tag, sim in tag_similarity_dict.items()}
-    # Sort
-        (key.replace('_', ' ').replace('(', '\\(').replace(')', '\\)'),
-        for key,
-
@@ -463,22 +769,62 @@ def find_similar_tags(test_tags, tag_to_context_similarity, context_similarity_w
-
-
-
-                result.append(modified_tag_for_search.replace('_',' '), 1)
@@ -503,27 +849,47 @@
-        #Remove NSFW tags if appropriate.
-            result = [(
-
-        #
-
-
-
-
-    # If no tags were processed, add a message
-    return html_content, bad_entities, known_entities_in_prompt # Return list of lists for Dataframe
@@ -581,57 +947,91 @@ def augment_bad_entities_with_regex(text):
-        "Double Comma": "One comma between tags is considered ample."
-        "Unknown Tag": ("white", "red"),
-        "Duplicate": ("black", "yellow"),
-        "
-        "
-        "
-
-    #
-
-    combined_entities = sorted(combined_entities, key=lambda x: x['start'],reverse=True)
-
-    # Generate HTML for the main text
-
-
-
-
-            wiki_url = entity.get(
-            count = entity
-            wiki_entry = entity.get(
-            sanitized_wiki_entry = escape_html(wiki_entry) if wiki_entry else
-
-
-            html_part =
-
-            html_part = f'<span style="background-color: {
-
-    #
-    used_labels =
-    for label,
-
-
-    return f'<div style="padding: 10px; font-size: 16px;">{html_text}</div>{color_key_html}'
@@ -648,8 +1048,9 @@ def find_similar_artists(original_tags_string, top_n, context_similarity_weight,
-
-
@@ -660,40 +1061,130 @@
-
-        suggested_tags_filtered = OrderedDict(
-
-
-
-
-
-                baseline, artists = generate_artist_image_tuples([name
-    except ParseError
-
-            with gr.Column(scale=3):
-                image_tags = gr.Textbox(
-
@@ -708,23 +1199,92 @@ with gr.Blocks(css=css) as app:
-                allow_nsfw = gr.Checkbox(label="Allow NSFW
-                    unseen_tags = gr.HTML(
-                    suggested_tags = gr.HTML(
-                        top_artists = gr.HTML(
-                    styles = gr.Gallery(
@@ -732,6 +1292,13 @@
+++ app.py (new version): context and added lines

@@ -23,9 +23,28 @@ import itertools
 from itertools import islice
 from pathlib import Path
 import logging
+import hnswlib
+import pathlib
+from collections import Counter
 
 # Set up logging
+# Minimal prod logging: warnings+ to stderr, no file by default
+import os, logging
+
+LOG_LEVEL = os.environ.get("PSQ_LOG_LEVEL", "WARNING").upper()
+logging.basicConfig(
+    level=getattr(logging, LOG_LEVEL, logging.WARNING),
+    format="%(asctime)s %(levelname)s:%(message)s",
+    handlers=[logging.StreamHandler()]  # no file -> avoids huge logs on Spaces
+)
+
+# Quiet down common noisy libs (optional)
+for _name in ("gensim", "gradio", "hnswlib", "httpx", "uvicorn"):
+    logging.getLogger(_name).setLevel(logging.ERROR)
+
+# Turn off Gradio analytics phone-home to avoid those background thread errors (optional)
+os.environ["GRADIO_ANALYTICS_ENABLED"] = "0"
+
 
 
 faq_content="""
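The log level above is taken from the PSQ_LOG_LEVEL environment variable with a WARNING fallback, so verbose output is opt-in. A hypothetical way to turn on debug logging for a local run, set before this module is imported:

    import os
    os.environ["PSQ_LOG_LEVEL"] = "DEBUG"   # picked up by the basicConfig block above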
@@ -35,7 +54,7 @@ faq_content="""
 
 Since Stable Diffusion's initial release in 2022, users have developed a myriad of fine-tuned text to image models, each with unique "linguistic" preferences depending on the data from which it was fine-tuned.
 Some models react best when prompted with verbose scene descriptions akin to DALL-E, while others fine-tuned on images scraped from popular image boards understand those boards' tag sets.
+This tool serves as a linguistic bridge to the e621 image board tag lexicon, on which many popular models such as Fluffyrock, NoobAI, and Pony Diffusion v6 were trained.
 
 When you enter a txt2img prompt and press the "submit" button, Prompt Squirrel parses your prompt and checks that all your tags are valid e621 tags.
 If it finds any that are not, it recommends some valid e621 tags you can use to replace them in the "Unknown Tags" section.
@@ -115,16 +134,123 @@ Each subsequent row of images was generated using the same process, but with a d
 See SamplePrompts.csv for the list of prompts used and their descriptions.
 """
 
+TOOLTIP_NOTE_HTML = '<div class="hover-hint">Underlined items can be hovered for more info.</div>'
+
+HOVER_HINT_CSS = """
+/* Solid, visible underline for tagged items */
+.gradio-container .hover-underline{
+  text-decoration-line: underline !important;
+  text-decoration-thickness: 2px;
+  text-underline-offset: 2px;
+}
+
+/* Small, subtle hint text */
+.hover-hint{
+  font-size: 12px;
+  opacity: .85;
+  line-height: 1.2;
+}
+
+/* Wrapper to position the hint in the bottom-right of the annotated box */
+.annotated-wrap{ position: relative; }
+.annotated-wrap .hover-hint{
+  position: absolute;
+  right: 6px;
+  bottom: 6px;
+  text-align: right;
+}
+"""
+
+
+
+try:
+    from gradio_client import utils as _gc_utils
+
+    _orig_get_type = _gc_utils.get_type
+    _orig_j2p = _gc_utils._json_schema_to_python_type
+    _orig_pub = _gc_utils.json_schema_to_python_type
+
+    def _get_type_safe(schema):
+        # Sometimes schema is a bare True/False (JSON Schema boolean form)
+        if not isinstance(schema, dict):
+            return "any"
+        return _orig_get_type(schema)
+
+    def _j2p_safe(schema, defs=None):
+        # Accept non-dict schemas (True/False/None) and treat as "any"
+        if not isinstance(schema, dict):
+            return "any"
+        return _orig_j2p(schema, defs or schema.get("$defs"))
+
+    def _pub_safe(schema):
+        # Public wrapper used by Gradio; keep it resilient too
+        if not isinstance(schema, dict):
+            return "any"
+        return _j2p_safe(schema, schema.get("$defs"))
+
+    _gc_utils.get_type = _get_type_safe
+    _gc_utils._json_schema_to_python_type = _j2p_safe
+    _gc_utils.json_schema_to_python_type = _pub_safe
+
+except Exception as e:
+    print("gradio_client hotfix not applied:", e)
+# -------------------------------------------------------------------------------
+
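The hotfix above monkey-patches gradio_client's schema-to-type helpers so that bare-boolean JSON Schemas no longer raise. A minimal sketch (not part of the commit) of the behavior it is meant to guarantee:

    # Hypothetical check: with the patch applied, a boolean schema maps to "any".
    from gradio_client import utils as _gc_utils
    assert _gc_utils.json_schema_to_python_type(True) == "any"
    assert _gc_utils.get_type(False) == "any"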
|
| 200 |
nsfw_threshold = 0.95 # Assuming the threshold value is defined here
|
| 201 |
|
| 202 |
+
css = HOVER_HINT_CSS + """
|
| 203 |
+
.scrollable-content{
|
| 204 |
+
max-height: 420px;
|
| 205 |
+
overflow-y: scroll; /* always show scrollbar */
|
| 206 |
+
overflow-x: hidden;
|
| 207 |
+
padding-right: 8px;
|
| 208 |
+
padding-bottom: 14px; /* <— add this */
|
| 209 |
+
scrollbar-gutter: stable; /* prevent layout shift as it fills */
|
| 210 |
+
|
| 211 |
+
/* Firefox */
|
| 212 |
+
scrollbar-width: auto;
|
| 213 |
+
scrollbar-color: rgba(180,180,180,.9) rgba(0,0,0,.15);
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
/* WebKit/Chromium (Chrome/Edge/Safari) */
|
| 217 |
+
.scrollable-content::-webkit-scrollbar{ width: 10px; }
|
| 218 |
+
.scrollable-content::-webkit-scrollbar-thumb{ background: rgba(180,180,180,.9); border-radius: 8px; }
|
| 219 |
+
.scrollable-content::-webkit-scrollbar-track{ background: rgba(0,0,0,.15); }
|
| 220 |
+
|
| 221 |
+
/* --- Fade that blends into the pane background, no chip --- */
|
| 222 |
+
.scroll-fade {
|
| 223 |
+
position: relative;
|
| 224 |
+
/* ensure our ::after overlay paints above children */
|
| 225 |
+
isolation: isolate;
|
| 226 |
}
|
| 227 |
+
|
| 228 |
+
.scroll-fade::after{
|
| 229 |
+
content: "";
|
| 230 |
+
position: absolute;
|
| 231 |
+
left: 0; right: 0; bottom: 0;
|
| 232 |
+
height: 20px; /* a hair taller; tweak if you like */
|
| 233 |
+
pointer-events: none;
|
| 234 |
+
/* transparent → panel background (Gradio theme var, with dark fallback) */
|
| 235 |
+
background: linear-gradient(
|
| 236 |
+
to bottom,
|
| 237 |
+
rgba(0,0,0,0),
|
| 238 |
+
var(--background-fill-secondary, #1f2937)
|
| 239 |
+
);
|
| 240 |
+
transition: opacity .18s ease;
|
| 241 |
+
z-index: 3; /* sit above the scroller’s content */
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
.scroll-fade.at-bottom::after { opacity: 0; }
|
| 245 |
+
|
| 246 |
+
/* no chip */
|
| 247 |
+
.scroll-fade::before { content: none; }
|
| 248 |
+
|
| 249 |
"""
|
| 250 |
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
#Parser
|
| 254 |
grammar=r"""
|
| 255 |
!start: (prompt | /[][():]/+)*
|
| 256 |
prompt: (emphasized | plain | comma | WHITESPACE)*
|
|
|
|
@@ -139,6 +265,127 @@ plain: /([^,\\\[\]():|]|\\.)+/
 # Initialize the parser
 parser = Lark(grammar, start='start')
 
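The grammar feeding this Lark parser is only partially visible in the hunks above (the emphasized and comma rules fall outside the captured context). A rough usage sketch, assuming those omitted rules accept ordinary comma-separated tags:

    # Hypothetical smoke test of the parser defined above.
    tree = parser.parse("fox, outside, detailed background")
    print(tree.pretty())  # inspect the prompt/plain nodes that extract_tags walks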
+# ---------- Two HNSW indexes: artists and non-artist tags ----------
+_HNSW_ART = None
+_HNSW_TAG = None
+_HNSW_DIM = None
+_HNSW_N_ART = None
+_HNSW_N_TAG = None
+_HNSW_ART_PATH = pathlib.Path("tfidf_hnsw_artists.bin")
+_HNSW_TAG_PATH = pathlib.Path("tfidf_hnsw_tags.bin")
+
+def _l2_normalize_rows(mat: np.ndarray) -> np.ndarray:
+    mat = np.asarray(mat, dtype=np.float32)
+    norms = np.linalg.norm(mat, axis=1, keepdims=True)
+    norms[norms == 0.0] = 1.0
+    return mat / norms
+
+def _ensure_dual_hnsw_indexes():
+    """
+    Build/load two HNSW indexes over the SVD-reduced TF-IDF matrix:
+      • _HNSW_ART — rows whose tag (with optional 'by_' stripped) is in the artist_set
+      • _HNSW_TAG — only rows that are NOT artist tags
+    Index item IDs are the ORIGINAL row indices in reduced_matrix.
+    """
+    global _HNSW_ART, _HNSW_TAG, _HNSW_DIM, _HNSW_N_ART, _HNSW_N_TAG
+
+    if _HNSW_ART is not None and _HNSW_TAG is not None:
+        return
+
+    reduced_matrix = tf_idf_components['reduced_matrix']   # (N, D)
+    row_to_tag = tf_idf_components['row_to_tag']           # {row:int -> "tag_with_underscores"}
+    rm = _l2_normalize_rows(reduced_matrix).astype(np.float32)
+    n_items, dim = rm.shape
+
+    # Partition rows
+    artist_rows = []
+    tag_rows = []
+
+    for i in range(n_items):
+        tag = row_to_tag.get(i, "")
+
+        # Strip leading "by_" if present in the TF-IDF vocabulary, but don't rely on it.
+        base = tag[3:] if tag.startswith("by_") else tag
+
+        # Some corpora contain buckets you don't want shown as artists:
+        if tag in {"by_unknown_artist", "by_conditional_dnp"}:
+            tag_rows.append(i)
+            continue
+
+        if is_artist(base):
+            artist_rows.append(i)
+        else:
+            tag_rows.append(i)
+
+    logging.debug(f"HNSW partition: artists={len(artist_rows)} non_artists={len(tag_rows)}")
+
+    # Helper: build or load an index for a subset of rows
+    def _build_or_load(path: pathlib.Path, rows: list[int]) -> hnswlib.Index:
+        idx = hnswlib.Index(space='cosine', dim=dim)
+        need_build = True
+        if path.exists():
+            try:
+                idx.load_index(str(path), max_elements=max(1, len(rows)))
+                # Rebuild if the saved index count doesn’t match our rows
+                if getattr(idx, "get_current_count", None) and idx.get_current_count() == len(rows) and len(rows) > 0:
+                    need_build = False
+                else:
+                    logging.debug(f"Rebuilding {path.name}: saved_count!=rows_len ({idx.get_current_count()} vs {len(rows)})")
+            except Exception as e:
+                logging.debug(f"Reload {path.name} failed, rebuilding: {e}")
+
+        if need_build:
+            try:
+                if path.exists():
+                    path.unlink()
+            except Exception:
+                pass
+            idx.init_index(max_elements=max(1, len(rows)), ef_construction=200, M=16)
+            if rows:
+                idx.add_items(rm[rows], ids=np.asarray(rows, dtype=np.int32))
+            idx.save_index(str(path))
+
+        idx.set_ef(200)
+        return idx
+
+
+    _HNSW_ART = _build_or_load(_HNSW_ART_PATH, artist_rows)
+    _HNSW_TAG = _build_or_load(_HNSW_TAG_PATH, tag_rows)
+    _HNSW_DIM = dim
+    _HNSW_N_ART = len(artist_rows)
+    _HNSW_N_TAG = len(tag_rows)
+
+def _hnsw_query(idx: hnswlib.Index, vec: np.ndarray, k: int):
+    """
+    Query a given HNSW index with a (1, D) or (D,) vector in SVD space.
+    Returns (indices, sims) with cosine similarity scores.
+    """
+    _ensure_dual_hnsw_indexes()
+    q = np.asarray(vec, dtype=np.float32).reshape(-1)
+    q_norm = np.linalg.norm(q)
+    if q_norm > 0:
+        q = q / q_norm
+    labels, dists = idx.knn_query(q, k=k)
+    inds = labels[0]
+    sims = 1.0 - dists[0]  # cosine distance -> similarity
+    return inds, sims
+
+def _ann_tags_topk(vec: np.ndarray, k: int):
+    _ensure_dual_hnsw_indexes()
+    k = min(k, _HNSW_N_TAG if _HNSW_N_TAG else 0)
+    return _hnsw_query(_HNSW_TAG, vec, k) if k else (np.array([], dtype=int), np.array([], dtype=float))
+
+def _ann_artists_topk(vec: np.ndarray, k: int):
+    _ensure_dual_hnsw_indexes()
+    k = min(k, _HNSW_N_ART if _HNSW_N_ART else 0)
+    return _hnsw_query(_HNSW_ART, vec, k) if k else (np.array([], dtype=int), np.array([], dtype=float))
+# ------------------------------------------------------------------
+
+
+def _norm_tag_for_lookup(s: str) -> str:
+    # convert "name with spaces" -> "name_with_spaces" and unescape parens
+    return s.replace(' ', '_').replace('\\(', '(').replace('\\)', ')')
+
 # Function to extract tags
 def extract_tags(tree):
     tags_with_positions = []
@@ -160,7 +407,38 @@ def remove_special_tags(original_string):
     remaining_tags = [tag for tag in tags if tag not in special_tags]
     removed_tags = [tag for tag in tags if tag in special_tags]
     return ", ".join(remaining_tags), removed_tags
+
+#Model specific tags
+MODEL_SPECIFIC_TAGS = {
+    "masterpiece",
+    "best quality",
+    "good quality",
+    "normal quality",
+    "newest",
+    "absurdres",
+    "highres",
+    "safe",
+    "worst quality",
+    "early",
+    "low quality",
+    "lowres",
+    "explict content",
+    "very awa",
+    "worst aesthetic",
+    "score_9",
+    "score_8_up",
+    "score_7_up",
+    "score_6_up",
+    "score_5_up",
+    "score_4_up",
+    "source_pony",
+    "source_furry",
+    "source_cartoon",
+    "source_anime",
+    "rating_safe",
+    "rating_questionable",
+    "rating_explicit"
+}
 
 # Define a function to load all necessary components
 def load_model_components(file_path):
@@ -182,7 +460,15 @@ def load_model_components(file_path):
 
 # Load all components at the start
 tf_idf_components = load_model_components('tf_idf_files_420.joblib')
+idf = tf_idf_components['idf']
+if isinstance(idf, dict):
+    # idf is term -> idf_value; build a column-aligned vector
+    t2c = tf_idf_components['tag_to_column_index']
+    n_cols = max(t2c.values()) + 1
+    idf_by_col = np.ones(n_cols, dtype=np.float32)
+    for term, col in t2c.items():
+        idf_by_col[col] = float(idf.get(term, 1.0))
+    tf_idf_components['idf'] = idf_by_col
 
 nsfw_tags = set() # Initialize an empty set to store words meeting the threshold
 # Open and read the CSV file
@@ -218,18 +504,41 @@ def is_artist(name):
 sample_images_directory_path = 'sampleimages'
 def generate_artist_image_tuples(top_artists, image_directory):
     json_files = glob.glob(f'{image_directory}/*.json')
+    if not json_files:
+        return [], []  # no mapping present; return empty galleries safely
+    json_file_path = json_files[0]
     with open(json_file_path, 'r') as json_file:
         artist_to_file_map = json.load(json_file)
+    # DEBUG: mapping + baseline info
+    logging.debug("Gallery %s: loaded %d entries (map file=%s)",
+                  image_directory, len(artist_to_file_map), json_file_path)
+    _base = artist_to_file_map.get("")
+    logging.debug(
+        "Gallery %s: baseline '' -> %r (exists=%s)",
+        image_directory,
+        _base,
+        os.path.exists(os.path.join(image_directory, _base)) if _base else None,
+    )
+
+    baseline_tuple = []
     filename = artist_to_file_map.get("")
+    if filename:
+        image_path = os.path.join(image_directory, filename)
+        if os.path.exists(image_path):
+            baseline_tuple = [(image_path, "No Artist")]
+
     artist_image_tuples = []
     for artist in top_artists:
         filename = artist_to_file_map.get(artist)
+        # DEBUG: per-artist resolution
+        logging.debug(
+            "Gallery %s: %s -> %r (exists=%s)",
+            image_directory,
+            artist,
+            filename,
+            os.path.exists(os.path.join(image_directory, filename)) if filename else None,
+        )
+
         if filename:
             image_path = os.path.join(image_directory, filename)
             if os.path.exists(image_path):
@@ -321,7 +630,7 @@ def create_html_tables_for_tags(subtable_heading, item_heading, word_similarity_
     # Loop through the results and add table rows for each
     for word, sim in word_similarity_tuples:
         word_with_underscores = word.replace(' ', '_')
+        word_with_escaped_parentheses = escape_parens_for_display(word)
         count = tag2count.get(word_with_underscores.replace("\\(", "(").replace("\\)", ")"), 0) # Get the count if available, otherwise default to 0
         tag_id, wiki_entry = tag2idwiki.get(word_with_underscores, (None, ''))
         # Check if tag_id and wiki_entry are valid
@@ -329,7 +638,10 @@
             # Construct the URL for the tag's wiki page
             wiki_url = f"https://e621.net/wiki_pages/{tag_id}"
             # Make the tag a hyperlink with a tooltip
+            tag_element = (
+                f"<a class='hover-underline' href='{wiki_url}' target='_blank' "
+                f"title='{wiki_entry}'>{word_with_escaped_parentheses}</a>"
+            )
         else:
             # Display the word without any hyperlink or tooltip
             tag_element = word_with_escaped_parentheses
@@ -341,36 +653,34 @@
 
 
 def create_top_artists_table(top_artists):
     html_str = "<div class=\"scrollable-content\" style='display: inline-block; margin: 20px; text-align: center;'>"
+    html_str += "<h1>Top Artists</h1>"
     html_str += "<table style='font-size: 20px; border-collapse: collapse;'>"
     html_str += "<thead><tr><th>Artist</th><th>Similarity</th></tr></thead><tbody>"
+
     for artist, score in top_artists:
+        artist_disp = escape_html(escape_parens_for_display(artist))
+        similarity_percentage = "{:.1f}%".format(score * 100)
+        html_str += (
+            f"<tr><td style='padding: 3px 20px; border: none;'>{artist_disp}</td>"
+            f"<td style='padding: 3px 20px; border: none;'>{similarity_percentage}</td></tr>"
+        )
 
     html_str += "</tbody></table></div>"
     return html_str
 
 
+def construct_pseudo_vector(pseudo_doc_terms, idf, term_to_column_index):
+    cols, data = [], []
+    for term, w in pseudo_doc_terms.items():
+        j = term_to_column_index.get(term)
+        if j is None:
+            continue
+        cols.append(j)
+        data.append(w * idf[j])
+    n_cols = len(idf)
+    indptr = [0, len(cols)]
+    return csr_matrix((data, cols, indptr), shape=(1, n_cols), dtype=np.float32)
 
 
 def get_top_indices(reduced_pseudo_vector, reduced_matrix):
@@ -388,36 +698,32 @@ def get_tfidf_reduced_similar_tags(pseudo_doc_terms, allow_nsfw_tags):
     idf = tf_idf_components['idf']
     term_to_column_index = tf_idf_components['tag_to_column_index']
     row_to_tag = tf_idf_components['row_to_tag']
     svd = tf_idf_components['svd_model']
 
+    # 1) Build the pseudo TF-IDF, reduce to SVD space (unchanged)
     pseudo_tfidf_vector = construct_pseudo_vector(pseudo_doc_terms, idf, term_to_column_index)
+    reduced_pseudo_vector = svd.transform(pseudo_tfidf_vector)  # shape (1, D)
 
+    # 2) ANN: only fetch nearest non-artist candidates (no full-matrix cosine)
+    K = 2000  # tune for speed/recall
+    top_inds, top_sims = _ann_tags_topk(reduced_pseudo_vector, k=K)
 
+    # 3) Build similarity dict from those candidates
+    tag_similarity_dict = {}
+    for i, sim in zip(top_inds, top_sims):
+        tag = row_to_tag.get(int(i))
+        if tag is not None:
+            tag_similarity_dict[tag] = float(sim)
 
     if not allow_nsfw_tags:
+        tag_similarity_dict = {t: s for t, s in tag_similarity_dict.items() if t not in nsfw_tags}
 
+    # 4) Sort & escape like before
     sorted_tag_similarity_dict = OrderedDict(sorted(tag_similarity_dict.items(), key=lambda x: x[1], reverse=True))
     transformed_sorted_tag_similarity_dict = OrderedDict(
+        (key.replace('_', ' ').replace('(', '\\(').replace(')', '\\)'), val)
+        for key, val in sorted_tag_similarity_dict.items()
     )
     return transformed_sorted_tag_similarity_dict
 
 
@@ -463,22 +769,62 @@ def find_similar_tags(test_tags, tag_to_context_similarity, context_similarity_w
         end_pos = tag_info['end_pos']
         node_type = tag_info['node_type']
 
+        # Build the underscore form up-front
+        modified_tag_for_search = modified_tag.replace(' ', '_')
+
         if modified_tag in special_tags:
             bad_entities.append({"entity":"Special", "start":start_pos, "end":end_pos})
             continue
+
+        # Only accept exact underscore model-specific tokens (e.g., "score_9")
+        # special score/rating tags (kept as-is)
+        if modified_tag in special_tags:
+            bad_entities.append({"entity": "Special", "start": start_pos, "end": end_pos})
+            continue
+
+        # Model-specific tokens must match the user's input *exactly* (no pre-normalization).
+        # Use the original token as typed in the prompt, lowercased.
+        original_raw = tag_info["original_tag"].strip().lower()
+        if original_raw in MODEL_SPECIFIC_TAGS:
+            bad_entities.append({"entity": "Model Specific", "start": start_pos, "end": end_pos})
+            continue
+
+
         if modified_tag in encountered_modified_tags:
             bad_entities.append({"entity":"Duplicate", "start":start_pos, "end":end_pos})
             continue
         encountered_modified_tags.add(modified_tag)
+
+        norm_artist = (
+            modified_tag_for_search
+            .lower()
+            .removeprefix('by_')  # tolerate users typing "by something" or not
+        )
+        if is_artist(norm_artist):
+            by_key = f"by_{norm_artist}"
+            # try by_* first, then raw form as fallback
+            count = (find_similar_tags.tag2count.get(by_key) or
+                     find_similar_tags.tag2count.get(modified_tag_for_search, 0))
+            tag_id, wiki_entry = (
+                find_similar_tags.tag2idwiki.get(by_key) or
+                find_similar_tags.tag2idwiki.get(modified_tag_for_search, (None, ''))
+            )
+            wiki_url = f"https://e621.net/wiki_pages/{tag_id}" if tag_id is not None and wiki_entry else ""
+            known_entities_in_prompt.append({
+                "entity": "Known Tag",
+                "start": start_pos,
+                "end": end_pos,
+                "count": count,
+                "wiki_url": wiki_url,
+                "wiki_entry": wiki_entry
+            })
+            continue
         similar_words = find_similar_tags.fasttext_small_model.most_similar(modified_tag_for_search, topn = 100)
         result, seen = [], set(transformed_tags)
 
         if modified_tag_for_search in find_similar_tags.tag2aliases:
             if modified_tag in find_similar_tags.tag2aliases and "_" in modified_tag: #Implicitly tell the user that they should get rid of the underscore
+                result.append((modified_tag_for_search.replace('_',' '), 1))
                 seen.add(modified_tag)
             else: #The user correctly did not put underscores in their tag
                 count = find_similar_tags.tag2count.get(modified_tag_for_search, 0) # Get the count if available, otherwise default to 0
@@ -503,27 +849,47 @@
                     result.append((similar_tag.replace('_', ' '), round(similarity, 3)))
                     seen.add(similar_tag)
 
+        # Remove NSFW tags if appropriate.
         if not allow_nsfw_tags:
+            result = [(w, s) for (w, s) in result if w.replace(' ', '_') not in nsfw_tags]
+
+        # --- Context re-scoring (keys match how get_tfidf_reduced_similar_tags formats them) ---
+        def _ctx_score(name: str) -> float:
+            v = tag_to_context_similarity.get(name)
+            if v is None:
+                # TF-IDF dict escapes parentheses; candidates from FT do not.
+                v = tag_to_context_similarity.get(name.replace('(', '\\(').replace(')', '\\)'))
+            return float(v) if v is not None else 0.0
+
+        # If the slider is at 1.0, only keep candidates that exist in the TF-IDF context list.
+        if context_similarity_weight >= 0.999:
+            ctx_keys = set(tag_to_context_similarity.keys())
+            result = [
+                (w, s) for (w, s) in result
+                if (w in ctx_keys) or (w.replace('(', '\\(').replace(')', '\\)') in ctx_keys)
+            ]
+
+        # Linear blend: final = (1-λ)*fasttext + λ*context (no extra 0.5 scaling)
+        result = [
+            (w, (1.0 - context_similarity_weight) * s + context_similarity_weight * _ctx_score(w))
+            for (w, s) in result
+        ]
 
         result = sorted(result, key=lambda x: x[1], reverse=True)[:10]
+
         html_content += create_html_tables_for_tags(modified_tag, "Corrected Tag", result, find_similar_tags.tag2count, find_similar_tags.tag2idwiki)
 
         bad_entities.append({"entity":"Unknown Tag", "start":start_pos, "end":end_pos})
 
         tags_added=True
+    # If no tags were processed, add a message; otherwise close the wrapper div
     if not tags_added:
        html_content = create_html_placeholder(title="Unknown Tags", content="No Unknown Tags Found")
+    else:
+        html_content += "</div>"
+
+    return html_content, bad_entities, known_entities_in_prompt
 
 
 
 def build_tag_offsets_dicts(new_image_tags_with_positions):
@@ -581,57 +947,91 @@ def augment_bad_entities_with_regex(text):
 def escape_html(text):
     return text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;").replace('"', "&quot;").replace("'", "&#39;")
 
+def escape_parens_for_display(s: str) -> str:
+    # ensure single backslash before any literal parens in display text
+    return (
+        s.replace("\\(", "(")
+         .replace("\\)", ")")
+         .replace("(", "\\(")
+         .replace(")", "\\)")
+    )
+
 def format_annotated_html(bad_entities, known_entities, text):
     tooltip_map = {
         "Unknown Tag": "This may not be a valid e621 tag. Consider removing or replacing it with tag(s) from the \"Unknown Tags\" section.",
         "Duplicate": "This tag has appeared multiple times in your prompt. Consider removing the copies.",
         "Remove Final Comma": "There should be no comma at the end of your prompt. Consider removing it.",
         "Move Comma Inside Parentheses": "In most e621-based models, the comma following a tag functions as an \"attention anchor\", carrying most of the tag's information. It should therefore be assigned the same weight as the rest of the tag. So instead of \"(lineless:1.1),\", consider \"(lineless,:1.1)\" or \"(lineless,)\"",
+        "Double Comma": "One comma between tags is considered ample.",
+        "Model Specific": "This is not an e621 tag, but may still be valid with the right model. Check your model's documentation. If the tag is not mentioned in the documentation, do not use it."
     }
     color_map = {
+        "Unknown Tag": ("white", "red"),
+        "Duplicate": ("black", "yellow"),
+        "Move Comma Inside Parentheses": ("white", "green"),
+        "Double Comma": ("white", "orange"),
+        "Model Specific": ("black", "lightgray"),
+        "Remove Final Comma": ("white", "brown")
     }
+
+    # Splice from the original text so indexes stay valid.
+    combined = sorted(bad_entities + known_entities, key=lambda x: x["start"], reverse=True)
     html_text = text
+
+    for entity in combined:
+        start = entity["start"]
+        end = entity["end"]
+        label = entity["entity"]
+
+        # Escape only the replaced segment (keeps indices correct).
+        segment = text[start:end]
+        disp = escape_html(escape_parens_for_display(segment))
+
         if label == "Known Tag":
+            wiki_url = entity.get("wiki_url", "")
+            count = entity.get("count", 0)
+            wiki_entry = entity.get("wiki_entry", "")
+            sanitized_wiki_entry = escape_html(wiki_entry) if wiki_entry else "Unavailable"
+
+            if wiki_url:
+                html_part = (
+                    f'<a class="hover-underline" href="{wiki_url}" target="_blank" '
+                    f'title="Count: {count}\tWiki: {sanitized_wiki_entry}" '
+                    f'style="cursor: pointer; font-style: italic;">{disp}</a>'
+                )
             else:
+                html_part = (
+                    f'<span class="hover-underline" title="Count: {count}\tWiki: {sanitized_wiki_entry}" '
+                    f'style="cursor: help; font-style: italic;">{disp}</span>'
+                )
         else:
+            fg, bg = color_map.get(label, ("black", "white"))
+            html_part = f'<span style="background-color: {bg}; color: {fg};">{disp}</span>'
+
         html_text = html_text[:start] + html_part + html_text[end:]
+
+    # Color key (only for labels that actually appeared)
     color_key_html = "<div style='text-align: right; margin-top: 20px;'>Key:"
+    used_labels = {e["entity"] for e in bad_entities}
+    for label, (fg, bg) in color_map.items():
         if label in used_labels:
             tooltip = tooltip_map.get(label, "")
+            color_key_html += (
+                f" <span class='hover-underline' style='background-color: {bg}; color: {fg}; margin-right: 10px;' "
+                f"title='{tooltip}'>{label}</span>"
+            )
     color_key_html += "</div>"
+
+    # Wrap the whole annotated area so we can place the hint inside it
+    annotated_box = (
+        "<div class='annotated-wrap' style='padding:10px;font-size:16px;'>"
+        f"{html_text}"
+        f"{TOOLTIP_NOTE_HTML}"
+        "</div>"
+    )
+
+    return annotated_box + color_key_html
 
 
 
 def find_similar_artists(original_tags_string, top_n, context_similarity_weight, allow_nsfw_tags):
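A quick round-trip sketch of the escape_parens_for_display helper added at the top of this hunk (hypothetical inputs): both raw and already-escaped parentheses end up with exactly one backslash, which is the form the display code expects.

    # Hypothetical check of escape_parens_for_display behavior.
    assert escape_parens_for_display("wolf (feral)") == "wolf \\(feral\\)"
    assert escape_parens_for_display("wolf \\(feral\\)") == "wolf \\(feral\\)"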
@@ -648,8 +1048,9 @@ def find_similar_artists(original_tags_string, top_n, context_similarity_weight,
        #Suggested tags stuff
        suggested_tags_html_content = "<div class=\"scrollable-content\" style='display: inline-block; margin: 20px; text-align: center;'>"
        suggested_tags_html_content += "<h1>Suggested Tags</h1>" # Heading for the table
+        terms = [item["tf_idf_matrix_tag"] for item in tag_data] + removed_tags
+        suggested_tags = get_tfidf_reduced_similar_tags(dict(Counter(terms)), allow_nsfw_tags)
+
        unseen_tags_data, bad_entities, known_entities = find_similar_tags(tag_data, suggested_tags, context_similarity_weight, allow_nsfw_tags)
 
        #Bad tags stuff
@@ -660,40 +1061,130 @@
 
        # Create a set of tags that should be filtered out
        filter_tags = {entry["original_tag"].strip() for entry in tag_data}
+        filter_tags_norm = { _norm_tag_for_lookup(t.lower().removeprefix('by ').removeprefix('by_')) for t in filter_tags }
+        suggested_tags_filtered = OrderedDict(
+            (k, v) for k, v in suggested_tags.items()
+            if k not in filter_tags and _norm_tag_for_lookup(k.lower()) not in filter_tags_norm
+        )
 
        # Splitting the dictionary into two based on the condition
+        def _norm_no_by(s: str) -> str:
+            n = _norm_tag_for_lookup(s)
+            return n[3:] if n.startswith("by_") else n
+
+        suggested_artist_tags_filtered = OrderedDict(
+            (k, v) for k, v in suggested_tags_filtered.items()
+            if is_artist(_norm_no_by(k))
+        )
+
+        suggested_non_artist_tags_filtered = OrderedDict(
+            (k, v) for k, v in suggested_tags_filtered.items()
+            if not is_artist(_norm_no_by(k)) and k not in special_tags
+        )
+
 
        topnsuggestions = list(islice(suggested_non_artist_tags_filtered.items(), 100))
        suggested_tags_html_content += create_html_tables_for_tags("-", "Suggested Tag", topnsuggestions, find_similar_tags.tag2count, find_similar_tags.tag2idwiki)
+        suggested_tags_html_content += "</div>"
+
+
+        # --- Artist stuff: query artist-only index directly ---
+        idf_vec = tf_idf_components['idf']
+        t2c = tf_idf_components['tag_to_column_index']
+        svd = tf_idf_components['svd_model']
+        pseudo_terms = dict(Counter(terms))
+        pseudo_vec = construct_pseudo_vector(pseudo_terms, idf_vec, t2c)
+        reduced_q = svd.transform(pseudo_vec)
+
+        K_art = max(100, top_n * 10)  # widen search to stabilize ranks
+        art_inds, art_sims = _ann_artists_topk(reduced_q, k=K_art)
+
+        row_to_tag = tf_idf_components['row_to_tag']
+        bad_labels = {"by_unknown_artist", "by_conditional_dnp", "unknown_artist", "conditional_dnp"}
+
+        top_artists_raw = []
+        for idx_i, sim in zip(art_inds, art_sims):
+            tag = row_to_tag.get(int(idx_i), "")
+            if not tag:
+                continue
+
+            # Normalize spaces to underscores for reliable checks
+            norm = tag.replace(" ", "_")
+
+            # Drop known non-artist placeholders
+            if norm in bad_labels:
+                continue
 
+            # Accept either "by_foo" or plain "foo"
+            base = norm[3:] if norm.startswith("by_") else norm
+
+            # Guard: only keep if this *really* is an artist we know
+            if not is_artist(base):
+                continue
+
+            name_disp = base.replace("_", " ")
+            top_artists_raw.append((name_disp, float(sim)))
+
+        if not top_artists_raw:
+            logging.debug("No artist hits. First few neighbor labels: %s",
+                          [row_to_tag.get(int(i), "") for i in art_inds[:10]])
+
+        # take the best unique names, in order
+        seen = set()
+        deduped = []
+        for n, s in top_artists_raw:
+            if n not in seen:
+                deduped.append((n, s))
+                seen.add(n)
+            if len(deduped) >= top_n:
+                break
+
+        top_artists = deduped
+        logging.debug("Top artists (n=%d): %s", len(top_artists), top_artists)
        top_artists_str = create_top_artists_table(top_artists)
        dynamic_prompts_formatted_artists = "{" + "|".join([artist for artist, _ in top_artists]) + "}"
+        dynamic_prompts_formatted_artists = "{" + "|".join(
+            [escape_parens_for_display(artist) for artist, _ in top_artists]
+        ) + "}"
 
        image_galleries = []
        for root, dirs, files in os.walk(sample_images_directory_path):
            for name in dirs:
+                baseline, artists = generate_artist_image_tuples([name for name, _ in top_artists], os.path.join(root, name))
+                dir_path = os.path.join(root, name)
+                baseline, artists = generate_artist_image_tuples([n for n, _ in top_artists], dir_path)
+                logging.debug("Gallery built for %s -> baseline=%d, artists_found=%d", dir_path, len(baseline), len(artists))
                image_galleries.append(baseline)  # Add baseline as its own gallery item
                image_galleries.append(artists)  # Extend the list with artist tuples
 
        return (unseen_tags_data, bad_tags_illustrated_html, suggested_tags_html_content, top_artists_str, dynamic_prompts_formatted_artists, *image_galleries)
+    except ParseError:
+        # Build empty galleries so the tuple length matches the declared outputs
+        empty_galleries = []
+        for _root, _dirs, _files in os.walk(sample_images_directory_path):
+            for _ in _dirs:
+                empty_galleries.extend([[], []])  # one empty list per Gallery component
+
+        return (
+            create_html_placeholder(title="Unknown Tags", content="Parse Error"),
+            "Parse Error: Check for mismatched parentheses or something",
+            create_html_placeholder(title="Suggested Tags"),
+            "",  # top_artists
+            "",  # dynamic_prompts
+            *empty_galleries,
+        )
+
 
 
 with gr.Blocks(css=css) as app:
     with gr.Group():
         with gr.Row():
+            with gr.Column(scale=3, elem_classes=["prompt-col"]):
+                image_tags = gr.Textbox(
+                    label="Enter Prompt",
+                    placeholder="e.g. fox, outside, detailed background, ...",
+                    lines=1  # Enter submits (see .submit() below)
+                )
                bad_tags_illustrated_string = gr.HTML()
            with gr.Column(scale=1):
                gr.HTML(
@@ -708,23 +1199,92 @@ with gr.Blocks(css=css) as app:
    with gr.Group():
        with gr.Row():
            context_similarity_weight = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Context Similarity Weight")
+            allow_nsfw = gr.Checkbox(label="Allow NSFW Tag Suggestions", value=False)
        with gr.Row():
            with gr.Column(scale=2):
+                unseen_tags = gr.HTML(
+                    label="Unknown Tags",
+                    value=create_html_placeholder(title="Unknown Tags"),
+                    elem_id="unseen_html",
+                    elem_classes=["scroll-fade"],
+                )
            with gr.Column(scale=1):
+                suggested_tags = gr.HTML(
+                    label="Suggested Tags",
+                    value=create_html_placeholder(title="Suggested Tags"),
+                    elem_id="suggested_html",
+                    elem_classes=["scroll-fade"],
+                )
            with gr.Column(scale=1):
                with gr.Group():
                    num_artists = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of artists")
+                    top_artists = gr.HTML(
+                        label="Top Artists",
+                        value=create_html_placeholder(title="Top Artists"),
+                        elem_id="artists_html",
+                        elem_classes=["scroll-fade"],
+                    )
+    gr.HTML("""
+    <script>
+    (function(){
+      function wire(id){
+        const host = document.getElementById(id);
+        if (!host) return;
+
+        // Always use the *inner* .scrollable-content as the scroller
+        const getScroller = () => host.querySelector('.scrollable-content') || host;
+        let scroller = getScroller();
+
+        // Set CSS var so the fade blends with host background
+        const bg = getComputedStyle(host).backgroundColor;
+        host.style.setProperty('--host-bg', bg);
+
+        const refresh = () => {
+          // guard for fractional pixels across browsers
+          const atBottom = Math.ceil(scroller.scrollTop + scroller.clientHeight) >= scroller.scrollHeight;
+          host.classList.toggle('at-bottom', atBottom);
+        };
+
+        // (Re)attach scroll listener to the current scroller
+        const attach = (el) => {
+          if (!el) return;
+          el.addEventListener('scroll', refresh, {passive:true});
+          // initial state
+          refresh();
+        };
+
+        attach(scroller);
+
+        // If Gradio replaces inner HTML, re-wire to new scroller
+        new MutationObserver(() => {
+          const next = getScroller();
+          if (next && next !== scroller) {
+            scroller.removeEventListener && scroller.removeEventListener('scroll', refresh);
+            scroller = next;
+            attach(scroller);
+          }
+          // background might change with themes; keep it fresh
+          const newBg = getComputedStyle(host).backgroundColor;
+          host.style.setProperty('--host-bg', newBg);
+          refresh();
+        }).observe(host, {childList: true, subtree: true});
+
+        // Also respond to resizes
+        new ResizeObserver(refresh).observe(host);
+      }
+
+      ['unseen_html','suggested_html','artists_html'].forEach(wire);
+    })();
+    </script>
+    """, visible=False)
+
    dynamic_prompts = gr.Textbox(label="Dynamic Prompts Format", info="For if you're using the Automatic1111 webui (https://github.com/AUTOMATIC1111/stable-diffusion-webui) with the Dynamic Prompts extension activated (https://github.com/adieyal/sd-dynamic-prompts) and want to try them all individually.")
    galleries = []
    for root, dirs, files in os.walk(sample_images_directory_path):
        for name in dirs:
            with gr.Row():
                baseline = gr.Gallery(allow_preview=False, rows=1, columns=1, height=420, scale=3)
+                styles = gr.Gallery(allow_preview=False, rows=2, columns=5, height=420, scale=8)
                galleries.extend([baseline, styles])
 
    submit_button.click(
@@ -732,6 +1292,13 @@
        inputs=[image_tags, num_artists, context_similarity_weight, allow_nsfw],
        outputs=[unseen_tags, bad_tags_illustrated_string, suggested_tags, top_artists, dynamic_prompts] + galleries
    )
+    # Also run when pressing Enter in the prompt box
+    image_tags.submit(
+        find_similar_artists,
+        inputs=[image_tags, num_artists, context_similarity_weight, allow_nsfw],
+        outputs=[unseen_tags, bad_tags_illustrated_string, suggested_tags, top_artists, dynamic_prompts] + galleries
+    )
+
 
    gr.Markdown(faq_content)
 
requirements.txt (CHANGED)

@@ -2,7 +2,7 @@ gradio==4.44.1
 gradio-client==1.3.0
 fastapi==0.116.1
 starlette==0.47.3
-
+hnswlib==0.8.0
 numpy==1.25.1
 scikit-learn==1.4.1.post1
 h5py==3.8.0
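The new hnswlib dependency backs the two approximate-nearest-neighbour indexes app.py now builds over the SVD-reduced TF-IDF rows. A minimal, self-contained sketch of the pattern with made-up data (the real code wraps this in _ensure_dual_hnsw_indexes and _hnsw_query):

    import numpy as np
    import hnswlib

    data = np.random.rand(1000, 64).astype(np.float32)     # stand-in for the reduced matrix
    index = hnswlib.Index(space='cosine', dim=64)
    index.init_index(max_elements=1000, ef_construction=200, M=16)
    index.add_items(data, ids=np.arange(1000, dtype=np.int32))
    index.set_ef(200)

    labels, dists = index.knn_query(data[:1], k=5)          # cosine distance; similarity = 1 - dist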