import os
import requests
import sys
import copy
import random
import time
import glob
import math
import yaml
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from peft import LoraConfig
from types import SimpleNamespace
from transformers import AutoTokenizer, CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from diffusers.training_utils import compute_snr  # needed by VSD.compute_lora_loss below
from diffusers.utils.peft_utils import set_weights_and_activate_adapters
from diffusers.utils.import_utils import is_xformers_available

def make_1step_sched(pretrained_model_path):
    # Build a DDPM scheduler restricted to a single denoising step, with the
    # cumulative-alpha table moved to the GPU for fast indexing.
    noise_scheduler_1step = DDPMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
    noise_scheduler_1step.set_timesteps(1, device="cuda")
    noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.cuda()
    return noise_scheduler_1step

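# A minimal usage sketch (hedged): the model path below is an assumption, not a
# value shipped with this file; any Stable Diffusion checkpoint that contains a
# "scheduler" subfolder should behave the same way.
#
#   sched = make_1step_sched("stabilityai/sd-turbo")
#   assert len(sched.timesteps) == 1  # single-step denoising schedule
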
def find_filepath(directory, filename):
    matches = glob.glob(f"{directory}/**/{filename}", recursive=True)
    return matches[0] if matches else None

def read_yaml(file_path):
    with open(file_path, 'r') as file:
        data = yaml.safe_load(file)
    return data

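# Hedged example of the two helpers above; "configs" and "train.yaml" are
# placeholder names, not files shipped with this code.
#
#   cfg_path = find_filepath("configs", "train.yaml")
#   if cfg_path is not None:
#       cfg = read_yaml(cfg_path)
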
def initialize_vae(rank, return_lora_module_names=False, pretrained_model_name_or_path=None):
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
    vae.requires_grad_(False)
    vae.train()

    # Collect the conv/attention submodules that should receive LoRA adapters,
    # split into encoder-side and decoder-side groups. l_modules_others is
    # collected for bookkeeping only; no adapter is added for it on the VAE.
    l_target_modules_encoder, l_target_modules_decoder, l_modules_others = [], [], []
    l_grep = ["conv1", "conv2", "conv_in", "conv_shortcut",
              "conv", "conv_out", "to_k", "to_q", "to_v", "to_out.0",
              ]
    for n, p in vae.named_parameters():
        if "bias" in n or "norm" in n:
            continue
        for pattern in l_grep:
            if pattern in n and ("encoder" in n):
                l_target_modules_encoder.append(n.replace(".weight", ""))
                break
            elif pattern in n and ("decoder" in n):
                l_target_modules_decoder.append(n.replace(".weight", ""))
                break
            elif ('quant_conv' in n) and ('post_quant_conv' not in n):
                l_target_modules_encoder.append(n.replace(".weight", ""))
                break
            elif 'post_quant_conv' in n:
                l_target_modules_decoder.append(n.replace(".weight", ""))
                break
            elif pattern in n:
                l_modules_others.append(n.replace(".weight", ""))
                break
    lora_conf_encoder = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=l_target_modules_encoder)
    lora_conf_decoder = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=l_target_modules_decoder)
    vae.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
    vae.add_adapter(lora_conf_decoder, adapter_name="default_decoder")

    if return_lora_module_names:
        return vae, l_target_modules_encoder, l_target_modules_decoder, l_modules_others
    else:
        return vae

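# Hedged sketch: after initialize_vae, both LoRA adapters can be activated at
# full strength via the peft utility imported at the top of this file. The
# rank and model path are illustrative assumptions.
#
#   vae = initialize_vae(rank=4, pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5")
#   set_weights_and_activate_adapters(vae, ["default_encoder", "default_decoder"], [1.0, 1.0])
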
def initialize_unet(rank, return_lora_module_names=False, pretrained_model_name_or_path=None):
    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
    unet.requires_grad_(False)
    unet.train()

    # Same pattern matching as for the VAE: down_blocks (plus conv_in) get the
    # encoder adapter, up_blocks the decoder adapter, everything else "others".
    l_target_modules_encoder, l_target_modules_decoder, l_modules_others = [], [], []
    l_grep = ["to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2", "conv_in", "conv_shortcut", "conv_out", "proj_out", "proj_in", "ff.net.2", "ff.net.0.proj"]
    for n, p in unet.named_parameters():
        if "bias" in n or "norm" in n:
            continue
        for pattern in l_grep:
            if pattern in n and ("down_blocks" in n or "conv_in" in n):
                l_target_modules_encoder.append(n.replace(".weight", ""))
                break
            elif pattern in n and "up_blocks" in n:
                l_target_modules_decoder.append(n.replace(".weight", ""))
                break
            elif pattern in n:
                l_modules_others.append(n.replace(".weight", ""))
                break
    lora_conf_encoder = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=l_target_modules_encoder)
    lora_conf_decoder = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=l_target_modules_decoder)
    lora_conf_others = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=l_modules_others)
    unet.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
    unet.add_adapter(lora_conf_decoder, adapter_name="default_decoder")
    unet.add_adapter(lora_conf_others, adapter_name="default_others")
    if return_lora_module_names:
        return unet, l_target_modules_encoder, l_target_modules_decoder, l_modules_others
    else:
        return unet

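# Hedged sketch: inspect how many parameters the three UNet adapters add.
# rank=8 and the SD-1.5 path are assumptions for illustration.
#
#   unet = initialize_unet(rank=8, pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5")
#   n_lora = sum(p.numel() for n, p in unet.named_parameters() if "lora" in n)
#   print(f"LoRA parameters: {n_lora / 1e6:.2f}M")
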
def initialize_unet_sr(rank, return_lora_module_names=False, pretrained_model_name_or_path=None, args=None):
    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
    if args.use_lr_concat_lr_999noise:
        # Widen conv_in from 4 to 8 latent channels so the UNet can take the LR
        # latent concatenated with its noised copy; the pretrained 4-channel
        # kernel is duplicated into both halves and the bias is reused.
        new_conv_in = torch.nn.Conv2d(8, 320, 3, 1, 1)
        new_conv_in.weight.data[:, :4, ...] = unet.conv_in.weight.data
        new_conv_in.weight.data[:, -4:, ...] = unet.conv_in.weight.data
        new_conv_in.bias.data = unet.conv_in.bias.data
        unet.conv_in = new_conv_in
    unet.requires_grad_(False)
    unet.train()

    l_target_modules_encoder, l_target_modules_decoder, l_modules_others = [], [], []
    l_grep = ["to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2", "conv_in", "conv_shortcut", "conv_out", "proj_out", "proj_in", "ff.net.2", "ff.net.0.proj"]
    for n, p in unet.named_parameters():
        if "bias" in n or "norm" in n:
            continue
        for pattern in l_grep:
            if pattern in n and ("down_blocks" in n or "conv_in" in n):
                l_target_modules_encoder.append(n.replace(".weight", ""))
                break
            elif pattern in n and "up_blocks" in n:
                l_target_modules_decoder.append(n.replace(".weight", ""))
                break
            elif pattern in n:
                l_modules_others.append(n.replace(".weight", ""))
                break
    lora_conf_encoder = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=l_target_modules_encoder)
    lora_conf_decoder = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=l_target_modules_decoder)
    lora_conf_others = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=l_modules_others)
    unet.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
    unet.add_adapter(lora_conf_decoder, adapter_name="default_decoder")
    unet.add_adapter(lora_conf_others, adapter_name="default_others")
    if return_lora_module_names:
        return unet, l_target_modules_encoder, l_target_modules_decoder, l_modules_others
    else:
        return unet

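# Hedged sketch of the 8-channel variant: with use_lr_concat_lr_999noise set,
# the UNet expects the clean LR latent concatenated with its noised copy along
# the channel axis. The SimpleNamespace args object below is an assumption
# standing in for the real argparse namespace.
#
#   args = SimpleNamespace(use_lr_concat_lr_999noise=True)
#   unet = initialize_unet_sr(rank=8,
#                             pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
#                             args=args)
#   assert unet.conv_in.in_channels == 8
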
class VSD(torch.nn.Module):
    def __init__(self, args, accelerator):
        super().__init__()

        self.tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
        self.sched = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
        self.args = args

        weight_dtype = torch.float32
        if accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16

        # unet_fix is the frozen pretrained teacher; unet_update carries the
        # trainable LoRA adapters ("fake" score model).
        self.vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
        self.unet_fix = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
        self.unet_update, self.lora_unet_modules_encoder, self.lora_unet_modules_decoder, self.lora_unet_others = \
            initialize_unet(rank=args.lora_rank_unet_vsd, pretrained_model_name_or_path=args.pretrained_model_name_or_path, return_lora_module_names=True)
        self.lora_rank_unet = args.lora_rank_unet_vsd
        # forward() below references self.timesteps, which the original never
        # defined; mirror the other model classes in this file (assumed fix,
        # falling back to 999 when args carries no time_step).
        self.timesteps = torch.tensor([getattr(args, "time_step", 999)], device=accelerator.device).long()

        if args.enable_xformers_memory_efficient_attention:
            if is_xformers_available():
                self.unet_fix.enable_xformers_memory_efficient_attention()
                self.unet_update.enable_xformers_memory_efficient_attention()
            else:
                raise ValueError("xformers is not available, please install it by running `pip install xformers`")

        if args.gradient_checkpointing:
            self.unet_fix.enable_gradient_checkpointing()
            self.unet_update.enable_gradient_checkpointing()

        self.text_encoder.to(accelerator.device, dtype=weight_dtype)
        self.unet_fix.to(accelerator.device, dtype=weight_dtype)
        self.unet_update.to(accelerator.device)
        self.vae.to(accelerator.device)

        self.text_encoder.requires_grad_(False)
        self.vae.requires_grad_(False)
        self.unet_fix.requires_grad_(False)

    def set_eval(self):
        # The original also called self.unet.eval(), but VSD defines no plain
        # self.unet; only the two UNets below exist on this class.
        self.unet_fix.eval()
        self.unet_update.eval()

    def set_train(self):
        self.unet_update.train()
        for n, _p in self.unet_update.named_parameters():
            if "lora" in n:
                _p.requires_grad = True

    def forward(self, c_t, prompt=None, neg_prompt_tokens=None, prompt_tokens=None, deterministic=True, r=1.0, noise_map=None, args=None):

        caption_enc = self.text_encoder(prompt_tokens)[0]
        neg_caption_enc = self.text_encoder(neg_prompt_tokens)[0]

        # The original referenced self.unet, which this class never defines;
        # unet_update (the LoRA model) is the only plausible target here.
        encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
        model_pred = self.unet_update(encoded_control, self.timesteps, encoder_hidden_states=caption_enc.to(torch.float32)).sample
        x_denoised = self.sched.step(model_pred, self.timesteps, encoded_control, return_dict=True).prev_sample

        output_image = (self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample).clamp(-1, 1)

        return output_image, caption_enc, neg_caption_enc

    def forward_latent(self, model, latents, timestep, prompt_embeds):
        noise_pred = model(
            latents,
            timestep=timestep,
            encoder_hidden_states=prompt_embeds,
        ).sample

        return noise_pred

    def compute_lora_loss(self, latents_pred, prompt_embeds, args):
        # Standard noise-prediction loss used to keep the LoRA "fake" score
        # model in sync with the current generator distribution.
        latents_pred = latents_pred.detach()
        prompt_embeds = prompt_embeds.detach()
        noise = torch.randn_like(latents_pred)
        bsz = latents_pred.shape[0]
        timesteps = torch.randint(0, self.sched.config.num_train_timesteps, (bsz,), device=latents_pred.device)
        timesteps = timesteps.long()
        noisy_latents = self.sched.add_noise(latents_pred, noise, timesteps)
        disc_pred = self.forward_latent(
            self.unet_update,
            timestep=timesteps,
            latents=noisy_latents,
            prompt_embeds=prompt_embeds
        )
        if args.snr_gamma_vsd is None:
            loss_d = F.mse_loss(disc_pred.float(), noise.float(), reduction="mean")
        else:
            # Min-SNR weighting; compute_snr comes from diffusers.training_utils.
            # The original referenced undefined model_pred/target names and the
            # unrelated args.snr_gamma; disc_pred/noise/args.snr_gamma_vsd are
            # the corresponding values in this function.
            snr = compute_snr(self.sched, timesteps)
            if self.sched.config.prediction_type == "v_prediction":
                # Velocity objective requires SNR weights of (SNR + 1).
                snr = snr + 1
            mse_loss_weights = torch.stack([snr, args.snr_gamma_vsd * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr

            loss = F.mse_loss(disc_pred.float(), noise.float(), reduction="none")
            loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
            loss_d = loss.mean()

        return loss_d

    def eps_to_mu(self, scheduler, model_output, sample, timesteps):
        # Convert an epsilon prediction at time t back to the predicted clean
        # latent x_0 (the name is historical; see the note below).
        alphas_cumprod = scheduler.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
        alpha_prod_t = alphas_cumprod[timesteps]
        while len(alpha_prod_t.shape) < len(sample.shape):
            alpha_prod_t = alpha_prod_t.unsqueeze(-1)
        beta_prod_t = 1 - alpha_prod_t
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        return pred_original_sample

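    # A worked check of eps_to_mu (hedged, documentation only): DDPM defines
    # x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps, so solving for x_0 gives
    # x_0 = (x_t - sqrt(1 - abar_t) * eps) / sqrt(abar_t), which is exactly the
    # expression above. Despite the "mu" in its name, the method returns the
    # predicted x_0, not the DDPM posterior mean.
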
    def distribution_matching_loss(
        self,
        real_model,
        fake_model,
        noise_scheduler,
        latents,
        prompt_embeds,
        negative_prompt_embeds,
        args,
    ):
        bsz = latents.shape[0]
        min_dm_step = int(noise_scheduler.config.num_train_timesteps * args.min_dm_step_ratio)
        max_dm_step = int(noise_scheduler.config.num_train_timesteps * args.max_dm_step_ratio)

        timestep = torch.randint(min_dm_step, max_dm_step, (bsz,), device=latents.device).long()
        noise = torch.randn_like(latents)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timestep)

        with torch.no_grad():
            # x_0 prediction under the LoRA "fake" score model.
            noise_pred = self.forward_latent(
                fake_model,
                latents=noisy_latents,
                timestep=timestep,
                prompt_embeds=prompt_embeds.float(),
            )
            pred_fake_latents = self.eps_to_mu(noise_scheduler, noise_pred, noisy_latents, timestep)

            # Classifier-free guidance for the frozen "real" model: run the
            # unconditional and conditional branches in one batch.
            noisy_latents_input = torch.cat([noisy_latents] * 2)
            timestep_input = torch.cat([timestep] * 2)
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

            noise_pred = self.forward_latent(
                real_model,
                latents=noisy_latents_input.to(dtype=torch.float16),
                timestep=timestep_input,
                prompt_embeds=prompt_embeds.to(dtype=torch.float16),
            )
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + args.cfg_vsd * (noise_pred_text - noise_pred_uncond)
            noise_pred = noise_pred.to(dtype=torch.float32)  # fixed: the original dropped this assignment

            pred_real_latents = self.eps_to_mu(noise_scheduler, noise_pred, noisy_latents, timestep)

        weighting_factor = torch.abs(latents - pred_real_latents).mean(dim=[1, 2, 3], keepdim=True)

        grad = (pred_fake_latents - pred_real_latents) / weighting_factor
        loss = F.mse_loss(latents, self.stopgrad(latents - grad))
        return loss

    def stopgrad(self, x):
        return x.detach()

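    # Why mse(latents, stopgrad(latents - grad)) (hedged reasoning note): with
    # the target detached, the loss is ||latents - (latents - grad)||^2 up to
    # the mean reduction, so d(loss)/d(latents) is proportional to grad itself.
    # This is the standard trick for injecting a hand-computed gradient through
    # an MSE objective without backpropagating through the teacher models.
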
    def save_model(self, outf):
        sd = {}
        sd["unet_lora_encoder_modules"], sd["unet_lora_decoder_modules"], sd["unet_lora_others_modules"] = \
            self.lora_unet_modules_encoder, self.lora_unet_modules_decoder, self.lora_unet_others
        sd["rank_unet"] = self.lora_rank_unet
        # unet_update is the only UNet with LoRA weights on this class (the
        # original said self.unet, which does not exist on VSD).
        sd["state_dict_unet"] = {k: v for k, v in self.unet_update.state_dict().items() if "lora" in k}
        torch.save(sd, outf)

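# Hedged sketch of how the two VSD losses above are typically combined in one
# training step. The tokenized prompts and args fields are assumed context,
# not code from this file; only the VSD API calls are taken from the class.
#
#   out_img, prompt_emb, neg_prompt_emb = vsd(lr_img, prompt_tokens=tok, neg_prompt_tokens=neg_tok)
#   latents = vsd.vae.encode(out_img).latent_dist.sample() * vsd.vae.config.scaling_factor
#   loss_dm = vsd.distribution_matching_loss(vsd.unet_fix, vsd.unet_update, vsd.sched,
#                                            latents, prompt_emb, neg_prompt_emb, args)  # trains the generator
#   loss_lora = vsd.compute_lora_loss(latents, prompt_emb, args)                          # trains unet_update
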
class NAOSD(torch.nn.Module):
    def __init__(self, args):
        super().__init__()

        self.tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder").cuda()
        self.sched = make_1step_sched(args.pretrained_model_name_or_path)
        self.sched2 = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
        self.args = args
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
        unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

        if args.pretrained_path is None:
            vae, lora_vae_modules_encoder, lora_vae_modules_decoder, lora_vae_others = \
                initialize_vae(rank=args.lora_rank_vae, pretrained_model_name_or_path=args.pretrained_model_name_or_path, return_lora_module_names=True)
            unet, lora_unet_modules_encoder, lora_unet_modules_decoder, lora_unet_others = \
                initialize_unet_sr(rank=args.lora_rank_unet, pretrained_model_name_or_path=args.pretrained_model_name_or_path, return_lora_module_names=True, args=args)
            self.lora_rank_unet = args.lora_rank_unet
            self.lora_rank_vae = args.lora_rank_vae
            self.lora_vae_modules_encoder, self.lora_vae_modules_decoder, self.lora_vae_others = \
                lora_vae_modules_encoder, lora_vae_modules_decoder, lora_vae_others
            self.lora_unet_modules_encoder, self.lora_unet_modules_decoder, self.lora_unet_others = \
                lora_unet_modules_encoder, lora_unet_modules_decoder, lora_unet_others

        self.unet, self.vae = unet, vae

        if args.pretrained_path is not None:
            print('==================================> loading pre-trained weight')
            sd = torch.load(args.pretrained_path)
            self.load_ckpt_from_state_dict(sd)
            self.lora_rank_unet = sd['rank_unet']
            self.lora_rank_vae = sd['rank_vae']
            self.lora_vae_modules_encoder, self.lora_vae_modules_decoder, self.lora_vae_others = \
                sd['vae_lora_encoder_modules'], sd['vae_lora_decoder_modules'], sd['vae_lora_others_modules']
            self.lora_unet_modules_encoder, self.lora_unet_modules_decoder, self.lora_unet_others = \
                sd['unet_lora_encoder_modules'], sd['unet_lora_decoder_modules'], sd['unet_lora_others_modules']

        self.unet, self.vae = self.unet.cuda(), self.vae.cuda()
        self.timesteps = torch.tensor([args.time_step], device="cuda").long()
        self.timestepsnoise = torch.tensor([args.time_step_noise], device="cuda").long()
        self.text_encoder.requires_grad_(False)

    def set_eval(self):
        self.unet.eval()
        self.vae.eval()
        self.unet.requires_grad_(False)
        self.vae.requires_grad_(False)

    def set_train(self):
        self.unet.train()
        self.vae.train()
        for n, _p in self.unet.named_parameters():
            if "lora" in n:
                _p.requires_grad = True
        self.unet.conv_in.requires_grad_(True)
        for n, _p in self.vae.named_parameters():
            if "lora" in n:
                _p.requires_grad = True

    def encode_prompt(self, prompt):
        with torch.no_grad():
            text_input_ids = self.tokenizer(
                prompt, max_length=self.tokenizer.model_max_length,
                padding="max_length", truncation=True, return_tensors="pt"
            ).input_ids
            prompt_embeds = self.text_encoder(
                text_input_ids.to(self.text_encoder.device),
            )[0]
        return prompt_embeds

    def forward(self, c_t, positive_prompt=None, negative_prompt=None, args=None):
        caption_enc = self.encode_prompt(positive_prompt)
        neg_caption_enc = self.encode_prompt(negative_prompt)
        # Encode the LR input, then perturb it at the configured noise timestep.
        encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
        noise = torch.randn_like(encoded_control)
        encoded_control = self.sched2.add_noise(encoded_control, noise, self.timestepsnoise)

        # One UNet evaluation plus a single scheduler step recovers the clean latent.
        model_pred = self.unet(encoded_control, self.timesteps, encoder_hidden_states=caption_enc.to(torch.float32)).sample
        x_denoised = self.sched.step(model_pred, self.timesteps, encoded_control, return_dict=True).prev_sample
        output_image = self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample
        output_image = output_image.clamp(-1, 1)

        return output_image, x_denoised, caption_enc, neg_caption_enc, noise

    def save_model(self, outf):
        sd = {}
        sd["vae_lora_encoder_modules"], sd["vae_lora_decoder_modules"], sd["vae_lora_others_modules"] = \
            self.lora_vae_modules_encoder, self.lora_vae_modules_decoder, self.lora_vae_others
        sd["unet_lora_encoder_modules"], sd["unet_lora_decoder_modules"], sd["unet_lora_others_modules"] = \
            self.lora_unet_modules_encoder, self.lora_unet_modules_decoder, self.lora_unet_others
        sd["rank_unet"] = self.lora_rank_unet
        sd["rank_vae"] = self.lora_rank_vae
        sd["state_dict_unet"] = {k: v for k, v in self.unet.state_dict().items() if "lora" in k or "conv_in" in k}
        sd["state_dict_vae"] = {k: v for k, v in self.vae.state_dict().items() if "lora" in k or "skip" in k}
        torch.save(sd, outf)

    def load_ckpt_from_state_dict(self, sd):
        lora_conf_encoder = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_encoder_modules"])
        lora_conf_decoder = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_decoder_modules"])
        lora_conf_others = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_others_modules"])
        self.unet.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
        self.unet.add_adapter(lora_conf_decoder, adapter_name="default_decoder")
        self.unet.add_adapter(lora_conf_others, adapter_name="default_others")
        for n, p in self.unet.named_parameters():
            if "lora" in n or "conv_in" in n:
                p.data.copy_(sd["state_dict_unet"][n])

        vae_lora_conf_encoder = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_encoder_modules"])
        vae_lora_conf_decoder = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_decoder_modules"])
        self.vae.add_adapter(vae_lora_conf_encoder, adapter_name="default_encoder")
        self.vae.add_adapter(vae_lora_conf_decoder, adapter_name="default_decoder")
        for n, p in self.vae.named_parameters():
            if "lora" in n:
                p.data.copy_(sd["state_dict_vae"][n])

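# Hedged end-to-end sketch for NAOSD. The args fields mirror ones referenced in
# this file (time_step, time_step_noise, LoRA ranks, model path); the concrete
# values and the dummy input are assumptions for illustration.
#
#   args = SimpleNamespace(pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
#                          pretrained_path=None, lora_rank_unet=8, lora_rank_vae=4,
#                          time_step=999, time_step_noise=999,
#                          use_lr_concat_lr_999noise=False)
#   model = NAOSD(args)
#   model.set_train()
#   lr = torch.randn(1, 3, 512, 512, device="cuda")  # dummy LR input in [-1, 1]
#   out, x_denoised, cap, neg_cap, noise = model(lr, ["a photo"], [""], args=args)
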
class GDPOSR(torch.nn.Module):
    def __init__(self, args):
        super().__init__()

        self.tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder").cuda()
        self.sched = make_1step_sched(args.pretrained_model_name_or_path)
        self.sched2 = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
        self.args = args
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        vae = AutoencoderKL.from_pretrained(args.basemodel_path, subfolder="vae")
        unet = UNet2DConditionModel.from_pretrained(args.basemodel_path, subfolder="unet")
        ref_unet = UNet2DConditionModel.from_pretrained(args.basemodel_path, subfolder="unet")

        if args.pretrained_path is None:
            print('==================================> randomly initializing the weights')
            unet, lora_unet_modules_encoder, lora_unet_modules_decoder, lora_unet_others = \
                initialize_unet_sr(rank=args.lora_rank_unet, pretrained_model_name_or_path=args.basemodel_path, return_lora_module_names=True, args=args)
            self.lora_rank_unet = args.lora_rank_unet
            self.lora_unet_modules_encoder, self.lora_unet_modules_decoder, self.lora_unet_others = \
                lora_unet_modules_encoder, lora_unet_modules_decoder, lora_unet_others

        self.unet, self.vae = unet, vae

        if args.pretrained_path is not None:
            print('==================================> loading pre-trained weight')
            sd = torch.load(args.pretrained_path)
            self.load_ckpt_from_state_dict(sd)
            self.lora_rank_unet = sd['rank_unet']
            self.lora_unet_modules_encoder, self.lora_unet_modules_decoder, self.lora_unet_others = \
                sd['unet_lora_encoder_modules'], sd['unet_lora_decoder_modules'], sd['unet_lora_others_modules']

        self.unet, self.vae = self.unet.cuda(), self.vae.cuda()
        self.ref_unet = ref_unet.cuda()
        self.timesteps = torch.tensor([args.time_step], device="cuda").long()
        self.timestepsnoise = torch.tensor([args.time_step_noise], device="cuda").long()
        self.text_encoder.requires_grad_(False)

    def set_eval(self):
        self.unet.eval()
        self.vae.eval()
        self.ref_unet.eval()
        self.unet.requires_grad_(False)
        self.vae.requires_grad_(False)
        self.ref_unet.requires_grad_(False)

    def set_train(self):
        self.unet.train()
        self.vae.train()
        for n, _p in self.unet.named_parameters():
            if "lora" in n:
                _p.requires_grad = True
        for n, _p in self.ref_unet.named_parameters():
            _p.requires_grad = False

    def encode_prompt(self, prompt):
        with torch.no_grad():
            text_input_ids = self.tokenizer(
                prompt, max_length=self.tokenizer.model_max_length,
                padding="max_length", truncation=True, return_tensors="pt"
            ).input_ids
            prompt_embeds = self.text_encoder(
                text_input_ids.to(self.text_encoder.device),
            )[0]
        return prompt_embeds

    def forward(self, c_t, positive_prompt=[''], negative_prompt=[''], args=None):
        caption_enc = self.encode_prompt(positive_prompt)
        neg_caption_enc = self.encode_prompt(negative_prompt)
        with torch.no_grad():
            encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
            encoded_control_ref = encoded_control
            noise = torch.randn_like(encoded_control)
            encoded_control = self.sched2.add_noise(encoded_control, noise, self.timestepsnoise)

        model_pred = self.unet(encoded_control, self.timesteps, encoder_hidden_states=caption_enc.to(torch.float32)).sample
        x_denoised = self.sched.step(model_pred, self.timesteps, encoded_control, return_dict=True).prev_sample
        output_image = self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample
        output_image = output_image.clamp(-1, 1)

        # Reference branch: the frozen ref_unet sees the same clean latents and
        # the same noise realization, so the two branches differ only in the UNet.
        with torch.no_grad():
            encoded_control_ref = self.sched2.add_noise(encoded_control_ref, noise, self.timestepsnoise)
            ref_model_pred = self.ref_unet(encoded_control_ref, self.timesteps, encoder_hidden_states=caption_enc.to(torch.float32)).sample
            ref_x_denoised = self.sched.step(ref_model_pred, self.timesteps, encoded_control_ref, return_dict=True).prev_sample
            ref_output_image = self.vae.decode(ref_x_denoised / self.vae.config.scaling_factor).sample
            ref_output_image = ref_output_image.clamp(-1, 1)

        return output_image, x_denoised, model_pred, caption_enc, neg_caption_enc, noise, ref_output_image, ref_x_denoised, ref_model_pred

    def GDPOReference(self, c_t, positive_prompt=[''], negative_prompt=[''], args=None, groupsize=6):
        # Draw `groupsize` one-step reference samples per input with the frozen
        # ref_unet; only the injected noise differs across samples in a group.
        with torch.no_grad():
            caption_enc = self.encode_prompt(positive_prompt).unsqueeze(1)
            encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
            b, c, h, w = encoded_control.shape
            encoded_control = encoded_control.unsqueeze(1)
            caption_enc = caption_enc.repeat(1, groupsize, 1, 1)
            encoded_control = encoded_control.repeat(1, groupsize, 1, 1, 1)
            noise = torch.randn_like(encoded_control)
            output_image = torch.zeros_like(c_t).unsqueeze(1).repeat(1, groupsize, 1, 1, 1)
            x_denoised = torch.zeros_like(noise)
            model_pred = torch.zeros_like(noise)
            for i in range(b):
                encoded_control_i = self.sched2.add_noise(encoded_control[i], noise[i], self.timestepsnoise)
                model_pred_i = self.ref_unet(encoded_control_i, self.timesteps, encoder_hidden_states=caption_enc[i]).sample
                x_denoised_i = self.sched.step(model_pred_i, self.timesteps, encoded_control_i, return_dict=True).prev_sample
                output_image_i = self.vae.decode(x_denoised_i / self.vae.config.scaling_factor).sample
                output_image_i = output_image_i.clamp(-1, 1)
                output_image[i] = output_image_i
                x_denoised[i] = x_denoised_i
                model_pred[i] = model_pred_i

        return output_image, x_denoised, model_pred

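    # Hedged usage note for GDPOReference: with groupsize samples per input, the
    # returned tensors carry a leading (batch, groupsize) layout, e.g. for a
    # single 512x512 input and groupsize=6 the decoded images come back as
    # (1, 6, 3, 512, 512). The shapes assume no up/down-sampling between c_t and
    # the decoded output, matching the forward pass above.
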
    def save_model(self, outf):
        sd = {}
        sd["unet_lora_encoder_modules"], sd["unet_lora_decoder_modules"], sd["unet_lora_others_modules"] = \
            self.lora_unet_modules_encoder, self.lora_unet_modules_decoder, self.lora_unet_others
        sd["rank_unet"] = self.lora_rank_unet
        sd["state_dict_unet"] = {k: v for k, v in self.unet.state_dict().items() if "lora" in k or "conv_in" in k}
        torch.save(sd, outf)

    def load_ckpt_from_state_dict(self, sd):
        lora_conf_encoder = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_encoder_modules"])
        lora_conf_decoder = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_decoder_modules"])
        lora_conf_others = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_others_modules"])
        self.unet.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
        self.unet.add_adapter(lora_conf_decoder, adapter_name="default_decoder")
        self.unet.add_adapter(lora_conf_others, adapter_name="default_others")
        for n, p in self.unet.named_parameters():
            if "lora" in n or "conv_in" in n:
                p.data.copy_(sd["state_dict_unet"][n])

class GDPOSRTest(torch.nn.Module):
    def __init__(self, args):
        super().__init__()

        self.tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder").cuda()
        self.sched = make_1step_sched(args.pretrained_model_name_or_path)
        self.sched2 = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
        self.args = args
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # For testing, the fine-tuned vae/unet are loaded directly from a
        # diffusers-format directory (args.pretrained_path) rather than by
        # reattaching LoRA adapters.
        vae = AutoencoderKL.from_pretrained(args.pretrained_path, subfolder="vae")
        unet = UNet2DConditionModel.from_pretrained(args.pretrained_path, subfolder="unet")

        self.unet, self.vae = unet, vae
        self.unet, self.vae = self.unet.cuda(), self.vae.cuda()
        self.timesteps = torch.tensor([args.time_step], device="cuda").long()
        self.timestepsnoise = torch.tensor([args.time_step_noise], device="cuda").long()
        self.text_encoder.requires_grad_(False)

    def set_eval(self):
        self.unet.eval()
        self.vae.eval()
        self.unet.requires_grad_(False)
        self.vae.requires_grad_(False)

    def encode_prompt(self, prompt):
        with torch.no_grad():
            text_input_ids = self.tokenizer(
                prompt, max_length=self.tokenizer.model_max_length,
                padding="max_length", truncation=True, return_tensors="pt"
            ).input_ids
            prompt_embeds = self.text_encoder(
                text_input_ids.to(self.text_encoder.device),
            )[0]
        return prompt_embeds

    def forward(self, c_t, positive_prompt=['']):
        caption_enc = self.encode_prompt(positive_prompt)
        encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
        noise = torch.randn_like(encoded_control)
        encoded_control = self.sched2.add_noise(encoded_control, noise, self.timestepsnoise)

        model_pred = self.unet(encoded_control, self.timesteps, encoder_hidden_states=caption_enc.to(torch.float32)).sample
        x_denoised = self.sched.step(model_pred, self.timesteps, encoded_control, return_dict=True).prev_sample
        output_image = self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample
        output_image = output_image.clamp(-1, 1)

        return output_image
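
# Hedged inference sketch for GDPOSRTest; paths, sizes, and prompt values are
# placeholders, and lr_image is assumed to be a [-1, 1] tensor on the GPU.
#
#   args = SimpleNamespace(pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
#                          pretrained_path="/path/to/finetuned/checkpoint",
#                          time_step=999, time_step_noise=999)
#   model = GDPOSRTest(args)
#   model.set_eval()
#   with torch.no_grad():
#       sr = model(lr_image, positive_prompt=["a photo"])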