dinalt committed on
Commit
0b1deee
·
verified ·
1 Parent(s): bb91586

Cleanup model implementation

Browse files

- Added support for inference cache.
- Refactored common code in attention
- Removed unused code (fragments from another project)

Files changed (1) hide show
  1. modelling_walsh.py +64 -294
modelling_walsh.py CHANGED
@@ -340,7 +340,6 @@ class HFCausalModel(PreTrainedModel):
340
  ):
341
  attention_mask = attention_mask[:, -max_cache_length:]
342
 
343
- # NOTE: "RSWalsh" models don't need to have their absolute positions adjusted to zero; they are trained for this.
344
  position_ids = kwargs.get("position_ids", None)
345
  if attention_mask is not None and position_ids is None:
346
  # create position_ids on the fly for batch generation
@@ -420,6 +419,7 @@ class HFCausalModel(PreTrainedModel):
420
  num_heads=config.num_attention_heads,
421
  attn_type=attn_type,
422
  layer_idx=layer_idx,
 
423
  **config.attention_args,
424
  )
425
 
@@ -516,25 +516,6 @@ class Transformer(nn.Module):
516
  init.constant_(self.output_projection.bias, 0.)
517
  init.normal_(self.embedding.weight, std=self.d_model**-0.5)
518
 
519
- # A vanilla positional encoder
520
- class PositionalEncoder(nn.Module):
521
- def __init__(self, d_embed, max_seq):
522
- super().__init__()
523
- self.d_embed = d_embed
524
- self.max_seq = max_seq
525
-
526
- weight = torch.zeros(max_seq, d_embed)
527
- position = torch.arange(0, max_seq, dtype=torch.float).unsqueeze(1)
528
- div_term = torch.exp(torch.arange(0, d_embed, 2).float() * (-math.log(10000.0) / d_embed))
529
- weight[:, 0::2] = torch.sin(position * div_term)
530
- weight[:, 1::2] = torch.cos(position * div_term)
531
- weight = weight.unsqueeze(0)
532
- self.register_buffer('weight', weight)
533
-
534
- def forward(self, x):
535
- seq_len = x.size(-2)
536
- return x + self.weight[:, :seq_len]
537
-
538
  # Converts a torch array of integers into their equivalent binary codes.
539
  def binary_tensor(x, bits):
540
  mask = 2**torch.arange(bits).to(x.device, x.dtype)
@@ -791,42 +772,6 @@ class FeedforwardLayer(nn.Module):
791
  init.constant_(self.linear1.bias, 0.)
792
  init.constant_(self.linear2.bias, 0.)
793
 
794
- # GLU Variants Improve Transformer
795
- # https://arxiv.org/pdf/2002.05202v1.pdf
796
- class SwiGLUFeedforwardLayer(nn.Module):
797
- def __init__(
798
- self,
799
- d_model,
800
- d_feedforward,
801
- layer_idx,
802
- beta=1.0,
803
- dropout=0.1
804
- ):
805
- super().__init__()
806
- self.d_model = d_model
807
- self.d_feedforward = d_feedforward
808
- self.beta = 1.0
809
-
810
- self.linear1 = nn.Linear(self.d_model, self.d_feedforward * 2, bias=False)
811
- self.linear2 = nn.Linear(self.d_feedforward, self.d_model, bias=False)
812
- self.dropout = nn.Dropout(dropout)
813
- self.reset_parameters()
814
-
815
- def forward(self, x):
816
- x, gate = self.linear1(x).chunk(2, dim=-1)
817
- x = x * F.silu(gate)
818
- x = self.dropout(x)
819
- x = self.linear2(x)
820
- return x
821
-
822
- def reset_parameters(self):
823
- # Deepnet initialization
824
- # https://arxiv.org/pdf/2203.00555.pdf
825
- w, g = self.linear1.weight.chunk(2, dim=0)
826
- init.xavier_uniform_(w, gain=self.beta)
827
- init.xavier_uniform_(g, gain=self.beta)
828
- init.xavier_uniform_(self.linear2.weight, gain=self.beta)
829
-
830
  class CausalSelfAttention(nn.Module):
