| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import contextlib |
| | import os |
| | import warnings |
| | from dataclasses import dataclass |
| | from typing import Any, Callable, Dict, List, Optional, Union |
| |
|
| | import numpy as np |
| | import torch |
| | from diffusers import DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel |
| | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg |
| | from diffusers.utils import convert_state_dict_to_diffusers |
| |
|
| | from ..core import randn_tensor |
| | from ..import_utils import is_peft_available |
| |
|
| |
|
| | if is_peft_available(): |
| | from peft import LoraConfig |
| | from peft.utils import get_peft_model_state_dict |
| |
|
| |
|
@dataclass
class DDPOPipelineOutput(object):
    """
    Output class for the diffusers pipeline to be finetuned with the DDPO trainer

    Args:
        images (`torch.Tensor`):
            The generated images, in whatever format the pipeline's `output_type` produced.
        latents (`List[torch.Tensor]`):
            The latents used to generate the images: the initial latents followed by the latents produced at
            every denoising step.
        log_probs (`List[torch.Tensor]`):
            The log probabilities of the latents, one tensor per denoising step.
    """

    images: torch.Tensor
    # one entry per denoising step, plus the initial latents
    latents: List[torch.Tensor]
    # one entry per denoising step
    log_probs: List[torch.Tensor]
| |
|
| |
|
@dataclass
class DDPOSchedulerOutput:
    """
    Result of a single DDPO scheduler step.

    Args:
        latents (`torch.Tensor`):
            Sample for the previous timestep, shaped `(batch_size, num_channels, height, width)`.
        log_probs (`torch.Tensor`):
            Log probability of `latents` under the step's Gaussian, shaped `(batch_size)`.
    """

    # x_{t-1} produced by the step
    latents: torch.Tensor
    # one log-probability value per batch element
    log_probs: torch.Tensor
| |
|
| |
|
class DDPOStableDiffusionPipeline(object):
    """
    Main class for the diffusers pipeline to be finetuned with the DDPO trainer.

    This is an interface: every method and property below raises `NotImplementedError` and must be
    overridden by a concrete subclass.
    """

    def __call__(self, *args, **kwargs) -> DDPOPipelineOutput:
        """
        Generates images and returns them together with their latents and log probabilities
        """
        raise NotImplementedError

    def scheduler_step(self, *args, **kwargs) -> DDPOSchedulerOutput:
        """
        Performs one scheduler step, returning the previous-timestep latents and their log probability
        """
        raise NotImplementedError

    @property
    def unet(self):
        """
        Returns the 2d U-Net model used for diffusion.
        """
        raise NotImplementedError

    @property
    def vae(self):
        """
        Returns the Variational Autoencoder model used for mapping images to and from the latent space
        """
        raise NotImplementedError

    @property
    def tokenizer(self):
        """
        Returns the tokenizer used for tokenizing text inputs
        """
        raise NotImplementedError

    @property
    def scheduler(self):
        """
        Returns the scheduler associated with the pipeline used for the diffusion process
        """
        raise NotImplementedError

    @property
    def text_encoder(self):
        """
        Returns the text encoder used for encoding text inputs
        """
        raise NotImplementedError

    @property
    def autocast(self):
        """
        Returns the autocast context manager
        """
        raise NotImplementedError

    def set_progress_bar_config(self, *args, **kwargs):
        """
        Sets the progress bar config for the pipeline
        """
        raise NotImplementedError

    def save_pretrained(self, *args, **kwargs):
        """
        Saves all of the model weights
        """
        raise NotImplementedError

    def get_trainable_layers(self, *args, **kwargs):
        """
        Returns the trainable parameters of the pipeline
        """
        raise NotImplementedError

    def save_checkpoint(self, *args, **kwargs):
        """
        Light wrapper around accelerate's register_save_state_pre_hook which is run before saving state
        """
        raise NotImplementedError

    def load_checkpoint(self, *args, **kwargs):
        """
        Light wrapper around accelerate's register_load_state_pre_hook which is run before loading state
        """
        raise NotImplementedError
