| import argparse | |
| import gradio as gr | |
| from kohya_ss.library import train_util, config_util | |
| from scripts import presets, ui, ui_overrides | |
| from scripts.runner import initialize_runner | |
| from scripts.utilities import args_to_gradio, load_args_template, options_to_gradio | |
def title():
    """Return the human-readable title of this UI section."""
    section_title = "Train network"
    return section_title
def create_ui():
    """Build the Gradio UI for configuring and running ``train_network.py``.

    One ``argparse.ArgumentParser`` is created per kohya_ss argument group
    (model, dataset, training, optimizer, dataset config). Each is populated
    by the library's own ``add_*_arguments`` helpers and then rendered into
    its own options dict via ``args_to_gradio``. The network-specific options
    come from the ``train_network.py`` template returned by
    ``load_args_template`` and are rendered with ``options_to_gradio``.

    The runner and the preset selector receive ``get_templates`` and
    ``get_options`` as callables, so they always see the dicts' current
    contents rather than a snapshot taken at call time.

    NOTE(review): presumably called inside an active ``gr.Blocks`` context.
    The panels below are laid out in statement order, so the order of the
    ``with`` blocks is the on-screen order.
    """
    sd_models_arguments = argparse.ArgumentParser()
    dataset_arguments = argparse.ArgumentParser()
    training_arguments = argparse.ArgumentParser()
    optimizer_arguments = argparse.ArgumentParser()
    config_arguments = argparse.ArgumentParser()

    train_util.add_sd_models_arguments(sd_models_arguments)
    train_util.add_dataset_arguments(dataset_arguments, True, True, True)
    train_util.add_training_arguments(training_arguments, True)
    train_util.add_optimizer_arguments(optimizer_arguments)
    config_util.add_config_arguments(config_arguments)

    # Populated later by args_to_gradio / options_to_gradio; read lazily
    # through get_options() below.
    sd_models_options = {}
    dataset_options = {}
    training_options = {}
    optimizer_options = {}
    config_options = {}
    network_options = {}

    templates, script_file = load_args_template("train_network.py")

    # Both tuples are in merge order. When keys collide, the later entry
    # wins, as with the original ``{**a, **b, ...}`` unpacking.
    parsers = (
        sd_models_arguments,
        dataset_arguments,
        training_arguments,
        optimizer_arguments,
        config_arguments,
    )
    option_groups = (
        sd_models_options,
        dataset_options,
        training_options,
        optimizer_options,
        config_options,
        network_options,
    )

    def get_options():
        """Return every group's option entries merged into one new dict."""
        merged = {}
        for group in option_groups:
            merged.update(group)
        return merged

    def get_templates():
        """Return option string -> argparse action for all parsers, plus templates.

        The ``train_network.py`` templates are merged last, so they take
        precedence over any parser action with the same key.
        """
        merged = {}
        for parser in parsers:
            # NOTE(review): ``_option_string_actions`` is a private argparse
            # attribute (the original read it via ``__dict__``). Each parser
            # uses argparse's default add_help=True, so "-h"/"--help" are
            # included too; confirm consumers ignore them.
            merged.update(parser._option_string_actions)
        merged.update(templates)
        return merged

    with gr.Column():
        init_runner = initialize_runner(script_file, get_templates, get_options)
        with gr.Box():
            with gr.Row():
                init_id = presets.create_ui("train_network", get_templates, get_options)
        with gr.Row():
            with gr.Group():
                with gr.Box():
                    ui.title("Network options")
                    options_to_gradio(templates, network_options)
                with gr.Box():
                    ui.title("Model options")
                    args_to_gradio(sd_models_arguments, sd_models_options)
                with gr.Box():
                    ui.title("Dataset Config options")
                    args_to_gradio(config_arguments, config_options)
                with gr.Box():
                    ui.title("Dataset options")
                    args_to_gradio(dataset_arguments, dataset_options)
                with gr.Box():
                    ui.title("Training options")
                    args_to_gradio(training_arguments, training_options)
                with gr.Box():
                    ui.title("Optimizer options")
                    args_to_gradio(
                        optimizer_arguments,
                        optimizer_options,
                        ui_overrides.OPTIMIZER_OPTIONS,
                    )
    # Run the deferred initializers only after every component above exists.
    init_runner()
    init_id()