root committed on
Commit 55c8626 · 1 Parent(s): 1fd9820

update code

configuration_yuanvl.py CHANGED
@@ -1,6 +1,6 @@
 # --------------------------------------------------------
 # InternVL
-# Copyright (c) 2024 OpenGVLab
+# Copyright (c) 2024 YuanLabAI
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------

conversation.py CHANGED
@@ -391,7 +391,7 @@ register_conv_template(
     Conversation(
         name='yuan-chat',
         system_template='<|im_start|>system\n{system_message}',
-        system_message='你是IEI-源多模态模型,英文名是YuanVL,是由浪潮信息开发的多模态大语言模型。',
+        system_message='你是Yuan3.0 Flash多模态大模型,由YuanLab.ai 团队开发的多模态大语言模型。',
         roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
         sep_style=SeparatorStyle.MPT,
         sep='<|im_end|>\n',
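
For reference, the replaced system message translates to "You are the IEI-Yuan multimodal model; your English name is YuanVL; you are a multimodal large language model developed by 浪潮信息 (Inspur)", and the new one to "You are the Yuan3.0 Flash multimodal model, a multimodal large language model developed by the YuanLab.ai team." Below is a minimal sketch (not the repo's Conversation.get_prompt) of how these fields compose a prompt under SeparatorStyle.MPT, assuming each turn renders as role + content + sep:

# Sketch only: mirrors the 'yuan-chat' template fields registered above.
system_template = '<|im_start|>system\n{system_message}'
system_message = '你是Yuan3.0 Flash多模态大模型,由YuanLab.ai 团队开发的多模态大语言模型。'
roles = ('<|im_start|>user\n', '<|im_start|>assistant\n')
sep = '<|im_end|>\n'

def render(turns):
    # turns: list of (role_index, text); 0 = user, 1 = assistant
    prompt = system_template.format(system_message=system_message) + sep
    for role_idx, text in turns:
        prompt += roles[role_idx] + text + sep
    return prompt + roles[1]  # trailing assistant tag cues the model's reply

print(render([(0, 'Describe the image.')]))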
modeling_yuanlm2.py CHANGED
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+# Copyright 2022 YuanLabAI and the HuggingFace Inc. team. All rights reserved.
 #
 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
 # and OPT implementations in this library. It has been modified from its
@@ -26,17 +26,15 @@ import torch.utils.checkpoint
 from torch import einsum, nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.activations import ACT2FN
-from transformers.generation import GenerationMixin
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
-from .configuration_yuan import YuanConfig
+from configuration_yuan import YuanConfig
 from einops import rearrange
 # from flash_attn import flash_attn_varlen_func as flash_attn_unpadded_func
 #from apex.normalization import MixedFusedRMSNorm as RMSNorm
 #from flash_attn import flash_attn_func
-#from transformer_engine.pytorch import RMSNorm
-import pdb
+from transformer_engine.pytorch import RMSNorm
 import copy
 try:
     import grouped_gemm as gg
@@ -53,39 +51,6 @@ logger = logging.get_logger(__name__)

 _CONFIG_FOR_DOC = "YuanConfig"

-class RMSNorm(torch.nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__()
-        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-
-        # convert into half-precision if necessary
-        if self.weight.dtype in [torch.float16, torch.bfloat16]:
-            hidden_states = hidden_states.to(self.weight.dtype)
-
-        return self.weight * hidden_states
-
-
-"""
-class YuanRotaryEmbedding(nn.Module):
-    def __init__(self, dim, base=10000, dtype=torch.float32, device=None, scaling_factor=1.0, rope_type='default'):
-        super().__init__()
-        inv_freq = (1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))).to(dtype)#.to('cuda:1')
-        self.register_buffer('inv_freq', inv_freq)
-
-    def forward(self, max_seq_len, offset=0):
-        self.inv_freq = self.inv_freq.to(torch.float32)
-        seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset
-        freqs = einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq)
-        # first part even vector components, second part odd vector components,
-        # 2 * dim in dimension size
-        emb = torch.cat((freqs, freqs), dim=-1)
-        # emb [seq_length, .., dim]
-        return emb[:, None, None, :]"""

 class YuanRotaryEmbedding(nn.Module):
     def __init__(self, dim, base=10000, dtype=torch.float32, rotary_interleaved=False, seq_len_interpolation_factor=None):
@@ -143,17 +108,10 @@ class YuanRotaryEmbedding(nn.Module):
         )
         # emb [seq_length, .., dim]
         emb = emb[:, None, None, :]
-        #emb = emb[:, None, :]
         return emb


 def _rotate_half(x, rotary_interleaved):
-    """huggingface version
-    change sign so the last dimension becomes [-odd, +even]
-
-    x1, x2 = torch.chunk(x, 2, dim=-1)
-    return torch.cat((-x2, x1), dim=-1)
-    """
     if not rotary_interleaved:
         x1, x2 = torch.chunk(x, 2, dim=-1)
         return torch.cat((-x2, x1), dim=-1)
@@ -180,24 +138,6 @@ def apply_rotary_pos_emb(t, freqs, position_ids, rotary_interleaved=False):

     t = (t * cos_) + (_rotate_half(t, rotary_interleaved) * sin_)
     return torch.cat((t, t_pass), dim=-1)
-    """huggingface version
-    input tensor t is of shape [seq_length, ..., dim]
-    rotary positional embeding tensor freqs is of shape [seq_length, ..., dim]
-    check https://kexue.fm/archives/8265 for detailed formulas
-
-    dtype = t.dtype
-    rot_dim = freqs.shape[-1]
-    t_pass = t[..., rot_dim:]
-    if position_ids.shape[1] > 1:
-        freqs = freqs[position_ids]
-        freqs = freqs.view(t.shape[1],freqs.shape[1],freqs.shape[2],freqs.shape[4]).transpose(0,1)
-    # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
-    t = t[..., :rot_dim]
-    # first part is cosine component
-    # second part is sine component, need to change signs with _rotate_half method
-    t = (t * freqs.cos()) + (_rotate_half(t) * freqs.sin())
-    t = t.to(dtype)
-    """

     return torch.cat((t, t_pass), dim=-1)

@@ -287,53 +227,6 @@ class LocalizedFiltering(torch.nn.Module):
         lf_output = self.output_layernorm(output2 + residual)

         return lf_output
-        '''#IEIyuan huggingface version
-        if before_hidden_states == None:
-            inputs = inputs.transpose(0,1)
-            seq_len, bsz, embed_dim = inputs.size()
-            if embed_dim != self.embed_dim:
-                raise ValueError(
-                    f"Unexpected embedding dimension received: input is {embed_dim}, model expects {self.embed_dim}"
-                )
-            residual = inputs
-            inputs = inputs.view(seq_len, 1, bsz, embed_dim).permute(2, 3, 0, 1)
-            inputs = torch.cat((torch.zeros(bsz, embed_dim, 1, 1, dtype=inputs.dtype, device=inputs.device), inputs), dim=2).contiguous()
-            output1 = self.conv1(inputs)
-
-            output1 = torch.cat((torch.zeros(bsz, embed_dim // 2, 1, 1, dtype=inputs.dtype, device=inputs.device), output1), dim=2).contiguous()
-            output2 = self.conv2(output1).permute(2, 3, 0, 1).contiguous()
-            output2 = output2.view(seq_len, bsz, embed_dim)
-            assert output2.shape == residual.shape
-            norm_input = (output2 + residual)#.to('cuda:0')
-            torch.cuda.set_device(norm_input.device)
-            lf_output = self.output_layernorm(norm_input)
-            lf_output = lf_output#.to('cuda:1')
-            lf_output = lf_output.transpose(0,1)
-            return lf_output
-        else:
-            inputs = inputs.transpose(0,1)
-            before_hidden_states = before_hidden_states.transpose(0,1)
-            seq_len, bsz, embed_dim = inputs.size()
-            if embed_dim != self.embed_dim:
-                raise ValueError(
-                    f"Unexpected embedding dimension received: input is {embed_dim}, model expects {self.embed_dim}"
-                )
-            residual = inputs
-            inputs = inputs.view(seq_len, 1, bsz, embed_dim).permute(2, 3, 0, 1)
-            before_hidden_states = before_hidden_states.view(2, 1, bsz, embed_dim).permute(2, 3, 0, 1)
-            inputs = torch.cat((before_hidden_states, inputs), dim=2).contiguous()
-            output1 = self.conv1(inputs)
-            output2 = self.conv2(output1).permute(2, 3, 0, 1).contiguous()
-            output2 = output2.view(seq_len, bsz, embed_dim)
-            assert output2.shape == residual.shape
-
-            norm_input = (output2 + residual)#.to('cuda:0')
-            torch.cuda.set_device(norm_input.device)
-            lf_output = self.output_layernorm(norm_input)
-            lf_output = lf_output#.to('cuda:1')
-            lf_output = lf_output.transpose(0,1)
-            return lf_output
-        '''


     def forward(
@@ -445,8 +338,6 @@ class FlashSelfAttention(torch.nn.Module):
             # only on first autoregressive step q,k,v have same seqlen
             is_causal = seqlen_q == seqlen_k
             cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=q.device)
-            #cu_seqlens_q = [cu_seqlens_q[0], cu_seqlens_q[-1]]
-            #cu_seqlens_k = [cu_seqlens_k[0], cu_seqlens_k[-1]]
             dropout_p = 0

         output = flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k, dropout_p, softmax_scale=self.softmax_scale, causal=is_causal)
@@ -579,8 +470,6 @@ class YuanAttention(nn.Module):
         self.lf_gate = LocalizedFiltering(self.hidden_size, self.lf_conv2d_group, self.lf_conv2d_num_pad)
         self.get_query_key = nn.Linear(self.hidden_size, 2 * self.attention_projection_size, bias=False)
         self.core_attention = FlashSelfAttention(causal=True, attention_dropout=config.attn_dropout, softmax_scale=self.softmax_scale)
-        #self.core_attention_flash = DotProductAttention(num_attention_heads=self.num_heads,
-        #                                                kv_channels=self.head_dim)

     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
@@ -596,6 +485,7 @@
         output_attentions: bool = False,
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
         q_len, bsz, _ = hidden_states.size()
         hidden_states = hidden_states#.to('cuda:1')
         is_first_step = False
@@ -621,7 +511,6 @@
             else:
                 hidden_states = self.lf_gate(hidden_states, before_hidden_states)
             mixed_qk_layer = self.get_query_key(hidden_states)
-            #mixed_qk_layer = torch.matmul(hidden_states, qk_tensor)
             new_tensor_shape = mixed_qk_layer.size()[:-1] + (self.num_heads, 2 * self.head_dim)
             mixed_qk_layer = mixed_qk_layer.view(*new_tensor_shape)
             (query_states, key_states) = torch.split(mixed_qk_layer, self.head_dim, dim=-1)
@@ -635,7 +524,6 @@
         if rotary_pos_emb is not None:
             if position_ids.shape[1] == 1:
                 q_seq_start = position_ids[0,-1]
-                #seq_start = past_key_value[0].shape[0]
                 q_seq_end = q_seq_start + 1
                 k_seq_end = q_seq_end
             else:
@@ -654,17 +542,11 @@
             key_states = torch.cat([past_key_value[0], key_states], dim=0)
             value_states = torch.cat([past_key_value[1], value_states], dim=0)
         past_key_value = (key_states, value_states, inference_hidden_states_memory) if use_cache else None
-        #query_states = apply_rotary_pos_emb(query_states.permute(1, 0, 2, 3), q_pos_emb, position_ids)
-        #key_states = apply_rotary_pos_emb(key_states.permute(1, 0, 2, 3), k_pos_emb, position_ids)
         query_states = apply_rotary_pos_emb(query_states, q_pos_emb, position_ids)
         key_states = apply_rotary_pos_emb(key_states, k_pos_emb, position_ids_k)

         attn_weights = None
-        #query_states = query_states.transpose(0,1)
-        #key_states = key_states.transpose(0,1)
-        #value_states = value_states
         attn_output = self.core_attention(query_states, key_states, value_states)
-        #attn_output = self.core_attention(query_states, key_states, value_states, attention_mask)
         q_len, bsz, _, _ = attn_output.shape
         attn_output = attn_output.reshape(q_len, bsz, -1)

@@ -755,7 +637,6 @@ class GroupedMLP(nn.Module):
             return torch.nn.functional.silu(x[0]) * x[1]

         self.activation_func = glu
-        #self.ffn_hidden_size = config.moe_config['ffn_hidden_size']
         self.ffn_hidden_size = config.ffn_hidden_size
         fc1_output_size_per_partition = self.ffn_hidden_size * 2
         fc2_input_size = self.ffn_hidden_size
@@ -765,10 +646,6 @@
     def forward(self, permuted_hidden_states, tokens_per_expert):
         torch.cuda.set_device(permuted_hidden_states.device)
         permuted_hidden_states = permuted_hidden_states#.to('cuda:0')
-        #fc1_output = gg.ops.gmm(permuted_hidden_states, self.weight1, tokens_per_expert.cpu(), trans_b=False)
-
-        #intermediate_parallel = self.activation_func(fc1_output)
-        #fc2_output = gg.ops.gmm(intermediate_parallel, self.weight2, tokens_per_expert.cpu(), trans_b=False)

         fc2_outputs = []
         start_idx = 0
@@ -776,13 +653,10 @@
             if tokens_per_expert[i] == 0:
                 continue
             end_idx = start_idx + tokens_per_expert[i]
-            #fc1_output = torch.matmul(permuted_hidden_states[start_idx:end_idx], self.w1[i])
             # Use custom attributes for each expert's Linear layers

             fc1_output = self.w1[i](permuted_hidden_states[start_idx:end_idx])
-            #print("shape1:", self.w1[i].shape, "shape2:", permuted_hidden_states[start_idx:end_idx].shape)
             intermediate_parallel = self.activation_func(fc1_output)
-            #fc2_output = torch.matmul(intermediate_parallel, self.w2[i])
             fc2_output = self.w2[i](intermediate_parallel)
             fc2_outputs.append(fc2_output)
             start_idx = end_idx
@@ -800,7 +674,6 @@ class YuanMoeLayer(nn.Module):

         expert_indices_offset = (0)

-        #self.gate = ParallelAttention_router(config)
         self.router = ParallelAttention_router(config)
         self.token_dispatcher = MoEDroplessTokenDispatcher(self.num_experts, config=self.config)
         self.experts = GroupedMLP(self.num_experts, self.config)
@@ -812,7 +685,6 @@

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
-        #logits = self.gate(hidden_states)
         logits = self.router(hidden_states)
         scores, indices = self.routing(logits)
         scores = scores.to(hidden_states.dtype)
@@ -868,6 +740,7 @@ class YuanDecoderLayer(nn.Module):
         residual = hidden_states#.to('cuda:1')
         torch.cuda.set_device(hidden_states.device)
         hidden_states = self.input_layernorm(hidden_states) #.to('cuda:0')).to('cuda:1')
+
         # Self Attention
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
             hidden_states=hidden_states,
@@ -879,7 +752,9 @@
             output_attentions=output_attentions,
             use_cache=use_cache,
         )
+
         hidden_states = residual + hidden_states.permute(1, 0, 2)
+
         # Fully Connected
         residual = hidden_states#.to('cuda:1')
         torch.cuda.set_device(hidden_states.device)
@@ -1159,14 +1034,10 @@ class YuanModel(YuanPreTrainedModel):

         seq_length_with_past = seq_length
         past_key_values_length = 0
+
         if past_key_values is not None:
-            #past_key_values_length = past_key_values[0][0].shape[2]
-            #modify
-            print('0000')
             past_key_values_length = past_key_values[0][0].shape[0]
             seq_length_with_past = seq_length_with_past + past_key_values_length
-        else:
-            print('1111')

         # modify to reset position ids
         if past_key_values is not None:
@@ -1187,6 +1058,7 @@
             )
             position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
         else:
+            #position_ids = position_ids.view(-1, seq_length).long()
             pass

         if inputs_embeds is None:
@@ -1206,6 +1078,11 @@
         #rotary_pos_emb = self.rotary_pos_emb(self.max_position_embeddings)
         # Rotary positional embeddings (embedding is None for PP intermediate devices)
         rotary_pos_emb = None
+        '''
+        rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
+            transformer_input=inputs_embeds
+        )
+        '''
         rotary_pos_emb = self.rotary_pos_emb(self.max_position_embeddings)

         hidden_states = inputs_embeds
@@ -1220,8 +1097,8 @@
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
         next_decoder_cache = () if use_cache else None
-        #position_ids = position_ids.cpu()
-        #position_ids_k = position_ids_k.cpu()
+        position_ids = position_ids.cpu()
+        position_ids_k = position_ids_k.cpu()
         for idx, decoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
@@ -1262,9 +1139,8 @@
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)
         hidden_states = hidden_states#.to('cuda:0')
-        #torch.cuda.set_device(hidden_states.device)
+        torch.cuda.set_device(hidden_states.device)
         hidden_states = self.norm(hidden_states)
-        #print(hidden_states)
         # add hidden states from the last decoder layer
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
@@ -1279,12 +1155,11 @@

         )


-class YuanForCausalLM(YuanPreTrainedModel, GenerationMixin):
+class YuanForCausalLM(YuanPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.model = YuanModel(config)
-        #self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.post_init()

     def get_input_embeddings(self):
@@ -1425,8 +1300,9 @@ class YuanForCausalLM(YuanPreTrainedModel, GenerationMixin):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
+
         hidden_states = outputs[0].transpose(0,1)
-        #print(hidden_states)
+
         logits = self.lm_head(hidden_states)

         loss = None
 
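This commit drops the file-local RMSNorm class in favor of the fused transformer_engine.pytorch.RMSNorm (see the import hunk above). For reference, here is the deleted PyTorch fallback, reproduced verbatim from the hunk and checked against the plain RMS formula; the test tensor sizes are illustrative only:

import torch

# The unfused RMSNorm this commit deletes; transformer_engine's RMSNorm is
# its fused equivalent.
class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # x / sqrt(mean(x^2) + eps), computed in fp32, then scaled by weight
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)
        return self.weight * hidden_states

x = torch.randn(2, 8, 16)  # (batch, seq, hidden) — made-up sizes
ref = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
assert torch.allclose(RMSNorm(16)(x), ref, atol=1e-5)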
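The GroupedMLP hunks above keep the plain per-expert loop rather than the commented-out grouped-GEMM path (gg.ops.gmm). A minimal sketch of that dispatch pattern, with made-up sizes; it assumes, as the diff does, that tokens arrive pre-sorted by expert so tokens_per_expert indexes contiguous slices:

import torch
from torch import nn

hidden, ffn, num_experts = 8, 16, 3          # illustrative sizes only
w1 = nn.ModuleList(nn.Linear(hidden, ffn, bias=False) for _ in range(num_experts))
permuted_hidden_states = torch.randn(10, hidden)  # tokens sorted by expert id
tokens_per_expert = torch.tensor([4, 0, 6])       # expert 1 receives no tokens

fc1_outputs, start_idx = [], 0
for i, n in enumerate(tokens_per_expert.tolist()):
    if n == 0:                                # skip empty experts, as the loop above does
        continue
    fc1_outputs.append(w1[i](permuted_hidden_states[start_idx:start_idx + n]))
    start_idx += n
out = torch.cat(fc1_outputs, dim=0)           # (10, ffn), still in expert order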
modeling_yuanvl_chat.py CHANGED
@@ -1,6 +1,6 @@
 # --------------------------------------------------------
 # YuanVL
-# Copyright (c) 2024 OpenGVLab
+# Copyright (c) 2024 YuanLabAI
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------

@@ -17,10 +17,9 @@ from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
                           LlamaTokenizer)
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
-from transformers.generation import GenerationMixin
 from transformers.utils import ModelOutput, logging

-#from transformer_engine.pytorch import RMSNorm
+from transformer_engine.pytorch import RMSNorm
 from transformers.activations import ACT2FN

 from .configuration_yuanvl import YuanVLChatConfig
@@ -31,22 +30,6 @@ from .utils import flatten_bn, merge_multimodal_embeddings

 logger = logging.get_logger(__name__)

-class RMSNorm(torch.nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__()
-        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-
-        # convert into half-precision if necessary
-        if self.weight.dtype in [torch.float16, torch.bfloat16]:
-            hidden_states = hidden_states.to(self.weight.dtype)
-
-        return self.weight * hidden_states
-
 class InternVLImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
     data: Union[torch.Tensor, List[torch.Tensor]]
@@ -94,6 +77,9 @@ class YuanImageMLP(nn.Module):
         hidden_act: str,
     ) -> None:
         super().__init__()
+        #self.up_proj = ColumnParallelLinear(hidden_size, intermediate_size, bias=False,)
+        #self.gate_proj = ColumnParallelLinear(hidden_size, intermediate_size, bias=False,)
+        #self.down_proj = RowParallelLinear(intermediate_size, output_size, bias=False,)
         self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
         self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
         self.down_proj = nn.Linear(intermediate_size, output_size, bias=False)
@@ -108,13 +94,16 @@ class YuanImageMLP(nn.Module):
         return self.act_fn(y_1) * y_2

     def forward(self, x):
+        #import pdb
         x1 = self.up_proj(x)
         x2 = self.gate_proj(x)
         x3 = self.swiglu(x1, x2)
+        #x3 = self.act_fn(x1)
+        #x2 = self.gate_proj(x)
         x = self.down_proj(x3)
         return x

-class YuanVLChatModel(PreTrainedModel, GenerationMixin):
+class YuanVLChatModel(PreTrainedModel):
     config_class = YuanVLChatConfig
     main_input_name = 'pixel_values'
     base_model_prefix = 'language_model'
@@ -152,6 +141,10 @@ class YuanVLChatModel(PreTrainedModel, GenerationMixin):
             raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')

         self.pixel_unshuffle = torch.nn.PixelUnshuffle(downscale_factor=2)
+        #vit_hidden_size = config.vision_config.hidden_size
+        #llm_hidden_size = config.llm_config.hidden_size
+        #vit_mlp_ffn_hidden_size = config.vit_mlp_ffn_hidden_size
+        #layernorm_epsilon = config.llm_config.layernorm_epsilon
         layernorm_epsilon = config.llm_config.rms_norm_eps

         self.imagemlp_input_hiddensize = int(config.vision_config.hidden_size / self.downsample_ratio ** 2)
@@ -161,6 +154,18 @@ class YuanVLChatModel(PreTrainedModel, GenerationMixin):
                                       output_size=config.llm_config.hidden_size, hidden_act="silu")
         self.imagemlp_layernorm = RMSNorm(config.llm_config.hidden_size, eps=layernorm_epsilon)

+        '''
+        # modify internvl vision
+        vit_hidden_size = config.vision_config.hidden_size
+        llm_hidden_size = config.llm_config.hidden_size
+        self.mlp1 = nn.Sequential(
+            nn.LayerNorm(vit_hidden_size * int(1/self.downsample_ratio) ** 2),
+            nn.Linear(vit_hidden_size * int(1/self.downsample_ratio) ** 2, llm_hidden_size),
+            nn.GELU(),
+            nn.Linear(llm_hidden_size, llm_hidden_size)
+        )
+        '''
+
         self.img_context_token_id = config.img_context_token_id
         self.conv_template = get_conv_template(self.template)
         self.system_message = self.conv_template.system_message
@@ -235,6 +240,7 @@ class YuanVLChatModel(PreTrainedModel, GenerationMixin):
         assert self.vision_model is not None
         # (total_patches, tokens_per_image, llm_config.hidden_size)
         image_embeds = self.extract_feature(image_input["data"])
+
         patches_per_image = image_input["patches_per_image"]

         # Only one image in the current batch
@@ -289,7 +295,7 @@ class YuanVLChatModel(PreTrainedModel, GenerationMixin):
             input_ids, inputs_embeds, multimodal_embeddings,
             self.img_context_token_id)
         return inputs_embeds
-
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -306,21 +312,22 @@ class YuanVLChatModel(PreTrainedModel, GenerationMixin):
         image_token_id: Optional[List[torch.Tensor]] = None,
         image_embeds: Optional[List[torch.Tensor]] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-
+
+        import pdb
+        pdb.set_trace()
         if inputs_embeds is None:
             # (images, patches * token_per_image)
             vision_embeddings = self.get_multimodal_embeddings(pixel_values, image_token_id, image_embeds)
             # (tokens, hidden_size)
-            if input_ids is not None:
-                vision_embeddings = vision_embeddings.to(input_ids.device)
-            inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings) #.permute(1, 0, 2)
+            inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings).permute(1, 0, 2)
             input_ids = None

         hidden_states = self.language_model.model(input_ids, attention_mask, position_ids, past_key_values,
                                                   inputs_embeds, labels, use_cache, output_attentions,
                                                   output_hidden_states, return_dict)
+
         return hidden_states
-
+
     def pixel_shuffle(self, x, scale_factor=0.5):
         n, w, h, c = x.size()
         # N, W, H, C --> N, W, H * scale, C // scale
@@ -380,8 +387,6 @@ class YuanVLChatModel(PreTrainedModel, GenerationMixin):
             vit_embeds = visual_features
         else:
             vit_embeds = self.get_multimodal_embeddings(pixel_values)
-            if input_ids is not None:
-                vit_embeds = vit_embeds.to(input_ids.device)
             inputs_embeds = self.get_input_embeddings(input_ids, vit_embeds)
             input_ids = None

 
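One sizing detail worth flagging from this file: imagemlp_input_hiddensize is set to the vision hidden size divided by downsample_ratio ** 2, which matches what PixelUnshuffle(downscale_factor=2) does to the channel dimension. A quick sketch of that bookkeeping, assuming downsample_ratio = 0.5 and made-up feature-map sizes:

import torch

# PixelUnshuffle(2) folds each 2x2 spatial block into channels:
# (N, C, H, W) -> (N, 4C, H/2, W/2). With downsample_ratio = 0.5 this is
# hidden_size / downsample_ratio**2 = 4 * hidden_size, the image-MLP input
# size computed in __init__ above. Sizes here are illustrative only.
unshuffle = torch.nn.PixelUnshuffle(downscale_factor=2)
feat = torch.randn(1, 1024, 16, 16)   # hypothetical ViT feature map (N, C, H, W)
print(unshuffle(feat).shape)          # torch.Size([1, 4096, 8, 8])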