Harryx2025 committed on
Commit
ab80fdb
·
verified ·
1 Parent(s): f39a2d2

Update modeling_FalconTST.py

Browse files
Files changed (1) hide show
  1. modeling_FalconTST.py +2 -14
modeling_FalconTST.py CHANGED
@@ -247,7 +247,7 @@ class RMSNorm(nn.Module):
247
  return self.weight * hidden_states.to(input_dtype)
248
 
249
 
250
- class FlashAttention(nn.Module):
251
  """Implement the scaled dot product attention with softmax.
252
  Arguments
253
  ---------
@@ -294,16 +294,6 @@ class FlashAttention(nn.Module):
294
 
295
 
296
 
297
- class TEDotProductAttention(nn.Module):
298
- def __init__(self, flash_attention,):
299
- super().__init__()
300
- self.flash_attention = flash_attention
301
-
302
- def forward(self, q, k, v, mask=None):
303
- # Prioritize using FlashAttention
304
- return self.flash_attention(q, k, v, mask)
305
-
306
-
307
  class SelfAttention(nn.Module):
308
  def __init__(self,config,):
309
  super().__init__()
@@ -311,9 +301,7 @@ class SelfAttention(nn.Module):
311
  q_layernorm=config.q_layernorm
312
  k_layernorm=config.k_layernorm
313
  self.hidden_size = config.hidden_size
314
- self.core_attention = TEDotProductAttention(
315
- flash_attention=FlashAttention(),
316
- )
317
  self.linear_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.add_bias_linear,)
318
  self.linear_qkv = nn.Linear(self.hidden_size, 3*self.hidden_size, bias=config.add_bias_linear,)
319
  if q_layernorm:
 
247
  return self.weight * hidden_states.to(input_dtype)
248
 
249
 
250
+ class TEDotProductAttention(nn.Module):
251
  """Implement the scaled dot product attention with softmax.
252
  Arguments
253
  ---------
 
294
 
295
 
296
 
 
 
 
 
 
 
 
 
 
 
297
  class SelfAttention(nn.Module):
298
  def __init__(self,config,):
299
  super().__init__()
 
301
  q_layernorm=config.q_layernorm
302
  k_layernorm=config.k_layernorm
303
  self.hidden_size = config.hidden_size
304
+ self.core_attention = TEDotProductAttention()
 
 
305
  self.linear_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.add_bias_linear,)
306
  self.linear_qkv = nn.Linear(self.hidden_size, 3*self.hidden_size, bias=config.add_bias_linear,)
307
  if q_layernorm: