# NOTE: "Spaces: / Runtime error" page-scrape artifact from a Hugging Face Spaces capture — not part of the module.
| # modified starting from HuggingFace diffusers train_dreambooth.py example | |
| # https://github.com/huggingface/diffusers/blob/024c4376fb19caa85275c038f071b6e1446a5cad/examples/dreambooth/train_dreambooth.py | |
| import os | |
| from pathlib import Path | |
| import torch | |
| import torch.nn.functional as F | |
| import torch.utils.checkpoint | |
| from accelerate import Accelerator | |
| from accelerate.logging import get_logger | |
| from accelerate.utils import ProjectConfiguration, set_seed | |
| from PIL import Image | |
| from tqdm.auto import tqdm | |
| from diffusers import AutoencoderKL, StableDiffusionPipeline | |
| from torchvision.utils import make_grid | |
| import numpy as np | |
| from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( | |
| download_from_original_stable_diffusion_ckpt, | |
| ) | |
| from diffusers import StableDiffusionControlNetPipeline, ControlNetModel | |
| from diffusers.schedulers import UniPCMultistepScheduler | |
| from .data import PNGDataModule | |
| logger = get_logger(__name__) | |
| class Lab(Accelerator): | |
| def __init__(self, args, control_pipe=None): | |
| self.cond_key = "prompts" | |
| self.target_key = "images" | |
| self.args = args | |
| self.output_dir = Path(args.output_dir) | |
| logging_dir = str(self.output_dir / "logs") | |
| accelerator_project_config = ProjectConfiguration( | |
| logging_dir=logging_dir, | |
| ) | |
| super().__init__( | |
| mixed_precision=args.mixed_precision, | |
| log_with=args.report_to, | |
| project_config=accelerator_project_config, | |
| ) | |
| if self.mixed_precision == "fp16": | |
| self.weight_dtype = torch.float16 | |
| elif self.mixed_precision == "bf16": | |
| self.weight_dtype = torch.bfloat16 | |
| else: | |
| self.weight_dtype = torch.float32 | |
| if args.seed is not None: | |
| set_seed(args.seed) | |
| if control_pipe is None: | |
| control_pipe = self.load_pipe( | |
| args.pretrained_model_name_or_path, args.controlnet_weights_path | |
| ) | |
| self.control_pipe = control_pipe | |
| vae = control_pipe.vae | |
| unet = control_pipe.unet | |
| text_encoder = control_pipe.text_encoder | |
| tokenizer = control_pipe.tokenizer | |
| controlnet = ( | |
| control_pipe.controlnet if hasattr(control_pipe, "controlnet") else None | |
| ) | |
| self.noise_scheduler = UniPCMultistepScheduler.from_config(control_pipe.scheduler.config) | |
| vae.requires_grad_(False) | |
| text_encoder.requires_grad_(False) | |
| if controlnet: | |
| unet.requires_grad_(False) | |
| if args.training_stage == "zero convolutions": | |
| controlnet.requires_grad_(False) | |
| controlnet.controlnet_down_blocks.requires_grad_(True) | |
| controlnet.controlnet_mid_block.requires_grad_(True) | |
| # optimize only the zero convolution weights | |
| params_to_optimize = list( | |
| controlnet.controlnet_down_blocks.parameters() | |
| ) + list(controlnet.controlnet_mid_block.parameters()) | |
| elif args.training_stage == "input hint blocks": | |
| controlnet.requires_grad_(False) | |
| controlnet.controlnet_cond_embedding.requires_grad_(True) | |
| params_to_optimize = list( | |
| controlnet.controlnet_cond_embedding.parameters() | |
| ) | |
| else: | |
| controlnet.requires_grad_(True) | |
| params_to_optimize = list(controlnet.parameters()) | |
| else: | |
| unet.requires_grad_(True) | |
| params_to_optimize = list(unet.parameters()) | |
| self.params_to_optimize = params_to_optimize | |
| args.learning_rate = ( | |
| args.learning_rate | |
| * args.gradient_accumulation_steps | |
| * args.batch_size | |
| * self.num_processes | |
| ) | |
| if args.use_8bit_adam: | |
| import bitsandbytes as bnb | |
| optimizer_class = bnb.optim.AdamW8bit | |
| else: | |
| optimizer_class = torch.optim.AdamW | |
| self.optimizer = self.prepare( | |
| optimizer_class( | |
| params_to_optimize, | |
| lr=args.learning_rate, | |
| ) | |
| ) | |
| if args.enable_xformers_memory_efficient_attention: | |
| unet.enable_xformers_memory_efficient_attention() | |
| if controlnet: | |
| controlnet.enable_xformers_memory_efficient_attention() | |
| if args.gradient_checkpointing: | |
| unet.enable_gradient_checkpointing() | |
| if controlnet: | |
| controlnet.enable_gradient_checkpointing() | |
| torch.backends.cuda.matmul.allow_tf32 = True | |
| datamodule = PNGDataModule( | |
| tokenizer=tokenizer, | |
| from_hf_hub=args.from_hf_hub, | |
| resolution=[args.resolution, args.resolution], | |
| target_key=self.target_key, | |
| cond_key=self.cond_key, | |
| persistent_workers=True, | |
| num_workers=args.dataloader_num_workers, | |
| batch_size=args.batch_size, | |
| controlnet_hint_key=None if controlnet is None else args.controlnet_hint_key, | |
| ) | |
| self.train_dataloader = self.prepare( | |
| datamodule.get_dataloader(args.train_data_dir, shuffle=True) | |
| ) | |
| if args.valid_data_dir: | |
| self.valid_dataloader = self.prepare( | |
| datamodule.get_dataloader(args.valid_data_dir) | |
| ) | |
| self.vae = vae.to(self.device, dtype=self.weight_dtype) | |
| self.text_encoder = text_encoder.to(self.device, dtype=self.weight_dtype) | |
| if controlnet: | |
| controlnet = self.prepare(controlnet) | |
| self.controlnet = controlnet.to(self.device, dtype=torch.float32) | |
| self.unet = unet.to(self.device, dtype=self.weight_dtype) | |
| else: | |
| unet = self.prepare(unet) | |
| self.unet = unet.to(self.device, dtype=torch.float32) | |
| self.controlnet = None | |
    def load_pipe(self, sd_model_path, controlnet_path=None):
        """Build the StableDiffusion(ControlNet) pipeline to be trained.

        Args:
            sd_model_path: either a single original-format ``.ckpt`` /
                ``.safetensors`` checkpoint file, or a diffusers-format
                model directory / hub id.
            controlnet_path: optional ControlNet weights — a single
                checkpoint file, or a ``<repo-or-dir>/<subfolder>`` path.
                When falsy, a plain ``StableDiffusionPipeline`` is returned.

        Returns:
            ``StableDiffusionPipeline`` or ``StableDiffusionControlNetPipeline``.
        """
        # Optional VAE override, shared by both load paths below; only
        # bound when args.vae_path is set (both uses are guarded the same way).
        if self.args.vae_path:
            vae = AutoencoderKL.from_pretrained(
                self.args.vae_path, torch_dtype=self.weight_dtype
            )
        if os.path.isfile(sd_model_path):
            # Original (CompVis-format) single-file checkpoint.
            file_ext = sd_model_path.rsplit(".", 1)[-1]
            from_safetensors = file_ext == "safetensors"
            pipe = download_from_original_stable_diffusion_ckpt(
                sd_model_path,
                from_safetensors=from_safetensors,
                device="cpu",
                load_safety_checker=False,
            )
            # Safety checker is disabled for training use.
            pipe.safety_checker = None
            pipe.feature_extractor = None
            if self.args.vae_path:
                pipe.vae = vae
        else:
            # Diffusers-format directory or hub id.
            if self.args.vae_path:
                kw_args = dict(vae=vae)
            else:
                kw_args = dict()
            pipe = StableDiffusionPipeline.from_pretrained(
                sd_model_path,
                safety_checker=None,
                feature_extractor=None,
                requires_safety_checker=False,
                torch_dtype=self.weight_dtype,
                **kw_args
            )
        if not controlnet_path:
            return pipe
        pathobj = Path(controlnet_path)
        if pathobj.is_file():
            # Fresh ControlNet from a local config, then load ckpt weights.
            controlnet = ControlNetModel.from_config(
                ControlNetModel.load_config("configs/controlnet_config.json")
            )
            # NOTE(review): load_weights_from_sd_ckpt is not a stock
            # diffusers ControlNetModel method — presumably patched in
            # elsewhere in this project; verify it exists.
            controlnet.load_weights_from_sd_ckpt(controlnet_path)
        else:
            # Split "<repo-or-dir>/<subfolder>" into its two components.
            controlnet_path = str(Path().joinpath(*pathobj.parts[:-1]))
            subfolder = str(pathobj.parts[-1])
            controlnet = ControlNetModel.from_pretrained(
                controlnet_path,
                subfolder=subfolder,
                low_cpu_mem_usage=False,
                device_map=None,
            )
        # Reuse the SD pipeline's components and attach the ControlNet.
        return StableDiffusionControlNetPipeline(
            **pipe.components,
            controlnet=controlnet,
            requires_safety_checker=False,
        )
