| | import argparse |
| | import logging |
| | import os |
| | import shutil |
| |
|
| | import accelerate |
| | import torch |
| |
|
| | from utils import huggingface_utils |
| |
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# NOTE: this runs at import time and configures the root logger at INFO level
# (logging.basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
| |
|
| |
|
| | |
# Name templates for saved artifacts. Each template is filled with
# ``str.format(model_name, number)``; LAST_STATE_NAME takes the model name only.
# Epoch numbers are zero-padded to 6 digits, step numbers to 8 digits.

# Per-epoch artifacts.
EPOCH_FILE_NAME = "{}-{:06d}"
EPOCH_DIFFUSERS_DIR_NAME = EPOCH_FILE_NAME  # same pattern as the file name
EPOCH_STATE_NAME = "{}-{:06d}-state"

# Per-step artifacts.
STEP_FILE_NAME = "{}-step{:08d}"
STEP_DIFFUSERS_DIR_NAME = STEP_FILE_NAME  # same pattern as the file name
STEP_STATE_NAME = "{}-step{:08d}-state"

# State saved at the end of training.
LAST_STATE_NAME = "{}-state"
| |
|
| |
|
def get_sanitized_config_or_none(args: argparse.Namespace):
    """Return a loggable copy of ``args`` with sensitive entries removed.

    Args:
        args: Parsed command-line arguments. ``args.log_config`` decides
            whether anything is returned at all.

    Returns:
        ``None`` when ``args.log_config`` is falsy. Otherwise a dict holding
        every argument except credentials and path-like options. ``None``,
        ``bool``, ``str``, ``float`` and ``int`` values are stored unchanged;
        any other value (lists, arbitrary objects) is stored as its formatted
        string.
    """
    if not args.log_config:
        return None

    # Credentials plus arguments that hold filesystem paths; all are dropped.
    excluded_keys = {
        "wandb_api_key",
        "huggingface_token",
        "dit",
        "vae",
        "text_encoder1",
        "text_encoder2",
        "base_weights",
        "network_weights",
        "output_dir",
        "logging_dir",
    }
    passthrough_types = (bool, str, float, int)

    sanitized = {}
    for key, value in vars(args).items():
        if key in excluded_keys:
            continue
        if value is None or isinstance(value, passthrough_types):
            sanitized[key] = value
        else:
            # Lists and every other object are recorded in string form.
            sanitized[key] = f"{value}"

    return sanitized
| |
|
| |
|
class LossRecorder:
    """Record per-step losses and expose their mean.

    During epoch 0 every loss is appended. In later epochs the value stored at
    index ``step`` is overwritten, so the mean covers the most recent loss
    recorded for each step slot.
    """

    def __init__(self):
        # Latest loss recorded for each step slot.
        self.loss_list: list[float] = []
        # Running sum of the values currently held in ``loss_list``.
        self.loss_total: float = 0.0

    def add(self, *, epoch: int, step: int, loss: float) -> None:
        """Record ``loss`` for the given ``epoch`` and ``step``.

        Args:
            epoch: Epoch index; 0 appends, any other value overwrites.
            step: Index of the slot to overwrite when ``epoch`` is not 0.
            loss: Loss value to record.
        """
        if epoch == 0:
            self.loss_list.append(loss)
        else:
            # Grow the list with zeros if this step index was never seen.
            while len(self.loss_list) <= step:
                self.loss_list.append(0.0)
            # Swap the old value out of the running total.
            self.loss_total -= self.loss_list[step]
            self.loss_list[step] = loss
        self.loss_total += loss

    @property
    def moving_average(self) -> float:
        """Mean of the recorded losses, or 0.0 when nothing has been recorded."""
        count = len(self.loss_list)
        if count == 0:
            # Avoid ZeroDivisionError when queried before the first add().
            return 0.0
        return self.loss_total / count
| |
|
| |
|
def get_epoch_ckpt_name(model_name, epoch_no: int):
    """Return the ``.safetensors`` checkpoint file name for an epoch.

    The stem is ``EPOCH_FILE_NAME`` filled with ``model_name`` and ``epoch_no``.
    """
    stem = EPOCH_FILE_NAME.format(model_name, epoch_no)
    return stem + ".safetensors"
| |
|
| |
|
def get_step_ckpt_name(model_name, step_no: int):
    """Return the ``.safetensors`` checkpoint file name for a step.

    The stem is ``STEP_FILE_NAME`` filled with ``model_name`` and ``step_no``.
    """
    stem = STEP_FILE_NAME.format(model_name, step_no)
    return stem + ".safetensors"
| |
|
| |
|
def get_last_ckpt_name(model_name):
    """Return the final checkpoint file name: ``model_name`` plus ``.safetensors``."""
    extension = ".safetensors"
    return model_name + extension
| |
|
| |
|
def get_remove_epoch_no(args: argparse.Namespace, epoch_no: int):
    """Return the epoch number whose checkpoint should be removed, if any.

    Args:
        args: Provides ``save_last_n_epochs`` and ``save_every_n_epochs``.
        epoch_no: Epoch that has just been saved.

    Returns:
        ``epoch_no - save_every_n_epochs * save_last_n_epochs``, or ``None``
        when ``save_last_n_epochs`` is ``None`` or the result is negative.
    """
    keep_count = args.save_last_n_epochs
    if keep_count is None:
        return None

    candidate = epoch_no - args.save_every_n_epochs * keep_count
    return candidate if candidate >= 0 else None
