"""Knowledge-distillation training script for the removal diffusion model.

A student removal model is trained against a frozen teacher using a mix of
intermediate-feature KD, output KD, a task loss against the sampled noise and an
optional E-LatentLPIPS perceptual loss, combined with gradient-norm-based adaptive
weights (see ``cal_adaptive_weights_type8``).
"""
import argparse
import datetime
import gc
import json
import os
import os.path as osp
import shutil
import sys
import warnings
from concurrent.futures import ThreadPoolExecutor
from shutil import ignore_patterns
from typing import List

import toml
from tqdm import tqdm
import torch
import torch.nn.functional as F
from accelerate.utils import set_seed
from diffusers import DDPMScheduler, DDIMScheduler
from accelerate import DistributedType
from diffusers.utils import logging

import library.train_util as train_util
import library.chinese_sdxl_train_util as chinese_sdxl_train_util
from library.custom_train_functions import (
    apply_snr_weight,
    prepare_scheduler_for_custom_training,
    scale_v_prediction_loss_like_noise_prediction,
    add_v_prediction_like_loss,
)
from model_lib.nets.layers.ema import LitEma, load_litema, save_litema, ema_scope
from removal.v1_2 import (
    RemovalDataset,
    RemovalDataset_v1_2,
    load_cfg,
    build_removal_model,
    load_removal_model,
)
from utils_train import (
    build_accelerator,
    build_dataloader,
    build_vae,
    build_models,
    save,
    common_arguments,
    build_progress_bar,
)
from model_lib.nets.utils import CustomOutput
from utils_infer import encode_clean_latents

warnings.filterwarnings("ignore", message="Grad strides do not match bucket view strides.*")
warnings.filterwarnings(
    "ignore",
    message="Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. "
    "Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. "
    "See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.",
)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Explicit whitelist for ``--data_type``; replaces ``eval(args.data_type)``, which would
# execute arbitrary text coming from the command line or a config file.
_DATASET_CLASSES = {
    "RemovalDataset": RemovalDataset,
    "RemovalDataset_v1_2": RemovalDataset_v1_2,
}


def build_teacher_model(args, weight_dtype, accelerator):
    """Build the frozen teacher model used as the distillation target.

    The architecture is read from ``args.teacher_config_path``; weights are loaded from
    ``args.teacher_weight_path`` when it is set. The teacher is frozen, switched to eval
    mode and moved to ``accelerator.device`` in float32 (``weight_dtype`` is only printed).
    Memory-efficient attention is always enabled on ``teacher_model.diff_model``.

    Returns:
        The frozen teacher module.
    """
    teacher_cfg = load_cfg(args.teacher_config_path)
    teacher_model = build_removal_model(teacher_cfg, args.num_embeddings)
    if args.teacher_weight_path:
        accelerator.print(f"==> Loading teacher model from: {args.teacher_weight_path}")
        # NOTE: torch.load unpickles the checkpoint; only point this at trusted files.
        state_dict = torch.load(args.teacher_weight_path, map_location=accelerator.device)
        teacher_model.load_state_dict(state_dict)

    accelerator.print(f"weight_dtype:{weight_dtype}")
    if getattr(teacher_model, "unet", None) is not None:
        accelerator.print(f"unet:{teacher_model.unet.dtype}")
    else:
        accelerator.print(f"diff_model:{teacher_model.diff_model.dtype}")

    teacher_model.requires_grad_(False).eval()
    teacher_model.to(accelerator.device, dtype=torch.float32)

    if accelerator.is_main_process:
        from pprint import pprint

        pprint("Teacher Model Config:")
        pprint(teacher_model.diff_model.config)

    # set xformer/mem_eff_attn
    accelerator.print(
        f"Enable memory efficient attention, mem_eff_attn:{args.mem_eff_attn}, xformers:{args.xformers}"
    )
    chinese_sdxl_train_util.set_diffusers_xformers_flag(teacher_model.diff_model, True)
    return teacher_model


