Food Desert committed on
Commit
1136048
·
1 Parent(s): 41d10ff

Update UI (scroll cues, tooltips, model-specific tags, enter-to-run) and fixes

Browse files
Files changed (3) hide show
  1. .gitignore +6 -0
  2. app.py +689 -122
  3. requirements.txt +1 -1
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ .venv/
2
+ __pycache__/
3
+ *.pyc
4
+ *.log
5
+ *.tmp
6
+ .DS_Store
app.py CHANGED
@@ -23,9 +23,28 @@ import itertools
23
  from itertools import islice
24
  from pathlib import Path
25
  import logging
 
 
 
26
 
27
  # Set up logging
28
- logging.basicConfig(filename='error.log', level=logging.DEBUG, format='%(asctime)s %(levelname)s:%(message)s')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
 
31
  faq_content="""
@@ -35,7 +54,7 @@ faq_content="""
35
 
36
  Since Stable Diffusion's initial release in 2022, users have developed a myriad of fine-tuned text to image models, each with unique "linguistic" preferences depending on the data from which it was fine-tuned.
37
  Some models react best when prompted with verbose scene descriptions akin to DALL-E, while others fine-tuned on images scraped from popular image boards understand those boards' tag sets.
38
- This tool serves as a linguistic bridge to the e621 image board tag lexicon, on which many popular models such as Fluffyrock, Fluffusion, and Pony Diffusion v6 were trained.
39
 
40
  When you enter a txt2img prompt and press the "submit" button, Prompt Squirrel parses your prompt and checks that all your tags are valid e621 tags.
41
  If it finds any that are not, it recommends some valid e621 tags you can use to replace them in the "Unknown Tags" section.
@@ -115,16 +134,123 @@ Each subsequent row of images was generated using the same process, but with a d
115
  See SamplePrompts.csv for the list of prompts used and their descriptions.
