kadarakos commited on
Commit
f776570
·
1 Parent(s): f3f859d

fix precision oversight

Browse files
Files changed (1) hide show
  1. src/mentioned/model.py +1 -1
src/mentioned/model.py CHANGED
@@ -51,7 +51,7 @@ class SentenceEncoder(torch.nn.Module):
51
  word_mask = word_ids.unsqueeze(-1) == torch.arange(
52
  num_words, device=word_ids.device
53
  )
54
- word_mask = word_mask.float() # (B, S, W)
55
  # Sum embeddings for each word: (B, W, S) @ (B, S, D) -> (B, W, D)
56
  word_sums = torch.bmm(word_mask.transpose(1, 2), subword_embeddings)
57
  # Count subwords per word to get the denominator
 
51
  word_mask = word_ids.unsqueeze(-1) == torch.arange(
52
  num_words, device=word_ids.device
53
  )
54
+ word_mask = word_mask.to(subword_embeddings.dtype)
55
  # Sum embeddings for each word: (B, W, S) @ (B, S, D) -> (B, W, D)
56
  word_sums = torch.bmm(word_mask.transpose(1, 2), subword_embeddings)
57
  # Count subwords per word to get the denominator