# GDPO / GDPOSR / train / train_GDPOSR.py
# Hugging Face file-page header (scraped residue): author avatar "Joypop's picture",
# commit c3e16bb "Add model weights" (verified).
import os
import gc
import lpips
import clip
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.utils import set_seed
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
import copy
import diffusers
from diffusers.utils.import_utils import is_xformers_available
from diffusers.optimization import get_scheduler
import wandb
from cleanfid.fid import get_folder_features, build_feature_extractor, fid_from_feats
import sys
sys.path.append("GDPOSR")
from modelfile.GDPOSR import GDPOSR as GDPOSRModel
from my_utils.training_utils_realsr import parse_args_realsr_training, PairedSROnlineDataset
from pathlib import Path
from accelerate.utils import set_seed, ProjectConfiguration
from accelerate import DistributedDataParallelKwargs
sys.path.append('GDPOSR')
from GDPOSR.my_utils.wavelet_color_fix import adain_color_fix, wavelet_color_fix
from diffusers.training_utils import compute_snr
from diffusers import DDPMScheduler, AutoencoderKL
from GDPOSR.losses.grpo import AdaptiveReward as RewardFunction
from ram.models.ram_lora import ram
from ram import inference_ram as inference
def _build_prompts(x_tgt, ram_model, ram_transforms, args):
    """Caption each HR target with RAM and pair it with the fixed text prompts.

    Args:
        x_tgt: batch of HR target images. The ``* 0.5 + 0.5`` below assumes
            values in [-1, 1] -- TODO confirm against PairedSROnlineDataset.
        ram_model: RAM tagging model, already cast to float16 on the device.
        ram_transforms: resize/normalize pipeline applied before RAM.
        args: parsed training args; reads ``positive_prompt`` and
            ``negative_prompt``.

    Returns:
        ``(positive_prompt, negative_prompt)``: lists with one string per
        batch element, in batch order.
    """
    positive_prompt = []
    negative_prompt = []
    # Tag one image at a time; only element [0] of the inference result is
    # used, so each sample gets its own call.
    for i in range(x_tgt.shape[0]):
        ram_image = ram_transforms(x_tgt[i].unsqueeze(0) * 0.5 + 0.5)
        caption = inference(ram_image.to(dtype=torch.float16), ram_model)
        positive_prompt.append(f'{caption[0]}, {args.positive_prompt}')
        negative_prompt.append(args.negative_prompt)
    return positive_prompt, negative_prompt


def _gdpo_loss(model_pred, ref_model_pred, noise, rewards, group_shape, beta=5000.0):
    """Compute the reward-weighted GDPO preference loss.

    For both the trained and the reference noise prediction, the per-sample
    MSE to ``noise`` (mean over dims 1-3, so 4-D predictions are expected)
    is reshaped to ``group_shape`` = (batch, groupsize), multiplied by
    ``rewards`` and summed over the group. The loss is
    ``-logsigmoid(-0.5 * beta * (model_diff - ref_diff))`` averaged over the
    batch.

    Returns:
        ``(loss, implicit_acc)``. ``implicit_acc`` is the fraction of batch
        entries whose logit is positive.
    """
    def weighted_group_error(pred):
        per_sample = (pred - noise).pow(2).mean(dim=[1, 2, 3])
        return (rewards * per_sample.view(*group_shape)).sum(1)

    model_diff = weighted_group_error(model_pred)
    ref_diff = weighted_group_error(ref_model_pred)
    inside_term = -0.5 * beta * (model_diff - ref_diff)
    implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
    loss = -1 * F.logsigmoid(inside_term).mean()
    return loss, implicit_acc


