Commit a2b7c86 · Parent(s): c6a5a4d

refactor: restructure the class

Signed-off-by: jupyterjazz <saba.sturua@jina.ai>

modeling_lora.py CHANGED (+6 -13)
@@ -12,7 +12,6 @@ from transformers import PretrainedConfig
 from .modeling_xlm_roberta import (
     XLMRobertaFlashConfig,
     XLMRobertaModel,
-    XLMRobertaPreTrainedModel,
 )
 
 
@@ -209,19 +208,13 @@ class LoRAParametrization(nn.Module):
         layer.current_task = task_idx
 
 
-class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
+class XLMRobertaLoRA(XLMRobertaModel):
     def __init__(
         self,
         config: XLMRobertaFlashConfig,
-        roberta: Optional[XLMRobertaModel] = None,
     ):
         super().__init__(config)
 
-        if roberta is None:
-            self.roberta = XLMRobertaModel(config)
-        else:
-            self.roberta = roberta
-
         self._num_adaptations = len(config.lora_adaptations)
         self._rank = config.lora_rank
         self._dropout_p = config.lora_dropout_p
@@ -238,6 +231,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         # By default, we select the first LoRA
         self.current_task = 0
 
+
     @property
     def main_params_trainable(self):
         return self._main_params_trainable
@@ -273,15 +267,14 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         config = XLMRobertaFlashConfig.from_pretrained(
             pretrained_model_name_or_path, *model_args, **kwargs
         )
+
         if config.load_trained_adapters:
             return super().from_pretrained(
                 pretrained_model_name_or_path, *model_args, **kwargs
             )
         else:
-            roberta = XLMRobertaModel.from_pretrained(
-                pretrained_model_name_or_path, *model_args, **kwargs
-            )
-            return cls(config, roberta=roberta)
+            torch.set_default_dtype(torch.float16)
+            return cls(config)
 
     def _register_lora(self, num_adaptations, rank, dropout_p, alpha):
         self.apply(
@@ -320,7 +313,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
     def forward(self, *args, lora_adaptation: Union[None, int] = -1, **kwargs):
         if lora_adaptation is None or lora_adaptation >= 0:
             self.current_task = lora_adaptation
-        return self.roberta(*args, **kwargs)
+        return super().forward(*args, **kwargs)
 
     def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
         for _, param in self.named_parameters(recurse=recurse):
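Taken together, the hunks replace composition with inheritance: instead of holding an inner XLMRobertaModel as self.roberta (and optionally accepting a pre-built one), XLMRobertaLoRA now subclasses XLMRobertaModel and delegates through super(). A minimal sketch of the two shapes, with stand-in names since only XLMRobertaLoRA and XLMRobertaModel appear in this commit:

import torch
import torch.nn as nn

class Backbone(nn.Module):
    """Stand-in for XLMRobertaModel."""
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

# Before: the LoRA class wraps a backbone instance, optionally
# accepting a pre-built one (the removed `roberta` argument).
class LoRAByComposition(nn.Module):
    def __init__(self, backbone=None):
        super().__init__()
        self.backbone = backbone if backbone is not None else Backbone()

    def forward(self, x):
        return self.backbone(x)

# After: the LoRA class *is* a backbone and delegates through
# super().forward(), mirroring `return super().forward(*args, **kwargs)`.
class LoRAByInheritance(Backbone):
    def forward(self, x):
        return super().forward(x)

x = torch.randn(2, 4)
print(LoRAByComposition()(x).shape)  # torch.Size([2, 4])
print(LoRAByInheritance()(x).shape)  # torch.Size([2, 4])

A likely consequence, not stated in the commit, is that the weights now live directly on the class rather than under a roberta. prefix in the state dict, so checkpoints with trained adapters can load through super().from_pretrained without key remapping.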
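In the new else branch, torch.set_default_dtype(torch.float16) runs before cls(config), so the floating-point parameters the constructor creates are allocated in half precision from the start rather than built in float32 and cast afterwards. A quick sketch of the effect (assumes a PyTorch version that accepts float16 as the default dtype):

import torch

# After this call, newly created floating-point tensors and
# module parameters default to fp16.
torch.set_default_dtype(torch.float16)

layer = torch.nn.Linear(2, 2)
print(layer.weight.dtype)  # torch.float16

Note that the default dtype is process-global state, and the diff does not restore the previous value, so anything constructed after this branch also defaults to fp16.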