feat: return from_bert for from_pretrained
Browse files- modeling_lora.py +21 -0
modeling_lora.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import math
|
|
|
|
| 2 |
from functools import partial
|
| 3 |
from typing import Iterator, Optional, Tuple, Union
|
| 4 |
|
|
@@ -6,6 +7,7 @@ import torch
|
|
| 6 |
import torch.nn.utils.parametrize as parametrize
|
| 7 |
from torch import nn
|
| 8 |
from torch.nn import Parameter
|
|
|
|
| 9 |
|
| 10 |
from .modeling_bert import BertModel, BertPreTrainedModel, JinaBertConfig
|
| 11 |
|
|
@@ -193,6 +195,25 @@ class BertLoRA(BertPreTrainedModel):
|
|
| 193 |
return cls(config, bert=bert, num_adaptions=num_adaptions)
|
| 194 |
|
| 195 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
|
| 197 |
self.apply(
|
| 198 |
partial(
|
|
|
|
| 1 |
import math
|
| 2 |
+
import os
|
| 3 |
from functools import partial
|
| 4 |
from typing import Iterator, Optional, Tuple, Union
|
| 5 |
|
|
|
|
| 7 |
import torch.nn.utils.parametrize as parametrize
|
| 8 |
from torch import nn
|
| 9 |
from torch.nn import Parameter
|
| 10 |
+
from transformers import PretrainedConfig
|
| 11 |
|
| 12 |
from .modeling_bert import BertModel, BertPreTrainedModel, JinaBertConfig
|
| 13 |
|
|
|
|
| 195 |
return cls(config, bert=bert, num_adaptions=num_adaptions)
|
| 196 |
|
| 197 |
|
| 198 |
+
@classmethod
|
| 199 |
+
def from_pretrained(
|
| 200 |
+
cls,
|
| 201 |
+
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
| 202 |
+
*model_args,
|
| 203 |
+
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
|
| 204 |
+
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
| 205 |
+
ignore_mismatched_sizes: bool = False,
|
| 206 |
+
force_download: bool = False,
|
| 207 |
+
local_files_only: bool = False,
|
| 208 |
+
token: Optional[Union[str, bool]] = None,
|
| 209 |
+
revision: str = "main",
|
| 210 |
+
use_safetensors: bool = None,
|
| 211 |
+
**kwargs,
|
| 212 |
+
):
|
| 213 |
+
# TODO: choose between from_bert and super().from_pretrained
|
| 214 |
+
return cls.from_bert(pretrained_model_name_or_path)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
|
| 218 |
self.apply(
|
| 219 |
partial(
|