KitsuVp committed on
Commit
1f8397b
·
verified ·
1 Parent(s): 7a75741

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +6 -6
modeling_neollm.py CHANGED
@@ -1556,8 +1556,8 @@ def eager_attention_forward(
1556
  attn_weights = nn.functional.softmax(
1557
  attn_weights, dim=-1, dtype=torch.float32
1558
  ).to(query.dtype)
1559
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
1560
  attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()
 
1561
  return attn_output, attn_weights
1562
 
1563
 
@@ -1622,10 +1622,8 @@ def affine_scaled_eager_attention_forward(
1622
  if attn_analysis is not None:
1623
  attn_analysis.attn_weights_post_affine = attn_weights_affine.detach()
1624
 
1625
- attn_weights_affine = nn.functional.dropout(
1626
- attn_weights_affine, p=dropout, training=module.training
1627
- )
1628
  attn_output = torch.matmul(attn_weights_affine, value_states).transpose(1, 2).contiguous()
 
1629
  return attn_output, attn_weights_affine
1630
 
1631
 
@@ -1729,6 +1727,7 @@ def affine_scaled_flash_attention_forward(
1729
 
1730
  # ── Combine and apply dropout to the full affine output ───────────────
1731
  output = alpha_t * flash_out + beta_t * v_cumsum_t # [B, S, H_q, d_head]
 
1732
 
1733
  # attn_weights is None — flash never exposes the softmax weight matrix.
1734
  return output, None
@@ -2471,8 +2470,9 @@ class NeoLLMAttention(nn.Module):
2471
  **kwargs,
2472
  )
2473
  else:
2474
- attn_fn = eager_attention_forward
2475
- if self.config._attn_implementation != "eager":
 
2476
  attn_fn = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
2477
  attn_out, attn_weights = attn_fn(
2478
  self, q, k, v, attention_mask,
 
1556
  attn_weights = nn.functional.softmax(
1557
  attn_weights, dim=-1, dtype=torch.float32
1558
  ).to(query.dtype)
 
1559
  attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()
1560
+ attn_output = nn.functional.dropout(attn_output, p=dropout, training=module.training)
1561
  return attn_output, attn_weights
1562
 
1563
 
 
1622
  if attn_analysis is not None:
1623
  attn_analysis.attn_weights_post_affine = attn_weights_affine.detach()
1624
 
 
 
 
1625
  attn_output = torch.matmul(attn_weights_affine, value_states).transpose(1, 2).contiguous()
1626
+ attn_output = nn.functional.dropout(attn_output, p=dropout, training=module.training)
1627
  return attn_output, attn_weights_affine
1628
 
1629
 
 
1727
 
1728
  # ── Combine and apply dropout to the full affine output ───────────────
1729
  output = alpha_t * flash_out + beta_t * v_cumsum_t # [B, S, H_q, d_head]
1730
+ output = nn.functional.dropout(output, p=dropout, training=module.training)
1731
 
1732
  # attn_weights is None — flash never exposes the softmax weight matrix.
1733
  return output, None
 
2470
  **kwargs,
2471
  )
2472
  else:
2473
+ if self.config._attn_implementation == "eager":
2474
+ attn_fn = eager_attention_forward
2475
+ else:
2476
  attn_fn = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
2477
  attn_out, attn_weights = attn_fn(
2478
  self, q, k, v, attention_mask,