import itertools
import math
import os
import torch
import torch.nn.functional as F
from pathlib import Path
from torch.utils.data import Dataset
from accelerate import Accelerator
from accelerate.utils import set_seed
from diffusers import DDPMScheduler, StableDiffusionPipeline
from diffusers.optimization import get_scheduler
import bitsandbytes as bnb
from tqdm.auto import tqdm
from argparse import Namespace
import logging
from dataset import DreamBoothDataset

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
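

# Load the Stable Diffusion components needed for DreamBooth fine-tuning: the CLIP
# tokenizer and text encoder, the VAE, and the UNet denoiser.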
def load_models(pretrained_model_name_or_path):
    from transformers import CLIPTextModel, CLIPTokenizer
    from diffusers import AutoencoderKL, UNet2DConditionModel

    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
    return text_encoder, vae, unet, tokenizer
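

# Fine-tune the UNet (and optionally the text encoder) on the instance images, mixing
# class images into each batch when prior preservation is enabled.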
def training_function(args, text_encoder, vae, unet, tokenizer):
    set_seed(args.seed)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
    )

    vae.requires_grad_(False)
    if not args.train_text_encoder:
        text_encoder.requires_grad_(False)
    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder:
            text_encoder.gradient_checkpointing_enable()

    optimizer_class = bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW
    params_to_optimize = (
        itertools.chain(unet.parameters(), text_encoder.parameters())
        if args.train_text_encoder
        else unet.parameters()
    )
    optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
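
    # The DDPM scheduler defines the forward (noising) process used to corrupt latents
    # at randomly sampled timesteps during training.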
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
    )
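
    # When prior preservation is enabled, class examples are appended to each batch so
    # the same denoising objective also covers the class prompt.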
    def collate_fn(examples):
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]
        if args.with_prior_preservation:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]
        pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
        input_ids = tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids
        return {"input_ids": input_ids, "pixel_values": pixel_values}

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn
    )
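
    # Let Accelerate wrap the trainable modules, optimizer, and dataloader so the same
    # loop runs on single-GPU, multi-GPU, and mixed-precision setups.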
    if args.train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader
        )
    else:
        unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)

    if args.with_prior_preservation:
        class_images_dir = Path(args.class_data_dir)
        class_images_dir.mkdir(parents=True, exist_ok=True)
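
    # Keep frozen modules in the reduced mixed-precision dtype to save memory; anything
    # being trained stays in float32 so the optimizer updates remain stable.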
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    vae.to(accelerator.device, dtype=weight_dtype)
    if not args.train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)

    if args.train_text_encoder:
        text_encoder.train()
    unet.train()
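
    # Each step: encode images to latents, add noise at a random timestep, predict that
    # noise with the UNet conditioned on the prompt embeddings, and regress against the
    # scheduler's target with an MSE loss.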
    global_step = 0
    for epoch in range(args.num_train_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
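                # Encode the images into VAE latents and apply the Stable Diffusion
                # latent scaling factor (0.18215).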
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                latents = latents * 0.18215
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
                ).long()
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
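                # The regression target depends on the scheduler's parameterization:
                # the injected noise for epsilon-prediction, or velocity for v-prediction.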
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = (
                        itertools.chain(unet.parameters(), text_encoder.parameters())
                        if args.train_text_encoder
                        else unet.parameters()
                    )
                    accelerator.clip_grad_norm_(params_to_clip, 1.0)
                optimizer.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            global_step += 1
            logs = {"loss": loss.detach().item(), "lr": args.learning_rate}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            if global_step >= args.max_train_steps:
                break
        progress_bar.close()
        accelerator.wait_for_everyone()
        if accelerator.is_main_process and (epoch + 1) % args.save_interval == 0:
            # Build the pipeline around the unwrapped fine-tuned modules before saving.
            pipeline = StableDiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                unet=accelerator.unwrap_model(unet),
                text_encoder=accelerator.unwrap_model(text_encoder),
            )
            pipeline.save_pretrained(args.output_dir)

    accelerator.end_training()
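

# Training configuration; hyperparameters are hard-coded here rather than parsed from
# the command line.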
def parse_args():
    args = Namespace(
        pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
        instance_data_dir="datasets/imagedata/images",
        class_data_dir="./class_images",
        output_dir="./output",
        instance_prompt="a photo of yash Kothari",
        class_prompt="A photo of Yash Kothari with medium, dark hair and a full beard, smiling slightly",
        resolution=512,
        center_crop=False,
        train_text_encoder=True,
        gradient_accumulation_steps=1,
        mixed_precision="fp16",
        learning_rate=5e-6,
        use_8bit_adam=True,
        train_batch_size=4,
        num_train_epochs=100,
        save_interval=10,
        max_train_steps=2000,
        gradient_checkpointing=False,
        with_prior_preservation=True,
        seed=42,
    )
    return args


if __name__ == "__main__":
    args = parse_args()
    text_encoder, vae, unet, tokenizer = load_models(args.pretrained_model_name_or_path)
    training_function(args, text_encoder, vae, unet, tokenizer)
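
# A minimal inference sketch, assuming a pipeline was saved to args.output_dir
# ("./output"); the prompt and file name below are illustrative:
#
#   from diffusers import StableDiffusionPipeline
#   import torch
#
#   pipe = StableDiffusionPipeline.from_pretrained("./output", torch_dtype=torch.float16).to("cuda")
#   image = pipe("a photo of yash Kothari", num_inference_steps=50, guidance_scale=7.5).images[0]
#   image.save("sample.png")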