116
  """
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  nsfw_threshold = 0.95 # Assuming the threshold value is defined here
120
 
121
- css = """
122
- .scrollable-content {
123
- max-height: 500px;
124
- overflow-y: auto;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  """
127
 
 
 
 
128
  grammar=r"""
129
  !start: (prompt | /[][():]/+)*
130
  prompt: (emphasized | plain | comma | WHITESPACE)*
@@ -139,6 +265,127 @@ plain: /([^,\\\[\]():|]|\\.)+/
139
  # Initialize the parser
140
  parser = Lark(grammar, start='start')
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  # Function to extract tags
143
  def extract_tags(tree):
144
  tags_with_positions = []
@@ -160,7 +407,38 @@ def remove_special_tags(original_string):
160
  remaining_tags = [tag for tag in tags if tag not in special_tags]
161
  removed_tags = [tag for tag in tags if tag in special_tags]
162
  return ", ".join(remaining_tags), removed_tags
163
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
  # Define a function to load all necessary components
166
  def load_model_components(file_path):
@@ -182,7 +460,15 @@ def load_model_components(file_path):
182
 
183
  # Load all components at the start
184
  tf_idf_components = load_model_components('tf_idf_files_420.joblib')
185
-
 
 
 
 
 
 
 
 
186
 
187
  nsfw_tags = set() # Initialize an empty set to store words meeting the threshold
188
  # Open and read the CSV file
@@ -218,18 +504,41 @@ def is_artist(name):
218
  sample_images_directory_path = 'sampleimages'
219
  def generate_artist_image_tuples(top_artists, image_directory):
220
  json_files = glob.glob(f'{image_directory}/*.json')
221
- json_file_path = json_files[0] if json_files else None
 
 
222
  with open(json_file_path, 'r') as json_file:
223
  artist_to_file_map = json.load(json_file)
224
-
 
 
 
 
 
 
 
 
 
 
 
225
  filename = artist_to_file_map.get("")
226
- image_path = os.path.join(image_directory, filename)
227
- if os.path.exists(image_path):
228
- baseline_tuple = [(image_path, "No Artist")]
229
-
 
230
  artist_image_tuples = []
231
  for artist in top_artists:
232
  filename = artist_to_file_map.get(artist)
 
 
 
 
 
 
 
 
 
233
  if filename:
234
  image_path = os.path.join(image_directory, filename)
235
  if os.path.exists(image_path):
@@ -321,7 +630,7 @@ def create_html_tables_for_tags(subtable_heading, item_heading, word_similarity_
321
  # Loop through the results and add table rows for each
322
  for word, sim in word_similarity_tuples:
323
  word_with_underscores = word.replace(' ', '_')
324
- word_with_escaped_parentheses = word.replace("\\(", "(").replace("\\)", ")").replace("(", "\\(").replace(")", "\\)")
325
  count = tag2count.get(word_with_underscores.replace("\\(", "(").replace("\\)", ")"), 0) # Get the count if available, otherwise default to 0
326
  tag_id, wiki_entry = tag2idwiki.get(word_with_underscores, (None, ''))
327
  # Check if tag_id and wiki_entry are valid
@@ -329,7 +638,10 @@ def create_html_tables_for_tags(subtable_heading, item_heading, word_similarity_
329
  # Construct the URL for the tag's wiki page
330
  wiki_url = f"https://e621.net/wiki_pages/{tag_id}"
331
  # Make the tag a hyperlink with a tooltip
332
- tag_element = f"<a href='{wiki_url}' target='_blank' title='{wiki_entry}'>{word_with_escaped_parentheses}</a>"
 
 
 
333
  else:
334
  # Display the word without any hyperlink or tooltip
335
  tag_element = word_with_escaped_parentheses
@@ -341,36 +653,34 @@ def create_html_tables_for_tags(subtable_heading, item_heading, word_similarity_
341
 
342
 
343
  def create_top_artists_table(top_artists):
344
- # Add a heading above the table
345
  html_str = "<div class=\"scrollable-content\" style='display: inline-block; margin: 20px; text-align: center;'>"
346
- html_str += "<h1>Top Artists</h1>" # Heading for the table
347
- # Start the table with increased font size and no borders between rows
348
  html_str += "<table style='font-size: 20px; border-collapse: collapse;'>"
349
  html_str += "<thead><tr><th>Artist</th><th>Similarity</th></tr></thead><tbody>"
350
- # Loop through the top artists and add a row for each without the rank and without borders between rows
351
  for artist, score in top_artists:
352
- artist_name = artist[3:] if artist.startswith("by ") else artist # Remove "by " prefix
353
- similarity_percentage = "{:.1f}%".format(score * 100) # Convert score to percentage string with one decimal
354
- html_str += f"<td style='padding: 3px 20px; border: none;'>{artist_name}</td><td style='padding: 3px 20px; border: none;'>{similarity_percentage}</td></tr>"
 
 
 
355
 
356
- # Close the table HTML
357
  html_str += "</tbody></table></div>"
358
-
359
  return html_str
360
 
361
 
362
- def construct_pseudo_vector(pseudo_doc_terms, idf_loaded, tag_to_row_loaded):
363
- # Initialize a vector of zeros with the length of the term_to_index mapping
364
- pseudo_vector = np.zeros(len(tag_to_row_loaded))
365
-
366
- # Fill in the vector for terms in the pseudo document
367
- for term in pseudo_doc_terms:
368
- if term in tag_to_row_loaded:
369
- index = tag_to_row_loaded[term]
370
- pseudo_vector[index] = idf_loaded.get(term, 0)
371
-
372
- # Return the vector as a 2D array for compatibility with SVD transform
373
- return pseudo_vector.reshape(1, -1)
374
 
375
 
376
  def get_top_indices(reduced_pseudo_vector, reduced_matrix):
@@ -388,36 +698,32 @@ def get_tfidf_reduced_similar_tags(pseudo_doc_terms, allow_nsfw_tags):
388
  idf = tf_idf_components['idf']
389
  term_to_column_index = tf_idf_components['tag_to_column_index']
390
  row_to_tag = tf_idf_components['row_to_tag']
391
- reduced_matrix = tf_idf_components['reduced_matrix']
392
  svd = tf_idf_components['svd_model']
393
 
394
- # Construct the TF-IDF vector
395
  pseudo_tfidf_vector = construct_pseudo_vector(pseudo_doc_terms, idf, term_to_column_index)
 
396
 
397
- # Reduce the dimensionality of the pseudo-document vector for the reduced matrix
398
- reduced_pseudo_vector = svd.transform(pseudo_tfidf_vector)
399
-
400
- # Compute cosine similarities in the reduced space
401
- cosine_similarities_reduced = cosine_similarity(reduced_pseudo_vector, reduced_matrix).flatten()
402
 
403
- # Sort the indices by descending cosine similarity
404
- top_indices_reduced = np.argsort(cosine_similarities_reduced)
405
-
406
- # Map indices to tags with their similarities
407
- tag_similarity_dict = {row_to_tag[i]: cosine_similarities_reduced[i] for i in top_indices_reduced if i in row_to_tag}
 
408
 
409
  if not allow_nsfw_tags:
410
- tag_similarity_dict = {tag: sim for tag, sim in tag_similarity_dict.items() if tag not in nsfw_tags}
411
-
412
- tag_similarity_dict = {"by " + tag if is_artist(tag) else tag: sim for tag, sim in tag_similarity_dict.items()}
413
 
414
- # Sort and transform tag names
415
  sorted_tag_similarity_dict = OrderedDict(sorted(tag_similarity_dict.items(), key=lambda x: x[1], reverse=True))
416
  transformed_sorted_tag_similarity_dict = OrderedDict(
417
- (key.replace('_', ' ').replace('(', '\\(').replace(')', '\\)'), value)
418
- for key, value in sorted_tag_similarity_dict.items()
419
  )
420
-
421
  return transformed_sorted_tag_similarity_dict
422
 
423
 
@@ -463,22 +769,62 @@ def find_similar_tags(test_tags, tag_to_context_similarity, context_similarity_w
463
  end_pos = tag_info['end_pos']
464
  node_type = tag_info['node_type']
465
 
 
 
 
466
  if modified_tag in special_tags:
467
  bad_entities.append({"entity":"Special", "start":start_pos, "end":end_pos})
468
  continue
469
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
470
  if modified_tag in encountered_modified_tags:
471
  bad_entities.append({"entity":"Duplicate", "start":start_pos, "end":end_pos})
472
  continue
473
  encountered_modified_tags.add(modified_tag)
474
-
475
- modified_tag_for_search = modified_tag.replace(' ','_')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
476
  similar_words = find_similar_tags.fasttext_small_model.most_similar(modified_tag_for_search, topn = 100)
477
  result, seen = [], set(transformed_tags)
478
 
479
  if modified_tag_for_search in find_similar_tags.tag2aliases:
480
  if modified_tag in find_similar_tags.tag2aliases and "_" in modified_tag: #Implicitly tell the user that they should get rid of the underscore
481
- result.append(modified_tag_for_search.replace('_',' '), 1)
482
  seen.add(modified_tag)
483
  else: #The user correctly did not put underscores in their tag
484
  count = find_similar_tags.tag2count.get(modified_tag_for_search, 0) # Get the count if available, otherwise default to 0
@@ -503,27 +849,47 @@ def find_similar_tags(test_tags, tag_to_context_similarity, context_similarity_w
503
  result.append((similar_tag.replace('_', ' '), round(similarity, 3)))
504
  seen.add(similar_tag)
505
 
506
- #Remove NSFW tags if appropriate.
507
  if not allow_nsfw_tags:
508
- result = [(word, score) for word, score in result if word.replace(' ','_') not in nsfw_tags]
509
-
510
- #Adjust score based on context
511
- for i in range(len(result)):
512
- word, score = result[i] # Unpack the tuple
513
- context_score = tag_to_context_similarity.get(word,0)
514
- result[i] = (word, .5 * ((context_similarity_weight * context_score) + ((1 - context_similarity_weight) * score)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
515
 
516
  result = sorted(result, key=lambda x: x[1], reverse=True)[:10]
 
517
  html_content += create_html_tables_for_tags(modified_tag, "Corrected Tag", result, find_similar_tags.tag2count, find_similar_tags.tag2idwiki)
518
 
519
  bad_entities.append({"entity":"Unknown Tag", "start":start_pos, "end":end_pos})
520
 
521
  tags_added=True
522
- # If no tags were processed, add a message
523
  if not tags_added:
524
  html_content = create_html_placeholder(title="Unknown Tags", content="No Unknown Tags Found")
 
 
 
 
525
 
526
- return html_content, bad_entities, known_entities_in_prompt # Return list of lists for Dataframe
527
 
528
 
529
  def build_tag_offsets_dicts(new_image_tags_with_positions):
@@ -581,57 +947,91 @@ def augment_bad_entities_with_regex(text):
581
  def escape_html(text):
582
  return text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;").replace('"', "&quot;").replace("'", "&#039;")
583
 
 
 
 
 
 
 
 
 
 
584
  def format_annotated_html(bad_entities, known_entities, text):
585
  tooltip_map = {
586
  "Unknown Tag": "This may not be a valid e621 tag. Consider removing or replacing it with tag(s) from the \"Unknown Tags\" section.",
587
  "Duplicate": "This tag has appeared multiple times in your prompt. Consider removing the copies.",
588
  "Remove Final Comma": "There should be no comma at the end of your prompt. Consider removing it.",
589
  "Move Comma Inside Parentheses": "In most e621-based models, the comma following a tag functions as an &quot;attention anchor&quot;, carrying most of the tag&apos;s information. It should therefore be assigned the same weight as the rest of the tag. So instead of &quot;(lineless:1.1),&quot;, consider &quot;(lineless,:1.1)&quot; or &quot;(lineless,)&quot;",
590
- "Double Comma": "One comma between tags is considered ample."
 
591
  }
592
  color_map = {
593
- "Unknown Tag": ("white", "red"), # White text on red background
594
- "Duplicate": ("black", "yellow"), # Black text on yellow background
595
- "Remove Final Comma": ("white", "blue"), # White text on blue background
596
- "Move Comma Inside Parentheses": ("white", "green"), # White text on green background
597
- "Double Comma": ("white","orange")
 
598
  }
599
-
600
- # Combine and sort entities
601
- combined_entities = bad_entities + known_entities
602
- combined_entities = sorted(combined_entities, key=lambda x: x['start'],reverse=True)
603
-
604
- # Generate HTML for the main text
605
  html_text = text
606
- for entity in combined_entities:
607
- start = entity['start']
608
- end = entity['end']
609
- label = entity['entity']
 
 
 
 
 
 
610
  if label == "Known Tag":
611
- wiki_url = entity.get('wiki_url', '')
612
- count = entity['count']
613
- wiki_entry = entity.get('wiki_entry', '')
614
- sanitized_wiki_entry = escape_html(wiki_entry) if wiki_entry else 'Unavailable'
615
- if wiki_url: # Check if wiki_url is not empty
616
- html_part = f'<a href="{wiki_url}" target="_blank" title="Count: {count}\tWiki: {sanitized_wiki_entry}" style="text-decoration: none; cursor: pointer; font-style: italic;">{text[start:end]}</a>'
 
 
 
 
 
617
  else:
618
- html_part = f'<span title="Count: {count}\tWiki: {sanitized_wiki_entry}" style="text-decoration: none; cursor: help; font-style: italic;">{text[start:end]}</span>'
 
 
 
619
  else:
620
- color = color_map.get(label, ("black", "white"))
621
- html_part = f'<span style="background-color: {color[1]}; color: {color[0]};">{text[start:end]}</span>'
 
622
  html_text = html_text[:start] + html_part + html_text[end:]
623
-
624
- # Generate HTML for the color key
625
  color_key_html = "<div style='text-align: right; margin-top: 20px;'>Key:"
626
- used_labels = set(entity['entity'] for entity in bad_entities)
627
- for label, colors in color_map.items():
628
  if label in used_labels:
629
  tooltip = tooltip_map.get(label, "")
630
- # Adding margin-right for spacing between items
631
- color_key_html += f" <span style='background-color: {colors[1]}; color: {colors[0]}; margin-right: 10px;' title='{tooltip}'>{label}</span>"
 
 
632
  color_key_html += "</div>"
 
 
 
 
 
 
 
 
 
 
633
 
634
- return f'<div style="padding: 10px; font-size: 16px;">{html_text}</div>{color_key_html}'
635
 
636
 
637
  def find_similar_artists(original_tags_string, top_n, context_similarity_weight, allow_nsfw_tags):
@@ -648,8 +1048,9 @@ def find_similar_artists(original_tags_string, top_n, context_similarity_weight,
648
  #Suggested tags stuff
649
  suggested_tags_html_content = "<div class=\"scrollable-content\" style='display: inline-block; margin: 20px; text-align: center;'>"
650
  suggested_tags_html_content += "<h1>Suggested Tags</h1>" # Heading for the table
651
- suggested_tags = get_tfidf_reduced_similar_tags([item["tf_idf_matrix_tag"] for item in tag_data] + removed_tags, allow_nsfw_tags)
652
-
 
653
  unseen_tags_data, bad_entities, known_entities = find_similar_tags(tag_data, suggested_tags, context_similarity_weight, allow_nsfw_tags)
654
 
655
  #Bad tags stuff
@@ -660,40 +1061,130 @@ def find_similar_artists(original_tags_string, top_n, context_similarity_weight,
660
 
661
  # Create a set of tags that should be filtered out
662
  filter_tags = {entry["original_tag"].strip() for entry in tag_data}
663
- # Use this set to filter suggested_tags
664
- suggested_tags_filtered = OrderedDict((k, v) for k, v in suggested_tags.items() if k not in filter_tags)
 
 
 
665
 
666
  # Splitting the dictionary into two based on the condition
667
- suggested_artist_tags_filtered = OrderedDict((k, v) for k, v in suggested_tags_filtered.items() if k.startswith("by "))
668
- suggested_non_artist_tags_filtered = OrderedDict((k, v) for k, v in suggested_tags_filtered.items() if not k.startswith("by ") and k not in special_tags)
 
 
 
 
 
 
 
 
 
 
 
 
669
 
670
  topnsuggestions = list(islice(suggested_non_artist_tags_filtered.items(), 100))
671
  suggested_tags_html_content += create_html_tables_for_tags("-", "Suggested Tag", topnsuggestions, find_similar_tags.tag2count, find_similar_tags.tag2idwiki)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
672
 
673
- #Artist stuff
674
- excluded_artists = ["by conditional dnp", "by unknown artist"]
675
- top_artists = [(key, value) for key, value in suggested_artist_tags_filtered.items() if key.lower() not in excluded_artists][:top_n]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
676
  top_artists_str = create_top_artists_table(top_artists)
677
  dynamic_prompts_formatted_artists = "{" + "|".join([artist for artist, _ in top_artists]) + "}"
 
 
 
678
 
679
  image_galleries = []
680
  for root, dirs, files in os.walk(sample_images_directory_path):
681
  for name in dirs:
682
- baseline, artists = generate_artist_image_tuples([name[3:] for name, _ in top_artists], os.path.join(root, name))
 
 
 
683
  image_galleries.append(baseline) # Add baseline as its own gallery item
684
  image_galleries.append(artists) # Extend the list with artist tuples
685
 
686
  return (unseen_tags_data, bad_tags_illustrated_html, suggested_tags_html_content, top_artists_str, dynamic_prompts_formatted_artists, *image_galleries)
687
- except ParseError as e:
688
- return [], "Parse Error: Check for mismatched parentheses or something", "", "", None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
689
 
690
 
691
  with gr.Blocks(css=css) as app:
692
  with gr.Group():
693
  with gr.Row():
694
- with gr.Column(scale=3):
695
- image_tags = gr.Textbox(label="Enter Prompt", placeholder="e.g. fox, outside, detailed background, ...")
696
- #bad_tags_illustrated_string = gr.HighlightedText(show_legend=True, color_map={"Unknown Tag":"red","Duplicate":"yellow","Remove Final Comma":"purple","Move Comma Inside Parentheses":"green"}, label="Annotated Prompt")
 
 
 
697
  bad_tags_illustrated_string = gr.HTML()
698
  with gr.Column(scale=1):
699
  gr.HTML(
@@ -708,23 +1199,92 @@ with gr.Blocks(css=css) as app:
708
  with gr.Group():
709
  with gr.Row():
710
  context_similarity_weight = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Context Similarity Weight")
711
- allow_nsfw = gr.Checkbox(label="Allow NSFW Tags", value=False)
712
  with gr.Row():
713
  with gr.Column(scale=2):
714
- unseen_tags = gr.HTML(label="Unknown Tags", value=create_html_placeholder(title="Unknown Tags"))
 
 
 
 
 
715
  with gr.Column(scale=1):
716
- suggested_tags = gr.HTML(label="Suggested Tags", value=create_html_placeholder(title="Suggested Tags"))
 
 
 
 
 
717
  with gr.Column(scale=1):
718
  with gr.Group():
719
  num_artists = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of artists")
720
- top_artists = gr.HTML(label="Top Artists", value=create_html_placeholder(title="Top Artists"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
721
  dynamic_prompts = gr.Textbox(label="Dynamic Prompts Format", info="For if you're using the Automatic1111 webui (https://github.com/AUTOMATIC1111/stable-diffusion-webui) with the Dynamic Prompts extension activated (https://github.com/adieyal/sd-dynamic-prompts) and want to try them all individually.")
722
  galleries = []
723
  for root, dirs, files in os.walk(sample_images_directory_path):
724
  for name in dirs:
725
  with gr.Row():
726
  baseline = gr.Gallery(allow_preview=False, rows=1, columns=1, height=420, scale=3)
727
- styles = gr.Gallery(preview=False, rows=2, columns=5, height=420, scale=8)
728
  galleries.extend([baseline, styles])
729
 
730
  submit_button.click(
@@ -732,6 +1292,13 @@ with gr.Blocks(css=css) as app:
732
  inputs=[image_tags, num_artists, context_similarity_weight, allow_nsfw],
733
  outputs=[unseen_tags, bad_tags_illustrated_string, suggested_tags, top_artists, dynamic_prompts] + galleries
734
  )
 
 
 
 
 
 
 
735
 
736
  gr.Markdown(faq_content)
737
 
 
23
  from itertools import islice
24
  from pathlib import Path
25
  import logging
26
+ import hnswlib
27
+ import pathlib
28
+ from collections import Counter
29
 
30
  # Set up logging
31
+ # Minimal prod logging: warnings+ to stderr, no file by default
32
+ import os, logging
33
+
34
+ LOG_LEVEL = os.environ.get("PSQ_LOG_LEVEL", "WARNING").upper()
35
+ logging.basicConfig(
36
+ level=getattr(logging, LOG_LEVEL, logging.WARNING),
37
+ format="%(asctime)s %(levelname)s:%(message)s",
38
+ handlers=[logging.StreamHandler()] # no file -> avoids huge logs on Spaces
39
+ )
40
+
41
+ # Quiet down common noisy libs (optional)
42
+ for _name in ("gensim", "gradio", "hnswlib", "httpx", "uvicorn"):
43
+ logging.getLogger(_name).setLevel(logging.ERROR)
44
+
45
+ # Turn off Gradio analytics phone-home to avoid those background thread errors (optional)
46
+ os.environ["GRADIO_ANALYTICS_ENABLED"] = "0"
47
+
48
 
49
 
50
  faq_content="""
 
54
 
55
  Since Stable Diffusion's initial release in 2022, users have developed a myriad of fine-tuned text to image models, each with unique "linguistic" preferences depending on the data from which it was fine-tuned.
56
  Some models react best when prompted with verbose scene descriptions akin to DALL-E, while others fine-tuned on images scraped from popular image boards understand those boards' tag sets.
57
+ This tool serves as a linguistic bridge to the e621 image board tag lexicon, on which many popular models such as Fluffyrock, NoobAI, and Pony Diffusion v6 were trained.
58
 
59
  When you enter a txt2img prompt and press the "submit" button, Prompt Squirrel parses your prompt and checks that all your tags are valid e621 tags.
60
  If it finds any that are not, it recommends some valid e621 tags you can use to replace them in the "Unknown Tags" section.
 
134
  See SamplePrompts.csv for the list of prompts used and their descriptions.
135
  """
136
 
137
+ TOOLTIP_NOTE_HTML = '<div class="hover-hint">Underlined items can be hovered for more info.</div>'
138
+
139
+ HOVER_HINT_CSS = """
140
+ /* Solid, visible underline for tagged items */
141
+ .gradio-container .hover-underline{
142
+ text-decoration-line: underline !important;
143
+ text-decoration-thickness: 2px;
144
+ text-underline-offset: 2px;
145
+ }
146
+
147
+ /* Small, subtle hint text */
148
+ .hover-hint{
149
+ font-size: 12px;
150
+ opacity: .85;
151
+ line-height: 1.2;
152
+ }
153
+
154
+ /* Wrapper to position the hint in the bottom-right of the annotated box */
155
+ .annotated-wrap{ position: relative; }
156
+ .annotated-wrap .hover-hint{
157
+ position: absolute;
158
+ right: 6px;
159
+ bottom: 6px;
160
+ text-align: right;
161
+ }
162
+ """
163
+
164
+
165
+
166
+ try:
167
+ from gradio_client import utils as _gc_utils
168
+
169
+ _orig_get_type = _gc_utils.get_type
170
+ _orig_j2p = _gc_utils._json_schema_to_python_type
171
+ _orig_pub = _gc_utils.json_schema_to_python_type
172
+
173
+ def _get_type_safe(schema):
174
+ # Sometimes schema is a bare True/False (JSON Schema boolean form)
175
+ if not isinstance(schema, dict):
176
+ return "any"
177
+ return _orig_get_type(schema)
178
+
179
+ def _j2p_safe(schema, defs=None):
180
+ # Accept non-dict schemas (True/False/None) and treat as "any"
181
+ if not isinstance(schema, dict):
182
+ return "any"
183
+ return _orig_j2p(schema, defs or schema.get("$defs"))
184
+
185
+ def _pub_safe(schema):
186
+ # Public wrapper used by Gradio; keep it resilient too
187
+ if not isinstance(schema, dict):
188
+ return "any"
189
+ return _j2p_safe(schema, schema.get("$defs"))
190
+
191
+ _gc_utils.get_type = _get_type_safe
192
+ _gc_utils._json_schema_to_python_type = _j2p_safe
193
+ _gc_utils.json_schema_to_python_type = _pub_safe
194
+
195
+ except Exception as e:
196
+ print("gradio_client hotfix not applied:", e)
197
+ # -------------------------------------------------------------------------------
198
+
199
 
200
  nsfw_threshold = 0.95 # Assuming the threshold value is defined here
201
 
202
+ css = HOVER_HINT_CSS + """
203
+ .scrollable-content{
204
+ max-height: 420px;
205
+ overflow-y: scroll; /* always show scrollbar */
206
+ overflow-x: hidden;
207
+ padding-right: 8px;
208
+ padding-bottom: 14px; /* <— add this */
209
+ scrollbar-gutter: stable; /* prevent layout shift as it fills */
210
+
211
+ /* Firefox */
212
+ scrollbar-width: auto;
213
+ scrollbar-color: rgba(180,180,180,.9) rgba(0,0,0,.15);
214
+ }
215
+
216
+ /* WebKit/Chromium (Chrome/Edge/Safari) */
217
+ .scrollable-content::-webkit-scrollbar{ width: 10px; }
218
+ .scrollable-content::-webkit-scrollbar-thumb{ background: rgba(180,180,180,.9); border-radius: 8px; }
219
+ .scrollable-content::-webkit-scrollbar-track{ background: rgba(0,0,0,.15); }
220
+
221
+ /* --- Fade that blends into the pane background, no chip --- */
222
+ .scroll-fade {
223
+ position: relative;
224
+ /* ensure our ::after overlay paints above children */
225
+ isolation: isolate;
226
  }
227
+
228
+ .scroll-fade::after{
229
+ content: "";
230
+ position: absolute;
231
+ left: 0; right: 0; bottom: 0;
232
+ height: 20px; /* a hair taller; tweak if you like */
233
+ pointer-events: none;
234
+ /* transparent → panel background (Gradio theme var, with dark fallback) */
235
+ background: linear-gradient(
236
+ to bottom,
237
+ rgba(0,0,0,0),
238
+ var(--background-fill-secondary, #1f2937)
239
+ );
240
+ transition: opacity .18s ease;
241
+ z-index: 3; /* sit above the scroller’s content */
242
+ }
243
+
244
+ .scroll-fade.at-bottom::after { opacity: 0; }
245
+
246
+ /* no chip */
247
+ .scroll-fade::before { content: none; }
248
+
249
  """
250
 
251
+
252
+
253
+ #Parser
254
  grammar=r"""
255
  !start: (prompt | /[][():]/+)*
256
  prompt: (emphasized | plain | comma | WHITESPACE)*
 
265
  # Initialize the parser
266
  parser = Lark(grammar, start='start')
267
 
268
+ # ---------- Two HNSW indexes: artists and non-artist tags ----------
269
+ _HNSW_ART = None
270
+ _HNSW_TAG = None
271
+ _HNSW_DIM = None
272
+ _HNSW_N_ART = None
273
+ _HNSW_N_TAG = None
274
+ _HNSW_ART_PATH = pathlib.Path("tfidf_hnsw_artists.bin")
275
+ _HNSW_TAG_PATH = pathlib.Path("tfidf_hnsw_tags.bin")
276
+
277
+ def _l2_normalize_rows(mat: np.ndarray) -> np.ndarray:
278
+ mat = np.asarray(mat, dtype=np.float32)
279
+ norms = np.linalg.norm(mat, axis=1, keepdims=True)
280
+ norms[norms == 0.0] = 1.0
281
+ return mat / norms
282
+
283
+ def _ensure_dual_hnsw_indexes():
284
+ """
285
+ Build/load two HNSW indexes over the SVD-reduced TF-IDF matrix:
286
+ • _HNSW_ART — rows whose tag (with optional 'by_' stripped) is in the artist_set
287
+ • _HNSW_TAG — only rows that are NOT artist tags
288
+ Index item IDs are the ORIGINAL row indices in reduced_matrix.
289
+ """
290
+ global _HNSW_ART, _HNSW_TAG, _HNSW_DIM, _HNSW_N_ART, _HNSW_N_TAG
291
+
292
+ if _HNSW_ART is not None and _HNSW_TAG is not None:
293
+ return
294
+
295
+ reduced_matrix = tf_idf_components['reduced_matrix'] # (N, D)
296
+ row_to_tag = tf_idf_components['row_to_tag'] # {row:int -> "tag_with_underscores"}
297
+ rm = _l2_normalize_rows(reduced_matrix).astype(np.float32)
298
+ n_items, dim = rm.shape
299
+
300
+ # Partition rows
301
+ artist_rows = []
302
+ tag_rows = []
303
+
304
+ for i in range(n_items):
305
+ tag = row_to_tag.get(i, "")
306
+
307
+ # Strip leading "by_" if present in the TF-IDF vocabulary, but don't rely on it.
308
+ base = tag[3:] if tag.startswith("by_") else tag
309
+
310
+ # Some corpora contain buckets you don't want shown as artists:
311
+ if tag in {"by_unknown_artist", "by_conditional_dnp"}:
312
+ tag_rows.append(i)
313
+ continue
314
+
315
+ if is_artist(base):
316
+ artist_rows.append(i)
317
+ else:
318
+ tag_rows.append(i)
319
+
320
+ logging.debug(f"HNSW partition: artists={len(artist_rows)} non_artists={len(tag_rows)}")
321
+
322
+ # Helper: build or load an index for a subset of rows
323
+ def _build_or_load(path: pathlib.Path, rows: list[int]) -> hnswlib.Index:
324
+ idx = hnswlib.Index(space='cosine', dim=dim)
325
+ need_build = True
326
+ if path.exists():
327
+ try:
328
+ idx.load_index(str(path), max_elements=max(1, len(rows)))
329
+ # Rebuild if the saved index count doesn’t match our rows
330
+ if getattr(idx, "get_current_count", None) and idx.get_current_count() == len(rows) and len(rows) > 0:
331
+ need_build = False
332
+ else:
333
+ logging.debug(f"Rebuilding {path.name}: saved_count!=rows_len ({idx.get_current_count()} vs {len(rows)})")
334
+ except Exception as e:
335
+ logging.debug(f"Reload {path.name} failed, rebuilding: {e}")
336
+
337
+ if need_build:
338
+ try:
339
+ if path.exists():
340
+ path.unlink()
341
+ except Exception:
342
+ pass
343
+ idx.init_index(max_elements=max(1, len(rows)), ef_construction=200, M=16)
344
+ if rows:
345
+ idx.add_items(rm[rows], ids=np.asarray(rows, dtype=np.int32))
346
+ idx.save_index(str(path))
347
+
348
+ idx.set_ef(200)
349
+ return idx
350
+
351
+
352
+ _HNSW_ART = _build_or_load(_HNSW_ART_PATH, artist_rows)
353
+ _HNSW_TAG = _build_or_load(_HNSW_TAG_PATH, tag_rows)
354
+ _HNSW_DIM = dim
355
+ _HNSW_N_ART = len(artist_rows)
356
+ _HNSW_N_TAG = len(tag_rows)
357
+
358
+ def _hnsw_query(idx: hnswlib.Index, vec: np.ndarray, k: int):
359
+ """
360
+ Query a given HNSW index with a (1, D) or (D,) vector in SVD space.
361
+ Returns (indices, sims) with cosine similarity scores.
362
+ """
363
+ _ensure_dual_hnsw_indexes()
364
+ q = np.asarray(vec, dtype=np.float32).reshape(-1)
365
+ q_norm = np.linalg.norm(q)
366
+ if q_norm > 0:
367
+ q = q / q_norm
368
+ labels, dists = idx.knn_query(q, k=k)
369
+ inds = labels[0]
370
+ sims = 1.0 - dists[0] # cosine distance -> similarity
371
+ return inds, sims
372
+
373
+ def _ann_tags_topk(vec: np.ndarray, k: int):
374
+ _ensure_dual_hnsw_indexes()
375
+ k = min(k, _HNSW_N_TAG if _HNSW_N_TAG else 0)
376
+ return _hnsw_query(_HNSW_TAG, vec, k) if k else (np.array([], dtype=int), np.array([], dtype=float))
377
+
378
+ def _ann_artists_topk(vec: np.ndarray, k: int):
379
+ _ensure_dual_hnsw_indexes()
380
+ k = min(k, _HNSW_N_ART if _HNSW_N_ART else 0)
381
+ return _hnsw_query(_HNSW_ART, vec, k) if k else (np.array([], dtype=int), np.array([], dtype=float))
382
+ # ------------------------------------------------------------------
383
+
384
+
385
+ def _norm_tag_for_lookup(s: str) -> str:
386
+ # convert "name with spaces" -> "name_with_spaces" and unescape parens
387
+ return s.replace(' ', '_').replace('\\(', '(').replace('\\)', ')')
388
+
389
  # Function to extract tags
390
  def extract_tags(tree):
391
  tags_with_positions = []
 
407
  remaining_tags = [tag for tag in tags if tag not in special_tags]
408
  removed_tags = [tag for tag in tags if tag in special_tags]
409
  return ", ".join(remaining_tags), removed_tags
410
+
411
# Model-specific tags: quality/rating tokens understood by particular
# fine-tuned models (Pony score_*/source_* tags, Animagine/NoobAI quality
# tags, ...) rather than real e621 tags. Prompt tokens matching these are
# labelled "Model Specific" instead of being flagged as unknown.
MODEL_SPECIFIC_TAGS = {
    "masterpiece",
    "best quality",
    "good quality",
    "normal quality",
    "newest",
    "absurdres",
    "highres",
    "safe",
    "worst quality",
    "early",
    "low quality",
    "lowres",
    "explicit content",  # corrected spelling so real prompts match
    "explict content",   # historical typo kept for backward compatibility
    "very awa",
    "worst aesthetic",
    "score_9",
    "score_8_up",
    "score_7_up",
    "score_6_up",
    "score_5_up",
    "score_4_up",
    "source_pony",
    "source_furry",
    "source_cartoon",
    "source_anime",
    "rating_safe",
    "rating_questionable",
    "rating_explicit",
}
442
 
443
  # Define a function to load all necessary components
444
  def load_model_components(file_path):
 
460
 
461
  # Load all components at the start
462
  tf_idf_components = load_model_components('tf_idf_files_420.joblib')
463
# Normalise the stored IDF to a column-aligned numpy vector. Older model
# artifacts persist idf as a {term: idf_value} dict, while downstream code
# (e.g. construct_pseudo_vector) expects a dense array indexed by column id.
idf = tf_idf_components['idf']
if isinstance(idf, dict):
    term_to_col = tf_idf_components['tag_to_column_index']
    width = 1 + max(term_to_col.values())
    # Columns absent from the dict default to a neutral weight of 1.0.
    aligned = np.ones(width, dtype=np.float32)
    for term, col in term_to_col.items():
        aligned[col] = float(idf.get(term, 1.0))
    tf_idf_components['idf'] = aligned
472
 
473
  nsfw_tags = set() # Initialize an empty set to store words meeting the threshold
474
  # Open and read the CSV file
 
504
  sample_images_directory_path = 'sampleimages'
505
  def generate_artist_image_tuples(top_artists, image_directory):
506
  json_files = glob.glob(f'{image_directory}/*.json')
507
+ if not json_files:
508
+ return [], [] # no mapping present; return empty galleries safely
509
+ json_file_path = json_files[0]
510
  with open(json_file_path, 'r') as json_file:
511
  artist_to_file_map = json.load(json_file)
512
+ # DEBUG: mapping + baseline info
513
+ logging.debug("Gallery %s: loaded %d entries (map file=%s)",
514
+ image_directory, len(artist_to_file_map), json_file_path)
515
+ _base = artist_to_file_map.get("")
516
+ logging.debug(
517
+ "Gallery %s: baseline '' -> %r (exists=%s)",
518
+ image_directory,
519
+ _base,
520
+ os.path.exists(os.path.join(image_directory, _base)) if _base else None,
521
+ )
522
+
523
+ baseline_tuple = []
524
  filename = artist_to_file_map.get("")
525
+ if filename:
526
+ image_path = os.path.join(image_directory, filename)
527
+ if os.path.exists(image_path):
528
+ baseline_tuple = [(image_path, "No Artist")]
529
+
530
  artist_image_tuples = []
531
  for artist in top_artists:
532
  filename = artist_to_file_map.get(artist)
533
+ # DEBUG: per-artist resolution
534
+ logging.debug(
535
+ "Gallery %s: %s -> %r (exists=%s)",
536
+ image_directory,
537
+ artist,
538
+ filename,
539
+ os.path.exists(os.path.join(image_directory, filename)) if filename else None,
540
+ )
541
+
542
  if filename:
543
  image_path = os.path.join(image_directory, filename)
544
  if os.path.exists(image_path):
 
630
  # Loop through the results and add table rows for each
631
  for word, sim in word_similarity_tuples:
632
  word_with_underscores = word.replace(' ', '_')
633
+ word_with_escaped_parentheses = escape_parens_for_display(word)
634
  count = tag2count.get(word_with_underscores.replace("\\(", "(").replace("\\)", ")"), 0) # Get the count if available, otherwise default to 0
635
  tag_id, wiki_entry = tag2idwiki.get(word_with_underscores, (None, ''))
636
  # Check if tag_id and wiki_entry are valid
 
638
  # Construct the URL for the tag's wiki page
639
  wiki_url = f"https://e621.net/wiki_pages/{tag_id}"
640
  # Make the tag a hyperlink with a tooltip
641
+ tag_element = (
642
+ f"<a class='hover-underline' href='{wiki_url}' target='_blank' "
643
+ f"title='{wiki_entry}'>{word_with_escaped_parentheses}</a>"
644
+ )
645
  else:
646
  # Display the word without any hyperlink or tooltip
647
  tag_element = word_with_escaped_parentheses
 
653
 
654
 
655
def create_top_artists_table(top_artists):
    """Render (artist, similarity) pairs as a scrollable HTML table.

    `top_artists` is an iterable of ``(display_name, score)`` where score is
    a 0..1 cosine similarity; scores are shown as percentages with one
    decimal place. Names are paren-escaped and HTML-escaped for display.
    """
    parts = [
        "<div class=\"scrollable-content\" style='display: inline-block; margin: 20px; text-align: center;'>",
        "<h1>Top Artists</h1>",
        "<table style='font-size: 20px; border-collapse: collapse;'>",
        "<thead><tr><th>Artist</th><th>Similarity</th></tr></thead><tbody>",
    ]
    cell_style = "padding: 3px 20px; border: none;"
    for name, score in top_artists:
        safe_name = escape_html(escape_parens_for_display(name))
        pct = f"{score * 100:.1f}%"
        parts.append(
            f"<tr><td style='{cell_style}'>{safe_name}</td>"
            f"<td style='{cell_style}'>{pct}</td></tr>"
        )
    parts.append("</tbody></table></div>")
    return "".join(parts)
671
 
672
 
673
def construct_pseudo_vector(pseudo_doc_terms, idf, term_to_column_index):
    """Build a 1 x N sparse TF-IDF row for a pseudo-document.

    pseudo_doc_terms: mapping of term -> raw weight; terms missing from
    `term_to_column_index` are silently ignored. `idf` is indexable by
    column id and supplies the IDF multiplier for each kept term.
    Returns a scipy CSR matrix of shape (1, len(idf)), dtype float32.
    """
    entries = [
        (col, weight * idf[col])
        for term, weight in pseudo_doc_terms.items()
        if (col := term_to_column_index.get(term)) is not None
    ]
    columns = [c for c, _ in entries]
    values = [v for _, v in entries]
    # CSR triplet form: one row, so indptr is simply [0, nnz].
    return csr_matrix(
        (values, columns, [0, len(columns)]),
        shape=(1, len(idf)),
        dtype=np.float32,
    )
 
684
 
685
 
686
  def get_top_indices(reduced_pseudo_vector, reduced_matrix):
 
698
  idf = tf_idf_components['idf']
699
  term_to_column_index = tf_idf_components['tag_to_column_index']
700
  row_to_tag = tf_idf_components['row_to_tag']
 
701
  svd = tf_idf_components['svd_model']
702
 
703
+ # 1) Build the pseudo TF-IDF, reduce to SVD space (unchanged)
704
  pseudo_tfidf_vector = construct_pseudo_vector(pseudo_doc_terms, idf, term_to_column_index)
705
+ reduced_pseudo_vector = svd.transform(pseudo_tfidf_vector) # shape (1, D)
706
 
707
+ # 2) ANN: only fetch nearest non-artist candidates (no full-matrix cosine)
708
+ K = 2000 # tune for speed/recall
709
+ top_inds, top_sims = _ann_tags_topk(reduced_pseudo_vector, k=K)
 
 
710
 
711
+ # 3) Build similarity dict from those candidates
712
+ tag_similarity_dict = {}
713
+ for i, sim in zip(top_inds, top_sims):
714
+ tag = row_to_tag.get(int(i))
715
+ if tag is not None:
716
+ tag_similarity_dict[tag] = float(sim)
717
 
718
  if not allow_nsfw_tags:
719
+ tag_similarity_dict = {t: s for t, s in tag_similarity_dict.items() if t not in nsfw_tags}
 
 
720
 
721
+ # 4) Sort & escape like before
722
  sorted_tag_similarity_dict = OrderedDict(sorted(tag_similarity_dict.items(), key=lambda x: x[1], reverse=True))
723
  transformed_sorted_tag_similarity_dict = OrderedDict(
724
+ (key.replace('_', ' ').replace('(', '\\(').replace(')', '\\)'), val)
725
+ for key, val in sorted_tag_similarity_dict.items()
726
  )
 
727
  return transformed_sorted_tag_similarity_dict
728
 
729
 
 
769
  end_pos = tag_info['end_pos']
770
  node_type = tag_info['node_type']
771
 
772
+ # Build the underscore form up-front
773
+ modified_tag_for_search = modified_tag.replace(' ', '_')
774
+
775
  if modified_tag in special_tags:
776
  bad_entities.append({"entity":"Special", "start":start_pos, "end":end_pos})
777
  continue
778
+
779
+ # Only accept exact underscore model-specific tokens (e.g., "score_9")
780
+ # special score/rating tags (kept as-is)
781
+ if modified_tag in special_tags:
782
+ bad_entities.append({"entity": "Special", "start": start_pos, "end": end_pos})
783
+ continue
784
+
785
+ # Model-specific tokens must match the user's input *exactly* (no pre-normalization).
786
+ # Use the original token as typed in the prompt, lowercased.
787
+ original_raw = tag_info["original_tag"].strip().lower()
788
+ if original_raw in MODEL_SPECIFIC_TAGS:
789
+ bad_entities.append({"entity": "Model Specific", "start": start_pos, "end": end_pos})
790
+ continue
791
+
792
+
793
  if modified_tag in encountered_modified_tags:
794
  bad_entities.append({"entity":"Duplicate", "start":start_pos, "end":end_pos})
795
  continue
796
  encountered_modified_tags.add(modified_tag)
797
+
798
+ norm_artist = (
799
+ modified_tag_for_search
800
+ .lower()
801
+ .removeprefix('by_') # tolerate users typing "by something" or not
802
+ )
803
+ if is_artist(norm_artist):
804
+ by_key = f"by_{norm_artist}"
805
+ # try by_* first, then raw form as fallback
806
+ count = (find_similar_tags.tag2count.get(by_key) or
807
+ find_similar_tags.tag2count.get(modified_tag_for_search, 0))
808
+ tag_id, wiki_entry = (
809
+ find_similar_tags.tag2idwiki.get(by_key) or
810
+ find_similar_tags.tag2idwiki.get(modified_tag_for_search, (None, ''))
811
+ )
812
+ wiki_url = f"https://e621.net/wiki_pages/{tag_id}" if tag_id is not None and wiki_entry else ""
813
+ known_entities_in_prompt.append({
814
+ "entity": "Known Tag",
815
+ "start": start_pos,
816
+ "end": end_pos,
817
+ "count": count,
818
+ "wiki_url": wiki_url,
819
+ "wiki_entry": wiki_entry
820
+ })
821
+ continue
822
  similar_words = find_similar_tags.fasttext_small_model.most_similar(modified_tag_for_search, topn = 100)
823
  result, seen = [], set(transformed_tags)
824
 
825
  if modified_tag_for_search in find_similar_tags.tag2aliases:
826
  if modified_tag in find_similar_tags.tag2aliases and "_" in modified_tag: #Implicitly tell the user that they should get rid of the underscore
827
+ result.append((modified_tag_for_search.replace('_',' '), 1))
828
  seen.add(modified_tag)
829
  else: #The user correctly did not put underscores in their tag
830
  count = find_similar_tags.tag2count.get(modified_tag_for_search, 0) # Get the count if available, otherwise default to 0
 
849
  result.append((similar_tag.replace('_', ' '), round(similarity, 3)))
850
  seen.add(similar_tag)
851
 
852
+ # Remove NSFW tags if appropriate.
853
  if not allow_nsfw_tags:
854
+ result = [(w, s) for (w, s) in result if w.replace(' ', '_') not in nsfw_tags]
855
+
856
+ # --- Context re-scoring (keys match how get_tfidf_reduced_similar_tags formats them) ---
857
+ def _ctx_score(name: str) -> float:
858
+ v = tag_to_context_similarity.get(name)
859
+ if v is None:
860
+ # TF-IDF dict escapes parentheses; candidates from FT do not.
861
+ v = tag_to_context_similarity.get(name.replace('(', '\\(').replace(')', '\\)'))
862
+ return float(v) if v is not None else 0.0
863
+
864
+ # If the slider is at 1.0, only keep candidates that exist in the TF-IDF context list.
865
+ if context_similarity_weight >= 0.999:
866
+ ctx_keys = set(tag_to_context_similarity.keys())
867
+ result = [
868
+ (w, s) for (w, s) in result
869
+ if (w in ctx_keys) or (w.replace('(', '\\(').replace(')', '\\)') in ctx_keys)
870
+ ]
871
+
872
+ # Linear blend: final = (1-λ)*fasttext + λ*context (no extra 0.5 scaling)
873
+ result = [
874
+ (w, (1.0 - context_similarity_weight) * s + context_similarity_weight * _ctx_score(w))
875
+ for (w, s) in result
876
+ ]
877
 
878
  result = sorted(result, key=lambda x: x[1], reverse=True)[:10]
879
+
880
  html_content += create_html_tables_for_tags(modified_tag, "Corrected Tag", result, find_similar_tags.tag2count, find_similar_tags.tag2idwiki)
881
 
882
  bad_entities.append({"entity":"Unknown Tag", "start":start_pos, "end":end_pos})
883
 
884
  tags_added=True
885
+ # If no tags were processed, add a message; otherwise close the wrapper div
886
  if not tags_added:
887
  html_content = create_html_placeholder(title="Unknown Tags", content="No Unknown Tags Found")
888
+ else:
889
+ html_content += "</div>"
890
+
891
+ return html_content, bad_entities, known_entities_in_prompt
892
 
 
893
 
894
 
895
  def build_tag_offsets_dicts(new_image_tags_with_positions):
 
947
def escape_html(text):
    """Escape &, <, >, double quotes and single quotes for HTML embedding."""
    # Single-pass translation; equivalent to chained str.replace with the
    # ampersand handled first (no double-escaping possible either way).
    table = str.maketrans({
        "&": "&amp;",
        "<": "&lt;",
        ">": "&gt;",
        '"': "&quot;",
        "'": "&#039;",
    })
    return text.translate(table)
949
 
950
def escape_parens_for_display(s: str) -> str:
    """Return `s` with every literal paren escaped exactly once (``\\(``, ``\\)``).

    Idempotent: any already-escaped parens are first reduced to bare parens,
    so repeated calls never stack backslashes.
    """
    # Pass 1: strip existing escapes.
    for bare in "()":
        s = s.replace("\\" + bare, bare)
    # Pass 2: re-escape every paren once.
    pieces = []
    for ch in s:
        if ch in "()":
            pieces.append("\\")
        pieces.append(ch)
    return "".join(pieces)
958
+
959
def format_annotated_html(bad_entities, known_entities, text):
    """Render the user's prompt with inline highlights and a color key.

    Both entity lists contain dicts with at least "entity" (label string),
    "start" and "end" (character offsets into `text`). "Known Tag" entries
    may additionally carry "count", "wiki_url" and "wiki_entry". Returns an
    HTML string: the annotated prompt box (including the module-level
    TOOLTIP_NOTE_HTML hint) followed by a key describing any warning labels
    that actually appeared.
    """
    # Hover text for each warning label shown in the color key.
    tooltip_map = {
        "Unknown Tag": "This may not be a valid e621 tag. Consider removing or replacing it with tag(s) from the \"Unknown Tags\" section.",
        "Duplicate": "This tag has appeared multiple times in your prompt. Consider removing the copies.",
        "Remove Final Comma": "There should be no comma at the end of your prompt. Consider removing it.",
        "Move Comma Inside Parentheses": "In most e621-based models, the comma following a tag functions as an &quot;attention anchor&quot;, carrying most of the tag&apos;s information. It should therefore be assigned the same weight as the rest of the tag. So instead of &quot;(lineless:1.1),&quot;, consider &quot;(lineless,:1.1)&quot; or &quot;(lineless,)&quot;",
        "Double Comma": "One comma between tags is considered ample.",
        "Model Specific": "This is not an e621 tag, but may still be valid with the right model. Check your model&#39;s documentation. If the tag is not mentioned in the documentation, do not use it."
    }
    # Label -> (foreground, background) colors for inline highlights.
    color_map = {
        "Unknown Tag": ("white", "red"),
        "Duplicate": ("black", "yellow"),
        "Move Comma Inside Parentheses": ("white", "green"),
        "Double Comma": ("white", "orange"),
        "Model Specific": ("black", "lightgray"),
        "Remove Final Comma": ("white", "brown")
    }

    # Splice from the original text so indexes stay valid.
    # Processing right-to-left (descending start) means earlier offsets are
    # untouched by later replacements. Assumes entity spans do not overlap.
    combined = sorted(bad_entities + known_entities, key=lambda x: x["start"], reverse=True)

    html_text = text

    for entity in combined:
        start = entity["start"]
        end = entity["end"]
        label = entity["entity"]

        # Escape only the replaced segment (keeps indices correct).
        segment = text[start:end]
        disp = escape_html(escape_parens_for_display(segment))

        if label == "Known Tag":
            wiki_url = entity.get("wiki_url", "")
            count = entity.get("count", 0)
            wiki_entry = entity.get("wiki_entry", "")
            sanitized_wiki_entry = escape_html(wiki_entry) if wiki_entry else "Unavailable"

            # With a wiki URL, link to the e621 wiki page; otherwise show a
            # hover-only tooltip carrying the same count/wiki info.
            if wiki_url:
                html_part = (
                    f'<a class="hover-underline" href="{wiki_url}" target="_blank" '
                    f'title="Count: {count}\tWiki: {sanitized_wiki_entry}" '
                    f'style="cursor: pointer; font-style: italic;">{disp}</a>'
                )
            else:
                html_part = (
                    f'<span class="hover-underline" title="Count: {count}\tWiki: {sanitized_wiki_entry}" '
                    f'style="cursor: help; font-style: italic;">{disp}</span>'
                )
        else:
            # Warning labels get a solid highlight; unknown labels fall back
            # to black-on-white so nothing is silently dropped.
            fg, bg = color_map.get(label, ("black", "white"))
            html_part = f'<span style="background-color: {bg}; color: {fg};">{disp}</span>'

        html_text = html_text[:start] + html_part + html_text[end:]

    # Color key (only for labels that actually appeared).
    # Known Tag entries are styled inline, so only warning labels from
    # bad_entities contribute to the key.
    color_key_html = "<div style='text-align: right; margin-top: 20px;'>Key:"
    used_labels = {e["entity"] for e in bad_entities}
    for label, (fg, bg) in color_map.items():
        if label in used_labels:
            tooltip = tooltip_map.get(label, "")
            color_key_html += (
                f" <span class='hover-underline' style='background-color: {bg}; color: {fg}; margin-right: 10px;' "
                f"title='{tooltip}'>{label}</span>"
            )
    color_key_html += "</div>"

    # Wrap the whole annotated area so we can place the hint inside it
    annotated_box = (
        "<div class='annotated-wrap' style='padding:10px;font-size:16px;'>"
        f"{html_text}"
        f"{TOOLTIP_NOTE_HTML}"
        "</div>"
    )

    return annotated_box + color_key_html
1034
 
 
1035
 
1036
 
1037
  def find_similar_artists(original_tags_string, top_n, context_similarity_weight, allow_nsfw_tags):
 
1048
  #Suggested tags stuff
1049
  suggested_tags_html_content = "<div class=\"scrollable-content\" style='display: inline-block; margin: 20px; text-align: center;'>"
1050
  suggested_tags_html_content += "<h1>Suggested Tags</h1>" # Heading for the table
1051
+ terms = [item["tf_idf_matrix_tag"] for item in tag_data] + removed_tags
1052
+ suggested_tags = get_tfidf_reduced_similar_tags(dict(Counter(terms)), allow_nsfw_tags)
1053
+
1054
  unseen_tags_data, bad_entities, known_entities = find_similar_tags(tag_data, suggested_tags, context_similarity_weight, allow_nsfw_tags)
1055
 
1056
  #Bad tags stuff
 
1061
 
1062
  # Create a set of tags that should be filtered out
1063
  filter_tags = {entry["original_tag"].strip() for entry in tag_data}
1064
+ filter_tags_norm = { _norm_tag_for_lookup(t.lower().removeprefix('by ').removeprefix('by_')) for t in filter_tags }
1065
+ suggested_tags_filtered = OrderedDict(
1066
+ (k, v) for k, v in suggested_tags.items()
1067
+ if k not in filter_tags and _norm_tag_for_lookup(k.lower()) not in filter_tags_norm
1068
+ )
1069
 
1070
  # Splitting the dictionary into two based on the condition
1071
+ def _norm_no_by(s: str) -> str:
1072
+ n = _norm_tag_for_lookup(s)
1073
+ return n[3:] if n.startswith("by_") else n
1074
+
1075
+ suggested_artist_tags_filtered = OrderedDict(
1076
+ (k, v) for k, v in suggested_tags_filtered.items()
1077
+ if is_artist(_norm_no_by(k))
1078
+ )
1079
+
1080
+ suggested_non_artist_tags_filtered = OrderedDict(
1081
+ (k, v) for k, v in suggested_tags_filtered.items()
1082
+ if not is_artist(_norm_no_by(k)) and k not in special_tags
1083
+ )
1084
+
1085
 
1086
  topnsuggestions = list(islice(suggested_non_artist_tags_filtered.items(), 100))
1087
  suggested_tags_html_content += create_html_tables_for_tags("-", "Suggested Tag", topnsuggestions, find_similar_tags.tag2count, find_similar_tags.tag2idwiki)
1088
+ suggested_tags_html_content += "</div>"
1089
+
1090
+
1091
+ # --- Artist stuff: query artist-only index directly ---
1092
+ idf_vec = tf_idf_components['idf']
1093
+ t2c = tf_idf_components['tag_to_column_index']
1094
+ svd = tf_idf_components['svd_model']
1095
+ pseudo_terms = dict(Counter(terms))
1096
+ pseudo_vec = construct_pseudo_vector(pseudo_terms, idf_vec, t2c)
1097
+ reduced_q = svd.transform(pseudo_vec)
1098
+
1099
+ K_art = max(100, top_n * 10) # widen search to stabilize ranks
1100
+ art_inds, art_sims = _ann_artists_topk(reduced_q, k=K_art)
1101
+
1102
+ row_to_tag = tf_idf_components['row_to_tag']
1103
+ bad_labels = {"by_unknown_artist", "by_conditional_dnp", "unknown_artist", "conditional_dnp"}
1104
+
1105
+ top_artists_raw = []
1106
+ for idx_i, sim in zip(art_inds, art_sims):
1107
+ tag = row_to_tag.get(int(idx_i), "")
1108
+ if not tag:
1109
+ continue
1110
+
1111
+ # Normalize spaces to underscores for reliable checks
1112
+ norm = tag.replace(" ", "_")
1113
+
1114
+ # Drop known non-artist placeholders
1115
+ if norm in bad_labels:
1116
+ continue
1117
 
1118
+ # Accept either "by_foo" or plain "foo"
1119
+ base = norm[3:] if norm.startswith("by_") else norm
1120
+
1121
+ # Guard: only keep if this *really* is an artist we know
1122
+ if not is_artist(base):
1123
+ continue
1124
+
1125
+ name_disp = base.replace("_", " ")
1126
+ top_artists_raw.append((name_disp, float(sim)))
1127
+
1128
+ if not top_artists_raw:
1129
+ logging.debug("No artist hits. First few neighbor labels: %s",
1130
+ [row_to_tag.get(int(i), "") for i in art_inds[:10]])
1131
+
1132
+ # take the best unique names, in order
1133
+ seen = set()
1134
+ deduped = []
1135
+ for n, s in top_artists_raw:
1136
+ if n not in seen:
1137
+ deduped.append((n, s))
1138
+ seen.add(n)
1139
+ if len(deduped) >= top_n:
1140
+ break
1141
+
1142
+ top_artists = deduped
1143
+ logging.debug("Top artists (n=%d): %s", len(top_artists), top_artists)
1144
  top_artists_str = create_top_artists_table(top_artists)
1145
  dynamic_prompts_formatted_artists = "{" + "|".join([artist for artist, _ in top_artists]) + "}"
1146
+ dynamic_prompts_formatted_artists = "{" + "|".join(
1147
+ [escape_parens_for_display(artist) for artist, _ in top_artists]
1148
+ ) + "}"
1149
 
1150
  image_galleries = []
1151
  for root, dirs, files in os.walk(sample_images_directory_path):
1152
  for name in dirs:
1153
+ baseline, artists = generate_artist_image_tuples([name for name, _ in top_artists], os.path.join(root, name))
1154
+ dir_path = os.path.join(root, name)
1155
+ baseline, artists = generate_artist_image_tuples([n for n, _ in top_artists], dir_path)
1156
+ logging.debug("Gallery built for %s -> baseline=%d, artists_found=%d", dir_path, len(baseline), len(artists))
1157
  image_galleries.append(baseline) # Add baseline as its own gallery item
1158
  image_galleries.append(artists) # Extend the list with artist tuples
1159
 
1160
  return (unseen_tags_data, bad_tags_illustrated_html, suggested_tags_html_content, top_artists_str, dynamic_prompts_formatted_artists, *image_galleries)
1161
+ except ParseError:
1162
+ # Build empty galleries so the tuple length matches the declared outputs
1163
+ empty_galleries = []
1164
+ for _root, _dirs, _files in os.walk(sample_images_directory_path):
1165
+ for _ in _dirs:
1166
+ empty_galleries.extend([[], []]) # one empty list per Gallery component
1167
+
1168
+ return (
1169
+ create_html_placeholder(title="Unknown Tags", content="Parse Error"),
1170
+ "Parse Error: Check for mismatched parentheses or something",
1171
+ create_html_placeholder(title="Suggested Tags"),
1172
+ "", # top_artists
1173
+ "", # dynamic_prompts
1174
+ *empty_galleries,
1175
+ )
1176
+
1177
 
1178
 
1179
  with gr.Blocks(css=css) as app:
1180
  with gr.Group():
1181
  with gr.Row():
1182
+ with gr.Column(scale=3, elem_classes=["prompt-col"]):
1183
+ image_tags = gr.Textbox(
1184
+ label="Enter Prompt",
1185
+ placeholder="e.g. fox, outside, detailed background, ...",
1186
+ lines=1 # Enter submits (see .submit() below)
1187
+ )
1188
  bad_tags_illustrated_string = gr.HTML()
1189
  with gr.Column(scale=1):
1190
  gr.HTML(
 
1199
  with gr.Group():
1200
  with gr.Row():
1201
  context_similarity_weight = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Context Similarity Weight")
1202
+ allow_nsfw = gr.Checkbox(label="Allow NSFW Tag Suggestions", value=False)
1203
  with gr.Row():
1204
  with gr.Column(scale=2):
1205
+ unseen_tags = gr.HTML(
1206
+ label="Unknown Tags",
1207
+ value=create_html_placeholder(title="Unknown Tags"),
1208
+ elem_id="unseen_html",
1209
+ elem_classes=["scroll-fade"],
1210
+ )
1211
  with gr.Column(scale=1):
1212
+ suggested_tags = gr.HTML(
1213
+ label="Suggested Tags",
1214
+ value=create_html_placeholder(title="Suggested Tags"),
1215
+ elem_id="suggested_html",
1216
+ elem_classes=["scroll-fade"],
1217
+ )
1218
  with gr.Column(scale=1):
1219
  with gr.Group():
1220
  num_artists = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of artists")
1221
+ top_artists = gr.HTML(
1222
+ label="Top Artists",
1223
+ value=create_html_placeholder(title="Top Artists"),
1224
+ elem_id="artists_html",
1225
+ elem_classes=["scroll-fade"],
1226
+ )
1227
+ gr.HTML("""
1228
+ <script>
1229
+ (function(){
1230
+ function wire(id){
1231
+ const host = document.getElementById(id);
1232
+ if (!host) return;
1233
+
1234
+ // Always use the *inner* .scrollable-content as the scroller
1235
+ const getScroller = () => host.querySelector('.scrollable-content') || host;
1236
+ let scroller = getScroller();
1237
+
1238
+ // Set CSS var so the fade blends with host background
1239
+ const bg = getComputedStyle(host).backgroundColor;
1240
+ host.style.setProperty('--host-bg', bg);
1241
+
1242
+ const refresh = () => {
1243
+ // guard for fractional pixels across browsers
1244
+ const atBottom = Math.ceil(scroller.scrollTop + scroller.clientHeight) >= scroller.scrollHeight;
1245
+ host.classList.toggle('at-bottom', atBottom);
1246
+ };
1247
+
1248
+ // (Re)attach scroll listener to the current scroller
1249
+ const attach = (el) => {
1250
+ if (!el) return;
1251
+ el.addEventListener('scroll', refresh, {passive:true});
1252
+ // initial state
1253
+ refresh();
1254
+ };
1255
+
1256
+ attach(scroller);
1257
+
1258
+ // If Gradio replaces inner HTML, re-wire to new scroller
1259
+ new MutationObserver(() => {
1260
+ const next = getScroller();
1261
+ if (next && next !== scroller) {
1262
+ scroller.removeEventListener && scroller.removeEventListener('scroll', refresh);
1263
+ scroller = next;
1264
+ attach(scroller);
1265
+ }
1266
+ // background might change with themes; keep it fresh
1267
+ const newBg = getComputedStyle(host).backgroundColor;
1268
+ host.style.setProperty('--host-bg', newBg);
1269
+ refresh();
1270
+ }).observe(host, {childList: true, subtree: true});
1271
+
1272
+ // Also respond to resizes
1273
+ new ResizeObserver(refresh).observe(host);
1274
+ }
1275
+
1276
+ ['unseen_html','suggested_html','artists_html'].forEach(wire);
1277
+ })();
1278
+ </script>
1279
+ """, visible=False)
1280
+
1281
  dynamic_prompts = gr.Textbox(label="Dynamic Prompts Format", info="For if you're using the Automatic1111 webui (https://github.com/AUTOMATIC1111/stable-diffusion-webui) with the Dynamic Prompts extension activated (https://github.com/adieyal/sd-dynamic-prompts) and want to try them all individually.")
1282
  galleries = []
1283
  for root, dirs, files in os.walk(sample_images_directory_path):
1284
  for name in dirs:
1285
  with gr.Row():
1286
  baseline = gr.Gallery(allow_preview=False, rows=1, columns=1, height=420, scale=3)
1287
+ styles = gr.Gallery(allow_preview=False, rows=2, columns=5, height=420, scale=8)
1288
  galleries.extend([baseline, styles])
1289
 
1290
  submit_button.click(
 
1292
  inputs=[image_tags, num_artists, context_similarity_weight, allow_nsfw],
1293
  outputs=[unseen_tags, bad_tags_illustrated_string, suggested_tags, top_artists, dynamic_prompts] + galleries
1294
  )
1295
+ # Also run when pressing Enter in the prompt box
1296
+ image_tags.submit(
1297
+ find_similar_artists,
1298
+ inputs=[image_tags, num_artists, context_similarity_weight, allow_nsfw],
1299
+ outputs=[unseen_tags, bad_tags_illustrated_string, suggested_tags, top_artists, dynamic_prompts] + galleries
1300
+ )
1301
+
1302
 
1303
  gr.Markdown(faq_content)
1304
 
requirements.txt CHANGED
@@ -2,7 +2,7 @@ gradio==4.44.1
2
  gradio-client==1.3.0
3
  fastapi==0.116.1
4
  starlette==0.47.3
5
-
6
  numpy==1.25.1
7
  scikit-learn==1.4.1.post1
8
  h5py==3.8.0
 
2
  gradio-client==1.3.0
3
  fastapi==0.116.1
4
  starlette==0.47.3
5
+ hnswlib==0.8.0
6
  numpy==1.25.1
7
  scikit-learn==1.4.1.post1
8
  h5py==3.8.0