| import os
|
| import gradio as gr
|
| import shlex
|
|
|
| from .custom_logging import setup_logging
|
| from .class_gui_config import KohyaSSGUIConfig
|
|
|
|
|
| log = setup_logging()
|
|
|
| folder_symbol = "\U0001f4c2"
|
| refresh_symbol = "\U0001f504"
|
| save_style_symbol = "\U0001f4be"
|
| document_symbol = "\U0001F4C4"
|
|
|
|
|
|
|
|
|
|
|
| def create_prompt_file(sample_prompts, output_dir):
|
| """
|
| Creates a prompt file for image sampling.
|
|
|
| Args:
|
| sample_prompts (str): The prompts to use for image sampling.
|
| output_dir (str): The directory where the output images will be saved.
|
|
|
| Returns:
|
| str: The path to the prompt file.
|
| """
|
| sample_prompts_path = os.path.join(output_dir, "prompt.txt")
|
|
|
| with open(sample_prompts_path, "w", encoding="utf-8") as f:
|
| f.write(sample_prompts)
|
|
|
| return sample_prompts_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| class SampleImages:
|
| """
|
| A class for managing the Gradio interface for sampling images during training.
|
| """
|
|
|
| def __init__(
|
| self,
|
| config: KohyaSSGUIConfig = {},
|
| ):
|
| """
|
| Initializes the SampleImages class.
|
| """
|
| self.config = config
|
|
|
| self.initialize_accordion()
|
|
|
| def initialize_accordion(self):
|
| """
|
| Initializes the accordion for the Gradio interface.
|
| """
|
| with gr.Row():
|
| self.sample_every_n_steps = gr.Number(
|
| label="Sample every n steps",
|
| value=self.config.get("samples.sample_every_n_steps", 0),
|
| precision=0,
|
| interactive=True,
|
| )
|
| self.sample_every_n_epochs = gr.Number(
|
| label="Sample every n epochs",
|
| value=self.config.get("samples.sample_every_n_epochs", 0),
|
| precision=0,
|
| interactive=True,
|
| )
|
| self.sample_sampler = gr.Dropdown(
|
| label="Sample sampler",
|
| choices=[
|
| "ddim",
|
| "pndm",
|
| "lms",
|
| "euler",
|
| "euler_a",
|
| "heun",
|
| "dpm_2",
|
| "dpm_2_a",
|
| "dpmsolver",
|
| "dpmsolver++",
|
| "dpmsingle",
|
| "k_lms",
|
| "k_euler",
|
| "k_euler_a",
|
| "k_dpm_2",
|
| "k_dpm_2_a",
|
| ],
|
| value=self.config.get("samples.sample_sampler", "euler_a"),
|
| interactive=True,
|
| )
|
| with gr.Row():
|
| self.sample_prompts = gr.Textbox(
|
| lines=5,
|
| label="Sample prompts",
|
| interactive=True,
|
| placeholder="masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28",
|
| info="Enter one sample prompt per line to generate multiple samples per cycle. Optional specifiers include: --w (width), --h (height), --d (seed), --l (cfg scale), --s (sampler steps) and --n (negative prompt). To modify sample prompts during training, edit the prompt.txt file in the samples directory.",
|
| value=self.config.get("samples.sample_prompts", ""),
|
| )
|
|
|