Commit: Fix attention mask dtype issues.

Changed file: amplify.py (+1 line, −1 line)

Hunk: @@ -248,7 +248,7 @@ class AMPLIFY(AMPLIFYPreTrainedModel):
|
|
| 248 |
if attention_mask is not None and not torch.all(attention_mask == 0):
|
| 249 |
assert attention_mask.dtype != torch.bool and 1.0 not in attention_mask, (
|
| 250 |
"AMPLIFY expects an additive attention_mask.\n"
|
| 251 |
-
"Modify the output of the tokenizer with attention_mask = torch.where(attention_mask, float(0.0), float(
|
| 252 |
)
|
| 253 |
attention_mask = (
|
| 254 |
attention_mask.unsqueeze(1)
|
|
|
|
| 248 |
if attention_mask is not None and not torch.all(attention_mask == 0):
|
| 249 |
assert attention_mask.dtype != torch.bool and 1.0 not in attention_mask, (
|
| 250 |
"AMPLIFY expects an additive attention_mask.\n"
|
| 251 |
+
"Modify the output of the tokenizer with attention_mask = torch.where(attention_mask, float(0.0), float('-inf'))"
|
| 252 |
)
|
| 253 |
attention_mask = (
|
| 254 |
attention_mask.unsqueeze(1)
|