fix: fix LoRA implementation
Browse files- modeling_lora.py +2 -1
modeling_lora.py
CHANGED
|
@@ -210,6 +210,7 @@ class BertLoRA(BertPreTrainedModel):
|
|
| 210 |
self._num_adaptions = config.num_loras
|
| 211 |
self._register_lora(self._num_adaptions)
|
| 212 |
self.main_params_trainable = False
|
|
|
|
| 213 |
self.current_task = 0
|
| 214 |
|
| 215 |
@property
|
|
@@ -265,7 +266,7 @@ class BertLoRA(BertPreTrainedModel):
|
|
| 265 |
@current_task.setter
|
| 266 |
def current_task(self, task_idx: Union[None, int]):
|
| 267 |
assert task_idx is None or 0 <= task_idx < self._num_adaptions
|
| 268 |
-
if self._task_idx != task_idx
|
| 269 |
self._task_idx = task_idx
|
| 270 |
self.apply(
|
| 271 |
partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
|
|
|
|
| 210 |
self._num_adaptions = config.num_loras
|
| 211 |
self._register_lora(self._num_adaptions)
|
| 212 |
self.main_params_trainable = False
|
| 213 |
+
self._task_idx = None
|
| 214 |
self.current_task = 0
|
| 215 |
|
| 216 |
@property
|
|
|
|
| 266 |
@current_task.setter
|
| 267 |
def current_task(self, task_idx: Union[None, int]):
|
| 268 |
assert task_idx is None or 0 <= task_idx < self._num_adaptions
|
| 269 |
+
if self._task_idx != task_idx:
|
| 270 |
self._task_idx = task_idx
|
| 271 |
self.apply(
|
| 272 |
partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
|