Upload kohya_gui_kaggle.py
Browse files- kohya_gui_kaggle.py +22 -11
kohya_gui_kaggle.py
CHANGED
|
@@ -1,11 +1,12 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import os
|
| 3 |
import argparse
|
| 4 |
-
from
|
| 5 |
-
from
|
| 6 |
-
from
|
|
|
|
| 7 |
from kohya_gui.utilities import utilities_tab
|
| 8 |
-
from lora_gui import lora_tab
|
| 9 |
from kohya_gui.class_lora_tab import LoRATools
|
| 10 |
|
| 11 |
from kohya_gui.custom_logging import setup_logging
|
|
@@ -22,9 +23,9 @@ def UI(**kwargs):
|
|
| 22 |
headless = kwargs.get("headless", False)
|
| 23 |
log.info(f"headless: {headless}")
|
| 24 |
|
| 25 |
-
if os.path.exists("./style.css"):
|
| 26 |
-
with open(os.path.join("./style.css"), "r", encoding="utf8") as file:
|
| 27 |
-
log.
|
| 28 |
css += file.read() + "\n"
|
| 29 |
|
| 30 |
if os.path.exists("./.release"):
|
|
@@ -38,6 +39,8 @@ def UI(**kwargs):
|
|
| 38 |
interface = gr.Blocks(
|
| 39 |
css=css, title=f"Kohya_ss GUI {release}", theme=gr.themes.Default()
|
| 40 |
)
|
|
|
|
|
|
|
| 41 |
|
| 42 |
with interface:
|
| 43 |
with gr.Tab("Dreambooth"):
|
|
@@ -46,13 +49,13 @@ def UI(**kwargs):
|
|
| 46 |
reg_data_dir_input,
|
| 47 |
output_dir_input,
|
| 48 |
logging_dir_input,
|
| 49 |
-
) = dreambooth_tab(headless=headless)
|
| 50 |
with gr.Tab("LoRA"):
|
| 51 |
-
lora_tab(headless=headless)
|
| 52 |
with gr.Tab("Textual Inversion"):
|
| 53 |
-
ti_tab(headless=headless)
|
| 54 |
with gr.Tab("Finetuning"):
|
| 55 |
-
finetune_tab(headless=headless)
|
| 56 |
with gr.Tab("Utilities"):
|
| 57 |
utilities_tab(
|
| 58 |
train_data_dir_input=train_data_dir_input,
|
|
@@ -102,6 +105,12 @@ def UI(**kwargs):
|
|
| 102 |
if __name__ == "__main__":
|
| 103 |
# torch.cuda.set_per_process_memory_fraction(0.48)
|
| 104 |
parser = argparse.ArgumentParser()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
parser.add_argument(
|
| 106 |
"--listen",
|
| 107 |
type=str,
|
|
@@ -130,10 +139,12 @@ if __name__ == "__main__":
|
|
| 130 |
)
|
| 131 |
|
| 132 |
parser.add_argument("--use-ipex", action="store_true", help="Use IPEX environment")
|
|
|
|
| 133 |
|
| 134 |
args = parser.parse_args()
|
| 135 |
|
| 136 |
UI(
|
|
|
|
| 137 |
username=args.username,
|
| 138 |
password=args.password,
|
| 139 |
inbrowser=args.inbrowser,
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import os
|
| 3 |
import argparse
|
| 4 |
+
from kohya_gui.class_gui_config import KohyaSSGUIConfig
|
| 5 |
+
from kohya_gui.dreambooth_gui import dreambooth_tab
|
| 6 |
+
from kohya_gui.finetune_gui import finetune_tab
|
| 7 |
+
from kohya_gui.textual_inversion_gui import ti_tab
|
| 8 |
from kohya_gui.utilities import utilities_tab
|
| 9 |
+
from kohya_gui.lora_gui import lora_tab
|
| 10 |
from kohya_gui.class_lora_tab import LoRATools
|
| 11 |
|
| 12 |
from kohya_gui.custom_logging import setup_logging
|
|
|
|
| 23 |
headless = kwargs.get("headless", False)
|
| 24 |
log.info(f"headless: {headless}")
|
| 25 |
|
| 26 |
+
if os.path.exists("./assets/style.css"):
|
| 27 |
+
with open(os.path.join("./assets/style.css"), "r", encoding="utf8") as file:
|
| 28 |
+
log.debug("Load CSS...")
|
| 29 |
css += file.read() + "\n"
|
| 30 |
|
| 31 |
if os.path.exists("./.release"):
|
|
|
|
| 39 |
interface = gr.Blocks(
|
| 40 |
css=css, title=f"Kohya_ss GUI {release}", theme=gr.themes.Default()
|
| 41 |
)
|
| 42 |
+
|
| 43 |
+
config = KohyaSSGUIConfig(config_file_path=kwargs.get("config_file_path"))
|
| 44 |
|
| 45 |
with interface:
|
| 46 |
with gr.Tab("Dreambooth"):
|
|
|
|
| 49 |
reg_data_dir_input,
|
| 50 |
output_dir_input,
|
| 51 |
logging_dir_input,
|
| 52 |
+
) = dreambooth_tab(headless=headless, config=config)
|
| 53 |
with gr.Tab("LoRA"):
|
| 54 |
+
lora_tab(headless=headless, config=config)
|
| 55 |
with gr.Tab("Textual Inversion"):
|
| 56 |
+
ti_tab(headless=headless, config=config)
|
| 57 |
with gr.Tab("Finetuning"):
|
| 58 |
+
finetune_tab(headless=headless, config=config)
|
| 59 |
with gr.Tab("Utilities"):
|
| 60 |
utilities_tab(
|
| 61 |
train_data_dir_input=train_data_dir_input,
|
|
|
|
| 105 |
if __name__ == "__main__":
|
| 106 |
# torch.cuda.set_per_process_memory_fraction(0.48)
|
| 107 |
parser = argparse.ArgumentParser()
|
| 108 |
+
parser.add_argument(
|
| 109 |
+
"--config",
|
| 110 |
+
type=str,
|
| 111 |
+
default="./config.toml",
|
| 112 |
+
help="Path to the toml config file for interface defaults",
|
| 113 |
+
)
|
| 114 |
parser.add_argument(
|
| 115 |
"--listen",
|
| 116 |
type=str,
|
|
|
|
| 139 |
)
|
| 140 |
|
| 141 |
parser.add_argument("--use-ipex", action="store_true", help="Use IPEX environment")
|
| 142 |
+
parser.add_argument("--use-rocm", action="store_true", help="Use ROCm environment")
|
| 143 |
|
| 144 |
args = parser.parse_args()
|
| 145 |
|
| 146 |
UI(
|
| 147 |
+
config_file_path=args.config,
|
| 148 |
username=args.username,
|
| 149 |
password=args.password,
|
| 150 |
inbrowser=args.inbrowser,
|