831
  def __init__(
832
  self,
@@ -838,6 +783,7 @@ class CausalSelfAttention(nn.Module):
838
  # flash2: Use Flash-Attention2 implementation; fastest; limited to float16 and bfloat16 types; least memory usage.
839
  attn_type,
840
  layer_idx,
 
841
  beta=1.0,
842
  dropout=0.1,
843
  ):
@@ -847,6 +793,7 @@ class CausalSelfAttention(nn.Module):
847
  self.beta = beta
848
  self.attn_type = attn_type
849
  self.layer_idx = layer_idx
 
850
 
851
  assert d_model % num_heads == 0, "d_model must be evenly divisible by num_heads"
852
 
@@ -877,9 +824,21 @@ class CausalSelfAttention(nn.Module):
877
  init.constant_(self.in_proj.bias, 0.)
878
  init.constant_(self.output_linear.bias, 0.)
879
 
880
- def project_input(self, qkv):
 
 
881
  proj = self.in_proj(qkv)
882
- return proj.chunk(chunks=3, dim=-1)
 
 
 
 
 
 
 
 
 
 
883
 
884
  def forward(
885
  self,
@@ -888,7 +847,15 @@ class CausalSelfAttention(nn.Module):
888
  past_key_values,
889
  use_cache,
890
  ):
891
- if self.attn_type == "flash2":
 
 
 
 
 
 
 
 
892
  if use_cache is None or use_cache == False:
893
  return self.flash2_forward(qkv)
894
  else:
@@ -898,21 +865,15 @@ class CausalSelfAttention(nn.Module):
898
  batch_size, seq_len, d_embed = qkv.shape
899
 
900
  # Feed the inputs through the K, Q, V matrices.
901
- query, key, value = self.project_input(qkv)
902
-
903
- # Split projections into multiple heads and swap position of sequence / heads dimension
904
- query = query.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
905
- key = key.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
906
- value = value.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
907
-
908
- # Update the cache values.
909
- if past_key_values is not None:
910
- key, value = past_key_values.update(key, value, self.layer_idx)
911
-
912
  # Default to returning empty attention weights.
913
  attentions = None
 
 
914
 
915
- if self.attn_type == "torch":
916
  # This context manager can be used to force which implementation to use.
917
  #with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
918
  attended_values = F.scaled_dot_product_attention(
@@ -921,7 +882,7 @@ class CausalSelfAttention(nn.Module):
921
  value,
922
  attn_mask=None,
923
  dropout_p=self.dropout.p if self.training else 0.0,
924
- is_causal=True,
925
  scale=self.dot_product_scale
926
  )
927
  # "native" scaled-dot-product attention implementation.
@@ -930,13 +891,14 @@ class CausalSelfAttention(nn.Module):
930
  scores = torch.matmul(query, key.transpose(-2, -1)) * self.dot_product_scale
931
 
932
  # Mask future positions from the past
933
- scores.masked_fill_(
934
- torch.tril(
935
- torch.ones(seq_len, seq_len, dtype=torch.bool, device=qkv.device),
936
- diagonal=0,
937
- ).logical_not(),
938
- float('-inf'),
939
- )
 
940
 
941
  # Calculate the attention weights; avoid NANs that might emerge from zeros in softmax's denominator
942
  attentions = self.dropout(torch.softmax(scores, dim=-1).clamp(min=1e-10))
@@ -956,10 +918,10 @@ class CausalSelfAttention(nn.Module):
956
  return dict(
957
  hidden_states=attended_values,
958
  attentions=attentions,
959
- # Unimplemented...
960
- past_key_values=None
961
  )
962
-
 
963
  def flash2_forward(
964
  self,
965
  qkv,
@@ -977,9 +939,9 @@ class CausalSelfAttention(nn.Module):
977
  -1,
978
  (3, self.num_heads, self.d_head)
979
  )