def cal_KD_loss(pred: CustomOutput, target: CustomOutput, args):
    """Compute the feature-level and output-level distillation losses.

    Feature KD (enabled by ``--kl_feat_loss`` or ``--mse_feat_loss``; KL takes precedence
    when both are set) compares ``pred.block_outputs[i_s]`` with
    ``target.block_outputs[i_t]`` for each ``(i_s, i_t, weight)`` in
    ``zip(args.feat_index_S, args.feat_index_T, args.feat_loss_weight)``. The KL variant
    applies softmax/log-softmax along dim 1 at temperature ``args.kl_temp``. Output KD is
    always the MSE between ``pred.sample`` and ``target.sample`` in float32.

    Returns:
        ``(loss_kd, loss_dict)`` where ``loss_dict`` has ``"loss_featkd"`` (the int ``0``
        when feature KD is disabled) and ``"loss_outkd"``, and ``loss_kd`` is their sum.

    Raises:
        ValueError: if the three feature-KD argument lists differ in length.
    """
    loss_dict = dict()
    if args.kl_feat_loss or args.mse_feat_loss:
        if not (len(args.feat_index_S) == len(args.feat_index_T) == len(args.feat_loss_weight)):
            raise ValueError(
                "feat_index_S, feat_index_T and feat_loss_weight must have the same length"
            )
        feat_loss_list = []
        for idx_s, idx_t, weight in zip(args.feat_index_S, args.feat_index_T, args.feat_loss_weight):
            feat_S, feat_T = pred.block_outputs[idx_s], target.block_outputs[idx_t]
            if args.kl_feat_loss:
                with torch.no_grad():
                    probs_T = torch.softmax(feat_T / args.kl_temp, dim=1)
                log_probs_S = torch.log_softmax(feat_S / args.kl_temp, dim=1)
                feat_loss = F.kl_div(log_probs_S, probs_T, reduction="batchmean")
            else:  # mse_feat_loss
                feat_loss = F.mse_loss(feat_S, feat_T, reduction="mean")
            feat_loss_list.append(feat_loss * weight)
        loss_dict["loss_featkd"] = sum(feat_loss_list)
    else:
        loss_dict["loss_featkd"] = 0

    loss_dict["loss_outkd"] = F.mse_loss(pred.sample.float(), target.sample.float(), reduction="mean")
    loss_kd = sum(loss_dict.values())
    return loss_kd, loss_dict


def cal_task_loss(pred: CustomOutput, target: torch.Tensor, args):
    """Task loss between the ground-truth noise and the student's prediction (as in SnapGen).

    Returns:
        ``(loss_task, {"loss_task": loss_task})``; ``loss_task`` is the float32 MSE when
        ``args.task_loss`` is set, otherwise the int ``0``.
    """
    if args.task_loss:
        loss_task = F.mse_loss(pred.sample.float(), target.float(), reduction="mean")
    else:
        loss_task = 0
    return loss_task, {"loss_task": loss_task}


def cal_elatentlpips_loss(
    pred: CustomOutput,
    target: torch.Tensor,
    encoder_model: torch.nn.Module,
    noise_scheduler,
    timesteps,
    noisy_latents,
    args=None,
):
    """E-LatentLPIPS perceptual loss between two x0 estimates.

    For every sample, ``noise_scheduler.step(...).pred_original_sample`` turns the
    student's prediction (``pred.sample``) and ``target`` into x0 estimates from the same
    ``noisy_latents`` / ``timesteps``; ``encoder_model`` then scores the pair. (With an
    epsilon-predicting scheduler and ``target`` being the sampled noise, the target
    estimate is the clean latent up to dtype rounding — assumes
    ``prediction_type == "epsilon"``, TODO confirm.)

    Returns:
        ``(loss, {"loss_elatentlpips": loss})``; ``loss`` is the int ``0`` when
        ``args.elatentlpips_loss`` is not set (or ``args`` is ``None``).
    """
    if getattr(args, "elatentlpips_loss", False):
        # NOTE(review): normalize=True is meant for latents NOT already scaled by
        # vae.config.scaling_factor / shift_factor — confirm this matches encode_clean_latents.
        student_x0 = torch.stack([
            noise_scheduler.step(n, t, x_t).pred_original_sample
            for n, t, x_t in zip(pred.sample, timesteps, noisy_latents)
        ])
        target_x0 = torch.stack([
            noise_scheduler.step(tgt, t, x_t).pred_original_sample
            for tgt, t, x_t in zip(target.float(), timesteps, noisy_latents)
        ])
        loss = encoder_model(student_x0, target_x0, normalize=True, ensembling=True).mean()
    else:
        loss = 0
    return loss, {"loss_elatentlpips": loss}