| |
|
| |
|
| | def _left_broadcast(input_tensor, shape): |
| | """ |
| | As opposed to the default direction of broadcasting (right to left), this function broadcasts |
| | from left to right |
| | Args: |
| | input_tensor (`torch.FloatTensor`): is the tensor to broadcast |
| | shape (`Tuple[int]`): is the shape to broadcast to |
| | """ |
| | input_ndim = input_tensor.ndim |
| | if input_ndim > len(shape): |
| | raise ValueError("The number of dimensions of the tensor to broadcast cannot be greater than the length of the shape to broadcast to") |
| | return input_tensor.reshape(input_tensor.shape + (1,) * (len(shape) - input_ndim)).broadcast_to(shape) |
| |
|
| |
|
| | def _get_variance(self, timestep, prev_timestep): |
| | alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device) |
| | alpha_prod_t_prev = torch.where( |
| | prev_timestep.cpu() >= 0, |
| | self.alphas_cumprod.gather(0, prev_timestep.cpu()), |
| | self.final_alpha_cumprod, |
| | ).to(timestep.device) |
| | beta_prod_t = 1 - alpha_prod_t |
| | beta_prod_t_prev = 1 - alpha_prod_t_prev |
| |
|
| | variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) |
| |
|
| | return variance |
| |
|
| |
|
def scheduler_step(
    self,
    model_output: torch.FloatTensor,
    timestep: torch.Tensor,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    prev_sample: Optional[torch.FloatTensor] = None,
) -> DDPOSchedulerOutput:
    """
    Predict the sample at the previous timestep by reversing the SDE, and return its log probability.

    Core function to propagate the diffusion process from the learned model outputs (most often the predicted
    noise). `self` is the (DDIM) scheduler whose config and `alphas_cumprod` drive the step.

    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`torch.Tensor`): current discrete timestep(s) in the diffusion chain, either a scalar tensor
            or one entry per batch element (it is left-aligned to `sample`'s shape).
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.
        eta (`float`): weight of noise for added noise in diffusion step. With `eta=0.0` the step's standard
            deviation is zero and the returned log probability is not finite.
        use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
            predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
            `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
            coincide with the one provided as input and `use_clipped_model_output` will have no effect.
        generator: random number generator used to draw the step noise.
        prev_sample (`torch.FloatTensor`, *optional*): if given, used as the previous sample instead of drawing
            fresh noise, and its log probability under this step is returned. Cannot be combined with
            `generator`.

    Returns:
        `DDPOSchedulerOutput`: the predicted sample at the previous timestep and the log probability of the sample

    Raises:
        ValueError: if `set_timesteps` was not called, if both `generator` and `prev_sample` are given, or if
            `self.config.prediction_type` is unsupported.
    """
    if self.num_inference_steps is None:
        raise ValueError("Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler")

    if prev_sample is not None and generator is not None:
        raise ValueError("Cannot pass both generator and prev_sample. Please make sure that either `generator` or" " `prev_sample` stays `None`.")

    # 1. get previous step value (=t-1)
    prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
    # Clamp only the gather index (negative indices are out of bounds). The unclamped value still decides
    # when `final_alpha_cumprod` applies, so the last step matches DDIMScheduler.step.
    prev_index = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1).cpu()

    # 2. compute alphas, betas and the step variance (before broadcasting to the sample shape)
    alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu())
    alpha_prod_t_prev = torch.where(
        prev_timestep.cpu() >= 0,
        self.alphas_cumprod.gather(0, prev_index),
        self.final_alpha_cumprod,
    )
    # sigma_t(eta) from formula (16) of the DDIM paper
    variance = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) * (1 - alpha_prod_t / alpha_prod_t_prev)
    std_dev_t = eta * variance ** (0.5)

    alpha_prod_t = _left_broadcast(alpha_prod_t, sample.shape).to(sample.device)
    alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to(sample.device)
    std_dev_t = _left_broadcast(std_dev_t, sample.shape).to(sample.device)
    beta_prod_t = 1 - alpha_prod_t

    # 3. compute predicted original sample and predicted noise, formula (12) of the DDIM paper
    if self.config.prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        pred_epsilon = model_output
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
    elif self.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
        pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
    else:
        raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`")

    # 4. clip or threshold "predicted x_0"
    if self.config.thresholding:
        pred_original_sample = self._threshold_sample(pred_original_sample)
    elif self.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(-self.config.clip_sample_range, self.config.clip_sample_range)

    if use_clipped_model_output:
        # re-derive the noise from the (possibly clipped) x_0
        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

    # 5. "direction pointing to x_t" and the Gaussian mean of x_{t-1}
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
    prev_sample_mean = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

    if prev_sample is None:
        variance_noise = randn_tensor(
            model_output.shape,
            generator=generator,
            device=model_output.device,
            dtype=model_output.dtype,
        )
        prev_sample = prev_sample_mean + std_dev_t * variance_noise

    # 6. Gaussian log density of prev_sample; prev_sample is detached so gradients flow only through the mean
    log_prob = (
        -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))
        - torch.log(std_dev_t)
        - torch.log(torch.sqrt(2 * torch.as_tensor(np.pi)))
    )
    # mean over all non-batch dimensions -> one value per batch element
    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))

    return DDPOSchedulerOutput(prev_sample.type(sample.dtype), log_prob)
| |
|
| |
|
| | |
| | |
@torch.no_grad()
def pipeline_step(
    self,
    prompt: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Runs the Stable Diffusion denoising loop on `self` (a `StableDiffusionPipeline`), but advances the latents
    with `scheduler_step` so that the log probability of every step is recorded alongside the latents.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance is enabled when `guidance_scale > 1`.
            Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is not greater than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. It scales the
            noise of every `scheduler_step`; with `eta=0.0` that noise is zero and the returned log
            probabilities are not finite, so pass `eta > 0` when the log probabilities are needed.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            used to sample the initial latents.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a
            latents tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image, forwarded to `self.image_processor.postprocess`. With
            `"latent"`, VAE decoding and the safety checker are skipped.
        return_dict (`bool`, *optional*, defaults to `True`):
            Accepted for signature compatibility with `StableDiffusionPipeline.__call__` and ignored: a
            `DDPOPipelineOutput` is always returned.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the U-Net's attention processors. Its
            `"scale"` entry, if present, is also used as the LoRA scale of the text encoder.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://arxiv.org/pdf/2305.08891.pdf); it is defined as `φ` in equation 16. of that paper.
            Guidance rescale factor should fix overexposure when using zero terminal SNR. Only applied when
            classifier-free guidance is active and the value is greater than 0.

    Returns:
        `DDPOPipelineOutput`: The generated images, the latents (the initial latents followed by the latents
        produced at every timestep) and the log probability of every step.
    """
    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # guidance_scale = 1 corresponds to doing no classifier free guidance
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Denoising loop, recording every latent and log probability
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    all_latents = [latents]
    all_log_probs = []
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            if do_classifier_free_guidance and guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

            # compute the previous noisy sample x_t -> x_t-1 together with its log probability
            scheduler_output = scheduler_step(self.scheduler, noise_pred, t, latents, eta)
            latents = scheduler_output.latents
            log_prob = scheduler_output.log_probs

            all_latents.append(latents)
            all_log_probs.append(log_prob)

            # update the progress bar and call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    if output_type != "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    return DDPOPipelineOutput(image, all_latents, all_log_probs)
| |
|
| |
|
class DefaultDDPOStableDiffusionPipeline(DDPOStableDiffusionPipeline):
    """
    `DDPOStableDiffusionPipeline` implementation that wraps a diffusers `StableDiffusionPipeline`.

    The wrapped pipeline's scheduler is replaced by a `DDIMScheduler` built from the original scheduler config,
    the safety checker is disabled, and the VAE and text encoder are frozen. In LoRA mode the U-Net is frozen as
    well and `get_trainable_layers` attaches a LoRA adapter to it.

    Args:
        pretrained_model_name (`str`):
            Model id or path passed to `StableDiffusionPipeline.from_pretrained`.
        pretrained_model_revision (`str`, defaults to `"main"`):
            Revision used when loading the pipeline and any LoRA weights.
        use_lora (`bool`, defaults to `True`):
            Whether to train LoRA adapters instead of the full U-Net. Forced to `True` when
            `pytorch_lora_weights.safetensors` can be loaded from the pretrained model.
    """

    def __init__(self, pretrained_model_name: str, *, pretrained_model_revision: str = "main", use_lora: bool = True):
        self.sd_pipeline = StableDiffusionPipeline.from_pretrained(pretrained_model_name, revision=pretrained_model_revision)

        self.use_lora = use_lora
        self.pretrained_model = pretrained_model_name
        self.pretrained_revision = pretrained_model_revision

        try:
            self.sd_pipeline.load_lora_weights(
                pretrained_model_name,
                weight_name="pytorch_lora_weights.safetensors",
                revision=pretrained_model_revision,
            )
            # Existing LoRA weights switch the pipeline into LoRA mode, even if `use_lora=False` was requested.
            self.use_lora = True
        except OSError:
            # OSError is taken to mean the LoRA file is absent; only warn when LoRA training was requested.
            if use_lora:
                warnings.warn(
                    "If you are aware that the pretrained model has no lora weights to it, ignore this message. "
                    "Otherwise please check if `pytorch_lora_weights.safetensors` exists in the model folder."
                )

        self.sd_pipeline.scheduler = DDIMScheduler.from_config(self.sd_pipeline.scheduler.config)
        self.sd_pipeline.safety_checker = None

        # Freeze everything that is not trained; in LoRA mode only the adapter added later is trainable.
        self.sd_pipeline.vae.requires_grad_(False)
        self.sd_pipeline.text_encoder.requires_grad_(False)
        self.sd_pipeline.unet.requires_grad_(not self.use_lora)

    def __call__(self, *args, **kwargs) -> DDPOPipelineOutput:
        return pipeline_step(self.sd_pipeline, *args, **kwargs)

    def scheduler_step(self, *args, **kwargs) -> DDPOSchedulerOutput:
        return scheduler_step(self.sd_pipeline.scheduler, *args, **kwargs)

    @property
    def unet(self):
        return self.sd_pipeline.unet

    @property
    def vae(self):
        return self.sd_pipeline.vae

    @property
    def tokenizer(self):
        return self.sd_pipeline.tokenizer

    @property
    def scheduler(self):
        return self.sd_pipeline.scheduler

    @property
    def text_encoder(self):
        return self.sd_pipeline.text_encoder

    @property
    def autocast(self):
        # Returns the `contextlib.nullcontext` class (not an instance) in LoRA mode, and None otherwise.
        return contextlib.nullcontext if self.use_lora else None

    def save_pretrained(self, output_dir):
        """
        Save the wrapped pipeline to `output_dir`; in LoRA mode also save the U-Net LoRA weights.
        """
        if self.use_lora:
            state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(self.sd_pipeline.unet))
            self.sd_pipeline.save_lora_weights(save_directory=output_dir, unet_lora_layers=state_dict)
        self.sd_pipeline.save_pretrained(output_dir)

    def set_progress_bar_config(self, *args, **kwargs):
        self.sd_pipeline.set_progress_bar_config(*args, **kwargs)

    def get_trainable_layers(self):
        """
        Return the U-Net to be trained, adding a rank-4 LoRA adapter to its attention projections in LoRA mode.
        """
        unet = self.sd_pipeline.unet
        if self.use_lora:
            lora_config = LoraConfig(
                r=4,
                lora_alpha=4,
                init_lora_weights="gaussian",
                target_modules=["to_k", "to_q", "to_v", "to_out.0"],
            )
            unet.add_adapter(lora_config)

            # Upcast only parameters that require grad to float32. The base U-Net was frozen in __init__,
            # so these are presumably the adapter's parameters.
            for param in unet.parameters():
                if param.requires_grad:
                    param.data = param.to(torch.float32)
        return unet

    def save_checkpoint(self, models, weights, output_dir):
        """
        Save the single trainable model: LoRA weights in LoRA mode, otherwise the full U-Net under `unet/`.

        Raises:
            ValueError: if `models` does not hold exactly one model, or the model does not match the mode.
        """
        if len(models) != 1:
            raise ValueError("Given how the trainable params were set, this should be of length 1")
        model = models[0]
        if self.use_lora and getattr(model, "peft_config", None) is not None:
            state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
            self.sd_pipeline.save_lora_weights(save_directory=output_dir, unet_lora_layers=state_dict)
        elif not self.use_lora and isinstance(model, UNet2DConditionModel):
            model.save_pretrained(os.path.join(output_dir, "unet"))
        else:
            raise ValueError(f"Unknown model type {type(model)}")

    def load_checkpoint(self, models, input_dir):
        """
        Restore the single trainable model from `input_dir` (the layout written by `save_checkpoint`).

        Raises:
            ValueError: if `models` does not hold exactly one model, or the model does not match the mode.
        """
        if len(models) != 1:
            raise ValueError("Given how the trainable params were set, this should be of length 1")
        model = models[0]
        if self.use_lora:
            lora_state_dict, network_alphas = self.sd_pipeline.lora_state_dict(input_dir, weight_name="pytorch_lora_weights.safetensors")
            self.sd_pipeline.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=model)
        elif isinstance(model, UNet2DConditionModel):
            load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
            model.register_to_config(**load_model.config)
            model.load_state_dict(load_model.state_dict())
            del load_model
        else:
            raise ValueError(f"Unknown model type {type(model)}")
| |
|