| | import math |
| | import os |
| | import warnings |
| | from functools import partial |
| | from typing import Iterator, List, Optional, Tuple, Union |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn.utils.parametrize as parametrize |
| | from torch import nn |
| | from torch.nn import Parameter |
| | from transformers import PretrainedConfig |
| |
|
| | from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel, XLMRobertaPreTrainedModel |
| |
|
| |
|
# Sentinel default for the `task` argument of `XLMRobertaLoRA.forward` and
# `XLMRobertaLoRA.encode`: when `task` equals this value the currently selected
# LoRA adapter is left untouched (unlike `None`, which switches LoRA off).
LORA_NO_UPDATE = '__lora_no_update__'
| |
|
| |
|
def initialized_weights(
    shape: Tuple[int], num_adaptations: int, init: str = "kaiming"
) -> torch.Tensor:
    """Create one randomly initialised weight tensor per adaptation and stack them.

    :param shape: Shape of a single adaptation's weight tensor.
    :param num_adaptations: How many independent tensors to create.
    :param init: Initialisation scheme, either ``"kaiming"`` (Kaiming uniform
        with ``a=sqrt(5)``) or ``"normal"`` (standard normal).
    :return: Tensor of shape ``(num_adaptations, *shape)``.
    :raises NotImplementedError: If ``init`` names an unknown scheme.
    """

    def _fresh_slice() -> torch.Tensor:
        # Each slice starts from zeros and is filled in place by the initialiser.
        slice_ = torch.zeros(shape)
        if init == "kaiming":
            nn.init.kaiming_uniform_(slice_, a=math.sqrt(5))
            return slice_
        if init == "normal":
            nn.init.normal_(slice_)
            return slice_
        raise NotImplementedError

    slices = [_fresh_slice() for _ in range(num_adaptations)]
    return torch.stack(slices, dim=0)
