Update modeling_FalconTST.py
Browse files
- modeling_FalconTST.py +2 -14
modeling_FalconTST.py
CHANGED
|
@@ -247,7 +247,7 @@ class RMSNorm(nn.Module):
|
|
| 247 |
return self.weight * hidden_states.to(input_dtype)
|
| 248 |
|
| 249 |
|
| 250 |
-
class FlashAttention(nn.Module):
|
| 251 |
"""Implement the scaled dot product attention with softmax.
|
| 252 |
Arguments
|
| 253 |
---------
|
|
@@ -294,16 +294,6 @@ class FlashAttention(nn.Module):
|
|
| 294 |
|
| 295 |
|
| 296 |
|
| 297 |
-
class TEDotProductAttention(nn.Module):
|
| 298 |
-
def __init__(self, flash_attention,):
|
| 299 |
-
super().__init__()
|
| 300 |
-
self.flash_attention = flash_attention
|
| 301 |
-
|
| 302 |
-
def forward(self, q, k, v, mask=None):
|
| 303 |
-
# Prioritize using FlashAttention
|
| 304 |
-
return self.flash_attention(q, k, v, mask)
|
| 305 |
-
|
| 306 |
-
|
| 307 |
class SelfAttention(nn.Module):
|
| 308 |
def __init__(self,config,):
|
| 309 |
super().__init__()
|
|
@@ -311,9 +301,7 @@ class SelfAttention(nn.Module):
|
|
| 311 |
q_layernorm=config.q_layernorm
|
| 312 |
k_layernorm=config.k_layernorm
|
| 313 |
self.hidden_size = config.hidden_size
|
| 314 |
-
self.core_attention = TEDotProductAttention(
|
| 315 |
-
flash_attention=FlashAttention(),
|
| 316 |
-
)
|
| 317 |
self.linear_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.add_bias_linear,)
|
| 318 |
self.linear_qkv = nn.Linear(self.hidden_size, 3*self.hidden_size, bias=config.add_bias_linear,)
|
| 319 |
if q_layernorm:
|
|
|
|
| 247 |
return self.weight * hidden_states.to(input_dtype)
|
| 248 |
|
| 249 |
|
| 250 |
+
class TEDotProductAttention(nn.Module):
|
| 251 |
"""Implement the scaled dot product attention with softmax.
|
| 252 |
Arguments
|
| 253 |
---------
|
|
|
|
| 294 |
|
| 295 |
|
| 296 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
class SelfAttention(nn.Module):
|
| 298 |
def __init__(self,config,):
|
| 299 |
super().__init__()
|
|
|
|
| 301 |
q_layernorm=config.q_layernorm
|
| 302 |
k_layernorm=config.k_layernorm
|
| 303 |
self.hidden_size = config.hidden_size
|
| 304 |
+
self.core_attention = TEDotProductAttention()
|
|
|
|
|
|
|
| 305 |
self.linear_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.add_bias_linear,)
|
| 306 |
self.linear_qkv = nn.Linear(self.hidden_size, 3*self.hidden_size, bias=config.add_bias_linear,)
|
| 307 |
if q_layernorm:
|