Instructions to use Joypop/GDPO with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use Joypop/GDPO with Diffusers:
```bash
pip install -U diffusers transformers accelerate
```
```python
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import load_image

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("Joypop/GDPO", dtype=torch.bfloat16, device_map="cuda")

prompt = "Turn this cat into a dog"
input_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png")

image = pipe(image=input_image, prompt=prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| import os | |
| import gc | |
| import lpips | |
| import clip | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import torch.utils.checkpoint | |
| import transformers | |
| from accelerate import Accelerator | |
| from accelerate.utils import set_seed | |
| from PIL import Image | |
| from torchvision import transforms | |
| from tqdm.auto import tqdm | |
| import copy | |
| import diffusers | |
| from diffusers.utils.import_utils import is_xformers_available | |
| from diffusers.optimization import get_scheduler | |
| import wandb | |
| from cleanfid.fid import get_folder_features, build_feature_extractor, fid_from_feats | |
| import sys | |
| sys.path.append("GDPOSR") | |
| from modelfile.GDPOSR import GDPOSR as GDPOSRModel | |
| from my_utils.training_utils_realsr import parse_args_realsr_training, PairedSROnlineDataset | |
| from pathlib import Path | |
| from accelerate.utils import set_seed, ProjectConfiguration | |
| from accelerate import DistributedDataParallelKwargs | |
| sys.path.append('GDPOSR') | |
| from GDPOSR.my_utils.wavelet_color_fix import adain_color_fix, wavelet_color_fix | |
| from diffusers.training_utils import compute_snr | |
| from diffusers import DDPMScheduler, AutoencoderKL | |
| from GDPOSR.losses.grpo import AdaptiveReward as RewardFunction | |
| from ram.models.ram_lora import ram | |
| from ram import inference_ram as inference | |
def main(args):
    """Train the LoRA parameters of the GDPOSR UNet with a group-wise DPO loss.

    For every batch the loop (1) captions each HR image with RAM to build the
    positive prompts, (2) asks the model's ``GDPOReference`` method for a group
    of ``args.groupsize`` candidate images per LR input, (3) scores the
    candidates against the HR target with the reward network, and (4) runs the
    model on the flattened candidates and minimises
    ``-logsigmoid(scale * (reward-weighted model error - reward-weighted
    reference error))``.

    Args:
        args: Parsed training options (see ``parse_args_realsr_training``).
            Read here: output/logging dirs, accelerator settings, optimizer
            and LR-scheduler hyper-parameters, dataset options, prompts,
            ``groupsize``, ``max_grad_norm``, ``checkpointing_steps``,
            ``num_training_epochs`` and ``max_train_steps``.

    Raises:
        ValueError: if xformers attention is requested but xformers is not
            installed.

    NOTE(review): the original indentation of this function was lost; the
    nesting below (in particular the extent of the ``torch.no_grad()`` block)
    is reconstructed and should be confirmed against the upstream script.
    """
    logging_dir = Path(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[ddp_kwargs],
    )
    device = accelerator.device

    # Only the local main process keeps verbose library logging.
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    if accelerator.is_main_process:
        os.makedirs(os.path.join(args.output_dir, "checkpoints"), exist_ok=True)
        os.makedirs(os.path.join(args.output_dir, "eval"), exist_ok=True)

    net_pix2pix = GDPOSRModel(args)
    net_pix2pix.set_train()

    if args.enable_xformers_memory_efficient_attention:
        if not is_xformers_available():
            raise ValueError("xformers is not available, please install it by running `pip install xformers`")
        net_pix2pix.unet.enable_xformers_memory_efficient_attention()

    if args.gradient_checkpointing:
        net_pix2pix.unet.enable_gradient_checkpointing()

    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    # NOTE(review): net_lpips is created and prepared but never used in the loss below.
    net_lpips = lpips.LPIPS(net='vgg').to(device)
    net_lpips.requires_grad_(False)

    # Frozen reward network used to score the sampled candidates.
    net_ARF = RewardFunction()
    net_ARF.requires_grad_(False)

    # set adapter
    net_pix2pix.unet.set_adapter(['default_encoder', 'default_decoder', 'default_others'])

    # make the optimizer: only UNet parameters whose name contains "lora" are trained
    layers_to_opt = []
    for n, _p in net_pix2pix.unet.named_parameters():
        if "lora" in n:
            assert _p.requires_grad
            layers_to_opt.append(_p)

    optimizer = torch.optim.AdamW(
        layers_to_opt,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    # make the dataloader
    dataset_train = PairedSROnlineDataset(dataset_folder=args.dataset_folder, image_prep=args.train_image_prep, split="train", deg_file_path=args.deg_file_path, args=args)
    dataset_val = PairedSROnlineDataset(dataset_folder=args.dataset_folder, image_prep=args.test_image_prep, split="test", deg_file_path=args.deg_file_path, args=args)
    dl_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers)
    # NOTE(review): the validation loader is built but not consumed in this function.
    dl_val = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=0)

    # init RAM (tagging model used to caption the HR images)
    ram_transforms = transforms.Compose([
        transforms.Resize((384, 384)),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    RAM = ram(pretrained='./ckp/ram_swin_large_14m.pth',
              pretrained_condition=None,
              image_size=384,
              vit='swin_l')
    RAM.eval()
    RAM.to(device, dtype=torch.float16)

    # Prepare everything with our `accelerator`.
    net_pix2pix, optimizer, dl_train, lr_scheduler = accelerator.prepare(
        net_pix2pix, optimizer, dl_train, lr_scheduler
    )
    net_lpips, net_ARF = accelerator.prepare(net_lpips, net_ARF)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        tracker_config = dict(vars(args))
        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)

    progress_bar = tqdm(range(0, args.max_train_steps), initial=0, desc="Steps",
                        disable=not accelerator.is_local_main_process,)

    # Multiplier applied to (model error - reference error) inside the log-sigmoid.
    # NOTE(review): 5000 presumably plays the role of the DPO beta — confirm.
    scale_term = -0.5 * 5000

    # start the training loop
    # NOTE(review): the loop is bounded by num_training_epochs only; max_train_steps
    # sizes the progress bar and LR schedule but does not stop training.
    global_step = 0
    for epoch in range(0, args.num_training_epochs):
        for step, batch in enumerate(dl_train):
            with accelerator.accumulate(net_pix2pix):
                x_src = batch["LR"]
                x_tgt = batch["HR"]
                fidelity_ratio = batch["fedilty_ratio"]
                detail_ratio = batch["detail_ratio"]

                with torch.no_grad():
                    # image description: one RAM caption per HR image
                    # (x * 0.5 + 0.5 presumably maps [-1, 1] images to [0, 1] — confirm)
                    positive_prompt = []
                    for hr_image in x_tgt:
                        ram_input = ram_transforms(hr_image.unsqueeze(0) * 0.5 + 0.5)
                        caption = inference(ram_input.to(dtype=torch.float16), RAM)
                        positive_prompt.append(f'{caption[0]}, {args.positive_prompt}')
                    negative_prompt = [args.negative_prompt] * len(positive_prompt)

                    # generate some samples; unwrap_model works whether or not the
                    # model was wrapped for distributed training
                    reference_model = accelerator.unwrap_model(net_pix2pix)
                    sample_images, _, _ = reference_model.GDPOReference(
                        x_src,
                        positive_prompt=positive_prompt,
                        negative_prompt=negative_prompt,
                        args=args,
                        groupsize=args.groupsize,
                    )

                    # score every candidate against its (repeated) HR target
                    x_tgt_re = x_tgt.unsqueeze(1).repeat(1, args.groupsize, 1, 1, 1)
                    rewards = net_ARF(sample_images, x_tgt_re, fidelity_ratio, detail_ratio)
                    rewards = rewards.to(device)

                    # flatten (B, G, C, H, W) -> (B*G, C, H, W); rows are batch-major
                    b_sample, g_sample, c_sample, h_sample, w_sample = sample_images.shape
                    x_src_wl = sample_images.view(b_sample * g_sample, c_sample, h_sample, w_sample)

                    # Repeat each prompt G times *consecutively* so that row b*G+g of
                    # x_src_wl is paired with the prompt of image b.
                    ps_wl = [p for p in positive_prompt for _ in range(g_sample)]
                    nps_wl = [p for p in negative_prompt for _ in range(g_sample)]

                # forward pass; the model returns a 9-tuple, of which only the model
                # prediction (2), the noise target (5) and the reference prediction (8)
                # are used by the loss
                _, _, model_pred, _, _, noise, _, _, ref_model_pred = net_pix2pix(
                    x_src_wl, positive_prompt=ps_wl, negative_prompt=nps_wl, args=args
                )

                # GDPO: reward-weighted per-group error of the trainable model ...
                model_losses = (model_pred - noise).pow(2).mean(dim=[1, 2, 3])
                model_losses = rewards * model_losses.view(b_sample, g_sample)
                model_diff = model_losses.sum(1)

                # ... and of the reference prediction
                ref_losses = (ref_model_pred - noise).pow(2).mean(dim=[1, 2, 3])
                ref_losses = rewards * ref_losses.view(b_sample, g_sample)
                ref_diff = ref_losses.sum(1)

                inside_term = scale_term * (model_diff - ref_diff)
                # fraction of groups whose log-sigmoid argument is positive
                implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
                gdpo_loss = -1 * F.logsigmoid(inside_term).mean()
                loss = gdpo_loss

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(layers_to_opt, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=args.set_grads_to_none)

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if accelerator.is_main_process:
                    # log all the losses
                    logs = {
                        "loss": gdpo_loss.detach().item(),
                        "implicit_acc": implicit_acc.detach().item(),
                    }
                    progress_bar.set_postfix(**logs)

                    # checkpoint the model (fires at steps 1, 1 + k, 1 + 2k, ...)
                    # NOTE(review): never fires when checkpointing_steps == 1.
                    if global_step % args.checkpointing_steps == 1:
                        outf = os.path.join(args.output_dir, "checkpoints", f"model_{global_step}.pkl")
                        accelerator.unwrap_model(net_pix2pix).save_model(outf)

                    accelerator.log(logs, step=global_step)

    # Close any trackers opened by init_trackers.
    accelerator.end_training()
if __name__ == "__main__":
    # Script entry point: parse the real-SR training options and launch training.
    main(parse_args_realsr_training())