def _grad_or_zeros(loss, param):
    """Return d(loss)/d(param), or zeros when ``loss`` contributes no gradient to ``param``.

    Disabled loss terms are the plain int ``0`` and terms may not reach ``param``;
    ``torch.autograd.grad`` would raise for both, so they yield a zero gradient instead.
    The graph is retained so several gradients can be taken from the same forward pass.
    """
    if not (torch.is_tensor(loss) and loss.requires_grad):
        return torch.zeros_like(param)
    grad = torch.autograd.grad(loss, param, retain_graph=True, allow_unused=True)[0]
    return torch.zeros_like(param) if grad is None else grad


def cal_adaptive_weights_type8(
    featkd_loss, task_loss, outkd_loss, elatentlpips_loss, last_featkd_layer=None, outkd_layer=None
):
    """Compute gradient-norm-balancing weights for the distillation loss terms.

    With g_X(L) = ||dL/dX||:
      * ``out_weight_outkd``        = g_out(task) / (g_out(outkd) + 1e-6), clamped to [0, 1e6]
      * ``out_weight_elatentlpips`` = g_out(task) / (g_out(lpips) + 1e-6), clamped to [0, 1e6]
      * ``feat_weight_task``        = g_feat(featkd) / (g_feat(task) + 1e-4), clamped to [0, 1e4]
    where ``out`` is ``outkd_layer`` and ``feat`` is ``last_featkd_layer``. All weights
    are detached. Disabled (int ``0``) losses contribute zero gradient norms.

    Returns:
        A 10-tuple: the three weights above, then the gradient norms
        (feat: featkd, task, outkd, elatentlpips; out: task, outkd, elatentlpips).
    """
    assert last_featkd_layer is not None, "need last_featkd_layer's parameter to get gradient"
    assert outkd_layer is not None, "need outkd_layer's parameter to get gradient"

    feat_grad_featkd = _grad_or_zeros(featkd_loss, last_featkd_layer)
    feat_grad_outkd = _grad_or_zeros(outkd_loss, last_featkd_layer)
    feat_grad_task = _grad_or_zeros(task_loss, last_featkd_layer)
    feat_grad_elatentlpips = _grad_or_zeros(elatentlpips_loss, last_featkd_layer)

    out_grad_outkd = _grad_or_zeros(outkd_loss, outkd_layer)
    out_grad_task = _grad_or_zeros(task_loss, outkd_layer)
    out_grad_elatentlpips = _grad_or_zeros(elatentlpips_loss, outkd_layer)

    out_weight_outkd = torch.norm(out_grad_task) / (torch.norm(out_grad_outkd) + 1e-6)
    out_weight_outkd = torch.clamp(out_weight_outkd, 0.0, 1e6).detach()
    out_weight_elatentlpips = torch.norm(out_grad_task) / (torch.norm(out_grad_elatentlpips) + 1e-6)
    out_weight_elatentlpips = torch.clamp(out_weight_elatentlpips, 0.0, 1e6).detach()
    feat_weight_task = torch.norm(feat_grad_featkd) / (torch.norm(feat_grad_task) + 1e-4)
    feat_weight_task = torch.clamp(feat_weight_task, 0.0, 1e4).detach()

    return (
        feat_weight_task, out_weight_outkd, out_weight_elatentlpips,
        torch.norm(feat_grad_featkd), torch.norm(feat_grad_task),
        torch.norm(feat_grad_outkd), torch.norm(feat_grad_elatentlpips),
        torch.norm(out_grad_task), torch.norm(out_grad_outkd), torch.norm(out_grad_elatentlpips),
    )


def _save_checkpoint(student, student_ema, step, args, accelerator):
    """Save the student (and its EMA, if any) under ``<output_dir>/ckpt/exp-step<step>/``.

    ``save`` is called on every process (it receives the accelerator); the EMA file is
    written by the main process only so ranks do not race on the same path.
    """
    ckpt_dir = osp.join(args.output_dir, "ckpt", f"exp-step{step:08d}")
    save(student, osp.join(ckpt_dir, "diffusion_pytorch_model.bin"), accelerator)
    if student_ema is not None:
        ema_path = osp.join(ckpt_dir, "diffusion_pytorch_model.EMA.bin")
        if accelerator.is_main_process:
            save_litema(student_ema, ema_path)
        accelerator.print(f"d[info]: EMA Model saved at: {ema_path}\n")


