fix: normlaise after truncate
Browse files- modeling_xlm_roberta.py +19 -14
modeling_xlm_roberta.py
CHANGED
|
@@ -600,7 +600,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
| 600 |
|
| 601 |
truncate_dim = truncate_dim or self.config.truncate_dim
|
| 602 |
if truncate_dim:
|
| 603 |
-
all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
|
| 604 |
|
| 605 |
if convert_to_tensor:
|
| 606 |
all_embeddings = torch.stack(all_embeddings)
|
|
@@ -613,19 +613,24 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
| 613 |
self.train(is_training)
|
| 614 |
return all_embeddings
|
| 615 |
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 629 |
|
| 630 |
def mean_pooling(
|
| 631 |
self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
|
|
|
|
| 600 |
|
| 601 |
truncate_dim = truncate_dim or self.config.truncate_dim
|
| 602 |
if truncate_dim:
|
| 603 |
+
all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim, normalize_embeddings)
|
| 604 |
|
| 605 |
if convert_to_tensor:
|
| 606 |
all_embeddings = torch.stack(all_embeddings)
|
|
|
|
| 613 |
self.train(is_training)
|
| 614 |
return all_embeddings
|
| 615 |
|
| 616 |
+
def truncate_embeddings(self, embeddings, truncate_dim, normalize_embeddings):
|
| 617 |
+
if not self.config.matryoshka_dimensions:
|
| 618 |
+
logger.warning(
|
| 619 |
+
"Matryoshka embeddings are not supported, so dimension truncation will not be performed."
|
| 620 |
+
)
|
| 621 |
+
return embeddings
|
| 622 |
+
elif truncate_dim in self.config.matryoshka_dimensions:
|
| 623 |
+
truncated_embeddings = [tensor[:truncate_dim] for tensor in embeddings]
|
| 624 |
+
if normalize_embeddings:
|
| 625 |
+
truncated_embeddings = [
|
| 626 |
+
torch.nn.functional.normalize(tensor, p=2, dim=0) for tensor in truncated_embeddings
|
| 627 |
+
]
|
| 628 |
+
return truncated_embeddings
|
| 629 |
+
else:
|
| 630 |
+
raise ValueError(
|
| 631 |
+
f"The provided `truncate_dim` value of {truncate_dim} is not supported. "
|
| 632 |
+
f"Supported dimensions are {self.config.matryoshka_dimensions}."
|
| 633 |
+
)
|
| 634 |
|
| 635 |
def mean_pooling(
|
| 636 |
self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
|