| def compute_loss(self, batch): | |
| images = batch[self.target_key].to(dtype=self.weight_dtype) | |
| latents = self.vae.encode(images).latent_dist.sample() | |
| latents = latents * self.vae.config.scaling_factor | |
| # Sample noise that we'll add to the latents | |
| noise = torch.randn_like(latents) | |
| # Sample a random timestep for each image | |
| timesteps = torch.randint( | |
| 0, | |
| self.noise_scheduler.config.num_train_timesteps, | |
| (latents.shape[0],), | |
| device=latents.device, | |
| ) | |
| timesteps = timesteps.long() | |
| # Add noise to the latents according to the noise magnitude at each timestep | |
| # (this is the forward diffusion process) | |
| noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps) | |
| # Get the text embedding for conditioning | |
| encoder_hidden_states = self.text_encoder(batch[self.cond_key])[0] | |
| if self.controlnet: | |
| if self.args.controlnet_hint_key in batch: | |
| controlnet_hint = batch[self.args.controlnet_hint_key].to( | |
| dtype=self.weight_dtype | |
| ) | |
| else: | |
| controlnet_hint = torch.zeros(images.shape).to(images) | |
| down_block_res_samples, mid_block_res_sample = self.controlnet( | |
| noisy_latents, | |
| timesteps, | |
| encoder_hidden_states=encoder_hidden_states, | |
| controlnet_cond=controlnet_hint, | |
| return_dict=False, | |
| ) | |
| else: | |
| down_block_res_samples, mid_block_res_sample = None, None | |
| noise_pred = self.unet( | |
| noisy_latents, | |
| timesteps, | |
| encoder_hidden_states=encoder_hidden_states, | |
| down_block_additional_residuals=down_block_res_samples, | |
| mid_block_additional_residual=mid_block_res_sample, | |
| ).sample | |
| # Get the target for loss depending on the prediction type | |
| if self.noise_scheduler.config.prediction_type == "epsilon": | |
| target = noise | |
| elif self.noise_scheduler.config.prediction_type == "v_prediction": | |
| target = self.noise_scheduler.get_velocity(latents, noise, timesteps) | |
| else: | |
| raise ValueError( | |
| f"Unknown prediction type {self.noise_scheduler.config.prediction_type}" | |
| ) | |
| loss = F.mse_loss(noise_pred, target, reduction="mean") | |
| return loss, encoder_hidden_states | |
| def decode_latents(self, latents): | |
| latents = 1 / self.vae.config.scaling_factor * latents | |
| output_latents = self.vae.decode(latents).sample | |
| output_latents = (output_latents / 2 + 0.5).clamp(0, 1) | |
| return output_latents | |
    def log_images(self, batch, encoder_hidden_states, cond_scales=[0.0, 0.5, 1.0]):
        """Sample images with the current weights and save a comparison grid.

        The grid stacks: the ground-truth batch images, one sampled row per
        conditioning scale (ControlNet case) or a single sampled row (UNet
        case), and the conditioning hint if present.  The PNG is written to
        the first tracker's logging directory, named by global step.

        NOTE(review): mutable default argument ``cond_scales`` — harmless
        here since it is never mutated, but worth replacing with None.
        """
        input_tensors = batch[self.target_key].to(self.weight_dtype)
        # Training images are in [-1, 1]; map to [0, 1] for display.
        input_tensors = (input_tensors / 2 + 0.5).clamp(0, 1)
        tensors_to_log = [input_tensors.cpu()]
        [height, width] = input_tensors.shape[-2:]
        if self.controlnet:
            if self.args.controlnet_hint_key in batch:
                controlnet_hint = batch[self.args.controlnet_hint_key].to(
                    self.weight_dtype
                )
            else:
                # NOTE(review): image=None is then passed to the ControlNet
                # pipeline below — confirm the pipeline tolerates that.
                controlnet_hint = None
            # One sampled row per ControlNet conditioning strength.
            for cond_scale in cond_scales:
                latents = self.control_pipe(
                    image=controlnet_hint,
                    prompt_embeds=encoder_hidden_states,
                    controlnet_conditioning_scale=cond_scale,
                    height=height,
                    width=width,
                    output_type="latent",
                    num_inference_steps=25,
                )[0]
                tensors_to_log.append(self.decode_latents(latents).detach().cpu())
            if controlnet_hint is not None:
                tensors_to_log.append(controlnet_hint.detach().cpu())
        else:
            # Plain UNet fine-tuning: a single unconditional-scale sample row.
            latents = self.control_pipe(
                prompt_embeds=encoder_hidden_states,
                height=height,
                width=width,
                output_type="latent",
                num_inference_steps=25,
            )[0]
            tensors_to_log.append(self.decode_latents(latents).detach().cpu())
        image_tensors = torch.cat(tensors_to_log)
        # nrow = batch size so each logged row lines up under its input image.
        grid = make_grid(image_tensors, normalize=False, nrow=input_tensors.shape[0])
        grid = grid.permute(1, 2, 0).squeeze(-1) * 255
        grid = grid.numpy().astype(np.uint8)
        image_grid = Image.fromarray(grid)
        # Assumes init_trackers() has run (see train()) so trackers[0] exists.
        image_grid.save(Path(self.trackers[0].logging_dir) / f"{self.global_step}.png")
