cp524 commited on
Commit
1117f02
·
1 Parent(s): 01f8311

Convert mask len to int

Browse files
Files changed (1) hide show
  1. src/smc/scheduler.py +1 -1
src/smc/scheduler.py CHANGED
@@ -142,7 +142,7 @@ class MeissonicScheduler(BaseScheduler):
142
  else:
143
  raise ValueError(f"unknown masking schedule {self.masking_schedule}")
144
 
145
- mask_len = (seq_len * mask_ratio).floor()
146
  # do not mask more than amount previously masked
147
  mask_len = torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
148
  # mask at least one
 
142
  else:
143
  raise ValueError(f"unknown masking schedule {self.masking_schedule}")
144
 
145
+ mask_len = (seq_len * mask_ratio).floor().long()
146
  # do not mask more than amount previously masked
147
  mask_len = torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
148
  # mask at least one