Commit ·
c55e591
1
Parent(s): b27fa55
refactor: truncation fn
Browse files
Signed-off-by: jupyterjazz <saba.sturua@jina.ai>
- modeling_xlm_roberta.py +14 -9
modeling_xlm_roberta.py
CHANGED
|
@@ -579,15 +579,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
| 579 |
all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
|
| 580 |
|
| 581 |
if truncate_dim:
|
| 582 |
-
|
| 583 |
-
logger.warning(
|
| 584 |
-
'Matryoshka embeddings are not supported, so dimension truncation will not be performed.'
|
| 585 |
-
)
|
| 586 |
-
elif truncate_dim in self.config.matryoshka_dimensions:
|
| 587 |
-
all_embeddings = [tensor[:truncate_dim] for tensor in all_embeddings]
|
| 588 |
-
else:
|
| 589 |
-
raise ValueError(f'The provided `truncate_dim` value of {truncate_dim} is not supported. '
|
| 590 |
-
f'Supported dimensions are {self.config.matryoshka_dimensions}.')
|
| 591 |
|
| 592 |
if convert_to_tensor:
|
| 593 |
all_embeddings = torch.stack(all_embeddings)
|
|
@@ -600,6 +592,19 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
| 600 |
self.train(is_training)
|
| 601 |
return all_embeddings
|
| 602 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 603 |
def mean_pooling(
|
| 604 |
self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
|
| 605 |
):
|
|
|
|
| 579 |
all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
|
| 580 |
|
| 581 |
if truncate_dim:
|
| 582 |
+
all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 583 |
|
| 584 |
if convert_to_tensor:
|
| 585 |
all_embeddings = torch.stack(all_embeddings)
|
|
|
|
| 592 |
self.train(is_training)
|
| 593 |
return all_embeddings
|
| 594 |
|
| 595 |
+
|
| 596 |
+
def truncate_embeddings(self, embeddings, truncate_dim):
|
| 597 |
+
if not self.config.matryoshka_dimensions:
|
| 598 |
+
logger.warning(
|
| 599 |
+
'Matryoshka embeddings are not supported, so dimension truncation will not be performed.'
|
| 600 |
+
)
|
| 601 |
+
return embeddings
|
| 602 |
+
elif truncate_dim in self.config.matryoshka_dimensions:
|
| 603 |
+
return [tensor[:truncate_dim] for tensor in embeddings]
|
| 604 |
+
else:
|
| 605 |
+
raise ValueError(f'The provided `truncate_dim` value of {truncate_dim} is not supported. '
|
| 606 |
+
f'Supported dimensions are {self.config.matryoshka_dimensions}.')
|
| 607 |
+
|
| 608 |
def mean_pooling(
|
| 609 |
self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
|
| 610 |
):
|