Spaces:
Running
Running
Upload 5 files
Browse files. Added "Prompt Enhancer" from [John6666/danbooru-tags-transformer-v2-with-wd-tagger](https://huggingface.co/spaces/John6666/danbooru-tags-transformer-v2-with-wd-tagger/blob/main/tagger/promptenhancer.py) (Thanks!) and cleaned up some code.
- app.py +32 -18
- modules/classifyTags.py +2 -5
- modules/florence2.py +1 -6
- modules/llama_loader.py +1 -5
- modules/tag_enhancer.py +53 -0
app.py
CHANGED
|
@@ -11,6 +11,7 @@ import json
|
|
| 11 |
from modules.classifyTags import classify_tags,process_tags
|
| 12 |
from modules.florence2 import process_image,single_task_list,update_task_dropdown
|
| 13 |
from modules.llama_loader import llama_list,llama3reorganize
|
|
|
|
| 14 |
os.environ['PYTORCH_ENABLE_MPS_FALLBACK']='1'
|
| 15 |
|
| 16 |
TITLE = "Multi-Tagger"
|
|
@@ -249,9 +250,9 @@ class Predictor:
|
|
| 249 |
reverse=True,
|
| 250 |
)
|
| 251 |
sorted_general_list = [x[0] for x in sorted_general_list]
|
| 252 |
-
#Remove values from character_list that already exist in sorted_general_list
|
| 253 |
character_list = [item for item in character_list if item not in sorted_general_list]
|
| 254 |
-
#Remove values from sorted_general_list that already exist in prepend_list or append_list
|
| 255 |
if prepend_list:
|
| 256 |
sorted_general_list = [item for item in sorted_general_list if item not in prepend_list]
|
| 257 |
if append_list:
|
|
@@ -312,7 +313,8 @@ class Predictor:
|
|
| 312 |
"rating": rating,
|
| 313 |
"character_res": character_res,
|
| 314 |
"general_res": general_res,
|
| 315 |
-
"unclassified_tags": unclassified_tags
|
|
|
|
| 316 |
}
|
| 317 |
|
| 318 |
timer.report()
|
|
@@ -348,11 +350,12 @@ def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state:
|
|
| 348 |
"rating": "",
|
| 349 |
"character_res": "",
|
| 350 |
"general_res": "",
|
| 351 |
-
"unclassified_tags": "{}"
|
|
|
|
| 352 |
}
|
| 353 |
if selected_state.value["image"]["path"] in tag_results:
|
| 354 |
tag_result = tag_results[selected_state.value["image"]["path"]]
|
| 355 |
-
return (selected_state.value["image"]["path"], selected_state.value["caption"]), tag_result["strings"], tag_result["strings2"], tag_result["classified_tags"], tag_result["rating"], tag_result["character_res"], tag_result["general_res"], tag_result["unclassified_tags"]
|
| 356 |
def append_gallery(gallery:list,image:str):
|
| 357 |
if gallery is None:gallery=[]
|
| 358 |
if not image:return gallery,None
|
|
@@ -417,7 +420,6 @@ with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True
|
|
| 417 |
upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
|
| 418 |
remove_button = gr.Button("Remove Selected Image", size="sm")
|
| 419 |
gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Grid of images")
|
| 420 |
-
|
| 421 |
model_repo = gr.Dropdown(
|
| 422 |
dropdown_list,
|
| 423 |
value=EVA02_LARGE_MODEL_DSV3_REPO,
|
|
@@ -485,14 +487,17 @@ with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True
|
|
| 485 |
size="lg",
|
| 486 |
)
|
| 487 |
with gr.Column(variant="panel"):
|
| 488 |
-
download_file = gr.File(label="Download includes: All outputs* and image(s)")
|
| 489 |
-
character_res = gr.Label(label="Output (characters)")
|
| 490 |
-
sorted_general_strings = gr.Textbox(label="Output (string)*", show_label=True, show_copy_button=True)
|
| 491 |
-
final_categorized_output = gr.Textbox(label="Categorized (string)* - If it's too long, select an image to display tags correctly.", show_label=True, show_copy_button=True)
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
|
|
|
|
|
|
|
|
|
| 496 |
clear.add(
|
| 497 |
[
|
| 498 |
download_file,
|
|
@@ -503,8 +508,10 @@ with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True
|
|
| 503 |
character_res,
|
| 504 |
general_res,
|
| 505 |
unclassified,
|
|
|
|
|
|
|
| 506 |
]
|
| 507 |
-
)
|
| 508 |
tag_results = gr.State({})
|
| 509 |
# Define the event listener to add the uploaded image to the gallery
|
| 510 |
image_input.change(append_gallery, inputs=[gallery, image_input], outputs=[gallery, image_input])
|
|
@@ -512,9 +519,11 @@ with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True
|
|
| 512 |
upload_button.upload(extend_gallery, inputs=[gallery, upload_button], outputs=gallery)
|
| 513 |
# Event to update the selected image when an image is clicked in the gallery
|
| 514 |
selected_image = gr.Textbox(label="Selected Image", visible=False)
|
| 515 |
-
gallery.select(get_selection_from_gallery,
|
| 516 |
# Event to remove a selected image from the gallery
|
| 517 |
remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
|
|
|
|
|
|
|
| 518 |
submit.click(
|
| 519 |
predictor.predict,
|
| 520 |
inputs=[
|
|
@@ -543,7 +552,7 @@ with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True
|
|
| 543 |
character_mcut_enabled,
|
| 544 |
],
|
| 545 |
)
|
| 546 |
-
with gr.Tab(label="Tag Categorizer"):
|
| 547 |
with gr.Row():
|
| 548 |
with gr.Column(variant="panel"):
|
| 549 |
input_tags = gr.Textbox(label="Input Tags (Danbooru comma-separated)", placeholder="1girl, cat, horns, blue hair, ...")
|
|
@@ -551,7 +560,12 @@ with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True
|
|
| 551 |
with gr.Column(variant="panel"):
|
| 552 |
categorized_string = gr.Textbox(label="Categorized (string)", show_label=True, show_copy_button=True, lines=8)
|
| 553 |
categorized_json = gr.JSON(label="Categorized (tags) - JSON")
|
| 554 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 555 |
with gr.Tab(label="Florence 2 Image Captioning"):
|
| 556 |
with gr.Row():
|
| 557 |
with gr.Column(variant="panel"):
|
|
|
|
| 11 |
from modules.classifyTags import classify_tags,process_tags
|
| 12 |
from modules.florence2 import process_image,single_task_list,update_task_dropdown
|
| 13 |
from modules.llama_loader import llama_list,llama3reorganize
|
| 14 |
+
from modules.tag_enhancer import prompt_enhancer
|
| 15 |
os.environ['PYTORCH_ENABLE_MPS_FALLBACK']='1'
|
| 16 |
|
| 17 |
TITLE = "Multi-Tagger"
|
|
|
|
| 250 |
reverse=True,
|
| 251 |
)
|
| 252 |
sorted_general_list = [x[0] for x in sorted_general_list]
|
| 253 |
+
# Remove values from character_list that already exist in sorted_general_list
|
| 254 |
character_list = [item for item in character_list if item not in sorted_general_list]
|
| 255 |
+
# Remove values from sorted_general_list that already exist in prepend_list or append_list
|
| 256 |
if prepend_list:
|
| 257 |
sorted_general_list = [item for item in sorted_general_list if item not in prepend_list]
|
| 258 |
if append_list:
|
|
|
|
| 313 |
"rating": rating,
|
| 314 |
"character_res": character_res,
|
| 315 |
"general_res": general_res,
|
| 316 |
+
"unclassified_tags": unclassified_tags,
|
| 317 |
+
"enhanced_tags": "" # Initialize as empty string
|
| 318 |
}
|
| 319 |
|
| 320 |
timer.report()
|
|
|
|
| 350 |
"rating": "",
|
| 351 |
"character_res": "",
|
| 352 |
"general_res": "",
|
| 353 |
+
"unclassified_tags": "{}",
|
| 354 |
+
"enhanced_tags": ""
|
| 355 |
}
|
| 356 |
if selected_state.value["image"]["path"] in tag_results:
|
| 357 |
tag_result = tag_results[selected_state.value["image"]["path"]]
|
| 358 |
+
return (selected_state.value["image"]["path"], selected_state.value["caption"]), tag_result["strings"], tag_result["strings2"], tag_result["classified_tags"], tag_result["rating"], tag_result["character_res"], tag_result["general_res"], tag_result["unclassified_tags"], tag_result["enhanced_tags"]
|
| 359 |
def append_gallery(gallery:list,image:str):
|
| 360 |
if gallery is None:gallery=[]
|
| 361 |
if not image:return gallery,None
|
|
|
|
| 420 |
upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
|
| 421 |
remove_button = gr.Button("Remove Selected Image", size="sm")
|
| 422 |
gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Grid of images")
|
|
|
|
| 423 |
model_repo = gr.Dropdown(
|
| 424 |
dropdown_list,
|
| 425 |
value=EVA02_LARGE_MODEL_DSV3_REPO,
|
|
|
|
| 487 |
size="lg",
|
| 488 |
)
|
| 489 |
with gr.Column(variant="panel"):
|
| 490 |
+
download_file = gr.File(label="Download includes: All outputs* and image(s)") # 0
|
| 491 |
+
character_res = gr.Label(label="Output (characters)") # 1
|
| 492 |
+
sorted_general_strings = gr.Textbox(label="Output (string)*", show_label=True, show_copy_button=True) # 2
|
| 493 |
+
final_categorized_output = gr.Textbox(label="Categorized (string)* - If it's too long, select an image to display tags correctly.", show_label=True, show_copy_button=True) # 3
|
| 494 |
+
pe_generate_btn = gr.Button(value="ENHANCE TAGS", size="lg", variant="primary") # 4
|
| 495 |
+
enhanced_tags = gr.Textbox(label="Enhanced Tags", show_label=True, show_copy_button=True) # 5
|
| 496 |
+
prompt_enhancer_model = gr.Radio(["Medium", "Long", "Flux"], label="Model Choice", value="Medium", info="Enhance your prompts with Medium or Long answers") # 6
|
| 497 |
+
categorized = gr.JSON(label="Categorized (tags)* - JSON") # 7
|
| 498 |
+
rating = gr.Label(label="Rating") # 8
|
| 499 |
+
general_res = gr.Label(label="Output (tags)") # 9
|
| 500 |
+
unclassified = gr.JSON(label="Unclassified (tags)") # 10
|
| 501 |
clear.add(
|
| 502 |
[
|
| 503 |
download_file,
|
|
|
|
| 508 |
character_res,
|
| 509 |
general_res,
|
| 510 |
unclassified,
|
| 511 |
+
prompt_enhancer_model,
|
| 512 |
+
enhanced_tags,
|
| 513 |
]
|
| 514 |
+
)
|
| 515 |
tag_results = gr.State({})
|
| 516 |
# Define the event listener to add the uploaded image to the gallery
|
| 517 |
image_input.change(append_gallery, inputs=[gallery, image_input], outputs=[gallery, image_input])
|
|
|
|
| 519 |
upload_button.upload(extend_gallery, inputs=[gallery, upload_button], outputs=gallery)
|
| 520 |
# Event to update the selected image when an image is clicked in the gallery
|
| 521 |
selected_image = gr.Textbox(label="Selected Image", visible=False)
|
| 522 |
+
gallery.select(get_selection_from_gallery,inputs=[gallery, tag_results],outputs=[selected_image, sorted_general_strings, final_categorized_output, categorized, rating, character_res, general_res, unclassified, enhanced_tags])
|
| 523 |
# Event to remove a selected image from the gallery
|
| 524 |
remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
|
| 525 |
+
# Event handler for the Prompt Enhancer button
|
| 526 |
+
pe_generate_btn.click(lambda tags,model:prompt_enhancer('','',tags,model)[0],inputs=[final_categorized_output,prompt_enhancer_model],outputs=[enhanced_tags])
|
| 527 |
submit.click(
|
| 528 |
predictor.predict,
|
| 529 |
inputs=[
|
|
|
|
| 552 |
character_mcut_enabled,
|
| 553 |
],
|
| 554 |
)
|
| 555 |
+
with gr.Tab(label="Tag Categorizer + Enhancer"):
|
| 556 |
with gr.Row():
|
| 557 |
with gr.Column(variant="panel"):
|
| 558 |
input_tags = gr.Textbox(label="Input Tags (Danbooru comma-separated)", placeholder="1girl, cat, horns, blue hair, ...")
|
|
|
|
| 560 |
with gr.Column(variant="panel"):
|
| 561 |
categorized_string = gr.Textbox(label="Categorized (string)", show_label=True, show_copy_button=True, lines=8)
|
| 562 |
categorized_json = gr.JSON(label="Categorized (tags) - JSON")
|
| 563 |
+
submit_button.click(process_tags, inputs=[input_tags], outputs=[categorized_string, categorized_json])
|
| 564 |
+
with gr.Column(variant="panel"):
|
| 565 |
+
pe_generate_btn = gr.Button(value="ENHANCE TAGS", size="lg", variant="primary")
|
| 566 |
+
enhanced_tags = gr.Textbox(label="Enhanced Tags", show_label=True, show_copy_button=True)
|
| 567 |
+
prompt_enhancer_model = gr.Radio(["Medium", "Long", "Flux"], label="Model Choice", value="Medium", info="Enhance your prompts with Medium or Long answers")
|
| 568 |
+
pe_generate_btn.click(lambda tags,model:prompt_enhancer('','',tags,model)[0],inputs=[categorized_string,prompt_enhancer_model],outputs=[enhanced_tags])
|
| 569 |
with gr.Tab(label="Florence 2 Image Captioning"):
|
| 570 |
with gr.Row():
|
| 571 |
with gr.Column(variant="panel"):
|
modules/classifyTags.py
CHANGED
|
@@ -171,9 +171,6 @@ def process_tags(input_tags: str):
|
|
| 171 |
categorized_string = ', '.join([tag for category in classified_tags.values() for tag in category])
|
| 172 |
categorized_json = {category: tags for category, tags in classified_tags.items()}
|
| 173 |
|
| 174 |
-
return categorized_string, categorized_json
|
| 175 |
|
| 176 |
-
tags = []
|
| 177 |
-
if __name__ == "__main__":
|
| 178 |
-
classify_tags (tags, True)
|
| 179 |
-
process_tags(input_tags)
|
|
|
|
| 171 |
categorized_string = ', '.join([tag for category in classified_tags.values() for tag in category])
|
| 172 |
categorized_json = {category: tags for category, tags in classified_tags.items()}
|
| 173 |
|
| 174 |
+
return categorized_string, categorized_json, "" # Initialize enhanced_prompt as empty
|
| 175 |
|
| 176 |
+
tags = []
|
|
|
|
|
|
|
|
|
modules/florence2.py
CHANGED
|
@@ -94,9 +94,4 @@ def update_task_dropdown(choice):
|
|
| 94 |
if choice == 'Cascaded task':
|
| 95 |
return gr.Dropdown(choices=cascaded_task_list, value='Caption + Grounding')
|
| 96 |
else:
|
| 97 |
-
return gr.Dropdown(choices=single_task_list, value='Caption')
|
| 98 |
-
|
| 99 |
-
if __name__ == "__main__":
|
| 100 |
-
process_image()
|
| 101 |
-
single_task_list
|
| 102 |
-
update_task_dropdown()
|
|
|
|
| 94 |
if choice == 'Cascaded task':
|
| 95 |
return gr.Dropdown(choices=cascaded_task_list, value='Caption + Grounding')
|
| 96 |
else:
|
| 97 |
+
return gr.Dropdown(choices=single_task_list, value='Caption')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/llama_loader.py
CHANGED
|
@@ -182,8 +182,4 @@ class llama3reorganize:
|
|
| 182 |
except Exception as e:print(traceback.format_exc());print('Error reorganize text: '+str(e))
|
| 183 |
return result
|
| 184 |
|
| 185 |
-
llama_list=[META_LLAMA_3_3B_REPO,META_LLAMA_3_8B_REPO]
|
| 186 |
-
|
| 187 |
-
if __name__ == "__main__":
|
| 188 |
-
llama3reorganize()
|
| 189 |
-
llama_list
|
|
|
|
| 182 |
except Exception as e:print(traceback.format_exc());print('Error reorganize text: '+str(e))
|
| 183 |
return result
|
| 184 |
|
| 185 |
+
llama_list=[META_LLAMA_3_3B_REPO,META_LLAMA_3_8B_REPO]
|
|
|
|
|
|
|
|
|
|
|
|
modules/tag_enhancer.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from transformers import pipeline,AutoTokenizer,AutoModelForSeq2SeqLM
|
| 3 |
+
import re,torch
|
| 4 |
+
|
| 5 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 6 |
+
|
| 7 |
+
def load_models():
|
| 8 |
+
try:
|
| 9 |
+
enhancer_medium = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=device)
|
| 10 |
+
enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
|
| 11 |
+
model_checkpoint = "gokaygokay/Flux-Prompt-Enhance"
|
| 12 |
+
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
|
| 13 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).eval().to(device=device)
|
| 14 |
+
enhancer_flux = pipeline('text2text-generation', model=model, tokenizer=tokenizer, repetition_penalty=1.5, device=device)
|
| 15 |
+
except Exception as e:
|
| 16 |
+
print(e)
|
| 17 |
+
enhancer_medium = enhancer_long = enhancer_flux = None
|
| 18 |
+
return enhancer_medium, enhancer_long, enhancer_flux
|
| 19 |
+
|
| 20 |
+
enhancer_medium, enhancer_long, enhancer_flux = load_models()
|
| 21 |
+
|
| 22 |
+
def enhance_prompt(input_prompt, model_choice):
|
| 23 |
+
if model_choice == "Medium":
|
| 24 |
+
result = enhancer_medium("Enhance the description: " + input_prompt)
|
| 25 |
+
enhanced_text = result[0]['summary_text']
|
| 26 |
+
|
| 27 |
+
pattern = r'^.*?of\s+(.*?(?:\.|$))'
|
| 28 |
+
match = re.match(pattern, enhanced_text, re.IGNORECASE | re.DOTALL)
|
| 29 |
+
|
| 30 |
+
if match:
|
| 31 |
+
remaining_text = enhanced_text[match.end():].strip()
|
| 32 |
+
modified_sentence = match.group(1).capitalize()
|
| 33 |
+
enhanced_text = modified_sentence + ' ' + remaining_text
|
| 34 |
+
elif model_choice == "Flux":
|
| 35 |
+
result = enhancer_flux("enhance prompt: " + input_prompt, max_length=256)
|
| 36 |
+
enhanced_text = result[0]['generated_text']
|
| 37 |
+
else: # Long
|
| 38 |
+
result = enhancer_long("Enhance the description: " + input_prompt)
|
| 39 |
+
enhanced_text = result[0]['summary_text']
|
| 40 |
+
|
| 41 |
+
return enhanced_text
|
| 42 |
+
|
| 43 |
+
def prompt_enhancer(character: str, series: str, general: str, model_choice: str):
|
| 44 |
+
characters = character.split(",") if character else []
|
| 45 |
+
serieses = series.split(",") if series else []
|
| 46 |
+
generals = general.split(",") if general else []
|
| 47 |
+
tags = characters + serieses + generals
|
| 48 |
+
cprompt = ",".join(tags) if tags else ""
|
| 49 |
+
|
| 50 |
+
output = enhance_prompt(cprompt, model_choice)
|
| 51 |
+
prompt = cprompt + ", " + output
|
| 52 |
+
|
| 53 |
+
return prompt, gr.update(interactive=True), gr.update(interactive=True)
|