Update modeling_yuanlm2.py

#1
Files changed (1)
  1. modeling_yuanlm2.py +16 -140
modeling_yuanlm2.py CHANGED
@@ -1,5 +1,5 @@
  # coding=utf-8
- # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+ # Copyright 2022 YuanLab and the HuggingFace Inc. team. All rights reserved.
  #
  # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  # and OPT implementations in this library. It has been modified from its
@@ -32,11 +32,6 @@ from transformers.modeling_utils import PreTrainedModel
  from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
  from .configuration_yuan import YuanConfig
  from einops import rearrange
- # from flash_attn import flash_attn_varlen_func as flash_attn_unpadded_func
- #from apex.normalization import MixedFusedRMSNorm as RMSNorm
- #from flash_attn import flash_attn_func
- #from transformer_engine.pytorch import RMSNorm
- import pdb
  import copy
  try:
  import grouped_gemm as gg
@@ -70,23 +65,6 @@ class RMSNorm(torch.nn.Module):
  return self.weight * hidden_states


- """
- class YuanRotaryEmbedding(nn.Module):
- def __init__(self, dim, base=10000, dtype=torch.float32, device=None, scaling_factor=1.0, rope_type='default'):
- super().__init__()
- inv_freq = (1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))).to(dtype)#.to('cuda:1')
- self.register_buffer('inv_freq', inv_freq)
-
- def forward(self, max_seq_len, offset=0):
- self.inv_freq = self.inv_freq.to(torch.float32)
- seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset
- freqs = einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq)
- # first part even vector components, second part odd vector components,
- # 2 * dim in dimension size
- emb = torch.cat((freqs, freqs), dim=-1)
- # emb [seq_length, .., dim]
- return emb[:, None, None, :]"""
-
  class YuanRotaryEmbedding(nn.Module):
  def __init__(self, dim, base=10000, dtype=torch.float32, rotary_interleaved=False, seq_len_interpolation_factor=None):
  super().__init__()
@@ -143,17 +121,11 @@ class YuanRotaryEmbedding(nn.Module):
  )
  # emb [seq_length, .., dim]
  emb = emb[:, None, None, :]
- #emb = emb[:, None, :]
  return emb


  def _rotate_half(x, rotary_interleaved):
- """huggingface version
- change sign so the last dimension becomes [-odd, +even]
-
- x1, x2 = torch.chunk(x, 2, dim=-1)
- return torch.cat((-x2, x1), dim=-1)
- """
+
  if not rotary_interleaved:
  x1, x2 = torch.chunk(x, 2, dim=-1)
  return torch.cat((-x2, x1), dim=-1)
@@ -166,7 +138,6 @@ def _rotate_half(x, rotary_interleaved):
  def apply_rotary_pos_emb(t, freqs, position_ids, rotary_interleaved=False):

  rot_dim = freqs.shape[-1]
- #if position_ids.shape[1] > 1:
  freqs = freqs[position_ids]
  freqs = freqs.view(t.shape[1],freqs.shape[1],freqs.shape[2],freqs.shape[4]).transpose(0,1)
  # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
@@ -180,26 +151,7 @@ def apply_rotary_pos_emb(t, freqs, position_ids, rotary_interleaved=False):

  t = (t * cos_) + (_rotate_half(t, rotary_interleaved) * sin_)
  return torch.cat((t, t_pass), dim=-1)
- """huggingface version
- input tensor t is of shape [seq_length, ..., dim]
- rotary positional embeding tensor freqs is of shape [seq_length, ..., dim]
- check https://kexue.fm/archives/8265 for detailed formulas
-
- dtype = t.dtype
- rot_dim = freqs.shape[-1]
- t_pass = t[..., rot_dim:]
- if position_ids.shape[1] > 1:
- freqs = freqs[position_ids]
- freqs = freqs.view(t.shape[1],freqs.shape[1],freqs.shape[2],freqs.shape[4]).transpose(0,1)
- # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
- t = t[..., :rot_dim]
- # first part is cosine component
- # second part is sine component, need to change signs with _rotate_half method
- t = (t * freqs.cos()) + (_rotate_half(t) * freqs.sin())
- t = t.to(dtype)
- """

- return torch.cat((t, t_pass), dim=-1)

  class LocalizedFiltering(torch.nn.Module):
  """
@@ -287,54 +239,6 @@ class LocalizedFiltering(torch.nn.Module):
  lf_output = self.output_layernorm(output2 + residual)

  return lf_output
- '''#IEIyuan huggingface version
- if before_hidden_states == None:
- inputs = inputs.transpose(0,1)
- seq_len, bsz, embed_dim = inputs.size()
- if embed_dim != self.embed_dim:
- raise ValueError(
- f"Unexpected embedding dimension received: input is {embed_dim}, model expects {self.embed_dim}"
- )
- residual = inputs
- inputs = inputs.view(seq_len, 1, bsz, embed_dim).permute(2, 3, 0, 1)
- inputs = torch.cat((torch.zeros(bsz, embed_dim, 1, 1, dtype=inputs.dtype, device=inputs.device), inputs), dim=2).contiguous()
- output1 = self.conv1(inputs)
-
- output1 = torch.cat((torch.zeros(bsz, embed_dim // 2, 1, 1, dtype=inputs.dtype, device=inputs.device), output1), dim=2).contiguous()
- output2 = self.conv2(output1).permute(2, 3, 0, 1).contiguous()
- output2 = output2.view(seq_len, bsz, embed_dim)
- assert output2.shape == residual.shape
- norm_input = (output2 + residual)#.to('cuda:0')
- torch.cuda.set_device(norm_input.device)
- lf_output = self.output_layernorm(norm_input)
- lf_output = lf_output#.to('cuda:1')
- lf_output = lf_output.transpose(0,1)
- return lf_output
- else:
- inputs = inputs.transpose(0,1)
- before_hidden_states = before_hidden_states.transpose(0,1)
- seq_len, bsz, embed_dim = inputs.size()
- if embed_dim != self.embed_dim:
- raise ValueError(
- f"Unexpected embedding dimension received: input is {embed_dim}, model expects {self.embed_dim}"
- )
- residual = inputs
- inputs = inputs.view(seq_len, 1, bsz, embed_dim).permute(2, 3, 0, 1)
- before_hidden_states = before_hidden_states.view(2, 1, bsz, embed_dim).permute(2, 3, 0, 1)
- inputs = torch.cat((before_hidden_states, inputs), dim=2).contiguous()
- output1 = self.conv1(inputs)
- output2 = self.conv2(output1).permute(2, 3, 0, 1).contiguous()
- output2 = output2.view(seq_len, bsz, embed_dim)
- assert output2.shape == residual.shape
-
- norm_input = (output2 + residual)#.to('cuda:0')
- torch.cuda.set_device(norm_input.device)
- lf_output = self.output_layernorm(norm_input)
- lf_output = lf_output#.to('cuda:1')
- lf_output = lf_output.transpose(0,1)
- return lf_output
- '''
-

  def forward(
  self,
@@ -444,9 +348,7 @@ class FlashSelfAttention(torch.nn.Module):
  # turn off FA causal mask after first inference autoregressive iteration
  # only on first autoregressive step q,k,v have same seqlen
  is_causal = seqlen_q == seqlen_k
- cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=q.device)
- #cu_seqlens_q = [cu_seqlens_q[0], cu_seqlens_q[-1]]
- #cu_seqlens_k = [cu_seqlens_k[0], cu_seqlens_k[-1]]
+ cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=q.device)
  dropout_p = 0

  output = flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k, dropout_p, softmax_scale=self.softmax_scale, causal=is_causal)
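Note: `cu_seqlens_q`/`cu_seqlens_k` are the cumulative sequence boundaries that the varlen flash-attention kernel uses to slice a packed batch without materializing padding. A small sketch of the boundary arithmetic behind the `torch.arange` call above (plain PyTorch, no flash-attn needed to run it):

```python
import torch

# For a batch of equal-length sequences packed into one [batch*seqlen, heads, dim]
# tensor, cu_seqlens marks where each sequence starts: [0, L, 2L, ..., B*L].
batch_size, seqlen = 4, 16
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32)
print(cu_seqlens.tolist())  # [0, 16, 32, 48, 64]

# flash_attn_varlen_func then treats tokens cu_seqlens[i]:cu_seqlens[i+1] as sequence i.
```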
@@ -473,7 +375,6 @@ class ParallelAttention_router(nn.Module):
  is_first_step = False
  before_hidden_states = None

- #mixed_x_layer = torch.matmul(hidden_states, self.query_key_value)
  mixed_x_layer = self.query_key_value(hidden_states)
  (query_layer, key_layer, value_layer) = torch.split(mixed_x_layer, self.projection_size, -1)
  b, s, z = query_layer.shape
@@ -493,7 +394,6 @@ class YuanExpertMLP(nn.Module):
  def __init__(self, config):
  super(YuanExpertMLP, self).__init__()
  self.gated_linear_unit = config.moe_config['gated_linear_unit']
- #self.ffn_hidden_size = config.moe_config['ffn_hidden_size']
  self.ffn_hidden_size = config.ffn_hidden_size


@@ -579,8 +479,7 @@ class YuanAttention(nn.Module):
  self.lf_gate = LocalizedFiltering(self.hidden_size, self.lf_conv2d_group, self.lf_conv2d_num_pad)
  self.get_query_key = nn.Linear(self.hidden_size, 2 * self.attention_projection_size, bias=False)
  self.core_attention = FlashSelfAttention(causal=True, attention_dropout=config.attn_dropout, softmax_scale=self.softmax_scale)
- #self.core_attention_flash = DotProductAttention(num_attention_heads=self.num_heads,
- # kv_channels=self.head_dim)
+

  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
  return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
@@ -597,7 +496,7 @@ class YuanAttention(nn.Module):
  use_cache: bool = False,
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
  q_len, bsz, _ = hidden_states.size()
- hidden_states = hidden_states#.to('cuda:1')
+ hidden_states = hidden_states
  is_first_step = False
  if use_cache:
  if past_key_value is None:
@@ -621,7 +520,6 @@ class YuanAttention(nn.Module):
  else:
  hidden_states = self.lf_gate(hidden_states, before_hidden_states)
  mixed_qk_layer = self.get_query_key(hidden_states)
- #mixed_qk_layer = torch.matmul(hidden_states, qk_tensor)
  new_tensor_shape = mixed_qk_layer.size()[:-1] + (self.num_heads, 2 * self.head_dim)
  mixed_qk_layer = mixed_qk_layer.view(*new_tensor_shape)
  (query_states, key_states) = torch.split(mixed_qk_layer, self.head_dim, dim=-1)
@@ -635,7 +533,6 @@ class YuanAttention(nn.Module):
  if rotary_pos_emb is not None:
  if position_ids.shape[1] == 1:
  q_seq_start = position_ids[0,-1]
- #seq_start = past_key_value[0].shape[0]
  q_seq_end = q_seq_start + 1
  k_seq_end = q_seq_end
  else:
@@ -654,17 +551,11 @@ class YuanAttention(nn.Module):
  key_states = torch.cat([past_key_value[0], key_states], dim=0)
  value_states = torch.cat([past_key_value[1], value_states], dim=0)
  past_key_value = (key_states, value_states, inference_hidden_states_memory) if use_cache else None
- #query_states = apply_rotary_pos_emb(query_states.permute(1, 0, 2, 3), q_pos_emb, position_ids)
- #key_states = apply_rotary_pos_emb(key_states.permute(1, 0, 2, 3), k_pos_emb, position_ids)
  query_states = apply_rotary_pos_emb(query_states, q_pos_emb, position_ids)
  key_states = apply_rotary_pos_emb(key_states, k_pos_emb, position_ids_k)

  attn_weights = None
- #query_states = query_states.transpose(0,1)
- #key_states = key_states.transpose(0,1)
- #value_states = value_states
  attn_output = self.core_attention(query_states, key_states, value_states)
- #attn_output = self.core_attention(query_states, key_states, value_states, attention_mask)
  q_len, bsz, _, _ = attn_output.shape
  attn_output = attn_output.reshape(q_len, bsz, -1)

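Note: the cache concatenation above relies on the sequence-first `[seq, batch, heads, head_dim]` layout, so each decode step appends the new key/value along dim 0. A minimal illustration with hypothetical shapes:

```python
import torch

# Sequence-first KV cache: 10 cached positions, one new token per decode step.
past_k = torch.randn(10, 1, 8, 64)  # [seq, batch, heads, head_dim]
new_k = torch.randn(1, 1, 8, 64)    # current token's key
k = torch.cat([past_k, new_k], dim=0)
print(k.shape)  # torch.Size([11, 1, 8, 64])
```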
@@ -764,11 +655,8 @@ class GroupedMLP(nn.Module):
  self.w2 = nn.ModuleList([nn.Linear(self.ffn_hidden_size, self.config.hidden_size, bias=False) for _ in range(num_experts)])
  def forward(self, permuted_hidden_states, tokens_per_expert):
  torch.cuda.set_device(permuted_hidden_states.device)
- permuted_hidden_states = permuted_hidden_states#.to('cuda:0')
- #fc1_output = gg.ops.gmm(permuted_hidden_states, self.weight1, tokens_per_expert.cpu(), trans_b=False)
+ permuted_hidden_states = permuted_hidden_states

- #intermediate_parallel = self.activation_func(fc1_output)
- #fc2_output = gg.ops.gmm(intermediate_parallel, self.weight2, tokens_per_expert.cpu(), trans_b=False)

  fc2_outputs = []
  start_idx = 0
@@ -776,18 +664,15 @@ class GroupedMLP(nn.Module):
  if tokens_per_expert[i] == 0:
  continue
  end_idx = start_idx + tokens_per_expert[i]
- #fc1_output = torch.matmul(permuted_hidden_states[start_idx:end_idx], self.w1[i])
  # Use custom attributes for each expert's Linear layers

  fc1_output = self.w1[i](permuted_hidden_states[start_idx:end_idx])
- #print("shape1:", self.w1[i].shape, "shape2:", permuted_hidden_states[start_idx:end_idx].shape)
  intermediate_parallel = self.activation_func(fc1_output)
- #fc2_output = torch.matmul(intermediate_parallel, self.w2[i])
  fc2_output = self.w2[i](intermediate_parallel)
  fc2_outputs.append(fc2_output)
  start_idx = end_idx
  fc2_output = torch.cat(fc2_outputs, dim=0)
- return fc2_output#.to('cuda:1')
+ return fc2_output

  class YuanMoeLayer(nn.Module):
  def __init__(self, config:YuanConfig):
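Note: with the `gg.ops.gmm` calls gone, `GroupedMLP.forward` falls back to a plain per-expert loop over contiguous token slices. A runnable sketch of that pattern; SiLU is an assumption here, since `activation_func` is not shown in this diff:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Tokens arrive sorted (permuted) by expert; tokens_per_expert gives contiguous
# slice sizes, so each expert runs one dense matmul over its own slice.
hidden, ffn, num_experts = 32, 64, 4
w1 = nn.ModuleList([nn.Linear(hidden, ffn, bias=False) for _ in range(num_experts)])
w2 = nn.ModuleList([nn.Linear(ffn, hidden, bias=False) for _ in range(num_experts)])

permuted = torch.randn(10, hidden)
tokens_per_expert = [4, 0, 3, 3]  # expert 1 receives no tokens this step

outputs, start = [], 0
for i, n in enumerate(tokens_per_expert):
    if n == 0:
        continue  # skip empty experts, mirroring the `continue` above
    end = start + n
    outputs.append(w2[i](F.silu(w1[i](permuted[start:end]))))  # assumed activation
    start = end
print(torch.cat(outputs, dim=0).shape)  # torch.Size([10, 32])
```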
@@ -800,7 +685,6 @@ class YuanMoeLayer(nn.Module):

  expert_indices_offset = (0)

- #self.gate = ParallelAttention_router(config)
  self.router = ParallelAttention_router(config)
  self.token_dispatcher = MoEDroplessTokenDispatcher(self.num_experts, config=self.config)
  self.experts = GroupedMLP(self.num_experts, self.config)
@@ -812,7 +696,6 @@ class YuanMoeLayer(nn.Module):

  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  batch_size, sequence_length, hidden_dim = hidden_states.shape
- #logits = self.gate(hidden_states)
  logits = self.router(hidden_states)
  scores, indices = self.routing(logits)
  scores = scores.to(hidden_states.dtype)
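Note: `self.routing` itself is not part of this diff. For context, a generic top-k softmax routing step of the kind MoE layers use to produce `scores`/`indices` before dispatch; this is illustrative only, not Yuan's attention-based router (`ParallelAttention_router`):

```python
import torch

# Given router logits [tokens, num_experts], pick each token's top-k experts and
# normalized weights; dispatch then permutes tokens so expert slices are contiguous.
logits = torch.randn(6, 4)  # 6 tokens, 4 experts
top_k = 2
scores, indices = torch.topk(logits.softmax(dim=-1), top_k, dim=-1)
scores = scores / scores.sum(dim=-1, keepdim=True)  # renormalize over chosen experts
print(scores.shape, indices.shape)  # torch.Size([6, 2]) torch.Size([6, 2])
```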
@@ -865,9 +748,9 @@ class YuanDecoderLayer(nn.Module):
  (see `past_key_values`).
  past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
  """
- residual = hidden_states#.to('cuda:1')
+ residual = hidden_states
  torch.cuda.set_device(hidden_states.device)
- hidden_states = self.input_layernorm(hidden_states) #.to('cuda:0')).to('cuda:1')
+ hidden_states = self.input_layernorm(hidden_states)
  # Self Attention
  hidden_states, self_attn_weights, present_key_value = self.self_attn(
  hidden_states=hidden_states,
@@ -881,10 +764,10 @@ class YuanDecoderLayer(nn.Module):
  )
  hidden_states = residual + hidden_states.permute(1, 0, 2)
  # Fully Connected
- residual = hidden_states#.to('cuda:1')
+ residual = hidden_states
  torch.cuda.set_device(hidden_states.device)
- hidden_states = self.post_attention_layernorm(hidden_states) #.to('cuda:0')).to('cuda:1')
- hidden_states = self.mlp(hidden_states)# .to('cuda:1')
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
  hidden_states = residual + hidden_states
  outputs = (hidden_states,)

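Note: both cleaned-up blocks follow the same pre-norm residual pattern (normalize, transform, add back the residual). A compact stand-in sketch with placeholder sublayers; `LayerNorm` and `Linear` here are simplifications of the model's RMSNorm, attention, and MoE blocks:

```python
import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)   # stand-in for RMSNorm
        self.norm2 = nn.LayerNorm(dim)
        self.attn = nn.Linear(dim, dim)  # placeholder for self-attention
        self.mlp = nn.Linear(dim, dim)   # placeholder for the MoE/MLP block

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.norm1(x))  # residual around attention
        x = x + self.mlp(self.norm2(x))   # residual around feed-forward
        return x

print(PreNormBlock(16)(torch.randn(2, 3, 16)).shape)  # torch.Size([2, 3, 16])
```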
@@ -1160,13 +1043,10 @@ class YuanModel(YuanPreTrainedModel):
  seq_length_with_past = seq_length
  past_key_values_length = 0
  if past_key_values is not None:
- #past_key_values_length = past_key_values[0][0].shape[2]
- #modify
- print('0000')
  past_key_values_length = past_key_values[0][0].shape[0]
  seq_length_with_past = seq_length_with_past + past_key_values_length
  else:
- print('1111')
+ pass

  # modify to reset position ids
  if past_key_values is not None:
@@ -1203,7 +1083,6 @@ class YuanModel(YuanPreTrainedModel):
  attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
  )

- #rotary_pos_emb = self.rotary_pos_emb(self.max_position_embeddings)
  # Rotary positional embeddings (embedding is None for PP intermediate devices)
  rotary_pos_emb = None
  rotary_pos_emb = self.rotary_pos_emb(self.max_position_embeddings)
@@ -1220,8 +1099,7 @@ class YuanModel(YuanPreTrainedModel):
  all_hidden_states = () if output_hidden_states else None
  all_self_attns = () if output_attentions else None
  next_decoder_cache = () if use_cache else None
- #position_ids = position_ids.cpu()
- #position_ids_k = position_ids_k.cpu()
+
  for idx, decoder_layer in enumerate(self.layers):
  if output_hidden_states:
  all_hidden_states += (hidden_states,)
@@ -1261,10 +1139,9 @@ class YuanModel(YuanPreTrainedModel):

  if output_attentions:
  all_self_attns += (layer_outputs[1],)
- hidden_states = hidden_states#.to('cuda:0')
- #torch.cuda.set_device(hidden_states.device)
+ hidden_states = hidden_states
+
  hidden_states = self.norm(hidden_states)
- #print(hidden_states)
  # add hidden states from the last decoder layer
  if output_hidden_states:
  all_hidden_states += (hidden_states,)
@@ -1284,7 +1161,6 @@ class YuanForCausalLM(YuanPreTrainedModel, GenerationMixin):
  super().__init__(config)
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  self.model = YuanModel(config)
- #self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  self.post_init()

  def get_input_embeddings(self):
 