import torch

from diffusers.models.attention_processor import LoRAAttnProcessor

def add_tokens(tokenizer, text_encoder, placeholder_token, num_vec_per_token=1, initializer_token=None):
    """
    Add placeholder tokens to the tokenizer and set the initial values of their embeddings.
    """
    tokenizer.add_placeholder_tokens(placeholder_token, num_vec_per_token=num_vec_per_token)

    # Resize the token embeddings since new tokens were added to the tokenizer.
    text_encoder.resize_token_embeddings(len(tokenizer))

    token_embeds = text_encoder.get_input_embeddings().weight.data
    placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
    if initializer_token:
        # Copy the embedding of the initializer token(s) into each placeholder vector.
        token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
        for i, placeholder_token_id in enumerate(placeholder_token_ids):
            token_embeds[placeholder_token_id] = token_embeds[token_ids[i * len(token_ids) // num_vec_per_token]]
    else:
        # No initializer given: start the placeholder embeddings from random noise.
        for i, placeholder_token_id in enumerate(placeholder_token_ids):
            token_embeds[placeholder_token_id] = torch.randn_like(token_embeds[placeholder_token_id])
    return placeholder_token_ids

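# Illustrative usage sketch (not part of the original listing): wiring `add_tokens`
# into a textual inversion setup. It assumes a tokenizer that implements
# `add_placeholder_tokens` (e.g. the MultiTokenCLIPTokenizer from the diffusers
# multi-token textual inversion community example) and a CLIP text encoder; the
# placeholder and initializer strings below are arbitrary examples.
def setup_placeholder_tokens(tokenizer, text_encoder):
    # Register 4 embedding vectors for "<my-concept>" and initialize them from the
    # embedding of the word "cat"; the returned ids let a training loop restrict
    # embedding updates to exactly these rows.
    placeholder_token_ids = add_tokens(
        tokenizer,
        text_encoder,
        placeholder_token="<my-concept>",
        num_vec_per_token=4,
        initializer_token="cat",
    )
    return placeholder_token_ids
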
def tokenize_prompt(tokenizer, prompt, replace_token=False):
    text_inputs = tokenizer(
        prompt,
        replace_token=replace_token,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    return text_input_ids

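# Illustrative sketch (assumed names): encode a prompt containing the placeholder
# into text-encoder hidden states. `replace_token=True` relies on the tokenizer
# expanding the placeholder into its `num_vec_per_token` ids, as the multi-token
# tokenizer assumed above does.
def encode_prompt(tokenizer, text_encoder, prompt):
    text_input_ids = tokenize_prompt(tokenizer, prompt, replace_token=True)
    with torch.no_grad():
        # For a CLIP text encoder, the first output is the last hidden state.
        prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device))[0]
    return prompt_embeds
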
def get_processor(self, return_deprecated_lora: bool = False):
    r"""
    Get the attention processor in use.

    Args:
        return_deprecated_lora (`bool`, *optional*, defaults to `False`):
            Set to `True` to return the deprecated LoRA attention processor.

    Returns:
        "AttentionProcessor": The attention processor in use.
    """
    if not return_deprecated_lora:
        return self.processor

    # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
    # serialization format for LoRA Attention Processors. It should be deleted once the integration
    # with PEFT is completed.
    is_lora_activated = {
        name: module.lora_layer is not None
        for name, module in self.named_modules()
        if hasattr(module, "lora_layer")
    }

    # 1. if no layer has a LoRA activated we can return the processor as usual
    if not any(is_lora_activated.values()):
        return self.processor

    # `add_k_proj` and `add_v_proj` may not have LoRA applied, so exclude them
    # from the all-or-nothing check below.
    is_lora_activated.pop("add_k_proj", None)
    is_lora_activated.pop("add_v_proj", None)

    # 2. else it is not possible that only some layers have LoRA activated
    if not all(is_lora_activated.values()):
        raise ValueError(
            f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
        )

    # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
    # non_lora_processor_cls_name = self.processor.__class__.__name__
    # lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
    hidden_size = self.inner_dim

    # now create a LoRA attention processor from the LoRA layers
    kwargs = {
        "cross_attention_dim": self.cross_attention_dim,
        "rank": self.to_q.lora_layer.rank,
        "network_alpha": self.to_q.lora_layer.network_alpha,
        "q_rank": self.to_q.lora_layer.rank,
        "q_hidden_size": self.to_q.lora_layer.out_features,
        "k_rank": self.to_k.lora_layer.rank,
        "k_hidden_size": self.to_k.lora_layer.out_features,
        "v_rank": self.to_v.lora_layer.rank,
        "v_hidden_size": self.to_v.lora_layer.out_features,
        "out_rank": self.to_out[0].lora_layer.rank,
        "out_hidden_size": self.to_out[0].lora_layer.out_features,
    }

    if hasattr(self.processor, "attention_op"):
        kwargs["attention_op"] = self.processor.attention_op

    lora_processor = LoRAAttnProcessor(hidden_size, **kwargs)
    lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
    lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
    lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
    lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())

    return lora_processor

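# Illustrative sketch: recover the deprecated LoRA attention processor from an
# attention module and export its weights. `attn_module` is assumed to be a
# diffusers Attention block whose to_q/to_k/to_v/to_out projections carry
# `lora_layer` attributes; in that older serialization format LoRAAttnProcessor
# is an nn.Module, so `state_dict()` yields the merged LoRA weights.
def export_deprecated_lora_weights(attn_module):
    processor = get_processor(attn_module, return_deprecated_lora=True)
    if isinstance(processor, LoRAAttnProcessor):
        return processor.state_dict()
    # No LoRA layers were active; the regular processor has nothing to export.
    return {}
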
def get_attn_processors(self):
    r"""
    Returns:
        `dict` of attention processors: a dictionary containing all attention processors used in the model,
        indexed by their weight names.
    """
    # set recursively
    processors = {}

    def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
        if hasattr(module, "get_processor"):
            processors[f"{name}.processor"] = get_processor(module, return_deprecated_lora=True)

        for sub_name, child in module.named_children():
            fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

        return processors

    for name, module in self.named_children():
        fn_recursive_add_processors(name, module, processors)

    return processors

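# Illustrative sketch: collect every attention processor of a model (for example
# a UNet2DConditionModel) and report how many of them are deprecated LoRA
# processors. Any torch.nn.Module whose attention blocks expose `get_processor`
# would work here; the variable names are illustrative.
def summarize_attn_processors(model):
    processors = get_attn_processors(model)
    lora_names = [name for name, proc in processors.items() if isinstance(proc, LoRAAttnProcessor)]
    print(f"{len(processors)} attention processors found, {len(lora_names)} with deprecated LoRA weights")
    return processors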