Spaces:
Sleeping
Sleeping
Convert mask len to int
Browse files- src/smc/scheduler.py +1 -1
src/smc/scheduler.py
CHANGED
|
@@ -142,7 +142,7 @@ class MeissonicScheduler(BaseScheduler):
|
|
| 142 |
else:
|
| 143 |
raise ValueError(f"unknown masking schedule {self.masking_schedule}")
|
| 144 |
|
| 145 |
-
mask_len = (seq_len * mask_ratio).floor()
|
| 146 |
# do not mask more than amount previously masked
|
| 147 |
mask_len = torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
|
| 148 |
# mask at least one
|
|
|
|
| 142 |
else:
|
| 143 |
raise ValueError(f"unknown masking schedule {self.masking_schedule}")
|
| 144 |
|
| 145 |
+
mask_len = (seq_len * mask_ratio).floor().long()
|
| 146 |
# do not mask more than amount previously masked
|
| 147 |
mask_len = torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
|
| 148 |
# mask at least one
|