980
-
981
  attended_values = flash_attn_qkvpacked_func(
982
- qkv.bfloat16(),
983
  dropout_p=self.dropout.p if self.training else 0.0,
984
  softmax_scale=self.dot_product_scale,
985
  causal=True,
@@ -1007,18 +969,8 @@ class CausalSelfAttention(nn.Module):
1007
  batch_size, seq_len, d_embed = qkv.shape
1008
 
1009
  # Feed the inputs through the K, Q, V matrices.
1010
- query, key, value = self.project_input(qkv)
1011
-
1012
- # TODO: Refactor -- this code is repeated in the baseline implementation.
1013
- # Split projections into multiple heads and swap position of sequence / heads dimension
1014
- query = query.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
1015
- key = key.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
1016
- value = value.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
1017
-
1018
- if past_key_values is not None:
1019
- key, value = past_key_values.update(key, value, self.layer_idx)
1020
-
1021
- #query, key, value = self._downcast_to_float16(query, key, value)
1022
 
1023
  # Expected inputs to flash2:
1024
  # q: (batch_size, seqlen, nheads, headdim)
@@ -1049,204 +1001,22 @@ class CausalSelfAttention(nn.Module):
1049
  past_key_values=past_key_values
1050
  )
1051
 
1052
- @staticmethod
1053
- def _downcast_to_float16(query, key, value):
1054
- # Copied section for Transformers to handle this
1055
- # TODO: Revist other Flash2 impelementation, above
1056
- input_dtype = query.dtype
1057
- if input_dtype == torch.float32:
1058
- if torch.is_autocast_enabled():
1059
- target_dtype = torch.get_autocast_gpu_dtype()
1060
- # Handle the case where the model is quantized
1061
- elif hasattr(self.config, "_pre_quantization_dtype"):
1062
- target_dtype = self.config._pre_quantization_dtype
1063
- else:
1064
- target_dtype = self.q_proj.weight.dtype
1065
- logger.warning_once(
1066
- f"The input hidden states seems to be silently casted in float32, this might be related to"
1067
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
1068
- f" {target_dtype}."
1069
- )
1070
- query = query.to(target_dtype)
1071
- key = key.to(target_dtype)
1072
- value = value.to(target_dtype)
1073
- return query, key, value
1074
-
1075
-
1076
- ########### TODO: Update to newer API, with inference cache
1077
-
1078
- # Attention layer with ALiBi relative positional encoding
1079
- # TRAIN SHORT, TEST LONG: ATTENTION WITH LINEAR BIASES ENABLES INPUT LENGTH EXTRAPOLATION
1080
- # https://arxiv.org/pdf/2108.12409.pdf
1081
- def alibi_biases(query_len, key_len, device='cpu'):
1082
- x = torch.arange(key_len, device=device)[None, :]
1083
- y = torch.arange(query_len, device=device)[:, None]
1084
- return x - y
1085
-
1086
- class CausalAlibiAttention(nn.Module):
1087
- def __init__(
1088
- self,
1089
- d_model,
1090
- num_heads,
1091
- beta=1.0,
1092
- dropout=0.1,
1093
- # values:
1094
- # native: Use local impementation; slowest option; good for debugging; useful when experimenting with non-standard stuff.
1095
- # torch: Use pytorch "scaled_dot_product_attention()"; faster; generally good compatibility; does not support returning attn weights.
1096
- # flash2: Use Flash-Attention2 implementation; fastest; limited to int16 and bfloat16 types; can't train Alibi weights; least memory usage.
1097
- # Note: You can perform initial training with "torch," then switch to "flash2," after the Alibi weights have settled.
1098
- window_size=None,
1099
- attn_type="native",
1100
- freeze_alibi=True,
1101
- ):
1102
- super().__init__()
1103
- self.d_model = d_model
1104
- self.num_heads = num_heads
1105
- self.beta = beta
1106
- self.attn_type = attn_type
1107
-
1108
- assert d_model % num_heads == 0, "d_model must be evenly divisible by num_heads"
1109
-
1110
- # The dimension of each head.
1111
- self.d_head = d_model // num_heads
1112
-
1113
- # We scale the attention scores by the inverse-square-root of the head dimension
1114
- # this shifts the temerature of softmax.
1115
- self.dot_product_scale = 1.0 / math.sqrt(self.d_head)
1116
-
1117
- self.in_proj = nn.Parameter(torch.empty(3 * self.d_model, self.d_model))
1118
- self.output_linear = nn.Linear(self.d_model, self.d_model, bias=False)
1119
-
1120
- if window_size is not None:
1121
- self.window_size=(window_size, -1)
1122
- else:
1123
- self.window_size = (-1, -1)
1124
-
1125
- self.dropout = nn.Dropout(dropout)
1126
-
1127
- # This generates the original slope distribution from the paper.
1128
- # Observations with trainable slopes suggest that the high half of the slopes shift
1129
- # towards / past 1.0 and the low half approach zero or even go slightly negative.
1130
- # alibi_slopes = 1.0 / torch.logspace(1, 8, self.num_heads, base=2, dtype=torch.float)
1131
-
1132
- # These appear to work better, as initial values, in practice.
1133
- alibi_slopes = 1.0 / torch.logspace(0, 7, self.num_heads, base=2, dtype=torch.float)
1134
-
1135
- # If not trainable, it can improve performance somewhat if the low half are set to zero. Apparently
1136
- # making roughly half of the slopes position-agnostic is somehow closer to optimal?
1137
- # alibi_slopes.masked_fill_(torch.where(torch.arange(0, self.num_heads) >= (self.num_heads / 2), True, False), 0)
1138
-
1139
- self.alibi_slopes = nn.Parameter(alibi_slopes)
1140
-
1141
- # Optionally, allow/disallow training of ALiBi slopes.
1142
- self.alibi_slopes.requires_grad = (not freeze_alibi)
1143
- self.reset_parameters()
1144
-
1145
- def extra_repr(self) -> str:
1146
- return f'd_model={self.d_model}, num_heads={self.num_heads}, beta={self.beta}, attn_type={self.attn_type}, window_size={self.window_size}, dropout={self.dropout}'
1147
-
1148
- def reset_parameters(self):
1149
- # Deepnet initialization
1150
- # https://arxiv.org/pdf/2203.00555.pdf
1151
-
1152
- q, k, v = self.in_proj.chunk(3)
1153
- init.xavier_uniform_(q, gain=1.0)
1154
- init.xavier_uniform_(k, gain=1.0)
1155
- init.xavier_uniform_(v, gain=self.beta)
1156
- init.xavier_uniform_(self.output_linear.weight, gain=self.beta)
1157
-
1158
- def project_input(self, qkv):
1159
- proj = F.linear(qkv, self.in_proj)
1160
- return proj.chunk(chunks=3, dim=-1)
1161
-
1162
- def forward(self, qkv, need_weights):
1163
- if self.attn_type == "flash2":
1164
- return self.flash2_forward(qkv)
1165
-
1166
- # qkv: (batch_size, seq_len, d_embed)
1167
- batch_size, seq_len, d_embed = qkv.shape
1168
-
1169
- # Feed the inputs through the K, Q, V matrices.
1170
- query, key, value = self.project_input(qkv)
1171
-
1172
- # Split projections into multiple heads and swap position of sequence / heads dimension
1173
- query = query.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
1174
- key = key.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
1175
- value = value.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
1176
-
1177
- # Apply Alibi relative positional biases.
1178
- attn_bias = alibi_biases(seq_len, seq_len, device=query.device) * self.alibi_slopes.view(-1, 1, 1)
1179
-
1180
- # Mask future positions from the past
1181
- causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=qkv.device), diagonal=0)
1182
- attn_bias.masked_fill_(causal_mask.logical_not(), float('-inf'))
1183
- del causal_mask
1184
-
1185
- # Default to returning empty attention weights.
1186
- attention_weights = None
1187
-
1188
- if self.attn_type == "torch":
1189
- # This context manager can be used to force which implementation to use.
1190
- #with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
1191
- attended_values = F.scaled_dot_product_attention(
1192
- query,
1193
- key,
1194
- value,
1195
- attn_mask=attn_bias.to(dtype=query.dtype),
1196
- dropout_p=self.dropout.p if self.training else 0.0,
1197
- is_causal=False,
1198
- scale=self.dot_product_scale
1199
- )
1200
- # "native" scaled-dot-product attention implementation.
1201
  else:
