Upload modeling_esm_plusplus.py with huggingface_hub
Browse files
- modeling_esm_plusplus.py +5 -8
modeling_esm_plusplus.py
CHANGED
|
@@ -316,15 +316,12 @@ class MultiHeadAttention(nn.Module):
|
|
| 316 |
query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
|
| 317 |
|
| 318 |
if output_attentions: # Manual attention computation
|
| 319 |
-
L,
|
| 320 |
-
scale = 1 / math.sqrt(
|
| 321 |
-
attn_bias = torch.zeros(L,
|
| 322 |
if attention_mask is not None:
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
else:
|
| 326 |
-
attn_bias += attention_mask
|
| 327 |
-
|
| 328 |
attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
|
| 329 |
attn_weights += attn_bias
|
| 330 |
attn_weights = F.softmax(attn_weights, dim=-1)
|
|
|
|
| 316 |
query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
|
| 317 |
|
| 318 |
if output_attentions: # Manual attention computation
|
| 319 |
+
b, L, d = x.shape
|
| 320 |
+
scale = 1 / math.sqrt(d)
|
| 321 |
+
attn_bias = torch.zeros(b, 1, L, L, dtype=query_BLD.dtype, device=query_BLD.device)
|
| 322 |
if attention_mask is not None:
|
| 323 |
+
attn_bias.masked_fill_(attention_mask.logical_not(), float('-inf'))
|
| 324 |
+
|
|
|
|
|
|
|
|
|
|
| 325 |
attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
|
| 326 |
attn_weights += attn_bias
|
| 327 |
attn_weights = F.softmax(attn_weights, dim=-1)
|