Spaces:
Build error
Build error
| import gradio as gr | |
| from easygui import msgbox | |
| import subprocess | |
| from .common_gui import get_folder_path, add_pre_postfix | |
| import os | |
| from library.custom_logging import setup_logging | |
# Module-level logger, created once at import time by the project's
# logging helper and used by the captioning callback below.
log = setup_logging()
def _build_wd14_command(
    train_data_dir,
    caption_extension,
    batch_size,
    general_threshold,
    character_threshold,
    replace_underscores,
    model,
    recursive,
    max_data_loader_n_workers,
    debug,
    undesired_tags,
    frequency_tags,
):
    """Return the argv list that runs the WD14 tagger via ``accelerate launch``.

    Every value is a separate list element, so no shell quoting is involved:
    paths or tags containing spaces, quotes or shell metacharacters reach the
    script verbatim.
    """
    cmd = [
        'accelerate',
        'launch',
        './finetune/tag_images_by_wd14_tagger.py',
        f'--batch_size={int(batch_size)}',
        f'--general_threshold={general_threshold}',
        f'--character_threshold={character_threshold}',
        f'--caption_extension={caption_extension}',
        f'--model={model}',
        f'--max_data_loader_n_workers={int(max_data_loader_n_workers)}',
    ]
    # Boolean switches are appended only when the matching option is enabled,
    # in the same order the original command string used.
    for enabled, flag in (
        (recursive, '--recursive'),
        (debug, '--debug'),
        (replace_underscores, '--remove_underscore'),
        (frequency_tags, '--frequency_tags'),
    ):
        if enabled:
            cmd.append(flag)
    if undesired_tags:
        cmd.append(f'--undesired_tags={undesired_tags}')
    # Positional argument: the image folder to caption.
    cmd.append(train_data_dir)
    return cmd


def caption_images(
    train_data_dir,
    caption_extension,
    batch_size,
    general_threshold,
    character_threshold,
    replace_underscores,
    model,
    recursive,
    max_data_loader_n_workers,
    debug,
    undesired_tags,
    frequency_tags,
    prefix,
    postfix,
):
    """Caption every image in ``train_data_dir`` with the WD14 tagger.

    Runs ``./finetune/tag_images_by_wd14_tagger.py`` through
    ``accelerate launch``. If the script exits successfully, applies the
    optional ``prefix``/``postfix`` to the caption files via
    ``add_pre_postfix``.

    Args:
        train_data_dir: Folder containing the images to caption.
        caption_extension: Extension of the caption files (e.g. ``.txt``).
        batch_size: Forwarded as ``--batch_size`` (converted to int).
        general_threshold: Forwarded as ``--general_threshold``.
        character_threshold: Forwarded as ``--character_threshold``.
        replace_underscores: When truthy, adds ``--remove_underscore``.
        model: Tagger model identifier, forwarded as ``--model``.
        recursive: When truthy, adds ``--recursive``.
        max_data_loader_n_workers: Forwarded as
            ``--max_data_loader_n_workers`` (converted to int).
        debug: When truthy, adds ``--debug``.
        undesired_tags: Forwarded as ``--undesired_tags`` unless empty.
        frequency_tags: When truthy, adds ``--frequency_tags``.
        prefix: Text passed to ``add_pre_postfix`` as the caption prefix.
        postfix: Text passed to ``add_pre_postfix`` as the caption postfix.

    Returns:
        None. Missing inputs are reported to the user with a message box,
        and a failing tagger run is reported in the log.
    """
    import shlex  # local import: only used to render the command for logging

    # Validate required inputs (also rejects None, not just '').
    if not train_data_dir:
        msgbox('Image folder is missing...')
        return
    if not caption_extension:
        msgbox('Please provide an extension for the caption files.')
        return

    log.info(f'Captioning files in {train_data_dir}...')

    cmd = _build_wd14_command(
        train_data_dir,
        caption_extension,
        batch_size,
        general_threshold,
        character_threshold,
        replace_underscores,
        model,
        recursive,
        max_data_loader_n_workers,
        debug,
        undesired_tags,
        frequency_tags,
    )
    log.info(' '.join(shlex.quote(arg) for arg in cmd))

    # Run without a shell on every OS. The previous os.system(<string>)
    # path let quotes or metacharacters in user-supplied paths/tags break
    # the command or inject additional shell commands.
    result = subprocess.run(cmd)
    if result.returncode != 0:
        # Do not decorate captions (or claim success) after a failed run;
        # that would only touch stale caption files from an earlier run.
        log.error(
            f'WD14 tagger exited with code {result.returncode}; '
            'skipping prefix/postfix step.'
        )
        return

    # Add prefix and postfix to the generated caption files.
    add_pre_postfix(
        folder=train_data_dir,
        caption_file_ext=caption_extension,
        prefix=prefix,
        postfix=postfix,
    )

    log.info('...captioning done')
