from typing import Callable, Optional

import torch
from accelerate.logging import get_logger
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.cross_attention import CrossAttention
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
)
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None

logger = get_logger(__name__)
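

# Replacement for `CrossAttention.set_use_memory_efficient_attention_xformers`:
# it takes `self` as its first argument and is presumably patched onto the
# CrossAttention class, so that enabling xformers installs the Custom Diffusion
# attention processors defined below.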
def set_use_memory_efficient_attention_xformers(
    self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
):
    if use_memory_efficient_attention_xformers:
        if self.added_kv_proj_dim is not None:
            # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
            # which uses this type of cross attention ONLY because the attention mask of format
            # [0, ..., -10.000, ..., 0, ...,] is not supported
            raise NotImplementedError(
                "Memory efficient attention with `xformers` is currently not supported when"
                " `self.added_kv_proj_dim` is defined."
            )
        elif not is_xformers_available():
            raise ModuleNotFoundError(
                (
                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                    " xformers"
                ),
                name="xformers",
            )
        elif not torch.cuda.is_available():
            raise ValueError(
                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                " only available for GPU "
            )
        else:
            try:
                # Make sure we can run the memory efficient attention
                _ = xformers.ops.memory_efficient_attention(
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                )
            except Exception as e:
                raise e
        processor = CustomDiffusionXFormersAttnProcessor(attention_op=attention_op)
    else:
        processor = CustomDiffusionAttnProcessor()
    self.set_processor(processor)
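

# The two processors below implement the Custom Diffusion attention forward pass
# (a plain PyTorch variant and an xformers variant). Apart from the gradient
# masking applied to the first text token, they closely mirror the stock
# diffusers cross-attention processors.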
class CustomDiffusionAttnProcessor:
    """Attention processor that runs standard scaled dot-product attention."""

    def __call__(
        self,
        attn: CrossAttention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
    ):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)

        crossattn = False
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            crossattn = True
            if attn.cross_attention_norm:
                encoder_hidden_states = attn.norm_cross(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
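
        # Custom Diffusion gradient trick: for cross-attention, detach the
        # key/value of the first (start-of-text) token so that only the real
        # text tokens propagate gradients into the tuned to_k/to_v projections.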
        if crossattn:
            detach = torch.ones_like(key)
            detach[:, :1, :] = detach[:, :1, :] * 0.0
            key = detach * key + (1 - detach) * key.detach()
            value = detach * value + (1 - detach) * value.detach()

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states


class CustomDiffusionXFormersAttnProcessor:
    """Attention processor that uses xformers memory-efficient attention."""

    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)

        crossattn = False
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            crossattn = True
            if attn.cross_attention_norm:
                encoder_hidden_states = attn.norm_cross(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
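
        # Same first-token gradient mask as in CustomDiffusionAttnProcessor above.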
        if crossattn:
            detach = torch.ones_like(key)
            detach[:, :1, :] = detach[:, :1, :] * 0.0
            key = detach * key + (1 - detach) * key.detach()
            value = detach * value + (1 - detach) * value.detach()

        # xformers expects contiguous tensors.
        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op
        )
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states


class CustomDiffusionPipeline(StableDiffusionPipeline):
    r"""
    Pipeline for the Custom Diffusion model.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.).

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
        modifier_token_id (`list`):
            IDs of the tokens related to the target concept that are modified when ablated.
    """

    _optional_components = ["safety_checker", "feature_extractor", "modifier_token_id"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: SchedulerMixin,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        requires_safety_checker: bool = True,
        modifier_token_id: Optional[list] = None,
    ):
        super().__init__(
            vae,
            text_encoder,
            tokenizer,
            unet,
            scheduler,
            safety_checker,
            feature_extractor,
            requires_safety_checker,
        )
        # Avoid a shared mutable default argument for the token id list.
        self.modifier_token_id = modifier_token_id if modifier_token_id is not None else []

    def save_pretrained(self, save_path, parameter_group="cross-attn", all=False):
        if parameter_group not in ("cross-attn", "full-weight", "embedding"):
            raise ValueError(
                "parameter_group argument only supports one of [cross-attn, full-weight, embedding]"
            )
        if all:
            super().save_pretrained(save_path)
        else:
            delta_dict = {"unet": {}}
            if parameter_group == "embedding":
                delta_dict["text_encoder"] = self.text_encoder.state_dict()
            for name, params in self.unet.named_parameters():
                if parameter_group == "cross-attn":
                    # Custom Diffusion only fine-tunes the cross-attention key/value projections.
                    if "attn2.to_k" in name or "attn2.to_v" in name:
                        delta_dict["unet"][name] = params.cpu().clone()
                elif parameter_group == "full-weight":
                    delta_dict["unet"][name] = params.cpu().clone()
            torch.save(delta_dict, save_path)

    def load_model(self, save_path):
        st = torch.load(save_path)
        logger.info(f"loading delta weights with keys: {list(st.keys())}")
        if "text_encoder" in st:
            self.text_encoder.load_state_dict(st["text_encoder"])
        for name, params in self.unet.named_parameters():
            if name in st["unet"]:
                params.data.copy_(st["unet"][name])
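

# Minimal usage sketch (not part of the original file; the model id, checkpoint
# path, and modifier token in the prompt are placeholders, assuming a delta
# checkpoint produced by `save_pretrained` above):
#
#   pipe = CustomDiffusionPipeline.from_pretrained(
#       "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
#   ).to("cuda")
#   pipe.load_model("delta.bin")
#   image = pipe("photo of a <new1> cat", num_inference_steps=50).images[0]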