import argparse
import itertools
import json
import os
import random
import time
from pathlib import Path

import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration
from ip_adapter.resampler import Resampler
from ip_adapter.utils import is_torch2_available
from PIL import Image
from torchvision import transforms
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel


if is_torch2_available():
    from ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor
    from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor
else:
    from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor

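# Training dataset: each item yields a resized/normalized image tensor for the VAE,
# a CLIP-preprocessed copy of the same image for the image encoder, the tokenized
# caption, and a flag marking whether the image conditioning is dropped for this
# sample (used to enable classifier-free guidance at inference).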
class MyDataset(torch.utils.data.Dataset):
    def __init__(
        self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""
    ):
        super().__init__()

        self.tokenizer = tokenizer
        self.size = size
        self.i_drop_rate = i_drop_rate
        self.t_drop_rate = t_drop_rate
        self.ti_drop_rate = ti_drop_rate
        self.image_root_path = image_root_path

        with open(json_file) as f:
            self.data = json.load(f)

        self.transform = transforms.Compose(
            [
                transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(self.size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        self.clip_image_processor = CLIPImageProcessor()

    def __getitem__(self, idx):
        item = self.data[idx]
        text = item["text"]
        image_file = item["image_file"]

        # read image: the transformed copy feeds the VAE, the raw copy feeds CLIP
        raw_image = Image.open(os.path.join(self.image_root_path, image_file))
        image = self.transform(raw_image.convert("RGB"))
        clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values

        # randomly drop the image, the text, or both (classifier-free guidance)
        drop_image_embed = 0
        rand_num = random.random()
        if rand_num < self.i_drop_rate:
            drop_image_embed = 1
        elif rand_num < (self.i_drop_rate + self.t_drop_rate):
            text = ""
        elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
            text = ""
            drop_image_embed = 1

        # tokenize the (possibly emptied) caption
        text_input_ids = self.tokenizer(
            text,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids

        return {
            "image": image,
            "text_input_ids": text_input_ids,
            "clip_image": clip_image,
            "drop_image_embed": drop_image_embed,
        }

    def __len__(self):
        return len(self.data)

def collate_fn(data):
    images = torch.stack([example["image"] for example in data])
    text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
    clip_images = torch.cat([example["clip_image"] for example in data], dim=0)
    drop_image_embeds = [example["drop_image_embed"] for example in data]

    return {
        "images": images,
        "text_input_ids": text_input_ids,
        "clip_images": clip_images,
        "drop_image_embeds": drop_image_embeds,
    }

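# IPAdapter bundles the frozen UNet with the trainable image projection model and the
# adapter attention layers into a single module, so `accelerator.prepare` handles all
# trainable parameters together.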
class IPAdapter(torch.nn.Module):
    """IP-Adapter"""

    def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
        super().__init__()
        self.unet = unet
        self.image_proj_model = image_proj_model
        self.adapter_modules = adapter_modules

        if ckpt_path is not None:
            self.load_from_checkpoint(ckpt_path)

    def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
        # project the image embeddings to extra context tokens and append them to the
        # text tokens before running the UNet
        ip_tokens = self.image_proj_model(image_embeds)
        encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
        noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
        return noise_pred

    def load_from_checkpoint(self, ckpt_path: str):
        # calculate original checksums so we can verify that loading changed the weights
        orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
        orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        state_dict = torch.load(ckpt_path, map_location="cpu")

        # if the checkpoint's resampler latents do not match the current model's shape
        # (e.g. a different number of query tokens), drop them and load non-strictly
        strict_load_image_proj_model = True
        if "latents" in state_dict["image_proj"] and "latents" in self.image_proj_model.state_dict():
            if state_dict["image_proj"]["latents"].shape != self.image_proj_model.state_dict()["latents"].shape:
                print(f"Shapes of 'image_proj.latents' in checkpoint {ckpt_path} and current model do not match.")
                print("Removing 'latents' from checkpoint and loading the rest of the weights.")
                del state_dict["image_proj"]["latents"]
                strict_load_image_proj_model = False

        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict_load_image_proj_model)
        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)

        # recompute checksums and verify that the weights actually changed
        new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
        new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
        assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"

        print(f"Successfully loaded weights from checkpoint {ckpt_path}")

def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_ip_adapter_path",
        type=str,
        default=None,
        help="Path to pretrained ip adapter model. If not specified, weights are initialized randomly.",
    )
    parser.add_argument(
        "--num_tokens",
        type=int,
        default=16,
        help="Number of tokens to query from the CLIP image encoding.",
    )
    parser.add_argument(
        "--data_json_file",
        type=str,
        default=None,
        required=True,
        help="Path to the training data JSON file.",
    )
    parser.add_argument(
        "--data_root_path",
        type=str,
        default="",
        required=True,
        help="Root path of the training images.",
    )
    parser.add_argument(
        "--image_encoder_path",
        type=str,
        default=None,
        required=True,
        help="Path to CLIP image encoder.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd-ip_adapter",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help="The resolution for input images.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Learning rate to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=2000,
        help="Save a checkpoint of the training state every X updates.",
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system"
            " or the flag passed with the `accelerate.launch` command. Use this argument to override the"
            " accelerate config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args

def main():
    args = parse_args()
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load the scheduler, tokenizer, and models.
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)

    # Freeze all base model parameters; only the adapter weights are trained.
    unet.requires_grad_(False)
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    image_encoder.requires_grad_(False)

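    # Image projection model ("IP-Adapter Plus"): a perceiver-style Resampler that maps
    # the CLIP image encoder's penultimate hidden states to `--num_tokens` context
    # tokens with the UNet's cross-attention width.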
    image_proj_model = Resampler(
        dim=unet.config.cross_attention_dim,
        depth=4,
        dim_head=64,
        heads=12,
        num_queries=args.num_tokens,
        embedding_dim=image_encoder.config.hidden_size,
        output_dim=unet.config.cross_attention_dim,
        ff_mult=4,
    )
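    # Init adapter modules: every cross-attention layer (attn2) gets an IPAttnProcessor
    # whose new to_k_ip/to_v_ip projections start from the UNet's own to_k/to_v weights;
    # self-attention layers (attn1) keep the plain processor.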
    attn_procs = {}
    unet_sd = unet.state_dict()
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            attn_procs[name] = AttnProcessor()
        else:
            layer_name = name.split(".processor")[0]
            weights = {
                "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
                "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
            }
            attn_procs[name] = IPAttnProcessor(
                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=args.num_tokens
            )
            attn_procs[name].load_state_dict(weights)
    unet.set_attn_processor(attn_procs)
    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())

    ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)

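    # With mixed precision, the frozen vae/text/image encoders are cast to half
    # precision; the trainable adapter parameters remain in float32.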
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    image_encoder.to(accelerator.device, dtype=weight_dtype)

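    # Optimizer: only the image projection model and the adapter attention layers are trained.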
    params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
    optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)

    # dataloader
    train_dataset = MyDataset(
        args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

    # Prepare everything with our `accelerator`.
    ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)

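    # Training loop: add noise to the VAE latents at a random timestep and train the
    # adapter to predict that noise given the text and image context tokens.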
    global_step = 0
    for epoch in range(0, args.num_train_epochs):
        begin = time.perf_counter()
        for step, batch in enumerate(train_dataloader):
            load_data_time = time.perf_counter() - begin
            with accelerator.accumulate(ip_adapter):
                # Convert images to latent space
                with torch.no_grad():
                    latents = vae.encode(
                        batch["images"].to(accelerator.device, dtype=weight_dtype)
                    ).latent_dist.sample()
                    latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
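                # Zero out the CLIP input for samples whose image conditioning is dropped,
                # so the encoder produces the "null" image embedding used for
                # classifier-free guidance.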
                clip_images = []
                for clip_image, drop_image_embed in zip(batch["clip_images"], batch["drop_image_embeds"]):
                    if drop_image_embed == 1:
                        clip_images.append(torch.zeros_like(clip_image))
                    else:
                        clip_images.append(clip_image)
                clip_images = torch.stack(clip_images, dim=0)
                with torch.no_grad():
                    # take the penultimate hidden states as the image features
                    image_embeds = image_encoder(
                        clip_images.to(accelerator.device, dtype=weight_dtype), output_hidden_states=True
                    ).hidden_states[-2]

                with torch.no_grad():
                    encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]

                noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds)

                loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                # Gather the losses across all processes for logging (if we use distributed training)
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()

                # Backpropagate
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

                if accelerator.is_main_process:
                    print(
                        "Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
                            epoch, step, load_data_time, time.perf_counter() - begin, avg_loss
                        )
                    )

            global_step += 1

            if global_step % args.save_steps == 0:
                save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                accelerator.save_state(save_path)

            begin = time.perf_counter()


if __name__ == "__main__":
    main()