fix precision oversight
Browse files- src/mentioned/model.py +1 -1
src/mentioned/model.py
CHANGED
|
@@ -51,7 +51,7 @@ class SentenceEncoder(torch.nn.Module):
|
|
| 51 |
word_mask = word_ids.unsqueeze(-1) == torch.arange(
|
| 52 |
num_words, device=word_ids.device
|
| 53 |
)
|
| 54 |
-
word_mask = word_mask.float()
|
| 55 |
# Sum embeddings for each word: (B, W, S) @ (B, S, D) -> (B, W, D)
|
| 56 |
word_sums = torch.bmm(word_mask.transpose(1, 2), subword_embeddings)
|
| 57 |
# Count subwords per word to get the denominator
|
|
|
|
| 51 |
word_mask = word_ids.unsqueeze(-1) == torch.arange(
|
| 52 |
num_words, device=word_ids.device
|
| 53 |
)
|
| 54 |
+
word_mask = word_mask.to(subword_embeddings.dtype)
|
| 55 |
# Sum embeddings for each word: (B, W, S) @ (B, S, D) -> (B, W, D)
|
| 56 |
word_sums = torch.bmm(word_mask.transpose(1, 2), subword_embeddings)
|
| 57 |
# Count subwords per word to get the denominator
|