| ### | |
| # Gradio UI | |
| ### | |
def gradio_wd14_caption_gui_tab(headless=False):
    """Build the 'WD14 Captioning' Gradio tab and wire its button to ``caption_images``.

    Args:
        headless: When True, the folder-picker button next to the image
            folder textbox is hidden.
    """
    with gr.Tab('WD14 Captioning'):
        gr.Markdown(
            'This utility will use WD14 to caption files for each image in a folder.'
        )

        # Input settings
        with gr.Row():
            train_data_dir = gr.Textbox(
                label='Image folder to caption',
                placeholder='Directory containing the images to caption',
                interactive=True,
            )
            button_train_data_dir_input = gr.Button(
                '📂', elem_id='open_folder_small', visible=(not headless)
            )
            # The picker writes the chosen directory into the textbox.
            button_train_data_dir_input.click(
                get_folder_path,
                outputs=train_data_dir,
                show_progress=False,
            )

            caption_extension = gr.Textbox(
                label='Caption file extension',
                placeholder='Extension for caption file. eg: .caption, .txt',
                value='.txt',
                interactive=True,
            )

        undesired_tags = gr.Textbox(
            label='Undesired tags',
            placeholder='(Optional) Separate `undesired_tags` with comma `(,)` if you want to remove multiple tags, e.g. `1girl,solo,smile`.',
            interactive=True,
        )

        with gr.Row():
            prefix = gr.Textbox(
                label='Prefix to add to WD14 caption',
                placeholder='(Optional)',
                interactive=True,
            )
            postfix = gr.Textbox(
                label='Postfix to add to WD14 caption',
                placeholder='(Optional)',
                interactive=True,
            )

        with gr.Row():
            # NOTE(review): this checkbox maps to the tagger's
            # --remove_underscore flag; confirm whether that flag acts on
            # tag text or on filenames, and adjust the label if needed.
            replace_underscores = gr.Checkbox(
                label='Replace underscores in filenames with spaces',
                value=True,
                interactive=True,
            )
            recursive = gr.Checkbox(
                label='Recursive',
                value=False,
                info='Tag images in subfolders as well',
            )

            debug = gr.Checkbox(
                label='Verbose logging',
                value=True,
                info='Debug while tagging, it will print your image file with general tags and character tags.',
            )
            frequency_tags = gr.Checkbox(
                label='Show tags frequency',
                value=True,
                info='Show frequency of tags for images.',
            )

        # Model settings
        with gr.Row():
            model = gr.Dropdown(
                label='Model',
                choices=[
                    'SmilingWolf/wd-v1-4-convnext-tagger-v2',
                    'SmilingWolf/wd-v1-4-convnextv2-tagger-v2',
                    'SmilingWolf/wd-v1-4-vit-tagger-v2',
                    'SmilingWolf/wd-v1-4-swinv2-tagger-v2',
                ],
                value='SmilingWolf/wd-v1-4-convnextv2-tagger-v2',
            )

            general_threshold = gr.Slider(
                value=0.35,
                label='General threshold',
                info='Adjust `general_threshold` for pruning tags (fewer tags, less flexible)',
                minimum=0,
                maximum=1,
                step=0.05,
            )
            character_threshold = gr.Slider(
                value=0.35,
                label='Character threshold',
                info='Useful if you want to train with characters',
                minimum=0,
                maximum=1,
                step=0.05,
            )

        # Advanced settings
        with gr.Row():
            batch_size = gr.Number(
                value=8, label='Batch size', interactive=True
            )
            max_data_loader_n_workers = gr.Number(
                value=2, label='Max dataloader workers', interactive=True
            )

        caption_button = gr.Button('Caption images')

        # The inputs list must match caption_images' positional parameter
        # order exactly, because Gradio passes the values positionally.
        caption_button.click(
            caption_images,
            inputs=[
                train_data_dir,
                caption_extension,
                batch_size,
                general_threshold,
                character_threshold,
                replace_underscores,
                model,
                recursive,
                max_data_loader_n_workers,
                debug,
                undesired_tags,
                frequency_tags,
                prefix,
                postfix,
            ],
            show_progress=False,
        )