Commit
·
6f5a561
1
Parent(s):
09ed4f9
Upload LUAR
Browse files
model.py
CHANGED
|
@@ -44,7 +44,6 @@ class LUAR(PreTrainedModel):
|
|
| 44 |
def mean_pooling(self, token_embeddings, attention_mask):
|
| 45 |
"""Mean Pooling as described in the SBERT paper.
|
| 46 |
"""
|
| 47 |
-
# input_mask_expanded = repeat(attention_mask, 'b l -> b l d', d=self.hidden_size).float()
|
| 48 |
input_mask_expanded = repeat(attention_mask, 'b l -> b l d', d=self.hidden_size).type(token_embeddings.type())
|
| 49 |
sum_embeddings = reduce(token_embeddings * input_mask_expanded, 'b l d -> b d', 'sum')
|
| 50 |
sum_mask = torch.clamp(reduce(input_mask_expanded, 'b l d -> b d', 'sum'), min=1e-9)
|
|
|
|
| 44 |
def mean_pooling(self, token_embeddings, attention_mask):
|
| 45 |
"""Mean Pooling as described in the SBERT paper.
|
| 46 |
"""
|
|
|
|
| 47 |
input_mask_expanded = repeat(attention_mask, 'b l -> b l d', d=self.hidden_size).type(token_embeddings.type())
|
| 48 |
sum_embeddings = reduce(token_embeddings * input_mask_expanded, 'b l d -> b d', 'sum')
|
| 49 |
sum_mask = torch.clamp(reduce(input_mask_expanded, 'b l d -> b d', 'sum'), min=1e-9)
|