Update modeling_neollm.py
Browse files- modeling_neollm.py +6 -6
modeling_neollm.py
CHANGED
|
@@ -1556,8 +1556,8 @@ def eager_attention_forward(
|
|
| 1556 |
attn_weights = nn.functional.softmax(
|
| 1557 |
attn_weights, dim=-1, dtype=torch.float32
|
| 1558 |
).to(query.dtype)
|
| 1559 |
-
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 1560 |
attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()
|
|
|
|
| 1561 |
return attn_output, attn_weights
|
| 1562 |
|
| 1563 |
|
|
@@ -1622,10 +1622,8 @@ def affine_scaled_eager_attention_forward(
|
|
| 1622 |
if attn_analysis is not None:
|
| 1623 |
attn_analysis.attn_weights_post_affine = attn_weights_affine.detach()
|
| 1624 |
|
| 1625 |
-
attn_weights_affine = nn.functional.dropout(
|
| 1626 |
-
attn_weights_affine, p=dropout, training=module.training
|
| 1627 |
-
)
|
| 1628 |
attn_output = torch.matmul(attn_weights_affine, value_states).transpose(1, 2).contiguous()
|
|
|
|
| 1629 |
return attn_output, attn_weights_affine
|
| 1630 |
|
| 1631 |
|
|
@@ -1729,6 +1727,7 @@ def affine_scaled_flash_attention_forward(
|
|
| 1729 |
|
| 1730 |
# ── Combine and apply dropout to the full affine output ───────────────
|
| 1731 |
output = alpha_t * flash_out + beta_t * v_cumsum_t # [B, S, H_q, d_head]
|
|
|
|
| 1732 |
|
| 1733 |
# attn_weights is None — flash never exposes the softmax weight matrix.
|
| 1734 |
return output, None
|
|
@@ -2471,8 +2470,9 @@ class NeoLLMAttention(nn.Module):
|
|
| 2471 |
**kwargs,
|
| 2472 |
)
|
| 2473 |
else:
|
| 2474 |
-
|
| 2475 |
-
|
|
|
|
| 2476 |
attn_fn = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 2477 |
attn_out, attn_weights = attn_fn(
|
| 2478 |
self, q, k, v, attention_mask,
|
|
|
|
| 1556 |
attn_weights = nn.functional.softmax(
|
| 1557 |
attn_weights, dim=-1, dtype=torch.float32
|
| 1558 |
).to(query.dtype)
|
|
|
|
| 1559 |
attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()
|
| 1560 |
+
attn_output = nn.functional.dropout(attn_output, p=dropout, training=module.training)
|
| 1561 |
return attn_output, attn_weights
|
| 1562 |
|
| 1563 |
|
|
|
|
| 1622 |
if attn_analysis is not None:
|
| 1623 |
attn_analysis.attn_weights_post_affine = attn_weights_affine.detach()
|
| 1624 |
|
|
|
|
|
|
|
|
|
|
| 1625 |
attn_output = torch.matmul(attn_weights_affine, value_states).transpose(1, 2).contiguous()
|
| 1626 |
+
attn_output = nn.functional.dropout(attn_output, p=dropout, training=module.training)
|
| 1627 |
return attn_output, attn_weights_affine
|
| 1628 |
|
| 1629 |
|
|
|
|
| 1727 |
|
| 1728 |
# ── Combine and apply dropout to the full affine output ───────────────
|
| 1729 |
output = alpha_t * flash_out + beta_t * v_cumsum_t # [B, S, H_q, d_head]
|
| 1730 |
+
output = nn.functional.dropout(output, p=dropout, training=module.training)
|
| 1731 |
|
| 1732 |
# attn_weights is None — flash never exposes the softmax weight matrix.
|
| 1733 |
return output, None
|
|
|
|
| 2470 |
**kwargs,
|
| 2471 |
)
|
| 2472 |
else:
|
| 2473 |
+
if self.config._attn_implementation == "eager":
|
| 2474 |
+
attn_fn = eager_attention_forward
|
| 2475 |
+
else:
|
| 2476 |
attn_fn = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 2477 |
attn_out, attn_weights = attn_fn(
|
| 2478 |
self, q, k, v, attention_mask,
|