| |
|
| |
|
def get_remove_step_no(args: argparse.Namespace, step_no: int):
    """Return the step number whose checkpoint should be removed, if any.

    Args:
        args: Provides ``save_last_n_steps`` and ``save_every_n_steps``.
        step_no: Step that has just been saved.

    Returns:
        ``step_no - save_last_n_steps - 1`` rounded down to a multiple of
        ``save_every_n_steps``, or ``None`` when ``save_last_n_steps`` is
        ``None`` or the rounded value is negative.
    """
    keep_steps = args.save_last_n_steps
    if keep_steps is None:
        return None

    candidate = step_no - keep_steps - 1
    # Floor to a multiple of the save interval (Python's % is non-negative
    # for a positive interval, so negative candidates round further down).
    candidate -= candidate % args.save_every_n_steps
    return None if candidate < 0 else candidate
| |
|
| |
|
def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator: accelerate.Accelerator, epoch_no: int):
    """Save the training state for ``epoch_no`` and delete an outdated one.

    The state is written with ``accelerator.save_state`` into a directory
    under ``args.output_dir`` named from ``EPOCH_STATE_NAME``. When
    ``args.save_state_to_huggingface`` is truthy the directory is also passed
    to ``huggingface_utils.upload``. Afterwards, if a retention count is
    configured (``args.save_last_n_epochs_state``, falling back to
    ``args.save_last_n_epochs``), the state directory that falls out of the
    retention window is removed if it exists.
    """
    output_name = args.output_name

    logger.info("")
    logger.info(f"saving state at epoch {epoch_no}")
    os.makedirs(args.output_dir, exist_ok=True)

    current_name = EPOCH_STATE_NAME.format(output_name, epoch_no)
    current_dir = os.path.join(args.output_dir, current_name)
    accelerator.save_state(current_dir)
    if args.save_state_to_huggingface:
        logger.info("uploading state to huggingface.")
        huggingface_utils.upload(args, current_dir, "/" + current_name)

    # A falsy state-specific count falls back to the checkpoint count.
    keep_count = args.save_last_n_epochs_state or args.save_last_n_epochs
    if keep_count is None:
        return

    stale_epoch = epoch_no - args.save_every_n_epochs * keep_count
    stale_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(output_name, stale_epoch))
    if os.path.exists(stale_dir):
        logger.info(f"removing old state: {stale_dir}")
        shutil.rmtree(stale_dir)
| |
|
| |
|
def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator: accelerate.Accelerator, step_no: int):
    """Save the training state for ``step_no`` and delete an outdated one.

    The state is written with ``accelerator.save_state`` into a directory
    under ``args.output_dir`` named from ``STEP_STATE_NAME``. When
    ``args.save_state_to_huggingface`` is truthy the directory is also passed
    to ``huggingface_utils.upload``. Afterwards, if a retention count is
    configured (``args.save_last_n_steps_state``, falling back to
    ``args.save_last_n_steps``), the step to drop is computed, floored to a
    multiple of ``args.save_every_n_steps``, and its state directory is
    removed when that step is positive and the directory exists.
    """
    output_name = args.output_name

    logger.info("")
    logger.info(f"saving state at step {step_no}")
    os.makedirs(args.output_dir, exist_ok=True)

    current_name = STEP_STATE_NAME.format(output_name, step_no)
    current_dir = os.path.join(args.output_dir, current_name)
    accelerator.save_state(current_dir)
    if args.save_state_to_huggingface:
        logger.info("uploading state to huggingface.")
        huggingface_utils.upload(args, current_dir, "/" + current_name)

    # A falsy state-specific count falls back to the checkpoint count.
    keep_steps = args.save_last_n_steps_state or args.save_last_n_steps
    if keep_steps is None:
        return

    stale_step = step_no - keep_steps - 1
    # Floor to a multiple of the save interval.
    stale_step -= stale_step % args.save_every_n_steps
    if stale_step <= 0:
        return

    stale_dir = os.path.join(args.output_dir, STEP_STATE_NAME.format(output_name, stale_step))
    if os.path.exists(stale_dir):
        logger.info(f"removing old state: {stale_dir}")
        shutil.rmtree(stale_dir)
| |
|
| |
|
def save_state_on_train_end(args: argparse.Namespace, accelerator: accelerate.Accelerator):
    """Save the final training state.

    The state is written with ``accelerator.save_state`` into a directory
    under ``args.output_dir`` named from ``LAST_STATE_NAME``. When
    ``args.save_state_to_huggingface`` is truthy the directory is also passed
    to ``huggingface_utils.upload``.
    """
    logger.info("")
    logger.info("saving last state.")
    os.makedirs(args.output_dir, exist_ok=True)

    final_name = LAST_STATE_NAME.format(args.output_name)
    final_dir = os.path.join(args.output_dir, final_name)
    accelerator.save_state(final_dir)

    if not args.save_state_to_huggingface:
        return
    logger.info("uploading last state to huggingface.")
    huggingface_utils.upload(args, final_dir, "/" + final_name)
| |
|
| |
|