def train_distillation(args):
    """Distill the student removal model from the frozen teacher.

    Each loop iteration consumes one micro-batch; ``args.max_train_steps`` and
    ``args.global_step`` count those iterations. Checkpoints are written every
    ``args.save_every_n_steps`` iterations and once at the end.
    """
    chinese_sdxl_train_util.verify_sdxl_training_args(args, False)
    if args.seed is not None:
        set_seed(args.seed)

    accelerator = build_accelerator(args, fsdp_plugin=None)
    weight_dtype, save_dtype = train_util.prepare_dtype(args)
    student, vae = build_models(args, weight_dtype, accelerator)
    teacher = build_teacher_model(args, weight_dtype, accelerator)
    del vae.decoder  # the decoder is not needed: training only encodes images to latents

    if not (args.kl_feat_loss or args.mse_feat_loss):
        # feat_weight_task is proportional to the feature-KD gradient norm, so without
        # feature KD every other loss term is multiplied by zero.
        accelerator.print(
            "[warning] feature KD is disabled; the adaptive weighting will zero the "
            "task / output-KD / E-LatentLPIPS terms."
        )

    # EMA
    student_ema = None
    if args.use_model_ema:
        student_ema = LitEma(student, decay=args.ema_decay).to(accelerator.device)
        accelerator.print(f"Keeping EMAs of {len(list(student_ema.buffers()))}.")

    if args.use_compile:
        student.compile(backend="cudagraphs", fullgraph=False, dynamic=False)

    noise_scheduler = DDPMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
        clip_sample=False,
    )
    prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)

    if args.resume_from_ckpt:
        accelerator.print(f"==> resume ckpt from : {args.resume_from_ckpt}")
        msg = load_removal_model(student, args.resume_from_ckpt, accelerator.device, strict=True)
        accelerator.print(f"load state dict {msg}")
        if args.use_model_ema:
            load_path = args.resume_from_ckpt.replace(
                "diffusion_pytorch_model.bin", "diffusion_pytorch_model.EMA.bin"
            )
            load_litema(student_ema, load_path, map_location=accelerator.device)
            # NOTE: `msg` is the student's load result; load_litema's result is not captured.
            accelerator.print(f"load EMA state dict {msg}")

    # E-Latent-LPIPS
    if args.elatentlpips_loss:
        from elatentlpips import ELatentLPIPS

        # Encoder options: sd15, sd21, sdxl, sd3, flux. `augment` options: b, bg, bgc, bgco.
        # The perceptual network is a fixed metric: freeze it so its parameters do not
        # accumulate .grad every step (nothing ever zeroes or optimizes them).
        elatentlpips_model = ELatentLPIPS(encoder="sdxl", augment="bg").eval().requires_grad_(False)
        elatentlpips_model = accelerator.prepare(elatentlpips_model)
    else:
        elatentlpips_model = None

    training_models = [student]
    params_to_optimize = [{"params": list(student.parameters()), "lr": args.learning_rate}]
    named_params_to_optimize = [{"params": list(student.named_parameters()), "lr": args.learning_rate}]
    n_params = sum(p.numel() for group in params_to_optimize for p in group["params"])
    accelerator.print(f"number of models: {len(training_models)}")
    accelerator.print(f"number of trainable parameters: {n_params}")

    accelerator.print("prepare optimizer, data loader etc.")
    _, _, optimizer = train_util.get_optimizer(
        args, trainable_params=params_to_optimize, named_trainable_params=named_params_to_optimize
    )
    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

    # Tracker logging is offloaded to a single background thread; shut down before exit.
    executor = ThreadPoolExecutor(max_workers=1)

    student, optimizer, lr_scheduler = accelerator.prepare(student, optimizer, lr_scheduler)
    teacher = accelerator.prepare(teacher)

    if accelerator.is_main_process:
        init_kwargs = {}
        if args.log_tracker_config is not None:
            init_kwargs = toml.load(args.log_tracker_config)
        accelerator.init_trackers(
            "finetuning" if args.log_tracker_name is None else args.log_tracker_name,
            init_kwargs=init_kwargs,
        )

    for m in training_models:
        m.train()

    try:
        dataset_class = _DATASET_CLASSES[args.data_type]
    except KeyError:
        raise ValueError(
            f"unknown data_type {args.data_type!r}; expected one of {sorted(_DATASET_CLASSES)}"
        ) from None
    train_dataloader, _ = build_dataloader(args, dataset_class, accelerator)

    global_step = args.global_step
    loss_total = 0
    accumulate_loss = 0
    grad_norm = 0.0  # stays 0.0 when clipping is disabled (max_grad_norm == 0)
    pbar = build_progress_bar(
        range(args.max_train_steps), args.global_step, disable=not accelerator.is_local_main_process
    )
    for step in range(args.global_step, args.max_train_steps):
        # Pass the *prepared* model so accelerate can use DDP's no_sync() on
        # accumulation steps (the unwrapped module has no no_sync()).
        with accelerator.accumulate(student):
            batch = next(train_dataloader)
            latents, masked_image_latents = encode_clean_latents(batch, vae, weight_dtype, accelerator)

            # Resize masks to latent resolution (VAE downsamples by 2 per extra block).
            masks = batch["masks"]
            h, w = masks.shape[-2:]
            vae_ds_ratio = 2 ** (len(vae.config.block_out_channels) - 1)
            size = (h // vae_ds_ratio, w // vae_ds_ratio)
            resized_masks = F.interpolate(masks, size=size).to(accelerator.device, dtype=weight_dtype)

            noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
                args, noise_scheduler, latents
            )
            noisy_latents = noisy_latents.to(weight_dtype)

            with accelerator.autocast():
                latent_model_input = torch.cat([noisy_latents, resized_masks, masked_image_latents], dim=1)
                pred_S = student(latent_model_input, timesteps=timesteps, input_ids=batch["input_ids"])
                with torch.no_grad():  # the teacher is frozen; never record its graph
                    pred_T = teacher(latent_model_input, timesteps=timesteps, input_ids=batch["input_ids"])

            # The task target is the sampled noise.
            _, loss_dict_kd = cal_KD_loss(pred_S, pred_T, args)
            _, loss_dict_task = cal_task_loss(pred_S, noise, args)
            _, loss_dict_elatentlpips = cal_elatentlpips_loss(
                pred_S,
                noise,
                elatentlpips_model,
                noise_scheduler=noise_scheduler,
                timesteps=timesteps,
                noisy_latents=noisy_latents,
                args=args,
            )
            loss_dict = loss_dict_kd | loss_dict_task | loss_dict_elatentlpips

            raw_student = accelerator.unwrap_model(student)
            (
                feat_weight_task,
                out_weight_outkd,
                out_weight_elatentlpips,
                feat_gnorm_featkd,
                feat_gnorm_task,
                feat_gnorm_outkd,
                feat_gnorm_elatentlpips,
                out_gnorm_task,
                out_gnorm_outkd,
                out_gnorm_elatentlpips,
            ) = cal_adaptive_weights_type8(
                loss_dict["loss_featkd"],
                loss_dict["loss_task"],
                loss_dict["loss_outkd"],
                loss_dict["loss_elatentlpips"],
                last_featkd_layer=raw_student.diff_model.down_blocks[2].attentions[1].proj_out.weight,
                outkd_layer=raw_student.diff_model.conv_out.conv_pw.weight,
            )
            loss = loss_dict["loss_featkd"] * args.KD_loss_weight + feat_weight_task * (
                loss_dict["loss_task"] * args.task_loss_weight
                + loss_dict["loss_outkd"] * out_weight_outkd * args.KD_loss_weight
                + loss_dict["loss_elatentlpips"] * out_weight_elatentlpips * args.elatentlpips_loss_weight
            )

            accelerator.backward(loss)
            # Clip only once the accumulated gradient is complete.
            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                grad_norm = accelerator.clip_grad_norm_(student.parameters(), args.max_grad_norm).item()
            optimizer.step()
            # Update the EMA once per real optimizer step, not once per micro-batch.
            if student_ema is not None and accelerator.sync_gradients:
                student_ema(accelerator.unwrap_model(student))
            lr_scheduler.step()
            optimizer.zero_grad()

            accumulate_loss += loss.detach()

            if accelerator.sync_gradients:
                loss_total += accumulate_loss
                logs = {
                    "avr_loss": float(loss_total) / (step + 1 - args.global_step),
                    "loss": float(accumulate_loss) / accelerator.gradient_accumulation_steps,
                    "lr": float(lr_scheduler.get_last_lr()[0]),
                    "grad_norm": grad_norm,
                    "global_step": global_step,
                    "feat_gnorm_featkd": feat_gnorm_featkd.item(),
                    "feat_gnorm_task": feat_gnorm_task.item(),
                    "feat_gnorm_outkd": feat_gnorm_outkd.item(),
                    "feat_gnorm_elatentlpips": feat_gnorm_elatentlpips.item(),
                    "out_gnorm_task": out_gnorm_task.item(),
                    "out_gnorm_outkd": out_gnorm_outkd.item(),
                    "out_gnorm_elatentlpips": out_gnorm_elatentlpips.item(),
                    "feat_weight_task": feat_weight_task.item(),
                    "out_weight_outkd": out_weight_outkd.item(),
                    "out_weight_elatentlpips": out_weight_elatentlpips.item(),
                }
                # float() accepts both 0-d tensors and the int 0 used for disabled losses.
                logs |= {k: float(v) for k, v in loss_dict.items()}
                pbar.set_postfix(**logs, refresh=False)
                if args.logging_dir:
                    tb_logs = logs | {"rank": accelerator.process_index}
                    executor.submit(accelerator.log, tb_logs, step=global_step)
                accumulate_loss = 0

        if (
            global_step != args.global_step
            and args.save_every_n_steps
            and global_step % args.save_every_n_steps == 0
        ):
            _save_checkpoint(student, student_ema, global_step, args, accelerator)

        pbar.update()
        global_step += 1

    _save_checkpoint(student, student_ema, global_step, args, accelerator)

    executor.shutdown(wait=True)  # flush pending tracker writes before trackers close
    accelerator.wait_for_everyone()
    accelerator.end_training()