1202
- # Compute attention scores
1203
- scores = torch.matmul(query, key.transpose(-2, -1)) * self.dot_product_scale
1204
-
1205
- # Adjust scores with attn_mask
1206
- scores += attn_bias
1207
-
1208
- # Calculate the attention weights; avoid NANs that might emerge from zeros in softmax's denominator
1209
- attention_weights = self.dropout(torch.softmax(scores, dim=-1).clamp(min=1e-10))
1210
-
1211
- # Use the attention weights to get a weighted combination of value vectors
1212
- attended_values = torch.matmul(attention_weights, value)
1213
- if not output_attentions:
1214
- attention_weights = None
1215
-
1216
- # Concatenate attention heads and project to original embedding size using the output linear layer
1217
- attended_values = attended_values.transpose(1, 2).contiguous().view(batch_size, seq_len, d_embed)
1218
-
1219
- # Project the concatenated output through the output matrix.
1220
- attended_values = self.output_linear(attended_values)
1221
- return attended_values, attention_weights
1222
-
1223
- def flash2_forward(self, qkv):
1224
- batch_size, seq_len, d_embed = qkv.shape
1225
-
1226
- # Feed the inputs through the K, Q, V matrices.
1227
- # query : (batch_size, seq_len, d_model)
1228
- # qkv : (batch_size, seq_len, 3, num_heads, d_kq)
1229
- qkv = F.linear(
1230
- qkv,
1231
- self.in_proj,
1232
- ).unflatten(
1233
- -1,
1234
- (3, self.num_heads, self.d_head)
1235
  )
