import os
import gc
import lpips
import clip
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.utils import set_seed
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
import copy
import diffusers
from diffusers.utils.import_utils import is_xformers_available
from diffusers.optimization import get_scheduler
import wandb
from cleanfid.fid import get_folder_features, build_feature_extractor, fid_from_feats

import sys
# Make the project package importable when the script is run from the repo root.
# (The original appended "GDPOSR" twice; once is sufficient.)
sys.path.append("GDPOSR")
from modelfile.GDPOSR import VSD, NAOSD
from my_utils.training_utils_realsr import parse_args_realsr_training, PairedSROnlineDataset
from pathlib import Path
from accelerate.utils import set_seed, ProjectConfiguration
from accelerate import DistributedDataParallelKwargs
from GDPOSR.my_utils.wavelet_color_fix import adain_color_fix, wavelet_color_fix
from diffusers.training_utils import compute_snr
from diffusers import DDPMScheduler, AutoencoderKL
from ram.models.ram_lora import ram
from ram import inference_ram as inference


def main(args):
    """Train the NAOSD super-resolution generator against a VSD critic.

    Per step: the generator (``net_pix2pix``) restores a low-resolution image
    and is optimized with L2 + LPIPS reconstruction losses plus a
    distribution-matching (KL/VSD) loss from the critic; the critic's LoRA
    branch (``net_disc.unet_update``) is then updated to track the generator.
    Image captions for prompting are produced on the fly by a frozen RAM model.

    Args:
        args: Namespace from ``parse_args_realsr_training`` — paths, LoRA
            switches, loss weights, optimizer/scheduler hyper-parameters.

    Side effects: creates ``{output_dir}/checkpoints`` and ``{output_dir}/eval``,
    logs to the configured tracker, and periodically saves generator weights.
    """
    logging_dir = Path(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
    # find_unused_parameters=True: not every LoRA branch participates in every
    # backward pass, so plain DDP would error without it.
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[ddp_kwargs],
    )

    # Keep library chatter to one process.
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    if accelerator.is_main_process:
        os.makedirs(os.path.join(args.output_dir, "checkpoints"), exist_ok=True)
        os.makedirs(os.path.join(args.output_dir, "eval"), exist_ok=True)

    # Generator.
    net_pix2pix = NAOSD(args)
    net_pix2pix.set_train()

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            net_pix2pix.unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available, please install it by running `pip install xformers`")

    if args.gradient_checkpointing:
        net_pix2pix.unet.enable_gradient_checkpointing()

    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    # Critic: VSD model with a fixed UNet and a LoRA-updated UNet.
    net_disc = VSD(args=args, accelerator=accelerator)
    net_disc.set_train()

    # Frozen perceptual loss.
    net_lpips = lpips.LPIPS(net='vgg').cuda()
    net_lpips.requires_grad_(False)

    # Select which VAE LoRA adapters (if any) are trained.
    if args.use_vae_encode_lora and (not args.use_vae_decode_lora):
        print('==== Use Lora at VAE Encoder ====')
        net_pix2pix.vae.set_adapter(['default_encoder'])
    elif (not args.use_vae_encode_lora) and args.use_vae_decode_lora:
        print('==== Use Lora at VAE Decoder ====')
        net_pix2pix.vae.set_adapter(['default_decoder'])
    elif args.use_vae_encode_lora and args.use_vae_decode_lora:
        print('==== Use Lora at VAE En&Decoder ====')
        net_pix2pix.vae.set_adapter(['default_encoder', 'default_decoder'])
    else:
        print('==== Use Fix VAE ====')
        net_pix2pix.vae.disable_adapters()
    net_pix2pix.unet.set_adapter(['default_encoder', 'default_decoder', 'default_others'])

    # Collect trainable generator parameters: UNet LoRA weights, the UNet
    # input conv, and any enabled VAE LoRA weights.
    layers_to_opt = []
    for n, _p in net_pix2pix.unet.named_parameters():
        if "lora" in n:
            assert _p.requires_grad
            layers_to_opt.append(_p)
    layers_to_opt += list(net_pix2pix.unet.conv_in.parameters())
    for n, _p in net_pix2pix.vae.named_parameters():
        if "lora" in n:
            # VAE LoRA params may be frozen when the corresponding adapter is
            # disabled above, so no requires_grad assertion here.
            layers_to_opt.append(_p)

    optimizer = torch.optim.AdamW(layers_to_opt, lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,)
    lr_scheduler = get_scheduler(args.lr_scheduler, optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
        num_cycles=args.lr_num_cycles, power=args.lr_power,)

    # Critic optimizer: only the LoRA branch of the update UNet is trained.
    layers_to_opt_disc = []
    for n, _p in net_disc.unet_update.named_parameters():
        if "lora" in n:
            assert _p.requires_grad
            layers_to_opt_disc.append(_p)

    optimizer_disc = torch.optim.AdamW(layers_to_opt_disc, lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,)
    lr_scheduler_disc = get_scheduler(args.lr_scheduler, optimizer=optimizer_disc,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
        num_cycles=args.lr_num_cycles, power=args.lr_power)

    # Paired LR/HR data with online degradation.
    dataset_train = PairedSROnlineDataset(dataset_folder=args.dataset_folder,
        image_prep=args.train_image_prep, split="train",
        deg_file_path=args.deg_file_path, args=args)
    dataset_val = PairedSROnlineDataset(dataset_folder=args.dataset_folder,
        image_prep=args.test_image_prep, split="test",
        deg_file_path=args.deg_file_path, args=args)
    dl_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.train_batch_size,
        shuffle=True, num_workers=args.dataloader_num_workers)
    dl_val = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=0)

    # Frozen RAM tagger for automatic image captions (ImageNet normalization,
    # 384x384 input as required by the Swin-L backbone).
    ram_transforms = transforms.Compose([
        transforms.Resize((384, 384)),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    RAM = ram(pretrained='./ckp/ram_swin_large_14m.pth',
              pretrained_condition=None,
              image_size=384,
              vit='swin_l')
    RAM.eval()
    RAM.to("cuda", dtype=torch.float16)

    # Prepare everything with our `accelerator`.
    net_pix2pix, net_disc, optimizer, optimizer_disc, dl_train, lr_scheduler, lr_scheduler_disc = accelerator.prepare(
        net_pix2pix, net_disc, optimizer, optimizer_disc, dl_train, lr_scheduler, lr_scheduler_disc
    )
    net_lpips = accelerator.prepare(net_lpips)

    # NOTE(review): weight_dtype is derived but not used in this visible code
    # path — kept for parity with the original script; confirm before removing.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Trackers initialize automatically on the main process only.
    if accelerator.is_main_process:
        tracker_config = dict(vars(args))
        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)

    progress_bar = tqdm(range(0, args.max_train_steps), initial=0, desc="Steps",
        disable=not accelerator.is_local_main_process,)

    # Unwrap once: works whether or not accelerate wrapped the critic in DDP.
    # (The original tested `torch.cuda.device_count() > 1` to decide between
    # `net_disc.module.*` and `net_disc.*`, which crashes on a single-process
    # run on a multi-GPU machine — device count is not the same as DDP use.)
    disc = accelerator.unwrap_model(net_disc)

    # start the training loop
    global_step = 0
    for epoch in range(0, args.num_training_epochs):
        for step, batch in enumerate(dl_train):
            l_acc = [net_pix2pix, net_disc]
            with accelerator.accumulate(*l_acc):
                x_src = batch["LR"]
                x_tgt = batch["HR"]
                B, C, H, W = x_src.shape

                # Per-image captions from RAM; images are mapped from [-1, 1]
                # back to [0, 1] before normalization. (A redundant whole-batch
                # RAM pass that ran outside no_grad and whose result was never
                # used has been removed.)
                with torch.no_grad():
                    positive_prompt = []
                    negative_prompt = []
                    for i in range(B):
                        ram_image = x_tgt[i, :, :, :].unsqueeze(0)
                        x_tgt_ram = ram_transforms(ram_image * 0.5 + 0.5)
                        caption = inference(x_tgt_ram.to(dtype=torch.float16), RAM)
                        positive_prompt.append(f'{caption[0]}, {args.positive_prompt}')
                        negative_prompt.append(args.negative_prompt)

                # Generator forward pass.
                x_tgt_pred, latents_pred, prompt_embeds, neg_prompt_embeds, noise = net_pix2pix(
                    x_src, positive_prompt=positive_prompt, negative_prompt=negative_prompt, args=args)

                # Reconstruction losses.
                loss_l2 = F.mse_loss(x_tgt_pred.float(), x_tgt.float(), reduction="mean") * args.lambda_l2
                loss_lpips = net_lpips(x_tgt_pred.float(), x_tgt.float()).mean() * args.lambda_lpips
                loss = loss_l2 + loss_lpips

                # KL / distribution-matching (VSD) loss from the critic.
                loss_kl = disc.distribution_matching_loss(
                    disc.unet_fix,
                    disc.unet_update,
                    disc.sched,
                    latents_pred,
                    prompt_embeds,
                    neg_prompt_embeds,
                    args,
                ) * args.lambda_vsd
                loss = loss + loss_kl

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(layers_to_opt, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=args.set_grads_to_none)

                """
                Disc loss: let lora model closed to generator
                """
                loss_d = disc.compute_lora_loss(latents_pred, prompt_embeds, args) * args.lambda_vsd_lora

                accelerator.backward(loss_d)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(net_disc.parameters(), args.max_grad_norm)
                optimizer_disc.step()
                lr_scheduler_disc.step()
                optimizer_disc.zero_grad(set_to_none=args.set_grads_to_none)

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if accelerator.is_main_process:
                    logs = {}
                    # log all the losses
                    logs["loss_d"] = loss_d.detach().item()
                    logs["loss_kl"] = loss_kl.detach().item()
                    logs["loss_l2"] = loss_l2.detach().item()
                    logs["loss_lpips"] = loss_lpips.detach().item()
                    progress_bar.set_postfix(**logs)

                    # Checkpoint the generator. NOTE(review): `== 1` (not `== 0`)
                    # saves the first checkpoint at step 1 and then every
                    # `checkpointing_steps` thereafter — kept as in the original.
                    if global_step % args.checkpointing_steps == 1:
                        outf = os.path.join(args.output_dir, "checkpoints", f"model_{global_step}.pkl")
                        accelerator.unwrap_model(net_pix2pix).save_model(outf)

                    accelerator.log(logs, step=global_step)


if __name__ == "__main__":
    args = parse_args_realsr_training()
    main(args)