def main(args):
    """Fine-tune the GDPOSR UNet LoRA weights with the GDPO objective.

    Each optimisation step does the following:
      1. Caption the HR targets with RAM.
      2. Draw ``args.groupsize`` candidates per LR input with
         ``GDPOReference``.
      3. Score the candidates against the HR target with ``AdaptiveReward``.
      4. Run ``net_pix2pix`` on the flattened candidates. Its forward also
         returns a reference noise prediction.
      5. Minimise the reward-weighted preference loss from ``_gdpo_loss``.

    Args:
        args: namespace from ``parse_args_realsr_training``. Besides the
            optimiser/scheduler/accelerate settings, it reads ``groupsize``,
            ``positive_prompt``, ``negative_prompt``, ``checkpointing_steps``,
            ``max_train_steps`` and ``num_training_epochs``. The optional
            ``gdpo_beta`` (default 5000) sets the preference temperature.

    Side effects:
        Creates ``<output_dir>/checkpoints`` and ``<output_dir>/eval``. From
        the main process, writes ``model_<step>.pkl`` every
        ``checkpointing_steps`` optimiser steps.
    """
    logging_dir = Path(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(
        project_dir=args.output_dir, logging_dir=logging_dir
    )
    # Presumably not every LoRA adapter gets a gradient on every pass, hence
    # find_unused_parameters -- TODO confirm it is still required.
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[ddp_kwargs],
    )
    device = accelerator.device

    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    if accelerator.is_main_process:
        os.makedirs(os.path.join(args.output_dir, "checkpoints"), exist_ok=True)
        os.makedirs(os.path.join(args.output_dir, "eval"), exist_ok=True)

    # ---- models ----------------------------------------------------------
    net_pix2pix = GDPOSRModel(args)
    net_pix2pix.set_train()
    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            net_pix2pix.unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available, please install it by running `pip install xformers`")
    if args.gradient_checkpointing:
        net_pix2pix.unet.enable_gradient_checkpointing()
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    # NOTE(review): net_lpips is loaded and prepared but never used below.
    net_lpips = lpips.LPIPS(net='vgg').to(device)
    net_lpips.requires_grad_(False)
    net_ARF = RewardFunction()
    net_ARF.requires_grad_(False)

    net_pix2pix.unet.set_adapter(['default_encoder', 'default_decoder', 'default_others'])

    # ---- optimiser: only parameters whose name contains "lora" -----------
    layers_to_opt = []
    for name, param in net_pix2pix.unet.named_parameters():
        if "lora" in name:
            assert param.requires_grad
            layers_to_opt.append(param)
    optimizer = torch.optim.AdamW(
        layers_to_opt,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )
    # Step counts are scaled by num_processes, presumably to match how the
    # accelerate-prepared scheduler advances -- see accelerate docs.
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    # ---- data ------------------------------------------------------------
    dataset_train = PairedSROnlineDataset(dataset_folder=args.dataset_folder, image_prep=args.train_image_prep, split="train", deg_file_path=args.deg_file_path, args=args)
    # NOTE(review): the validation split is built but never evaluated here.
    dataset_val = PairedSROnlineDataset(dataset_folder=args.dataset_folder, image_prep=args.test_image_prep, split="test", deg_file_path=args.deg_file_path, args=args)
    dl_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers)
    dl_val = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=0)

    # ---- RAM captioner (frozen, fp16) -------------------------------------
    ram_transforms = transforms.Compose([
        transforms.Resize((384, 384)),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    RAM = ram(pretrained='./ckp/ram_swin_large_14m.pth',
              pretrained_condition=None,
              image_size=384,
              vit='swin_l')
    RAM.eval()
    RAM.to(device, dtype=torch.float16)

    net_pix2pix, optimizer, dl_train, lr_scheduler = accelerator.prepare(
        net_pix2pix, optimizer, dl_train, lr_scheduler
    )
    net_lpips, net_ARF = accelerator.prepare(net_lpips, net_ARF)

    # Trackers are initialised on the main process only.
    if accelerator.is_main_process:
        tracker_config = dict(vars(args))
        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)

    progress_bar = tqdm(range(0, args.max_train_steps), initial=0, desc="Steps",
                        disable=not accelerator.is_local_main_process)

    beta = getattr(args, "gdpo_beta", 5000.0)
    global_step = 0
    for _epoch in range(args.num_training_epochs):
        for batch in dl_train:
            with accelerator.accumulate(net_pix2pix):
                x_src = batch["LR"]
                x_tgt = batch["HR"]
                # The dataset key is spelled "fedilty_ratio"; keep it as-is.
                fidelity_ratio = batch["fedilty_ratio"]
                detail_ratio = batch["detail_ratio"]

                # Captioning, candidate sampling and scoring need no gradients.
                with torch.no_grad():
                    positive_prompt, negative_prompt = _build_prompts(
                        x_tgt, RAM, ram_transforms, args
                    )
                    # unwrap_model handles both DDP-wrapped and plain models.
                    # A device_count() test would wrongly assume DDP on a
                    # single-process multi-GPU host.
                    sample_images, _, _ = accelerator.unwrap_model(net_pix2pix).GDPOReference(
                        x_src,
                        positive_prompt=positive_prompt,
                        negative_prompt=negative_prompt,
                        args=args,
                        groupsize=args.groupsize,
                    )
                    x_tgt_re = x_tgt.unsqueeze(1).repeat(1, args.groupsize, 1, 1, 1)
                    rewards = net_ARF(sample_images, x_tgt_re, fidelity_ratio, detail_ratio)
                    rewards = rewards.to(device)

                b_sample, g_sample, c_sample, h_sample, w_sample = sample_images.shape
                x_src_wl = sample_images.view(b_sample * g_sample, c_sample, h_sample, w_sample)
                # view() flattens batch-major: candidate j of image i lands at
                # index i*G + j. Each caption must therefore be repeated G times
                # in a row. Tiling the whole list (p0..pB-1, p0..pB-1, ...) would
                # pair captions with the wrong images when B > 1.
                ps_wl = [p for p in positive_prompt for _ in range(g_sample)]
                nps_wl = [p for p in negative_prompt for _ in range(g_sample)]

                # Forward outputs, in order: x_tgt_pred, latents_pred, model_pred,
                # prompt_embeds, neg_prompt_embeds, noise, ref_output_image,
                # ref_x_denoised, ref_model_pred. Only three are needed here.
                (_, _, model_pred, _, _, noise, _, _, ref_model_pred) = net_pix2pix(
                    x_src_wl, positive_prompt=ps_wl, negative_prompt=nps_wl, args=args
                )

                gdpo_loss, implicit_acc = _gdpo_loss(
                    model_pred, ref_model_pred, noise, rewards,
                    (b_sample, g_sample), beta=beta,
                )
                accelerator.backward(gdpo_loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(layers_to_opt, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=args.set_grads_to_none)

            # sync_gradients is True once a real optimiser step has happened.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                if accelerator.is_main_process:
                    logs = {
                        "loss": gdpo_loss.detach().item(),
                        "implicit_acc": implicit_acc.item(),
                    }
                    progress_bar.set_postfix(**logs)
                    # `== 0` so that checkpointing_steps=1 saves every step;
                    # a `== 1` test can never fire in that case.
                    if global_step % args.checkpointing_steps == 0:
                        outf = os.path.join(args.output_dir, "checkpoints", f"model_{global_step}.pkl")
                        accelerator.unwrap_model(net_pix2pix).save_model(outf)
                    accelerator.log(logs, step=global_step)
                # Stop at the horizon the LR schedule and progress bar use.
                if global_step >= args.max_train_steps:
                    break
        if global_step >= args.max_train_steps:
            break

    accelerator.wait_for_everyone()
    accelerator.end_training()
if __name__ == "__main__":
    # Script entry point: read the training configuration from the command
    # line, then hand it to the training loop.
    args = parse_args_realsr_training()
    main(args=args)