1236
 
1237
- attended_values = flash_attn_qkvpacked_func(
1238
- qkv.bfloat16(),
1239
- dropout_p=self.dropout.p if self.training else 0.0,
1240
- softmax_scale=self.dot_product_scale,
1241
- causal=True,
1242
- window_size=self.window_size,
1243
- alibi_slopes=self.alibi_slopes.float(),
1244
- ).to(dtype=qkv.dtype)
1245
- # attended_values: (batch_size, seqlen, nheads, headdim)
1246
-
1247
- # Concatentate heads back into d_embed
1248
- attended_values = attended_values.view(batch_size, seq_len, d_embed)
1249
-
1250
- # Project the concatenated output through the output matrix.
1251
- attended_values = self.output_linear(attended_values)
1252
- return attended_values, None
 
340
  ):
341
  attention_mask = attention_mask[:, -max_cache_length:]
342
 
 
343
  position_ids = kwargs.get("position_ids", None)
344
  if attention_mask is not None and position_ids is None:
345
  # create position_ids on the fly for batch generation
 
419
  num_heads=config.num_attention_heads,
420
  attn_type=attn_type,
421
  layer_idx=layer_idx,
422
+ config=config,
423
  **config.attention_args,
424
  )
425
 
 
516
  init.constant_(self.output_projection.bias, 0.)
517
  init.normal_(self.embedding.weight, std=self.d_model**-0.5)
518
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
519
  # Converts a torch array of integers into their equivalent binary codes.
520
  def binary_tensor(x, bits):
521
  mask = 2**torch.arange(bits).to(x.device, x.dtype)
 
772
  init.constant_(self.linear1.bias, 0.)
773
  init.constant_(self.linear2.bias, 0.)
774
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
775
  class CausalSelfAttention(nn.Module):
776
  def __init__(
777
  self,
 
783
  # flash2: Use Flash-Attention2 implementation; fastest; limited to float16 and bfloat16 types; least memory usage.
784
  attn_type,
785
  layer_idx,
786
+ config,
787
  beta=1.0,
788
  dropout=0.1,
789
  ):
 
793
  self.beta = beta
794
  self.attn_type = attn_type
795
  self.layer_idx = layer_idx
796
+ self.config = config
797
 
798
  assert d_model % num_heads == 0, "d_model must be evenly divisible by num_heads"
799
 
 
824
  init.constant_(self.in_proj.bias, 0.)
825
  init.constant_(self.output_linear.bias, 0.)
826
 
827
+ # Project QKV input through input matrices, reshape to (batch_size, n_heads, seq_len, d_head), and apply cache.
828
+ def project_input(self, qkv, past_key_values):
829
+ batch_size, seq_len, d_embed = qkv.shape
830
  proj = self.in_proj(qkv)
831
+ query, key, value = proj.chunk(chunks=3, dim=-1)
832
+
833
+ # Split projections into multiple heads and swap position of sequence / heads dimension
834
+ query = query.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
835
+ key = key.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
836
+ value = value.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
837
+
838
+ # Update the cache values.
839
+ if past_key_values is not None:
840
+ key, value = past_key_values.update(key, value, self.layer_idx)
841
+ return query, key, value
842
 
