| | import os |
| | import gc |
| | import lpips |
| | import clip |
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | import torch.utils.checkpoint |
| | import transformers |
| | from accelerate import Accelerator |
| | from accelerate.utils import set_seed |
| | from PIL import Image |
| | from torchvision import transforms |
| | from tqdm.auto import tqdm |
| | import copy |
| |
|
| | import diffusers |
| | from diffusers.utils.import_utils import is_xformers_available |
| | from diffusers.optimization import get_scheduler |
| |
|
| | import wandb |
| | from cleanfid.fid import get_folder_features, build_feature_extractor, fid_from_feats |
| | import sys |
| | sys.path.append("GDPOSR") |
| | from modelfile.GDPOSR import GDPOSR as GDPOSRModel |
| | from my_utils.training_utils_realsr import parse_args_realsr_training, PairedSROnlineDataset |
| |
|
| | from pathlib import Path |
| | from accelerate.utils import set_seed, ProjectConfiguration |
| | from accelerate import DistributedDataParallelKwargs |
| |
|
| | sys.path.append('GDPOSR') |
| | from GDPOSR.my_utils.wavelet_color_fix import adain_color_fix, wavelet_color_fix |
| | from diffusers.training_utils import compute_snr |
| | from diffusers import DDPMScheduler, AutoencoderKL |
| | from GDPOSR.losses.grpo import AdaptiveReward as RewardFunction |
| |
|
| | from ram.models.ram_lora import ram |
| | from ram import inference_ram as inference |
| |
|
| |
|
def main(args):
    """Train GDPOSR with a GRPO/DPO-style objective on paired Real-SR data.

    Builds the accelerator, trainable model, frozen reward networks, optimizer,
    LR scheduler and data loaders, then runs the training loop: each step
    samples a group of candidate restorations per LR input from the reference
    policy, scores them with the adaptive reward, and minimizes a
    reward-weighted log-sigmoid (DPO-style) loss between the trainable model
    and its frozen reference.

    Args:
        args: namespace from ``parse_args_realsr_training`` — paths, optimizer
            and scheduler hyper-parameters, group size, prompts, etc.
    """
    logging_dir = Path(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
    # LoRA-adapter training leaves some parameters unused in a given step,
    # so DDP must tolerate unused parameters.
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[ddp_kwargs],
    )

    # Keep worker processes quiet; only the local main process logs verbosely.
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    if accelerator.is_main_process:
        os.makedirs(os.path.join(args.output_dir, "checkpoints"), exist_ok=True)
        os.makedirs(os.path.join(args.output_dir, "eval"), exist_ok=True)

    net_pix2pix = GDPOSRModel(args)
    net_pix2pix.set_train()

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            net_pix2pix.unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available, please install it by running `pip install xformers`")

    if args.gradient_checkpointing:
        net_pix2pix.unet.enable_gradient_checkpointing()

    if args.allow_tf32:
        # TF32 matmuls: faster on Ampere+ GPUs at negligible precision cost.
        torch.backends.cuda.matmul.allow_tf32 = True

    # Frozen auxiliary networks.
    # NOTE(review): net_lpips is loaded and prepared but never used in the
    # visible training loop — confirm whether it can be dropped.
    net_lpips = lpips.LPIPS(net='vgg').cuda()
    net_lpips.requires_grad_(False)
    net_ARF = RewardFunction()
    net_ARF.requires_grad_(False)

    net_pix2pix.unet.set_adapter(['default_encoder', 'default_decoder', 'default_others'])

    # Only the LoRA parameters are optimized; everything else stays frozen.
    layers_to_opt = []
    for n, _p in net_pix2pix.unet.named_parameters():
        if "lora" in n:
            assert _p.requires_grad
            layers_to_opt.append(_p)

    optimizer = torch.optim.AdamW(layers_to_opt, lr=args.learning_rate,
                                  betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
                                  eps=args.adam_epsilon,)
    lr_scheduler = get_scheduler(args.lr_scheduler, optimizer=optimizer,
                                 num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
                                 num_training_steps=args.max_train_steps * accelerator.num_processes,
                                 num_cycles=args.lr_num_cycles, power=args.lr_power,)

    dataset_train = PairedSROnlineDataset(dataset_folder=args.dataset_folder, image_prep=args.train_image_prep, split="train", deg_file_path=args.deg_file_path, args=args)
    dataset_val = PairedSROnlineDataset(dataset_folder=args.dataset_folder, image_prep=args.test_image_prep, split="test", deg_file_path=args.deg_file_path, args=args)
    dl_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers)
    # NOTE(review): dl_val is built but not consumed in this loop — verify
    # whether evaluation was intended here.
    dl_val = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=0)

    # RAM tagging model: produces per-image captions used as positive prompts.
    ram_transforms = transforms.Compose([
        transforms.Resize((384, 384)),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    RAM = ram(pretrained='./ckp/ram_swin_large_14m.pth',
              pretrained_condition=None,
              image_size=384,
              vit='swin_l')
    RAM.eval()
    RAM.to("cuda", dtype=torch.float16)

    net_pix2pix, optimizer, dl_train, lr_scheduler = accelerator.prepare(
        net_pix2pix, optimizer, dl_train, lr_scheduler
    )
    net_lpips, net_ARF = accelerator.prepare(net_lpips, net_ARF)

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    if accelerator.is_main_process:
        tracker_config = dict(vars(args))
        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)

    progress_bar = tqdm(range(0, args.max_train_steps), initial=0, desc="Steps",
                        disable=not accelerator.is_local_main_process,)

    global_step = 0
    for epoch in range(0, args.num_training_epochs):
        for step, batch in enumerate(dl_train):
            with accelerator.accumulate(net_pix2pix):
                x_src = batch["LR"]
                x_tgt = batch["HR"]
                fedilty_ratio = batch["fedilty_ratio"]
                detail_ratio = batch["detail_ratio"]

                B, C, H, W = x_src.shape

                # Prompt generation, group sampling and reward scoring need no
                # gradients.
                # FIX: the original also ran one extra RAM transform+inference
                # on the whole HR batch OUTSIDE torch.no_grad() and discarded
                # the result (`caption_r` was never used); that dead,
                # graph-building call is removed here.
                with torch.no_grad():
                    positive_prompt = []
                    negative_prompt = []
                    for i in range(B):
                        ram_image = x_tgt[i,:,:,:].unsqueeze(0)
                        # Images are in [-1, 1]; map to [0, 1] before RAM's
                        # ImageNet normalization.
                        x_tgt_ram = ram_transforms(ram_image*0.5+0.5)
                        caption = inference(x_tgt_ram.to(dtype=torch.float16), RAM)
                        positive_prompt.append(f'{caption[0]}, {args.positive_prompt}')
                        negative_prompt.append(args.negative_prompt)

                    # Sample a group of candidate restorations per LR input
                    # from the frozen reference policy.
                    if torch.cuda.device_count() > 1:
                        sample_images, _, _ = net_pix2pix.module.GDPOReference(x_src, positive_prompt=positive_prompt, negative_prompt=negative_prompt, args=args, groupsize=args.groupsize)
                    else:
                        sample_images, _, _ = net_pix2pix.GDPOReference(x_src, positive_prompt=positive_prompt, negative_prompt=negative_prompt, args=args, groupsize=args.groupsize)

                    # Score every sample against its HR target -> (B, groupsize).
                    x_tgt_re = x_tgt.unsqueeze(1).repeat(1,args.groupsize,1,1,1)
                    rewards = net_ARF(sample_images, x_tgt_re, fedilty_ratio, detail_ratio)
                    rewards = rewards.cuda()

                # Flatten (B, G, C, H, W) -> (B*G, C, H, W) for one forward pass.
                b_sample, g_sample, c_sample, h_sample, w_sample = sample_images.shape
                x_src_wl = sample_images.view(b_sample*g_sample, c_sample, h_sample, w_sample)
                # NOTE(review): this repeats the prompt lists batch-major while
                # the view above is group-major; the two orders only agree when
                # train_batch_size == 1 — confirm the intended pairing.
                ps_wl = []
                nps_wl = []
                for i in range(args.groupsize):
                    ps_wl += positive_prompt
                    nps_wl += negative_prompt

                x_tgt_pred, latents_pred, model_pred, prompt_embeds, neg_prompt_embeds, noise, ref_output_image, ref_x_denoised, ref_model_pred = net_pix2pix(x_src_wl, positive_prompt=ps_wl, negative_prompt=nps_wl, args=args)

                # Reward-weighted per-sample diffusion losses for the trainable
                # model and the frozen reference, summed over each group.
                model_losses = (model_pred - noise).pow(2).mean(dim=[1,2,3])
                model_losses = model_losses.view(b_sample, g_sample)
                model_losses = rewards * model_losses
                model_diff = model_losses.sum(1)

                ref_losses = (ref_model_pred - noise).pow(2).mean(dim=[1,2,3])
                ref_losses = ref_losses.view(b_sample, g_sample)
                ref_losses = rewards * ref_losses
                ref_diff = ref_losses.sum(1)

                # DPO-style objective: push the trainable model's weighted loss
                # below the reference's (beta-like scale = 0.5 * 5000).
                scale_term = -0.5 * 5000
                inside_term = scale_term * (model_diff - ref_diff)
                # Fraction of groups where the model already beats the reference.
                implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
                gdpo_loss = -1 * F.logsigmoid(inside_term).mean()
                loss = gdpo_loss

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(layers_to_opt, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=args.set_grads_to_none)

            # One optimizer update has been applied across all accumulation
            # micro-steps; advance bookkeeping on every process.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if accelerator.is_main_process:
                    logs = {}
                    logs["loss"] = gdpo_loss.detach().item()
                    progress_bar.set_postfix(**logs)

                    # FIX: checkpointing and tracker logging both live under
                    # is_main_process so `logs` is always defined when used —
                    # the original called accelerator.log(logs, ...) at a scope
                    # where worker processes never define `logs` (NameError).
                    if global_step % args.checkpointing_steps == 1:
                        outf = os.path.join(args.output_dir, "checkpoints", f"model_{global_step}.pkl")
                        accelerator.unwrap_model(net_pix2pix).save_model(outf)

                    accelerator.log(logs, step=global_step)
|
| | if __name__ == "__main__": |
| | args = parse_args_realsr_training() |
| | main(args) |
| |
|