def setup_parser() -> argparse.ArgumentParser:
    """Build the argument parser: shared train_util/common options plus distillation flags."""
    parser = argparse.ArgumentParser()
    train_util.add_sd_models_arguments(parser)
    train_util.add_training_arguments(parser, False)
    train_util.add_sd_saving_arguments(parser)
    train_util.add_optimizer_arguments(parser)
    common_arguments(parser)

    # teacher
    parser.add_argument("--teacher_config_path", type=str, default=None)
    parser.add_argument("--teacher_weight_path", type=str, default=None)

    # feature / output KD
    parser.add_argument("--kl_feat_loss", action="store_true",
                        help="use KLDivLoss for intermediate-feature KD.")
    parser.add_argument("--kl_temperature", "--kl_tempeature", type=float, default=1.0, dest="kl_temp",
                        help="softmax temperature applied to features in the KL feature-KD loss.")
    parser.add_argument("--mse_feat_loss", action="store_true",
                        help="use MSELoss for intermediate-feature KD.")
    parser.add_argument("--feat_index_T", nargs="*", type=int, default=[4],
                        help="index list of Teacher intermediate features for KD.")
    parser.add_argument("--feat_index_S", nargs="*", type=int, default=[4],
                        help="index list of Student intermediate features for KD.")
    parser.add_argument("--feat_loss_weight", nargs="*", type=float, default=[0.2],
                        help="loss weights of intermediate features for KD.")

    # task / perceptual losses
    parser.add_argument("--task_loss", action="store_true",
                        help="enable MSELoss for output and gt_noise.")
    parser.add_argument("--task_loss_weight", type=float, default=1.0,
                        help="weight multiplied to loss_task.")
    parser.add_argument("--KD_loss_weight", type=float, default=1.0,
                        help="weight multiplied to loss_kd.")
    parser.add_argument("--elatentlpips_loss", action="store_true",
                        help="enable E-LatentLPIPS perceptual loss between predicted and target x0 latents.")
    parser.add_argument("--elatentlpips_loss_weight", type=float, default=1.0,
                        help="weight multiplied to loss_elatentlpips.")

    # EMA / compile
    parser.add_argument("--use_model_ema", action="store_true", help="enable EMA on training model.")
    parser.add_argument("--ema_decay", type=float, default=0.9999)
    parser.add_argument("--use_compile", action="store_true",
                        help="use torch.compile on forward & backward.")

    # datatype
    parser.add_argument("--data_type", type=str, default="RemovalDataset",
                        choices=["RemovalDataset", "RemovalDataset_v1_2"],
                        help="different mask assignment strategy.")
    return parser


if __name__ == "__main__":
    import torch._dynamo

    torch._dynamo.config.suppress_errors = True

    parser = setup_parser()
    args = parser.parse_args()
    args = train_util.read_config_from_file(args, parser)
    train_distillation(args)