| |
|
| |
|
class LoRAParametrization(nn.Module):
    """
    Weight parametrization that adds a task-selectable low-rank (LoRA) update.

    One pair of low-rank factors (``lora_A``, ``lora_B``) is stored per adaptation.
    While ``current_task`` is ``None`` the parametrization returns the weight
    unchanged; otherwise the scaled product of the selected factors is added to it.

    This LoRA implementation was inspired by https://github.com/cccntu/minLoRA
    The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy
    Permission is hereby granted, free of charge, to any person obtaining a copy of this software
    and associated documentation files (the "Software"), to deal in the Software without restriction,
    including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
    and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
    subject to the following conditions:
    The above copyright notice and this permission notice shall be included in all copies or substantial
    portions of the Software.
    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
    LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
    IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
    SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
    """

    def __init__(
        self,
        fan_in: int,
        fan_out: int,
        layer_type: str = "linear",
        num_adaptations: int = 1,
        rank: int = 4,
        dropout_p: float = 0.0,
        alpha: float = 1,
    ):
        """Allocate the per-adaptation low-rank factors.

        :param fan_in: First weight dimension for embeddings, second for linear layers.
        :param fan_out: Second weight dimension for embeddings, first for linear layers.
        :param layer_type: ``"linear"`` or ``"embedding"``.
        :param num_adaptations: Number of independent LoRA adapters to allocate.
        :param rank: Rank of the low-rank factors.
        :param dropout_p: Dropout probability applied to ``lora_A`` through a mask;
            dropout is skipped entirely when this is not positive.
        :param alpha: Numerator of the update scaling ``alpha / rank``.
        :raises NotImplementedError: If ``layer_type`` is not supported.
        """
        super().__init__()

        # Embedding weights are laid out (fan_in, fan_out), so pairs of factors and
        # shapes have to be swapped relative to the linear case.
        is_embedding = layer_type == "embedding"
        if is_embedding:
            self.swap = lambda pair: (pair[1], pair[0])
        else:
            self.swap = lambda pair: pair

        # Exactly one factor of each pair starts at zero, so a freshly created
        # adapter contributes a zero update.
        if layer_type == "linear":
            factor_a = initialized_weights(
                (rank, fan_in), num_adaptations, init="kaiming"
            )
            factor_b = torch.zeros((num_adaptations, fan_out, rank))
        elif is_embedding:
            factor_a = torch.zeros((num_adaptations, fan_in, rank))
            factor_b = initialized_weights(
                (rank, fan_out), num_adaptations=num_adaptations, init="normal"
            )
        else:
            raise NotImplementedError
        self.lora_A = nn.Parameter(factor_a)
        self.lora_B = nn.Parameter(factor_b)

        self.lora_alpha = alpha
        self.rank = rank
        self.scaling = alpha / rank

        if dropout_p > 0:
            self.lora_dropout = nn.Dropout(p=dropout_p)
            self.dropout_fn = self._dropout
        else:
            self.lora_dropout = lambda x: x
            self.dropout_fn = lambda x: x

        mask_shape = self.swap((1, fan_in))
        self.register_buffer(
            "lora_dropout_mask",
            torch.ones(mask_shape, dtype=self.lora_A.dtype),
            persistent=False,
        )

        # The property setter also installs the pass-through `forward_fn`.
        self.current_task = None

    def _dropout(self, A):
        """Scale ``A`` by the dropped-out all-ones mask buffer."""
        dropped_mask = self.lora_dropout(self.lora_dropout_mask)
        return A * dropped_mask

    def lora_forward(self, X):
        """Return ``X`` plus the scaled low-rank update of the selected task."""
        task = self.current_task
        assert task is not None
        factor_b = self.lora_B[task]
        factor_a = self.dropout_fn(self.lora_A[task])
        left, right = self.swap((factor_b, factor_a))
        update = torch.matmul(left, right).view(X.shape)
        return X + update * self.scaling

    def forward(self, X):
        """Dispatch to whichever forward function the current task installed."""
        return self.forward_fn(X)

    @property
    def current_task(self):
        """Index of the active adapter, or ``None`` when LoRA is bypassed."""
        return self._current_task

    @current_task.setter
    def current_task(self, task: Union[None, int]):
        self._current_task = task
        # No task selected -> identity; otherwise apply the LoRA update.
        self.forward_fn = (lambda x: x) if task is None else self.lora_forward

    @classmethod
    def from_linear(
        cls,
        layer: nn.Module,
        num_adaptations: int,
        rank: int,
        dropout_p: float,
        alpha: float,
    ):
        """Build a parametrization sized for the weight of an ``nn.Linear`` layer."""
        assert isinstance(layer, nn.Linear)
        # Linear weights are stored as (out_features, in_features).
        fan_out, fan_in = layer.weight.shape
        return cls(
            fan_in,
            fan_out,
            layer_type="linear",
            num_adaptations=num_adaptations,
            rank=rank,
            dropout_p=dropout_p,
            alpha=alpha,
        )

    @classmethod
    def from_embedding(
        cls,
        layer: nn.Module,
        num_adaptations: int,
        rank: int,
        dropout_p: float,
        alpha: float,
    ):
        """Build a parametrization sized for the weight of an ``nn.Embedding`` layer."""
        assert isinstance(layer, nn.Embedding)
        fan_in, fan_out = layer.weight.shape
        return cls(
            fan_in,
            fan_out,
            layer_type="embedding",
            num_adaptations=num_adaptations,
            rank=rank,
            dropout_p=dropout_p,
            alpha=alpha,
        )

    @classmethod
    def add_to_layer(
        cls,
        layer: nn.Module,
        num_adaptations: int,
        rank: int,
        dropout_p: float,
        alpha: float,
    ):
        """Register a LoRA parametrization on ``layer.weight``.

        Only ``nn.Linear`` and ``nn.Embedding`` layers are parametrized; any other
        module is left untouched.
        """
        if isinstance(layer, nn.Linear):
            build = cls.from_linear
        elif isinstance(layer, nn.Embedding):
            build = cls.from_embedding
        else:
            return

        lora = build(
            layer,
            num_adaptations=num_adaptations,
            rank=rank,
            dropout_p=dropout_p,
            alpha=alpha,
        )
        parametrize.register_parametrization(layer, "weight", lora)

    @staticmethod
    def select_task_for_layer(layer: nn.Module, task_idx: Optional[int] = None):
        """Set the active task on ``layer`` if it is a ``LoRAParametrization``."""
        if not isinstance(layer, LoRAParametrization):
            return
        layer.current_task = task_idx
