Commit 3eb20d0
Parent(s): 509511d

refactor: modify encode

Signed-off-by: jupyterjazz <saba.sturua@jina.ai>

Files changed:
- modeling_lora.py (+7 -9)
- modeling_xlm_roberta.py (+5 -2)
modeling_lora.py (CHANGED)

```diff
@@ -337,7 +337,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
     def encode(
         self,
         *args,
-        task:
+        task: Optional[str] = None,
         **kwargs,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
@@ -351,13 +351,11 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
        adapters are disabled, and the model reverts to its original, general-purpose weights.
        If `task` is set to a specific LoRA adaptation, that adaptation is activated.
        """
-        if task
-        )
-        self.current_task = task
+        if task and task not in self._lora_adaptations:
+            raise ValueError(
+                f"Unsupported task '{task}'. "
+                f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
+                f"Alternatively, don't pass the `task` argument to disable LoRA."
+            )
 
         return self.roberta.encode(*args, **kwargs)
```
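After this change, `task` is an explicit keyword-only argument of `XLMRobertaLoRA.encode` (it sits after `*args`, so existing positional call sites are unaffected), and an unknown task name fails fast with a `ValueError` instead of being silently accepted. Below is a minimal, self-contained sketch of the same validation pattern; the class name, constructor, adapter names, and return value are simplified stand-ins for illustration, not the actual model code.

```python
from typing import List, Optional


class LoRAEncoderSketch:
    """Stand-in for XLMRobertaLoRA: validates `task` against configured adapters."""

    def __init__(self, lora_adaptations: List[str]):
        self._lora_adaptations = lora_adaptations

    def encode(self, *args, task: Optional[str] = None, **kwargs) -> str:
        # Fail fast on unknown adapter names; task=None keeps LoRA disabled.
        if task and task not in self._lora_adaptations:
            raise ValueError(
                f"Unsupported task '{task}'. "
                f"Supported tasks are: {', '.join(self._lora_adaptations)}. "
                f"Alternatively, don't pass the `task` argument to disable LoRA."
            )
        return f"encoded {args} with task={task!r}"


encoder = LoRAEncoderSketch(['retrieval.query', 'retrieval.passage'])
encoder.encode('hello', task='retrieval.query')  # known adapter name, accepted
encoder.encode('hello')                          # no task: LoRA stays disabled
try:
    encoder.encode('hello', task='typo')
except ValueError as err:
    print(err)  # message lists the supported adapter names
```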
modeling_xlm_roberta.py (CHANGED)

```diff
@@ -459,6 +459,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         device: Optional[torch.device] = None,
         normalize_embeddings: bool = False,
         truncate_dim: Optional[int] = None,
+        task: Optional[str] = None,
         **tokenizer_kwargs,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
@@ -549,14 +550,16 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
             )
         else:
             range_iter = range(0, len(sentences), batch_size)
-
+        lora_kwargs = {}
+        if task:
+            lora_kwargs['task'] = task
         for i in range_iter:
             encoded_input = self.tokenizer(
                 sentences[i : i + batch_size],
                 return_tensors='pt',
                 **tokenizer_kwargs,
             ).to(self.device)
-            token_embs = self.forward(**encoded_input)[0]
+            token_embs = self.forward(**encoded_input, **lora_kwargs)[0]
 
             # Accumulate in fp32 to avoid overflow
             token_embs = token_embs.float()
```
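The `lora_kwargs` indirection keeps the batched loop backward compatible: when no `task` is given, the dict stays empty and `self.forward(**encoded_input, **lora_kwargs)` reduces to exactly the pre-change call. Here is a runnable sketch of this conditional pass-through, including the fp32 accumulation step from the same loop; `fake_forward`, its fp16 output, and the dummy tokenizer are assumptions standing in for the real model.

```python
from typing import Dict, List, Optional

import torch


def fake_forward(input_ids: torch.Tensor, task: Optional[str] = None) -> torch.Tensor:
    """Stand-in for self.forward(...)[0]: fp16 token embeddings per batch item."""
    batch, seq = input_ids.shape
    return torch.full((batch, seq, 8), 0.5, dtype=torch.float16)


def encode_batched(
    sentences: List[str], batch_size: int = 2, task: Optional[str] = None
) -> List[torch.Tensor]:
    # Mirror the commit: add the keyword only when a task was requested, so the
    # call degrades gracefully to the old signature when lora_kwargs is empty.
    lora_kwargs: Dict[str, str] = {}
    if task:
        lora_kwargs['task'] = task

    embeddings = []
    for i in range(0, len(sentences), batch_size):
        batch = sentences[i : i + batch_size]
        # Dummy tokenizer: a fixed-length block of token ids per sentence.
        input_ids = torch.zeros(len(batch), 16, dtype=torch.long)
        token_embs = fake_forward(input_ids, **lora_kwargs)
        # Accumulate in fp32 to avoid overflow, as in the original loop.
        token_embs = token_embs.float()
        embeddings.append(token_embs.mean(dim=1))
    return embeddings


encode_batched(['a', 'b', 'c'])                    # old behavior: no task kwarg
encode_batched(['a', 'b', 'c'], task='retrieval')  # task threaded into forward
```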