843
  def forward(
844
  self,
 
847
  past_key_values,
848
  use_cache,
849
  ):
850
+ attn_type = self.attn_type
851
+ if output_attentions and attn_type != "native":
852
+ logger.warning_once(
853
+ "CausalSelfAttention(output_attentions=True) and attn_type is not 'native': "
854
+ "Forcing native attention."
855
+ )
856
+ attn_type = "native"
857
+
858
+ if attn_type == "flash2":
859
  if use_cache is None or use_cache == False:
860
  return self.flash2_forward(qkv)
861
  else:
 
865
  batch_size, seq_len, d_embed = qkv.shape
866
 
867
  # Feed the inputs through the K, Q, V matrices.
868
+ query, key, value = self.project_input(qkv, past_key_values)
869
+ kv_seq_len = key.shape[-2]
870
+
 
 
 
 
 
 
 
 
871
  # Default to returning empty attention weights.
872
  attentions = None
873
+
874
+ # https://github.com/pytorch/pytorch/issues/112577
875
 
876
+ if attn_type == "torch":
877
  # This context manager can be used to force which implementation to use.
878
  #with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
879
  attended_values = F.scaled_dot_product_attention(
 
882
  value,
883
  attn_mask=None,
884
  dropout_p=self.dropout.p if self.training else 0.0,
885
+ is_causal=(seq_len > 1),
886
  scale=self.dot_product_scale
887
  )
888
  # "native" scaled-dot-product attention implementation.
 
891
  scores = torch.matmul(query, key.transpose(-2, -1)) * self.dot_product_scale
892
 
893
  # Mask future positions from the past
894
+ if seq_len > 1:
895
+ scores.masked_fill_(
896
+ torch.tril(
897
+ torch.ones(seq_len, kv_seq_len, dtype=torch.bool, device=qkv.device),
898
+ diagonal=0,
899
+ ).logical_not(),
900
+ float('-inf'),
901
+ )
902
 
903
  # Calculate the attention weights; avoid NANs that might emerge from zeros in softmax's denominator
904
  attentions = self.dropout(torch.softmax(scores, dim=-1).clamp(min=1e-10))
 
918
  return dict(
919
  hidden_states=attended_values,
920
  attentions=attentions,
921
+ past_key_values=past_key_values
 
922
  )
923
+
924
+ # No cache support, but faster
925
  def flash2_forward(
926
  self,
927
  qkv,
 
939
  -1,
940
  (3, self.num_heads, self.d_head)
941
  )
942
+
943
  attended_values = flash_attn_qkvpacked_func(
944
+ self._downcast_to_float16(qkv)[0],
945
  dropout_p=self.dropout.p if self.training else 0.0,
946
  softmax_scale=self.dot_product_scale,
947
  causal=True,
 
969
  batch_size, seq_len, d_embed = qkv.shape
970
 
971
  # Feed the inputs through the K, Q, V matrices.
972
+ query, key, value = self.project_input(qkv, past_key_values)
973
+ query, key, value = self._downcast_to_float16(query, key, value)
 
 
 
 
 
 
 
 
 
 
974
 
975
  # Expected inputs to flash2:
976
  # q: (batch_size, seqlen, nheads, headdim)
 
1001
  past_key_values=past_key_values
1002
  )
1003
 
1004
+ def _downcast_to_float16(self, *args):
1005
+ if args[0].dtype != torch.float32:
1006
+ return args
1007
+
1008
+ if torch.is_autocast_enabled():
1009
+ target_dtype = torch.get_autocast_gpu_dtype()
1010
+ # Handle the case where the model is quantized
1011
+ elif hasattr(self.config, "_pre_quantization_dtype"):
1012
+ target_dtype = self.config._pre_quantization_dtype
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1013
  else:
1014
+ target_dtype = self.output_linear.weight.dtype
1015
+
1016
+ logger.warning_once(
1017
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
1018
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
1019
+ f" {target_dtype}."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1020
  )
1021
 
1022
+ return (arg.to(target_dtype) for arg in args)