| |
|
| |
|
class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
    """XLM-RoBERTa model whose layers carry switchable LoRA adapters.

    A `LoRAParametrization` is registered on the weight of every `nn.Linear` and
    `nn.Embedding` module, with one adapter per name in `config.lora_adaptations`.
    The active adapter is chosen by name through `current_task` or through the
    `task` argument of `forward` / `encode`.
    """

    def __init__(
        self,
        config: XLMRobertaFlashConfig,
        roberta: Optional[XLMRobertaModel] = None,
    ):
        """Wrap (or create) the backbone and register the LoRA adapters.

        :param config: Model configuration; must provide `lora_adaptations`,
            `lora_rank`, `lora_dropout_p`, `lora_alpha` and
            `lora_main_params_trainable`.
        :param roberta: Existing backbone to wrap. A new `XLMRobertaModel` is
            built from `config` when omitted.
        :raises ValueError: If `config.lora_adaptations` is not a non-empty list.
        """
        super().__init__(config)

        if roberta is None:
            self.roberta = XLMRobertaModel(config)
        else:
            self.roberta = roberta

        self._lora_adaptations = config.lora_adaptations
        if (
            not isinstance(self._lora_adaptations, list)
            or len(self._lora_adaptations) < 1
        ):
            raise ValueError(
                '`lora_adaptations` must be a list and contain at least one element'
            )
        # Adapter name -> index into the stacked LoRA parameters.
        self._adaptation_map = {
            name: idx for idx, name in enumerate(self._lora_adaptations)
        }
        self._rank = config.lora_rank
        self._dropout_p = config.lora_dropout_p
        self._alpha = config.lora_alpha
        self._register_lora(
            num_adaptations=len(self._lora_adaptations),
            rank=self._rank,
            dropout_p=self._dropout_p,
            alpha=self._alpha,
        )
        self.main_params_trainable = config.lora_main_params_trainable
        self._task_idx = None
        # Start with LoRA disabled.
        self.current_task = None

    @property
    def main_params_trainable(self):
        """Whether the non-LoRA parameters are currently trainable."""
        return self._main_params_trainable

    @main_params_trainable.setter
    def main_params_trainable(self, val: bool):
        """Whether the main parameters (i.e. those that are not LoRA) should be trainable.
        This method sets the `requires_grad_` attribute of the main weights
        and controls which parameters are returned in `self.parameters()`.
        :param val: Whether or not to make the parameters trainable.
        :return: None
        """
        self._main_params_trainable = val
        # Use the unfiltered parent iterator: this class's own `named_parameters`
        # hides the main weights while they are frozen.
        for name, param in super().named_parameters():
            if "lora" not in name:
                param.requires_grad_(val)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: bool = None,
        **kwargs,
    ):
        """Load the model, with trained adapters or with freshly initialised ones.

        If the loaded config has `load_trained_adapters` set, the whole model
        (backbone and adapters) is restored from the checkpoint. Otherwise only the
        `XLMRobertaModel` backbone is loaded and new adapters are created around it.

        The download/loading options (`cache_dir`, `force_download`,
        `local_files_only`, `token`, `revision`, `ignore_mismatched_sizes`,
        `use_safetensors`) are forwarded to the underlying loaders. They are bound
        as named parameters here and therefore never appear in `**kwargs`, so they
        have to be passed on explicitly or they would be dropped.

        :param pretrained_model_name_or_path: Model id or local path.
        :param config: Accepted for signature compatibility; the configuration is
            always re-loaded from `pretrained_model_name_or_path`.
        :return: An `XLMRobertaLoRA` instance.
        """
        # NOTE(review): `config` is intentionally left unused to keep the existing
        # behaviour (always reload); confirm whether a caller-supplied config
        # should take precedence.
        hub_kwargs = {
            "cache_dir": cache_dir,
            "force_download": force_download,
            "local_files_only": local_files_only,
            "token": token,
            "revision": revision,
        }
        # `*model_args` are model-constructor positionals and are deliberately not
        # passed to the config loader, where they would be bound to `cache_dir`.
        # NOTE(review): assumes XLMRobertaFlashConfig.from_pretrained follows the
        # standard `PretrainedConfig.from_pretrained` signature -- confirm.
        config = XLMRobertaFlashConfig.from_pretrained(
            pretrained_model_name_or_path, **hub_kwargs, **kwargs
        )

        load_kwargs = {
            **hub_kwargs,
            "ignore_mismatched_sizes": ignore_mismatched_sizes,
            "use_safetensors": use_safetensors,
            **kwargs,
        }
        if config.load_trained_adapters:
            # Checkpoint already contains the LoRA weights.
            return super().from_pretrained(
                pretrained_model_name_or_path, *model_args, **load_kwargs
            )

        # Load only the backbone and wrap it with new adapters.
        roberta = XLMRobertaModel.from_pretrained(
            pretrained_model_name_or_path, *model_args, **load_kwargs
        )
        return cls(config, roberta=roberta)

    def _register_lora(self, num_adaptations, rank, dropout_p, alpha):
        """Attach a LoRA parametrization to every supported layer of the model."""
        self.apply(
            partial(
                LoRAParametrization.add_to_layer,
                num_adaptations=num_adaptations,
                rank=rank,
                dropout_p=dropout_p,
                alpha=alpha,
            )
        )

    @property
    def current_task(self):
        """Which LoRA is currently selected
        :return: Integer or None (when LoRA is disabled)
        """
        return self._task_idx

    @current_task.setter
    def current_task(self, task_name: Union[None, str]):
        """Set the LoRA that is to be used.
        The LoRA is specified by `task_name`, which must be one of the names in
        `config.lora_adaptations`. If it is None (or empty), no LoRA is used.
        :param task_name: Which LoRA to use
        :return:
        :raises ValueError: If `task_name` is not a known adaptation.
        """
        if task_name and task_name not in self._lora_adaptations:
            raise ValueError(
                f"Unsupported task '{task_name}'. "
                f"Supported tasks are: {', '.join(self.config.lora_adaptations)}. "
                f"Alternatively, set `task` to `None` if you want to disable LoRA."
            )
        task_idx = self._adaptation_map[task_name] if task_name else None
        # Only walk the module tree when the selection actually changes.
        if self._task_idx != task_idx:
            self._task_idx = task_idx
            self.apply(
                partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
            )

    def forward(self, *args, task: Union[str, None] = LORA_NO_UPDATE, **kwargs):
        """Run the backbone, optionally switching the active LoRA adapter first.

        :param task: Adapter name, `None` to disable LoRA, or `LORA_NO_UPDATE`
            (default) to keep the current selection.
        :return: Whatever `self.roberta(*args, **kwargs)` returns.
        """
        if task != LORA_NO_UPDATE:
            self.current_task = task

        return self.roberta(*args, **kwargs)

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        """Yield the parameters exposed by `named_parameters` (see there)."""
        for _, param in self.named_parameters(recurse=recurse):
            yield param

    def named_parameters(
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, Parameter]]:
        """Yield named parameters, hiding main weights while they are frozen.

        LoRA parameters (those with "lora" in their name) are always yielded; all
        other parameters are yielded only when `main_params_trainable` is true.
        """
        for name, param in super().named_parameters(
            prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate
        ):
            if "lora" in name or self.main_params_trainable:
                yield name, param

    @torch.inference_mode()
    def encode(
        self,
        *args,
        task: Union[str, None] = LORA_NO_UPDATE,
        **kwargs,
    ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
        """
        Computes sentence embeddings

        task(`str`, *optional*, defaults to `LORA_NO_UPDATE`):
            Specifies the task for which the encoding is intended. This parameter controls the
            use of specialized LoRA adapters that are tuned for specific tasks. If `task` is set
            to `LORA_NO_UPDATE`, there will be no update to the current task, retaining the
            existing adapter configuration. If `task` is explicitly set to `None`, all LoRA
            adapters are disabled, and the model reverts to its original, general-purpose weights.
            If `task` is set to a specific LoRA adaptation, that adaptation is activated.

        All other positional and keyword arguments are passed through to
        `self.roberta.encode`, whose result is returned.
        """
        if task != LORA_NO_UPDATE:
            if not task:
                warnings.warn(
                    f"Task-specific embeddings are disabled. To enable, specify the `task` "
                    f"argument with one of the supported tasks: {', '.join(self.config.lora_adaptations)}",
                    category=UserWarning,
                )
            self.current_task = task

        return self.roberta.encode(*args, **kwargs)
| |
|