import gradio as gr

from typing import Optional, Tuple

from .custom_logging import setup_logging

log = setup_logging()


class BasicTraining:
    """
    This class configures and initializes the basic training settings for a machine
    learning model, including options for SDXL, learning rate, learning rate
    scheduler, and training epochs.

    Attributes:
        sdxl_checkbox (gr.Checkbox): Checkbox to enable SDXL training.
        learning_rate_value (float): Initial learning rate value.
        lr_scheduler_value (str): Initial learning rate scheduler value.
        lr_warmup_value (float): Initial learning rate warmup value, as a percentage of total steps.
        finetuning (bool): If True, enables fine-tuning of the model.
        dreambooth (bool): If True, enables Dreambooth training.
    """
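
    # Note: configuration values are looked up with flat, dot-separated keys,
    # e.g. config.get("basic.epoch", 1). If a plain dict is passed as `config`,
    # its keys must be the literal "basic.*" strings (assuming no nested-lookup
    # wrapper is applied upstream).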

    def __init__(
        self,
        sdxl_checkbox: gr.Checkbox,
        learning_rate_value: float = 1e-6,
        lr_scheduler_value: str = "constant",
        lr_warmup_value: float = 0,
        finetuning: bool = False,
        dreambooth: bool = False,
        config: Optional[dict] = None,
    ) -> None:
        """
        Initializes the BasicTraining object with the given parameters.

        Args:
            sdxl_checkbox (gr.Checkbox): Checkbox to enable SDXL training.
            learning_rate_value (float): Initial learning rate value.
            lr_scheduler_value (str): Initial learning rate scheduler value.
            lr_warmup_value (float): Initial learning rate warmup value, as a percentage of total steps.
            finetuning (bool): If True, enables fine-tuning of the model.
            dreambooth (bool): If True, enables Dreambooth training.
            config (dict): Optional dictionary of saved configuration values.
        """
        self.sdxl_checkbox = sdxl_checkbox
        self.learning_rate_value = learning_rate_value
        self.lr_scheduler_value = lr_scheduler_value
        self.lr_warmup_value = lr_warmup_value
        self.finetuning = finetuning
        self.dreambooth = dreambooth
        # Use a fresh dict per instance rather than a shared mutable default.
        self.config = config if config is not None else {}
        # Caches the user's LR warmup value while the "constant" scheduler
        # forces the field to 0, so it can be restored on scheduler change.
        self.old_lr_warmup = 0

        self.initialize_ui_components()

    def initialize_ui_components(self) -> None:
        """
        Initializes the UI components for the training settings.
        """
        self.init_training_controls()
        self.init_precision_and_resources_controls()
        self.init_lr_and_optimizer_controls()
        self.init_grad_and_lr_controls()
        self.init_learning_rate_controls()
        self.init_scheduler_controls()
        self.init_resolution_and_bucket_controls()
        self.setup_sdxl_checkbox_behavior()

    def init_training_controls(self) -> None:
        """
        Initializes the training controls for the model.
        """
        with gr.Row():
            self.train_batch_size = gr.Slider(
                minimum=1,
                maximum=64,
                label="Train batch size",
                value=self.config.get("basic.train_batch_size", 1),
                step=1,
            )
            self.epoch = gr.Number(
                label="Epoch", value=self.config.get("basic.epoch", 1), precision=0
            )
            self.max_train_epochs = gr.Number(
                label="Max train epoch",
                info="Training epochs (overrides max_train_steps); 0 = no override",
                step=1,
                minimum=0,
                value=self.config.get("basic.max_train_epochs", 0),
            )
            self.max_train_steps = gr.Number(
                label="Max train steps",
                info="Overrides the number of training steps; 0 = no override",
                step=1,
                value=self.config.get("basic.max_train_steps", 1600),
            )
            self.save_every_n_epochs = gr.Number(
                label="Save every N epochs",
                value=self.config.get("basic.save_every_n_epochs", 1),
                precision=0,
            )
            self.caption_extension = gr.Dropdown(
                label="Caption file extension",
                choices=["", ".cap", ".caption", ".txt"],
                value=".txt",
                interactive=True,
            )

    def init_precision_and_resources_controls(self) -> None:
        """
        Initializes the precision and resources controls for the model.
        """
        with gr.Row():
            self.seed = gr.Number(
                label="Seed",
                step=1,
                minimum=0,
                value=self.config.get("basic.seed", 0),
                info="Set to 0 to use a random seed",
            )
            self.cache_latents = gr.Checkbox(
                label="Cache latents",
                value=self.config.get("basic.cache_latents", True),
            )
            self.cache_latents_to_disk = gr.Checkbox(
                label="Cache latents to disk",
                value=self.config.get("basic.cache_latents_to_disk", False),
            )

    def init_lr_and_optimizer_controls(self) -> None:
        """
        Initializes the learning rate and optimizer controls for the model.
        """
        with gr.Row():
            self.lr_scheduler = gr.Dropdown(
                label="LR Scheduler",
                choices=[
                    "adafactor",
                    "constant",
                    "constant_with_warmup",
                    "cosine",
                    "cosine_with_restarts",
                    "linear",
                    "polynomial",
                ],
                value=self.config.get("basic.lr_scheduler", self.lr_scheduler_value),
            )
            self.optimizer = gr.Dropdown(
                label="Optimizer",
                choices=[
                    "AdamW",
                    "AdamW8bit",
                    "Adafactor",
                    "DAdaptation",
                    "DAdaptAdaGrad",
                    "DAdaptAdam",
                    "DAdaptAdan",
                    "DAdaptAdanIP",
                    "DAdaptAdamPreprint",
                    "DAdaptLion",
                    "DAdaptSGD",
                    "Lion",
                    "Lion8bit",
                    "PagedAdamW8bit",
                    "PagedAdamW32bit",
                    "PagedLion8bit",
                    "Prodigy",
                    "SGDNesterov",
                    "SGDNesterov8bit",
                ],
                value=self.config.get("basic.optimizer", "AdamW8bit"),
                interactive=True,
            )

    def init_grad_and_lr_controls(self) -> None:
        """
        Initializes the gradient and learning rate controls for the model.
        """
        with gr.Row():
            self.max_grad_norm = gr.Slider(
                label="Max grad norm",
                value=self.config.get("basic.max_grad_norm", 1.0),
                minimum=0.0,
                maximum=1.0,
                interactive=True,
            )
            self.lr_scheduler_args = gr.Textbox(
                label="LR scheduler extra arguments",
                lines=2,
                placeholder="(Optional) e.g. milestones=[1,10,30,50] gamma=0.1",
                value=self.config.get("basic.lr_scheduler_args", ""),
            )
            self.optimizer_args = gr.Textbox(
                label="Optimizer extra arguments",
                lines=2,
                placeholder="(Optional) e.g. relative_step=True scale_parameter=True warmup_init=True",
                value=self.config.get("basic.optimizer_args", ""),
            )

    def init_learning_rate_controls(self) -> None:
        """
        Initializes the learning rate controls for the model.
        """
        with gr.Row():
            # In finetuning/Dreambooth mode the Unet and Text Encoder learning
            # rates are exposed separately, so label this field accordingly.
            lr_label = (
                "Learning rate Unet"
                if self.finetuning or self.dreambooth
                else "Learning rate"
            )
            self.learning_rate = gr.Number(
                label=lr_label,
                value=self.config.get("basic.learning_rate", self.learning_rate_value),
                minimum=0,
                maximum=1,
                info="Set to 0 to not train the Unet",
            )
            self.learning_rate_te = gr.Number(
                label="Learning rate TE",
                value=self.config.get(
                    "basic.learning_rate_te", self.learning_rate_value
                ),
                visible=self.finetuning or self.dreambooth,
                minimum=0,
                maximum=1,
                info="Set to 0 to not train the Text Encoder",
            )
            self.learning_rate_te1 = gr.Number(
                label="Learning rate TE1",
                value=self.config.get(
                    "basic.learning_rate_te1", self.learning_rate_value
                ),
                visible=False,
                minimum=0,
                maximum=1,
                info="Set to 0 to not train the Text Encoder 1",
            )
            self.learning_rate_te2 = gr.Number(
                label="Learning rate TE2",
                value=self.config.get(
                    "basic.learning_rate_te2", self.learning_rate_value
                ),
                visible=False,
                minimum=0,
                maximum=1,
                info="Set to 0 to not train the Text Encoder 2",
            )
            self.lr_warmup = gr.Slider(
                label="LR warmup (% of total steps)",
                value=self.config.get("basic.lr_warmup", self.lr_warmup_value),
                minimum=0,
                maximum=100,
                step=1,
            )

        def lr_scheduler_changed(scheduler, value):
            # The "constant" scheduler has no warmup phase: stash the current
            # warmup value, zero the field, and disable it.
            if scheduler == "constant":
                self.old_lr_warmup = value
                value = 0
                interactive = False
                info = "Can't use LR warmup with LR Scheduler constant... setting to 0 and disabling field..."
            else:
                # Restore the stashed warmup value when leaving "constant".
                if self.old_lr_warmup != 0:
                    value = self.old_lr_warmup
                    self.old_lr_warmup = 0
                interactive = True
                info = ""
            # Returning a configured gr.Slider updates the existing component
            # in place (Gradio 4 component-update pattern).
            return gr.Slider(value=value, interactive=interactive, info=info)

        self.lr_scheduler.change(
            lr_scheduler_changed,
            inputs=[self.lr_scheduler, self.lr_warmup],
            outputs=self.lr_warmup,
        )

    def init_scheduler_controls(self) -> None:
        """
        Initializes the scheduler controls for the model.
        """
        with gr.Row(visible=not self.finetuning):
            self.lr_scheduler_num_cycles = gr.Number(
                label="LR # cycles",
                minimum=1,
                step=1,
                info="Number of restarts for cosine scheduler with restarts",
                value=self.config.get("basic.lr_scheduler_num_cycles", 1),
            )
            self.lr_scheduler_power = gr.Number(
                label="LR power",
                minimum=0.0,
                step=0.01,
                info="Polynomial power for polynomial scheduler",
                value=self.config.get("basic.lr_scheduler_power", 1.0),
            )

    def init_resolution_and_bucket_controls(self) -> None:
        """
        Initializes the resolution and bucket controls for the model.
        """
        with gr.Row(visible=not self.finetuning):
            self.max_resolution = gr.Textbox(
                label="Max resolution",
                value=self.config.get("basic.max_resolution", "512,512"),
                placeholder="512,512",
            )
            self.stop_text_encoder_training = gr.Slider(
                minimum=-1,
                maximum=100,
                value=self.config.get("basic.stop_text_encoder_training", 0),
                step=1,
                label="Stop TE (% of total steps)",
            )
            self.enable_bucket = gr.Checkbox(
                label="Enable buckets",
                value=self.config.get("basic.enable_bucket", True),
            )
            self.min_bucket_reso = gr.Slider(
                label="Minimum bucket resolution",
                value=self.config.get("basic.min_bucket_reso", 256),
                minimum=64,
                maximum=4096,
                step=64,
                info="Minimum size in pixels a bucket can be (>= 64)",
            )
            self.max_bucket_reso = gr.Slider(
                label="Maximum bucket resolution",
                value=self.config.get("basic.max_bucket_reso", 2048),
                minimum=64,
                maximum=4096,
                step=64,
                info="Maximum size in pixels a bucket can be (>= 64)",
            )

    def setup_sdxl_checkbox_behavior(self) -> None:
        """
        Sets up the behavior of the SDXL checkbox based on the finetuning and dreambooth flags.
        """
        self.sdxl_checkbox.change(
            self.update_learning_rate_te,
            inputs=[
                self.sdxl_checkbox,
                # Hidden checkboxes pass the static finetuning and dreambooth
                # flags into the Gradio event handler.
                gr.Checkbox(value=self.finetuning, visible=False),
                gr.Checkbox(value=self.dreambooth, visible=False),
            ],
            outputs=[
                self.learning_rate_te,
                self.learning_rate_te1,
                self.learning_rate_te2,
            ],
        )

    def update_learning_rate_te(
        self,
        sdxl_checkbox: bool,
        finetuning: bool,
        dreambooth: bool,
    ) -> Tuple[gr.Number, gr.Number, gr.Number]:
        """
        Updates the visibility of the learning rate TE, TE1, and TE2 fields based
        on the SDXL checkbox and the finetuning/dreambooth flags.

        Args:
            sdxl_checkbox (bool): Whether SDXL training is enabled (Gradio passes
                the checkbox's boolean value, not the component).
            finetuning (bool): Whether finetuning is enabled.
            dreambooth (bool): Whether dreambooth is enabled.

        Returns:
            Tuple[gr.Number, gr.Number, gr.Number]: Updated components setting the
            visibility of the learning rate TE, TE1, and TE2 fields.
        """
        # TE1/TE2 are the two SDXL text encoders; the single TE field applies
        # to non-SDXL models. Either way, the fields only appear for
        # finetuning or Dreambooth runs.
        visibility_condition = finetuning or dreambooth
        return (
            gr.Number(visible=(not sdxl_checkbox and visibility_condition)),
            gr.Number(visible=(sdxl_checkbox and visibility_condition)),
            gr.Number(visible=(sdxl_checkbox and visibility_condition)),
        )
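
# --- Usage sketch (illustrative only) ---
# A minimal example of wiring BasicTraining into a Gradio app. The module and
# package names below are assumptions; the relative import of .custom_logging
# above means this file must be imported as part of its package.
#
#     import gradio as gr
#     from mypackage.basic_training import BasicTraining  # hypothetical path
#
#     with gr.Blocks() as demo:
#         sdxl_checkbox = gr.Checkbox(label="SDXL", value=False)
#         BasicTraining(
#             sdxl_checkbox=sdxl_checkbox,
#             finetuning=True,
#             config={"basic.learning_rate": 1e-5},
#         )
#     demo.launch()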