| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import inspect |
| import os |
| import warnings |
| from contextlib import contextmanager |
|
|
| import torch |
| from accelerate import dispatch_model, infer_auto_device_map |
| from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules |
| from accelerate.utils import get_balanced_memory |
| from huggingface_hub import hf_hub_download |
| from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import SequenceClassifierOutput, TokenClassifierOutput |
| from transformers.utils import PushToHubMixin |
|
|
| import packaging.version |
| import transformers |
| from typing import Any, Literal, Optional, Union |
|
|
| from .sama import SamaTuner |
| from .utils import ( |
| TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, |
| WEIGHTS_NAME, |
| PeftConfig, |
| PeftType, |
| PromptLearningConfig, |
| TaskType, |
| _set_trainable, |
| get_peft_model_state_dict, |
| set_peft_model_state_dict, |
| shift_tokens_right, |
| ) |
|
|
class PeftModel(PushToHubMixin, torch.nn.Module):
    """
    Base wrapper that attaches a parameter-efficient fine-tuning (PEFT) method to a transformers model.

    How ``model`` is prepared depends on ``peft_config``:

    - a `PromptLearningConfig` freezes the parameters of every child of ``model`` and builds a
      ``prompt_encoder`` (see `_setup_prompt_encoder`);
    - a config whose ``peft_type`` is ``PeftType.SAMA`` wraps ``model`` in a `SamaTuner`;
    - any other config keeps ``model`` itself as ``base_model``.

    Args:
        model ([`~transformers.PreTrainedModel`]): The base transformer model to adapt.
        peft_config ([`PeftConfig`]): The configuration of the PEFT method.
        adapter_name (`str`, *optional*, defaults to `"default"`):
            Name of the adapter. Stored as ``active_adapter`` and used as the config key passed to `SamaTuner`.

    Attributes:
        base_model: The wrapped model (the raw ``model`` for prompt learning, a `SamaTuner` for SAMA).
        peft_config ([`PeftConfig`]): The PEFT configuration.
        modules_to_save (`list[str]` or `None`):
            Module-name fragments whose parameters are made trainable by `_set_trainable`.
        device (`torch.device`): ``cuda`` whenever CUDA is available, otherwise ``cpu``.
    """

    def __init__(self, model, peft_config: PeftConfig, adapter_name: str = "default"):
        super().__init__()
        self.peft_config = peft_config
        self.base_model = model
        self.config = self.base_model.config
        self.modules_to_save = None
        self.active_adapter = adapter_name

        if isinstance(self.peft_config, PromptLearningConfig):
            self._setup_prompt_encoder()
        else:
            if self.peft_config.peft_type == PeftType.SAMA:
                self.base_model = SamaTuner(model, {adapter_name: peft_config}, adapter_name)
        if getattr(self.peft_config, "modules_to_save", None) is not None:
            self.modules_to_save = self.peft_config.modules_to_save
            _set_trainable(self)
        # NOTE(review): this is the default CUDA device whenever one exists, not necessarily the
        # device the model's weights live on.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.base_model_torch_dtype = getattr(model, "dtype", None)

    def save_pretrained(self, save_directory, **kwargs):
        r"""
        Save the adapter weights and the adapter configuration to a directory.

        The directory can be reloaded with `PeftModel.from_pretrained`.

        Args:
            save_directory (`str`):
                Directory where the adapter weights and configuration are written (created if it does not
                exist).
            **kwargs:
                Only ``state_dict`` is read. It is passed as the second argument of
                `get_peft_model_state_dict` (presumably used instead of ``self.state_dict()`` — confirm).

        Raises:
            ValueError: If ``save_directory`` points to an existing file.
        """
        if os.path.isfile(save_directory):
            raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
        os.makedirs(save_directory, exist_ok=True)

        # Save only the adapter weights selected by get_peft_model_state_dict.
        output_state_dict = get_peft_model_state_dict(self, kwargs.get("state_dict", None))
        torch.save(output_state_dict, os.path.join(save_directory, WEIGHTS_NAME))

        # Record which base model this adapter belongs to, unless the config already says so.
        if self.peft_config.base_model_name_or_path is None:
            self.peft_config.base_model_name_or_path = (
                self.base_model.__dict__.get("name_or_path", None)
                if isinstance(self.peft_config, PromptLearningConfig)
                else self.base_model.model.__dict__.get("name_or_path", None)
            )

        # The saved config always has inference_mode=True.
        # Restore the in-memory flag even if writing the config fails.
        inference_mode = self.peft_config.inference_mode
        self.peft_config.inference_mode = True
        try:
            self.peft_config.save_pretrained(save_directory)
        finally:
            self.peft_config.inference_mode = inference_mode

    @classmethod
    def from_pretrained(cls, model, model_id, is_trainable=False, **kwargs):
        r"""
        Instantiate a PEFT model from a saved adapter configuration and adapter weights.

        Args:
            model ([`~transformers.PreTrainedModel`]):
                The model to be adapted, typically created with `from_pretrained` from the `transformers` library.
            model_id (`str`):
                Either the id of an adapter repository on the Hugging Face Hub, or a path to a directory
                created with `save_pretrained` (e.g. ``./my_adapter_directory/``).
            is_trainable (`bool`, *optional*, defaults to `False`):
                Whether the adapter is trainable. The loaded config's ``inference_mode`` is set to
                ``not is_trainable``.
            **kwargs:
                ``device_map`` (default ``"auto"``) and ``max_memory``. These are only read when ``model`` has an
                ``hf_device_map``; they are used to re-dispatch the wrapped model with accelerate.

        Returns:
            The wrapped model. It is an instance of the task-specific class registered in
            ``MODEL_TYPE_TO_PEFT_MODEL_MAPPING`` for the config's ``task_type``, or of ``cls`` otherwise.

        Raises:
            ValueError: If the adapter weights are found neither under ``model_id`` nor on the Hub.
        """
        from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING

        config = PEFT_TYPE_TO_CONFIG_MAPPING[PeftConfig.from_pretrained(model_id).peft_type].from_pretrained(model_id)
        config.inference_mode = not is_trainable

        # Drop existing accelerate hooks; the model is re-dispatched after the adapter weights are loaded.
        if getattr(model, "hf_device_map", None) is not None:
            remove_hook_from_submodules(model)

        if config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys():
            model = cls(model, config)
        else:
            model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](model, config)

        # Look for the adapter weights locally first, then on the Hugging Face Hub.
        if os.path.exists(os.path.join(model_id, WEIGHTS_NAME)):
            filename = os.path.join(model_id, WEIGHTS_NAME)
        else:
            try:
                filename = hf_hub_download(model_id, WEIGHTS_NAME)
            except Exception as err:
                raise ValueError(
                    f"Can't find weights for {model_id} in {model_id} or in the Hugging Face Hub. "
                    f"Please check that the file {WEIGHTS_NAME} is present at {model_id}."
                ) from err

        # NOTE(review): torch.load unpickles the file. Only load adapters from trusted sources.
        adapters_weights = torch.load(
            filename, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )

        model = set_peft_model_state_dict(model, adapters_weights)
        if getattr(model, "hf_device_map", None) is not None:
            device_map = kwargs.get("device_map", "auto")
            max_memory = kwargs.get("max_memory", None)
            no_split_module_classes = model._no_split_modules
            if device_map != "sequential":
                max_memory = get_balanced_memory(
                    model,
                    max_memory=max_memory,
                    no_split_module_classes=no_split_module_classes,
                    low_zero=(device_map == "balanced_low_0"),
                )
            if isinstance(device_map, str):
                device_map = infer_auto_device_map(
                    model, max_memory=max_memory, no_split_module_classes=no_split_module_classes
                )
            model = dispatch_model(model, device_map=device_map)
            hook = AlignDevicesHook(io_same_device=True)
            if (
                model.peft_config.peft_type == PeftType.LORA
                or model.peft_config.peft_type == PeftType.BOTTLENECK
                or model.peft_config.peft_type == PeftType.SAMA
            ):
                add_hook_to_module(model.base_model.model, hook)
            else:
                remove_hook_from_submodules(model.prompt_encoder)
                add_hook_to_module(model.base_model, hook)
        return model

    def _setup_prompt_encoder(self):
        """
        Freeze the wrapped model's children and build the prompt encoder for prompt-learning configs.

        Sets ``transformer_backbone_name``, ``word_embeddings``, ``prompt_encoder`` and ``prompt_tokens``.
        Fills in ``num_transformer_submodules`` on the config when it is unset: 2 for seq2seq, otherwise 1.

        Raises:
            ValueError: If the config's ``peft_type`` is not a supported prompt-learning type.
        """
        transformer_backbone = None
        for name, module in self.base_model.named_children():
            for param in module.parameters():
                param.requires_grad = False
            if isinstance(module, PreTrainedModel):
                # The first PreTrainedModel child is used as the transformer backbone.
                if transformer_backbone is None:
                    transformer_backbone = module
                    self.transformer_backbone_name = name

        if self.peft_config.num_transformer_submodules is None:
            self.peft_config.num_transformer_submodules = (
                2 if self.peft_config.task_type == TaskType.SEQ_2_SEQ_LM else 1
            )

        # The word-embedding module is the first backbone parameter whose leading dimension equals the
        # vocabulary size.
        for named_param, value in transformer_backbone.named_parameters():
            if value.shape[0] == self.base_model.config.vocab_size:
                self.word_embeddings = transformer_backbone.get_submodule(named_param.replace(".weight", ""))
                break

        # NOTE(review): PromptEmbedding, PromptEncoder and PrefixEncoder are not imported in this file.
        # They must be provided elsewhere, or the branches below raise NameError.
        if self.peft_config.peft_type == PeftType.PROMPT_TUNING:
            prompt_encoder = PromptEmbedding(self.peft_config, self.word_embeddings)
        elif self.peft_config.peft_type == PeftType.P_TUNING:
            prompt_encoder = PromptEncoder(self.peft_config)
        elif self.peft_config.peft_type == PeftType.PREFIX_TUNING:
            prompt_encoder = PrefixEncoder(self.peft_config)
        else:
            raise ValueError("Not supported")
        self.prompt_encoder = prompt_encoder
        self.prompt_tokens = torch.arange(
            self.peft_config.num_virtual_tokens * self.peft_config.num_transformer_submodules
        ).long()

    def get_prompt_embedding_to_save(self):
        """
        Return the prompt embedding to save with the model, detached and moved to CPU.

        Only applicable to prompt-learning configs. Prefix tuning only encodes the first
        ``num_virtual_tokens`` prompt tokens.
        """
        prompt_tokens = self.prompt_tokens.unsqueeze(0).expand(1, -1).to(self.device)
        if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
            prompt_tokens = prompt_tokens[:, : self.peft_config.num_virtual_tokens]
        prompt_embeddings = self.prompt_encoder(prompt_tokens)
        return prompt_embeddings[0].detach().cpu()

    def get_prompt(self, batch_size):
        """
        Return the virtual prompts for a batch. Only applicable to prompt-learning configs.

        For prefix tuning, this returns past key-values. They are viewed as
        ``(batch, num_virtual_tokens, num_layers * 2, num_attention_heads, head_dim)``, permuted, and split
        into chunks of ``num_transformer_submodules * 2`` along the layer dimension. A model-specific
        post-processing function is applied when one is registered. For the other prompt-learning types,
        this returns the prompt embeddings.

        Args:
            batch_size (`int`): Number of sequences in the batch.
        """
        prompt_tokens = self.prompt_tokens.unsqueeze(0).expand(batch_size, -1).to(self.device)
        if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
            prompt_tokens = prompt_tokens[:, : self.peft_config.num_virtual_tokens]
            if self.peft_config.inference_mode:
                past_key_values = self.prompt_encoder.embedding.weight.repeat(batch_size, 1, 1)
            else:
                past_key_values = self.prompt_encoder(prompt_tokens)
            past_key_values = past_key_values.view(
                batch_size,
                self.peft_config.num_virtual_tokens,
                self.peft_config.num_layers * 2,
                self.peft_config.num_attention_heads,
                self.peft_config.token_dim // self.peft_config.num_attention_heads,
            )
            if self.peft_config.num_transformer_submodules == 2:
                past_key_values = torch.cat([past_key_values, past_key_values], dim=2)
            past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(
                self.peft_config.num_transformer_submodules * 2
            )
            if TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING.get(self.config.model_type, None) is not None:
                post_process_fn = TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING[self.config.model_type]
                past_key_values = post_process_fn(past_key_values)
            return past_key_values
        else:
            if self.peft_config.inference_mode:
                prompts = self.prompt_encoder.embedding.weight.repeat(batch_size, 1, 1)
            else:
                prompts = self.prompt_encoder(prompt_tokens)
            return prompts

    def print_trainable_parameters(self):
        """
        Print the number of trainable parameters, the total number of parameters and their ratio.

        Parameters that report ``numel() == 0`` but carry a ``ds_numel`` attribute are counted by
        ``ds_numel`` instead. A model with no parameters reports 0% instead of raising ZeroDivisionError.
        """
        trainable_params = 0
        all_param = 0
        for _, param in self.named_parameters():
            num_params = param.numel()
            # Partitioned parameters (e.g. DeepSpeed ZeRO-3) may report 0 elements locally.
            if num_params == 0 and hasattr(param, "ds_numel"):
                num_params = param.ds_numel

            all_param += num_params
            if param.requires_grad:
                trainable_params += num_params
        percentage = 100 * trainable_params / all_param if all_param else 0.0
        print(
            f"trainable params: {trainable_params:,} || all params: {all_param:,} || trainable: {percentage}%"
        )

    def __getattr__(self, name: str):
        """Forward attributes that this module does not have to the wrapped ``base_model``."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.base_model, name)

    def forward(self, *args, **kwargs):
        """
        Forward pass through the underlying model returned by `get_base_model`.
        """
        return self.get_base_model()(*args, **kwargs)

    @contextmanager
    def disable_adapter(self):
        """
        Context manager that temporarily disables the adapter.

        For prompt-learning configs, ``forward`` is redirected to the wrapped model's ``forward``. Otherwise
        ``disable_adapter_layers`` / ``enable_adapter_layers`` are called on the wrapped model. The previous
        state is restored on exit, including when the ``with`` body raises.
        """
        if isinstance(self.peft_config, PromptLearningConfig):
            old_forward = self.forward
            self.forward = self.base_model.forward
            try:
                yield
            finally:
                self.forward = old_forward
        else:
            self.base_model.disable_adapter_layers()
            try:
                yield
            finally:
                self.base_model.enable_adapter_layers()

    def get_base_model(self):
        """
        Return the underlying transformers model.

        For prompt learning this is ``base_model`` itself; otherwise it is ``base_model.model``.
        """
        return self.base_model if isinstance(self.peft_config, PromptLearningConfig) else self.base_model.model
|
|
|
|
|
|
class PeftModelForSequenceClassification(PeftModel):
    """
    PEFT model for sequence classification.

    The user's ``modules_to_save`` plus the default head names ``"classifier"`` and ``"score"`` are
    made trainable.

    Args:
        model ([`~transformers.PreTrainedModel`]): Base transformer model.
        peft_config ([`PeftConfig`]): PEFT config.
        adapter_name (`str`, *optional*, defaults to `"default"`): Adapter name forwarded to `PeftModel`.
    """

    def __init__(self, model, peft_config: PeftConfig, adapter_name: str = "default"):
        super().__init__(model, peft_config, adapter_name)

        # Merge the user-requested modules with the default classifier head names. dict.fromkeys drops
        # duplicates while keeping a deterministic order (a set would not). list() also accepts tuples.
        user_modules = list(getattr(peft_config, "modules_to_save", None) or [])
        default_modules = ["classifier", "score"]
        self.modules_to_save = list(dict.fromkeys(user_modules + default_modules))

        # SamaTuner keeps the actual transformers model in `.model`; search that for the head.
        real_model = self.base_model.model if isinstance(self.base_model, SamaTuner) else self.base_model

        # Remember the last top-level child whose name matches one of modules_to_save.
        # `_prefix_tuning_forward` uses it as the classification head.
        for name, _ in real_model.named_children():
            if any(module_name in name for module_name in self.modules_to_save):
                self.cls_layer_name = name

        _set_trainable(self)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """
        Run the classifier, injecting the virtual prompts for prompt-learning configs.

        Non-prompt-learning configs call ``base_model`` directly. Prefix tuning goes through
        `_prefix_tuning_forward`. Prompt/p-tuning prepend the prompt embeddings to ``inputs_embeds`` and
        extend ``attention_mask`` / ``token_type_ids`` accordingly. The batch size is taken from
        ``input_ids``, or from ``inputs_embeds`` when ``input_ids`` is not given.
        """
        # The Trainer may pass `num_items_in_batch`; classification heads do not accept it.
        kwargs.pop("num_items_in_batch", None)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if not isinstance(self.peft_config, PromptLearningConfig):
            return self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                labels=labels,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs,
            )

        batch_size = (input_ids if input_ids is not None else inputs_embeds).shape[0]
        num_virtual_tokens = self.peft_config.num_virtual_tokens
        if attention_mask is not None:
            # The virtual tokens are always attended to; build their mask on the mask's own device.
            prefix_attention_mask = torch.ones(batch_size, num_virtual_tokens).to(attention_mask.device)
            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
        if kwargs.get("position_ids", None) is not None:
            warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
            kwargs["position_ids"] = None
        kwargs.update(
            {
                "attention_mask": attention_mask,
                "labels": labels,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
            }
        )

        if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
            return self._prefix_tuning_forward(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)

        token_type_ids = kwargs.get("token_type_ids", None)
        if token_type_ids is not None:
            # Virtual tokens get token type 0.
            kwargs["token_type_ids"] = torch.cat(
                (
                    torch.zeros(batch_size, num_virtual_tokens).to(token_type_ids.device),
                    token_type_ids,
                ),
                dim=1,
            ).long()
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        prompts = self.get_prompt(batch_size=batch_size)
        prompts = prompts.to(inputs_embeds.dtype)
        inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
        return self.base_model(inputs_embeds=inputs_embeds, **kwargs)

    def _prefix_tuning_forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """
        Prefix-tuning forward pass for sequence classification.

        If ``base_model.forward`` accepts ``past_key_values``, the prefix is passed straight through.
        Otherwise the backbone is run with the prefix, and then the optional ``dropout`` child and the
        head named ``cls_layer_name`` are applied manually. The loss is computed according to
        ``config.problem_type``; the problem type is inferred from ``num_labels`` and the label dtype
        when unset.

        Raises:
            ValueError: If neither the model nor its backbone accepts ``past_key_values``.
        """
        batch_size = (input_ids if input_ids is not None else inputs_embeds).shape[0]
        past_key_values = self.get_prompt(batch_size)
        fwd_params = list(inspect.signature(self.base_model.forward).parameters.keys())
        kwargs.update(
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "inputs_embeds": inputs_embeds,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
                "past_key_values": past_key_values,
            }
        )
        if "past_key_values" in fwd_params:
            return self.base_model(labels=labels, **kwargs)

        transformer_backbone = self.base_model.get_submodule(self.transformer_backbone_name)
        fwd_params = list(inspect.signature(transformer_backbone.forward).parameters.keys())
        if "past_key_values" not in fwd_params:
            raise ValueError("Model does not support past key values which are required for prefix tuning.")
        outputs = transformer_backbone(**kwargs)
        # Use the pooled output when the backbone returns one, otherwise the first output.
        pooled_output = outputs[1] if len(outputs) > 1 else outputs[0]
        if any(name == "dropout" for name, _ in self.base_model.named_children()):
            pooled_output = self.base_model.dropout(pooled_output)
        logits = self.base_model.get_submodule(self.cls_layer_name)(pooled_output)

        num_labels = self.base_model.num_labels
        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if num_labels == 1:
                    self.config.problem_type = "regression"
                elif num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
class PeftModelForCausalLM(PeftModel):
    """
    PEFT model for causal language modeling.

    Args:
        model ([`~transformers.PreTrainedModel`]): Base transformer model.
        peft_config ([`PeftConfig`]): PEFT config.
        adapter_name (`str`, *optional*, defaults to `"default"`): Adapter name forwarded to `PeftModel`.

    Example:

        >>> from transformers import AutoModelForCausalLM
        >>> from peft_local_tensor import PeftModelForCausalLM, get_peft_config

        >>> config = {
        ...     "peft_type": "PREFIX_TUNING",
        ...     "task_type": "CAUSAL_LM",
        ...     "inference_mode": False,
        ...     "num_virtual_tokens": 20,
        ...     "token_dim": 1280,
        ...     "num_transformer_submodules": 1,
        ...     "num_attention_heads": 20,
        ...     "num_layers": 36,
        ...     "encoder_hidden_size": 1280,
        ...     "prefix_projection": False,
        ...     "postprocess_past_key_value_function": None,
        ... }

        >>> peft_config = get_peft_config(config)
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2-large")
        >>> peft_model = PeftModelForCausalLM(model, peft_config)
        >>> peft_model.print_trainable_parameters()
        trainable params: 1,843,200 || all params: 775,873,280 || trainable: 0.23756456724479544%
    """

    def __init__(self, model, peft_config: PeftConfig, adapter_name: str = "default"):
        self.prompt_encoder = None
        self.modules_to_save = None
        super().__init__(model, peft_config, adapter_name)
        # Keep the wrapped model's own hook so `generate` can swap ours in and restore it afterwards.
        self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """
        Run the causal LM, injecting the virtual prompts for prompt-learning configs.

        Prefix tuning passes the prompts as ``past_key_values``. Prompt/p-tuning prepend the prompt
        embeddings to ``inputs_embeds`` and label the virtual positions with -100 (presumably the base
        model's loss ignore index — confirm). ``position_ids`` and ``token_type_ids`` are dropped with a
        warning. The batch size is taken from ``input_ids``, or from ``inputs_embeds`` when
        ``input_ids`` is not given.
        """
        if not isinstance(self.peft_config, PromptLearningConfig):
            return self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                labels=labels,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs,
            )

        batch_size = (input_ids if input_ids is not None else inputs_embeds).shape[0]
        num_virtual_tokens = self.peft_config.num_virtual_tokens
        if attention_mask is not None:
            # The virtual tokens are always attended to; build their mask on the mask's own device.
            prefix_attention_mask = torch.ones(batch_size, num_virtual_tokens).to(attention_mask.device)
            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        if kwargs.get("position_ids", None) is not None:
            warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
            kwargs["position_ids"] = None
        if kwargs.get("token_type_ids", None) is not None:
            warnings.warn("Token type ids are not supported for parameter efficient tuning. Ignoring token type ids")
            kwargs["token_type_ids"] = None
        kwargs.update(
            {
                "attention_mask": attention_mask,
                "labels": labels,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
            }
        )

        if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
            past_key_values = self.get_prompt(batch_size)
            return self.base_model(
                input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, **kwargs
            )

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        if labels is not None:
            prefix_labels = torch.full((batch_size, num_virtual_tokens), -100).to(labels.device)
            kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1)
        prompts = self.get_prompt(batch_size=batch_size)
        prompts = prompts.to(inputs_embeds.dtype)
        inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
        return self.base_model(inputs_embeds=inputs_embeds, **kwargs)

    def generate(self, **kwargs):
        """
        Generate with the wrapped model, using this class's `prepare_inputs_for_generation`.

        The wrapped model's original ``prepare_inputs_for_generation`` is restored afterwards, even on error.
        For prompt/p-tuning the attention mask is extended here once. For prefix tuning it is extended
        per step in `prepare_inputs_for_generation`, so it is not extended here (that would count the
        virtual tokens twice).

        Raises:
            ValueError: For prompt-learning configs when ``input_ids`` is missing.
        """
        self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
        try:
            if not isinstance(self.peft_config, PromptLearningConfig):
                return self.base_model.generate(**kwargs)

            if "input_ids" not in kwargs:
                raise ValueError("input_ids must be provided for Peft model generation")
            if (
                kwargs.get("attention_mask", None) is not None
                and self.peft_config.peft_type != PeftType.PREFIX_TUNING
            ):
                prefix_attention_mask = torch.ones(
                    kwargs["input_ids"].shape[0], self.peft_config.num_virtual_tokens
                ).to(kwargs["input_ids"].device)
                kwargs["attention_mask"] = torch.cat((prefix_attention_mask, kwargs["attention_mask"]), dim=1)

            if kwargs.get("position_ids", None) is not None:
                warnings.warn(
                    "Position ids are not supported for parameter efficient tuning. Ignoring position ids."
                )
                kwargs["position_ids"] = None
            if kwargs.get("token_type_ids", None) is not None:
                warnings.warn(
                    "Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
                )
                kwargs["token_type_ids"] = None

            return self.base_model.generate(**kwargs)
        finally:
            self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """
        Wrap the base model's ``prepare_inputs_for_generation`` to inject the virtual prompts.

        For prefix tuning:
        - the attention mask (when present) is extended by ``num_virtual_tokens`` at every step;
        - on the first step (no ``past_key_values`` yet) the prefix is supplied as ``past_key_values``,
          cast to ``base_model_torch_dtype`` when that is known.

        For prompt/p-tuning on the first step, the prompt embeddings are prepended to the input
        embeddings and ``input_ids`` is cleared.
        """
        model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
        if not isinstance(self.peft_config, PromptLearningConfig):
            return model_kwargs

        is_prefix_tuning = self.peft_config.peft_type == PeftType.PREFIX_TUNING
        if is_prefix_tuning and model_kwargs.get("attention_mask") is not None:
            prefix_attention_mask = torch.ones(
                model_kwargs["input_ids"].shape[0], self.peft_config.num_virtual_tokens
            ).to(model_kwargs["input_ids"].device)
            model_kwargs["attention_mask"] = torch.cat(
                (prefix_attention_mask, model_kwargs["attention_mask"]), dim=1
            )

        if model_kwargs.get("past_key_values") is not None:
            return model_kwargs

        batch_size = model_kwargs["input_ids"].shape[0]
        if is_prefix_tuning:
            past_key_values = self.get_prompt(batch_size=batch_size)
            if self.base_model_torch_dtype is not None:
                # Some post-processing functions return a tuple of tuples (one per layer); cast each tensor.
                if isinstance(past_key_values[0], tuple):
                    past_key_values = tuple(
                        tuple(past_key_value.to(self.base_model_torch_dtype) for past_key_value in layer)
                        for layer in past_key_values
                    )
                else:
                    past_key_values = tuple(
                        past_key_value.to(self.base_model_torch_dtype) for past_key_value in past_key_values
                    )
            model_kwargs["past_key_values"] = past_key_values
        else:
            inputs_embeds = self.word_embeddings(model_kwargs["input_ids"])
            prompts = self.get_prompt(batch_size=batch_size)
            prompts = prompts.to(inputs_embeds.dtype)
            model_kwargs["inputs_embeds"] = torch.cat((prompts, inputs_embeds), dim=1)
            model_kwargs["input_ids"] = None

        return model_kwargs
|
|
|
|
class PeftModelForSeq2SeqLM(PeftModel):
    """
    PEFT model for sequence-to-sequence language modeling.

    Args:
        model ([`~transformers.PreTrainedModel`]): Base encoder-decoder transformer model.
        peft_config ([`PeftConfig`]): PEFT config.
        adapter_name (`str`, *optional*, defaults to `"default"`): Adapter name forwarded to `PeftModel`.
    """

    def __init__(self, model, peft_config: PeftConfig, adapter_name: str = "default"):
        super().__init__(model, peft_config, adapter_name)
        # Keep the wrapped model's own hooks so `generate` can swap and restore them.
        self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation
        self.base_model_prepare_encoder_decoder_kwargs_for_generation = (
            self.base_model._prepare_encoder_decoder_kwargs_for_generation
        )

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        decoder_inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """
        Run the seq2seq model, injecting the virtual prompts for prompt-learning configs.

        Prefix tuning passes the prompts as ``past_key_values``. Prompt/p-tuning prepend the first
        ``num_virtual_tokens`` prompt embeddings to the encoder inputs. With two transformer submodules,
        they also prepend the remaining prompt embeddings to the decoder inputs and label those positions
        -100. ``position_ids`` and ``token_type_ids`` are dropped with a warning. The batch size is taken
        from ``input_ids``, or from ``inputs_embeds`` when ``input_ids`` is not given.
        """
        if not isinstance(self.peft_config, PromptLearningConfig):
            return self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                decoder_inputs_embeds=decoder_inputs_embeds,
                labels=labels,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs,
            )

        batch_size = (input_ids if input_ids is not None else inputs_embeds).shape[0]
        num_virtual_tokens = self.peft_config.num_virtual_tokens
        num_submodules = self.peft_config.num_transformer_submodules
        is_prefix_tuning = self.peft_config.peft_type == PeftType.PREFIX_TUNING

        # The decoder sees num_virtual_tokens extra positions only in two cases: prefix tuning (decoder
        # past key-values), or when prompts are also prepended to the decoder inputs (two submodules).
        # With a single submodule the decoder inputs are not extended, so the mask must not be either.
        if decoder_attention_mask is not None and (is_prefix_tuning or num_submodules == 2):
            prefix_attention_mask = torch.ones(batch_size, num_virtual_tokens).to(decoder_attention_mask.device)
            decoder_attention_mask = torch.cat((prefix_attention_mask, decoder_attention_mask), dim=1)

        if kwargs.get("position_ids", None) is not None:
            warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
            kwargs["position_ids"] = None
        if kwargs.get("token_type_ids", None) is not None:
            warnings.warn("Token type ids are not supported for parameter efficient tuning. Ignoring token type ids")
            kwargs["token_type_ids"] = None
        kwargs.update(
            {
                "attention_mask": attention_mask,
                "decoder_attention_mask": decoder_attention_mask,
                "labels": labels,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
            }
        )

        if is_prefix_tuning:
            past_key_values = self.get_prompt(batch_size)
            return self.base_model(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                decoder_input_ids=decoder_input_ids,
                decoder_inputs_embeds=decoder_inputs_embeds,
                past_key_values=past_key_values,
                **kwargs,
            )

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        if decoder_inputs_embeds is None:
            # Derive decoder inputs from the labels only when the caller gave none.
            if decoder_input_ids is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )
            decoder_inputs_embeds = self.word_embeddings(decoder_input_ids)

        if attention_mask is not None:
            prefix_attention_mask = torch.ones(batch_size, num_virtual_tokens).to(attention_mask.device)
            kwargs["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1)
        # With two submodules the decoder output grows by num_virtual_tokens; pad the labels to match.
        if labels is not None and num_submodules == 2:
            prefix_labels = torch.full((batch_size, num_virtual_tokens), -100).to(labels.device)
            kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1)

        prompts = self.get_prompt(batch_size=batch_size)
        prompts = prompts.to(inputs_embeds.dtype)
        inputs_embeds = torch.cat((prompts[:, :num_virtual_tokens], inputs_embeds), dim=1)
        if num_submodules == 1:
            # NOTE(review): decoder inputs are not forwarded here; the base model derives them from labels.
            return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
        elif num_submodules == 2:
            decoder_inputs_embeds = torch.cat((prompts[:, num_virtual_tokens:], decoder_inputs_embeds), dim=1)
            return self.base_model(
                inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, **kwargs
            )

    def generate(self, **kwargs):
        """
        Generate with the wrapped model, using this class's generation hooks.

        The wrapped model's original hooks are restored afterwards, even on error. Only prefix tuning is
        supported among the prompt-learning types.

        Raises:
            ValueError: For prompt-learning configs when ``input_ids`` is missing.
            NotImplementedError: For prompt-learning types other than prefix tuning.
        """
        self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
        # NOTE(review): this class does not define `_prepare_encoder_decoder_kwargs_for_generation`.
        # `__getattr__` resolves it to the wrapped model's own method, so this assignment looks like a no-op.
        self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
            self._prepare_encoder_decoder_kwargs_for_generation
        )
        try:
            if not isinstance(self.peft_config, PromptLearningConfig):
                return self.base_model.generate(**kwargs)

            if "input_ids" not in kwargs:
                raise ValueError("input_ids must be provided for Peft model generation")
            if kwargs.get("position_ids", None) is not None:
                warnings.warn(
                    "Position ids are not supported for parameter efficient tuning. Ignoring position ids."
                )
                kwargs["position_ids"] = None
            if kwargs.get("token_type_ids", None) is not None:
                warnings.warn(
                    "Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
                )
                kwargs["token_type_ids"] = None

            if self.peft_config.peft_type != PeftType.PREFIX_TUNING:
                raise NotImplementedError
            return self.base_model.generate(**kwargs)
        finally:
            self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
            self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
                self.base_model_prepare_encoder_decoder_kwargs_for_generation
            )

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """
        Wrap the base model's ``prepare_inputs_for_generation``.

        For prefix tuning, the prefix is supplied as ``past_key_values`` on the first step (when no
        ``past_key_values`` exist yet). The batch size is taken from ``decoder_input_ids``.
        """
        model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
        if model_kwargs.get("past_key_values") is None and self.peft_config.peft_type == PeftType.PREFIX_TUNING:
            batch_size = model_kwargs["decoder_input_ids"].shape[0]
            model_kwargs["past_key_values"] = self.get_prompt(batch_size)
        return model_kwargs
|
|
|
|
class PeftModelForTokenClassification(PeftModel):
    """
    PEFT model for token classification.

    The user's ``modules_to_save`` plus the default head names ``"classifier"`` and ``"score"`` are made
    trainable.

    Args:
        model ([`~transformers.PreTrainedModel`]): Base transformer model.
        peft_config ([`PeftConfig`]): PEFT config.
        adapter_name (`str`, *optional*, defaults to `"default"`): Adapter name forwarded to `PeftModel`.
    """

    def __init__(self, model, peft_config: PeftConfig, adapter_name: str = "default"):
        super().__init__(model, peft_config, adapter_name)

        # Keep the user's modules_to_save (already made trainable by PeftModel) alongside the default head
        # names. The order is deterministic and duplicates are removed.
        default_modules = ["classifier", "score"]
        user_modules = list(getattr(peft_config, "modules_to_save", None) or [])
        self.modules_to_save = list(dict.fromkeys(user_modules + default_modules))

        # SamaTuner keeps the actual transformers model in `.model`; search that for the head.
        real_model = self.base_model.model if isinstance(self.base_model, SamaTuner) else self.base_model

        # The classification head is the first top-level child matching a default head name.
        # `_prefix_tuning_forward` uses it.
        for name, _ in real_model.named_children():
            if any(module_name in name for module_name in default_modules):
                self.cls_layer_name = name
                break

        _set_trainable(self)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """
        Run the token classifier, injecting the virtual prompts for prompt-learning configs.

        Non-prompt-learning configs call ``base_model`` directly. Prefix tuning goes through
        `_prefix_tuning_forward`. Prompt/p-tuning prepend the prompt embeddings to ``inputs_embeds`` and
        extend ``attention_mask`` / ``token_type_ids`` accordingly. The batch size is taken from
        ``input_ids``, or from ``inputs_embeds`` when ``input_ids`` is not given.
        """
        # The Trainer may pass `num_items_in_batch`; classification heads do not accept it.
        kwargs.pop("num_items_in_batch", None)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if not isinstance(self.peft_config, PromptLearningConfig):
            return self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                labels=labels,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs,
            )

        batch_size = (input_ids if input_ids is not None else inputs_embeds).shape[0]
        num_virtual_tokens = self.peft_config.num_virtual_tokens
        if attention_mask is not None:
            # The virtual tokens are always attended to; build their mask on the mask's own device.
            prefix_attention_mask = torch.ones(batch_size, num_virtual_tokens).to(attention_mask.device)
            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
        if kwargs.get("position_ids", None) is not None:
            warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
            kwargs["position_ids"] = None
        kwargs.update(
            {
                "attention_mask": attention_mask,
                "labels": labels,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
            }
        )

        if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
            return self._prefix_tuning_forward(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)

        token_type_ids = kwargs.get("token_type_ids", None)
        if token_type_ids is not None:
            # Virtual tokens get token type 0.
            kwargs["token_type_ids"] = torch.cat(
                (
                    torch.zeros(batch_size, num_virtual_tokens).to(token_type_ids.device),
                    token_type_ids,
                ),
                dim=1,
            ).long()
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        prompts = self.get_prompt(batch_size=batch_size)
        prompts = prompts.to(inputs_embeds.dtype)
        inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
        return self.base_model(inputs_embeds=inputs_embeds, **kwargs)

    def _prefix_tuning_forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """
        Prefix-tuning forward pass for token classification.

        If ``base_model.forward`` accepts ``past_key_values``, the prefix is passed straight through.
        Otherwise the backbone is run with the prefix, and then the optional ``dropout`` child and the
        head named ``cls_layer_name`` are applied to the sequence output. A cross-entropy loss is
        computed when ``labels`` are given.

        Raises:
            ValueError: If neither the model nor its backbone accepts ``past_key_values``.
        """
        batch_size = (input_ids if input_ids is not None else inputs_embeds).shape[0]
        past_key_values = self.get_prompt(batch_size)
        fwd_params = list(inspect.signature(self.base_model.forward).parameters.keys())
        kwargs.update(
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "inputs_embeds": inputs_embeds,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
                "past_key_values": past_key_values,
            }
        )
        if "past_key_values" in fwd_params:
            return self.base_model(labels=labels, **kwargs)

        transformer_backbone = self.base_model.get_submodule(self.transformer_backbone_name)
        fwd_params = list(inspect.signature(transformer_backbone.forward).parameters.keys())
        if "past_key_values" not in fwd_params:
            raise ValueError("Model does not support past key values which are required for prefix tuning.")
        outputs = transformer_backbone(**kwargs)
        sequence_output = outputs[0]
        if any(name == "dropout" for name, _ in self.base_model.named_children()):
            sequence_output = self.base_model.dropout(sequence_output)
        logits = self.base_model.get_submodule(self.cls_layer_name)(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.base_model.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|