| def save_weights(self, to_safetensors=True): | |
| save_dir = self.output_dir / f"checkpoint-{self.global_step}" | |
| os.makedirs(save_dir, exist_ok=True) | |
| if self.args.save_whole_pipeline: | |
| self.control_pipe.save_pretrained( | |
| str(save_dir), safe_serialization=to_safetensors | |
| ) | |
| elif self.controlnet: | |
| self.controlnet.save_pretrained( | |
| str(save_dir / "controlnet"), safe_serialization=to_safetensors | |
| ) | |
| else: | |
| self.unet.save_pretrained( | |
| str(save_dir / "unet"), safe_serialization=to_safetensors | |
| ) | |
    def train(self, num_train_epochs=1000, gr_progress=None):
        """Run the training loop with periodic checkpointing/logging.

        Args:
            num_train_epochs: fallback epoch count, overridden by
                ``args.num_train_epochs`` when that is truthy.
            gr_progress: optional progress callback — presumably a Gradio
                ``gr.Progress`` object, called as ``gr_progress(frac, desc=...)``.

        Weights are saved on normal completion and on KeyboardInterrupt.
        """
        args = self.args
        if args.num_train_epochs:
            num_train_epochs = args.num_train_epochs
        # Optimizer steps, not batches: total batches // accumulation steps.
        max_train_steps = (
            num_train_epochs
            * len(self.train_dataloader)
            // args.gradient_accumulation_steps
        )
        if self.is_main_process:
            self.init_trackers("tb_logs", config=vars(args))
        self.global_step = 0
        # Only show the progress bar once on each machine.
        progress_bar = tqdm(
            range(max_train_steps),
            disable=not self.is_local_main_process,
        )
        progress_bar.set_description("Steps")
        try:
            for epoch in range(num_train_epochs):
                # run training loop
                if gr_progress is not None:
                    gr_progress(0, desc=f"Starting Epoch {epoch}")
                if self.controlnet:
                    self.controlnet.train()
                else:
                    self.unet.train()
                for i, batch in enumerate(self.train_dataloader):
                    loss, encoder_hidden_states = self.compute_loss(batch)
                    # Scale the loss so accumulated gradients average out.
                    loss /= args.gradient_accumulation_steps
                    self.backward(loss)
                    # NOTE(review): manual accumulation keyed on global_step;
                    # since global_step == 0 on the very first batch, the
                    # optimizer steps immediately before any gradients have
                    # accumulated — looks off by one; confirm intent.
                    if self.global_step % args.gradient_accumulation_steps == 0:
                        if self.sync_gradients:
                            self.clip_grad_norm_(
                                self.params_to_optimize, args.max_grad_norm
                            )
                        self.optimizer.step()
                        self.optimizer.zero_grad()
                    # Checks if the accelerator has performed an optimization step behind the scenes
                    if self.sync_gradients:
                        progress_bar.update(1)
                        if gr_progress is not None:
                            gr_progress(float(i/len(self.train_dataloader)))
                        self.global_step += 1
                        # Checkpointing / image logging only on the main process.
                        if self.is_main_process:
                            if self.global_step % args.checkpointing_steps == 0:
                                self.save_weights()
                            if args.image_logging_steps and (
                                self.global_step % args.image_logging_steps == 0
                                or self.global_step == 1
                            ):
                                self.log_images(batch, encoder_hidden_states)
                    logs = {"training_loss": loss.detach().item()}
                    self.log(logs, step=self.global_step)
                    progress_bar.set_postfix(**logs)
                    # NOTE(review): this only exits the inner loop; later
                    # epochs re-enter and break again after validation.
                    if self.global_step >= max_train_steps:
                        break
                self.wait_for_everyone()
                # run validation loop
                if args.valid_data_dir:
                    total_valid_loss = 0
                    if self.controlnet:
                        self.controlnet.eval()
                    else:
                        self.unet.eval()
                    for batch in self.valid_dataloader:
                        with torch.no_grad():
                            loss, encoder_hidden_states = self.compute_loss(batch)
                            loss = loss.detach().item()
                            total_valid_loss += loss
                            logs = {"validation_loss": loss}
                            progress_bar.set_postfix(**logs)
                    # Log the epoch-mean validation loss.
                    self.log(
                        {
                            "validation_loss": total_valid_loss
                            / len(self.valid_dataloader)
                        },
                        step=self.global_step,
                    )
                self.wait_for_everyone()
        except KeyboardInterrupt:
            print("Keyboard interrupt detected, attempting to save trained weights")
        # except Exception as e:
        #     print(f"Encountered error {e}, attempting to save trained weights")
        # Runs on both normal completion and interrupt (dedent reconstructed
        # from the mangled source — saving only on interrupt would skip the
        # final save on normal exit).
        self.save_weights()
        self.end_training()