Upload modeling_esm_plusplus.py with huggingface_hub
Browse files- modeling_esm_plusplus.py +1 -1
modeling_esm_plusplus.py
CHANGED
|
@@ -321,7 +321,7 @@ class MultiHeadAttention(nn.Module):
|
|
| 321 |
attn_bias = torch.zeros(L, S, dtype=query_BLD.dtype, device=query_BLD.device)
|
| 322 |
if attention_mask is not None:
|
| 323 |
if attention_mask.dtype == torch.bool:
|
| 324 |
-
|
| 325 |
else:
|
| 326 |
attn_bias += attention_mask
|
| 327 |
|
|
|
|
| 321 |
attn_bias = torch.zeros(L, S, dtype=query_BLD.dtype, device=query_BLD.device)
|
| 322 |
if attention_mask is not None:
|
| 323 |
if attention_mask.dtype == torch.bool:
|
| 324 |
+
attn_bias.masked_fill_(attention_mask.logical_not(), float('-inf'))
|
| 325 |
else:
|
| 326 |
attn_bias += attention_mask
|
| 327 |
|