fredzzp committed on
Commit ed3972d · verified · 1 Parent(s): e267107

Initial model upload with custom code

Files changed (1)
  1. modeling_qwen2.py +76 -291
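The commit message describes an initial model upload with custom code: the repository ships its own modeling_qwen2.py, which replaces the built-in transformers Qwen2 implementation at load time. A minimal loading sketch under that assumption follows; the repo id is a placeholder (the actual id is not shown on this page), and routing to the custom file relies on the standard trust_remote_code mechanism:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo_id = "fredzzp/<repo-name>"  # placeholder, not the actual repo id
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    # trust_remote_code=True lets transformers import the repo's modeling_qwen2.py
    # (typically routed via the auto_map entries in the repo's config.json)
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)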
modeling_qwen2.py CHANGED
@@ -1,3 +1,8 @@
1
  # Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
2
  # Copyright 2025 Bytedance Ltd. and/or its affiliates
3
  #
@@ -38,20 +43,14 @@ from transformers.utils import (
38
  replace_return_docstrings,
39
  )
40
 
41
- from veomni.models.transformers.qwen2.generation_utils import MDMGenerationMixin
42
 
43
- from ....data.constants import IGNORE_INDEX
44
- from ....distributed.parallel_state import get_parallel_state
45
- from ....distributed.sequence_parallel import (
46
  gather_heads_scatter_seq,
47
  gather_seq_scatter_heads,
48
  reduce_sequence_parallel_loss,
49
  )
50
- from ....utils import logging
51
- from ....utils.import_utils import is_liger_kernel_available
52
 
53
 
54
- if is_liger_kernel_available():
55
  from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss # type: ignore
56
  from liger_kernel.transformers.rms_norm import LigerRMSNorm
57
  from liger_kernel.transformers.rope import liger_rotary_pos_emb
@@ -183,7 +182,7 @@ class Qwen2Attention(nn.Module):
183
  query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
184
  key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
185
  value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
186
- if get_parallel_state().ulysses_enabled:
187
  query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
188
  key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
189
  value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)
@@ -229,7 +228,7 @@ class Qwen2Attention(nn.Module):
229
  )
230
 
231
  attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()
232
- if get_parallel_state().ulysses_enabled:
233
  attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)
234
 
235
  attn_output = attn_output.reshape(bsz, q_len, self.config.hidden_size).contiguous()
@@ -533,115 +532,79 @@ class Qwen2Model(Qwen2PreTrainedModel):
533
  def set_input_embeddings(self, value):
534
  self.embed_tokens = value
535
 
 
536
  @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
 
537
  def forward(
538
  self,
539
  input_ids: torch.LongTensor = None,
540
  attention_mask: Optional[torch.Tensor] = None,
541
  position_ids: Optional[torch.LongTensor] = None,
542
- past_key_values: Optional[Cache] = None,
543
  inputs_embeds: Optional[torch.FloatTensor] = None,
 
544
  use_cache: Optional[bool] = None,
545
  output_attentions: Optional[bool] = None,
546
  output_hidden_states: Optional[bool] = None,
547
  return_dict: Optional[bool] = None,
548
  cache_position: Optional[torch.LongTensor] = None,
549
  is_causal: bool = True,
550
- **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
551
- ) -> Union[Tuple, BaseModelOutputWithPast]:
552
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
553
  output_hidden_states = (
554
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
555
  )
556
- use_cache = use_cache if use_cache is not None else self.config.use_cache
557
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
558
 
559
- if (input_ids is None) ^ (inputs_embeds is not None):
560
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
561
-
562
- if self.gradient_checkpointing and self.training and use_cache:
563
- logger.warning_once(
564
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
565
- )
566
- use_cache = False
567
-
568
- if inputs_embeds is None:
569
- inputs_embeds = self.embed_tokens(input_ids)
570
-
571
- if use_cache and past_key_values is None:
572
- past_key_values = DynamicCache()
573
-
574
- if cache_position is None:
575
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
576
- cache_position = torch.arange(
577
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
578
- )
579
-
580
- if position_ids is None:
581
- position_ids = cache_position.unsqueeze(0)
582
-
583
- causal_mask = self._update_causal_mask(
584
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
585
  )
586
 
587
- hidden_states = inputs_embeds
588
-
589
- # create position embeddings to be shared across the decoder layers
590
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
591
-
592
- # decoder layers
593
- all_hidden_states = () if output_hidden_states else None
594
- all_self_attns = () if output_attentions else None
595
-
596
- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
597
- if output_hidden_states:
598
- all_hidden_states += (hidden_states,)
599
-
600
- if self.gradient_checkpointing and self.training:
601
- layer_outputs = self._gradient_checkpointing_func(
602
- decoder_layer.__call__,
603
- hidden_states,
604
- causal_mask,
605
- position_ids,
606
- past_key_values,
607
- output_attentions,
608
- use_cache,
609
- cache_position,
610
- position_embeddings,
611
- is_causal,
612
- )
613
- else:
614
- layer_outputs = decoder_layer(
615
- hidden_states,
616
- attention_mask=causal_mask,
617
- position_ids=position_ids,
618
- past_key_value=past_key_values,
619
- output_attentions=output_attentions,
620
- use_cache=use_cache,
621
- cache_position=cache_position,
622
- position_embeddings=position_embeddings,
623
- is_causal=is_causal,
624
- **flash_attn_kwargs,
625
- )
626
-
627
- hidden_states = layer_outputs[0]
628
-
629
- if output_attentions:
630
- all_self_attns += (layer_outputs[1],)
631
-
632
- hidden_states = self.norm(hidden_states)
633
 
634
- # add hidden states from the last decoder layer
635
- if output_hidden_states:
636
- all_hidden_states += (hidden_states,)
637
 
638
- output = BaseModelOutputWithPast(
639
- last_hidden_state=hidden_states,
640
- past_key_values=past_key_values if use_cache else None,
641
- hidden_states=all_hidden_states,
642
- attentions=all_self_attns,
 
643
  )
644
- return output if return_dict else output.to_tuple()
645
 
646
  def _update_causal_mask(
647
  self,
@@ -799,7 +762,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
799
  class KwargsForCausalLM(FlashAttentionKwargs, ): ...
800
 
801
 
802
- class Qwen2ForCausalLM(Qwen2PreTrainedModel, MDMGenerationMixin):
803
  _tied_weights_keys = ["lm_head.weight"]
804
  _tp_plan = {"lm_head": "colwise_rep"}
805
  _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
@@ -831,6 +794,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, MDMGenerationMixin):
831
  def get_decoder(self):
832
  return self.model
833
 
 
834
  @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
835
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
836
  def forward(
@@ -841,76 +805,27 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, MDMGenerationMixin):
841
  past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
842
  inputs_embeds: Optional[torch.FloatTensor] = None,
843
  labels: Optional[torch.LongTensor] = None,
844
- mask_ratio: Optional[torch.FloatTensor]=None,
845
  use_cache: Optional[bool] = None,
846
  output_attentions: Optional[bool] = None,
847
  output_hidden_states: Optional[bool] = None,
848
  return_dict: Optional[bool] = None,
849
  cache_position: Optional[torch.LongTensor] = None,
850
- logits_to_keep: Union[int, torch.Tensor] = 0,
851
  is_causal: bool = True,
852
- **kwargs: Unpack[KwargsForCausalLM],
853
  ) -> Union[Tuple, CausalLMOutputWithPast]:
854
- r"""
855
  Args:
856
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
857
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
858
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
859
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
860
-
861
- logits_to_keep (`int` or `torch.Tensor`, *optional*):
862
- If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
863
- `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
864
- token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
865
- If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
866
- This is useful when using packed tensor format (single dimension for batch and sequence length).
867
-
868
- Returns:
869
-
870
- Example:
871
-
872
- ```python
873
- >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
874
-
875
- >>> model = Qwen2ForCausalLM.from_pretrained("meta-qwen2/Qwen2-2-7b-hf")
876
- >>> tokenizer = AutoTokenizer.from_pretrained("meta-qwen2/Qwen2-2-7b-hf")
877
-
878
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
879
- >>> inputs = tokenizer(prompt, return_tensors="pt")
880
-
881
- >>> # Generate
882
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
883
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
884
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
885
- ```"""
886
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
887
  output_hidden_states = (
888
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
889
  )
890
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
891
 
892
- if not get_parallel_state().sp_enabled and labels is not None:
893
- # Shift so that tokens < n predict n
894
- labels = labels[..., 1:].contiguous()
895
- labels = labels.view(-1)
896
- if (
897
- position_ids is not None
898
- and position_ids.size(0) == 1
899
- and not (torch.diff(position_ids, dim=-1) >= 0).all()
900
- ):
901
- position_ids_ = position_ids.flatten()
902
- indices_q = torch.arange(position_ids_.size(0), device=position_ids_.device, dtype=torch.int32)
903
- cu_seq_lens = torch.cat(
904
- (
905
- indices_q[position_ids_ == 0],
906
- torch.tensor(position_ids_.size(), device=position_ids_.device, dtype=torch.int32),
907
- )
908
- )
909
- labels[cu_seq_lens[1:-1] - 1] = IGNORE_INDEX
910
- if mask_ratio is not None:
911
- is_causal = False
912
- mask_ratio = mask_ratio[..., 1:].contiguous()
913
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
914
  outputs = self.model(
915
  input_ids=input_ids,
916
  attention_mask=attention_mask,
@@ -927,64 +842,19 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, MDMGenerationMixin):
927
  )
928
 
929
  hidden_states = outputs[0]
930
- # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
931
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
932
- hidden_states = hidden_states[:, slice_indices, :]
933
-
934
  loss = None
935
- logits = None
936
  if labels is not None:
937
- labels = labels.view(-1) # flatten label
938
- if is_liger_kernel_available():
939
- if mask_ratio is not None:
940
- loss_fct = LigerFusedLinearCrossEntropyLoss(reduction="none",ignore_index=IGNORE_INDEX)
941
- if not get_parallel_state().sp_enabled:
942
- # Shift so that tokens < n predict n
943
- hidden_states = hidden_states[..., :-1, :].contiguous()
944
- loss = loss_fct(
945
- self.lm_head.weight,
946
- hidden_states.view(-1, self.config.hidden_size),
947
- labels
948
- )
949
- path_loss = (-loss).exp().detach() * loss
950
- loss = loss + path_loss
951
- loss_mask = labels != IGNORE_INDEX
952
- loss = (loss * loss_mask * (1/mask_ratio)).sum() / (loss_mask.sum() + 1e-8)
953
- else:
954
- loss_fct = LigerFusedLinearCrossEntropyLoss(reduction="mean")
955
- if not get_parallel_state().sp_enabled:
956
- # Shift so that tokens < n predict n
957
- hidden_states = hidden_states[..., :-1, :].contiguous()
958
- hidden_states = hidden_states.view(-1, self.config.hidden_size)
959
- loss = loss_fct(self.lm_head.weight, hidden_states, labels)
960
- else:
961
- if mask_ratio is not None:
962
- loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
963
- logits = self.lm_head(hidden_states)
964
- logits = logits.view(-1, self.vocab_size)
965
- loss = loss_fct(logits, labels.view(-1))
966
- path_loss = (-loss).exp().detach() * loss
967
- loss = loss + path_loss
968
- loss_mask = labels != IGNORE_INDEX
969
- loss = (loss * loss_mask * (1/mask_ratio)).sum() / (loss_mask.sum() + 1e-8)
970
- else:
971
- loss_fct = torch.nn.CrossEntropyLoss(reduction="mean")
972
- logits = self.lm_head(hidden_states)
973
- # Upcast to float if we need to compute the loss to avoid potential precision issues
974
- logits = logits.float()
975
- if not get_parallel_state().sp_enabled:
976
- # Shift so that tokens < n predict n
977
- logits = logits[..., :-1, :].contiguous()
978
-
979
- # Flatten the tokens
980
- logits = logits.view(-1, self.vocab_size)
981
- loss = loss_fct(logits, labels)
982
-
983
- if get_parallel_state().sp_enabled:
984
- num_valid_tokens = (labels != IGNORE_INDEX).sum()
985
- loss = reduce_sequence_parallel_loss(loss, num_valid_tokens)
986
- else:
987
- logits = self.lm_head(hidden_states)
988
 
989
  if not return_dict:
990
  output = (logits,) + outputs[1:]
@@ -997,88 +867,3 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, MDMGenerationMixin):
997
  hidden_states=outputs.hidden_states,
998
  attentions=outputs.attentions,
999
  )
1000
-
1001
-
1002
-
1003
-
1004
- import torch
1005
- from tqdm import tqdm
1006
- from typing import Callable, Tuple, Any
1007
-
1008
-
1009
- def topk_masking(scores: torch.Tensor, cutoff_len: torch.Tensor, mode: str = "lowest") -> torch.Tensor:
1010
- """Generate a mask selecting the top-k lowest or highest elements per row."""
1011
- sorted_scores = scores.sort(dim=-1, descending=(mode == "highest")).values
1012
- cutoff = sorted_scores.gather(dim=-1, index=cutoff_len)
1013
- return (scores >= cutoff) if mode == "highest" else (scores < cutoff)
1014
-
1015
-
1016
- def sample_categorical(
1017
- logits: torch.Tensor, temperature: float = 1.0, noise_scale: float = 1.0
1018
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1019
- """
1020
- Sample from a categorical distribution with optional Gumbel noise.
1021
- Returns sampled tokens, their scores, and the noised logits.
1022
- """
1023
- logits = logits.to(torch.float64)
1024
- if temperature > 0:
1025
- gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-8) + 1e-8)
1026
- logits = logits / temperature + noise_scale * gumbel_noise
1027
- log_probs = logits.log_softmax(dim=-1)
1028
- scores, tokens = log_probs.max(dim=-1)
1029
- return tokens, scores.to(logits.dtype), logits.to(logits.dtype)
1030
-
1031
-
1032
- @torch.inference_mode()
1033
- @torch.amp.autocast(device_type="cuda", dtype=torch.float16)
1034
- def p2_sampling(
1035
- xt: torch.Tensor,
1036
- model: Any,
1037
- mask_id: int,
1038
- num_steps: int,
1039
- tau: float = 1.0,
1040
- kappa_fn: Callable[[float], float] = lambda t: t,
1041
- eta: float = 1.0,
1042
- **kwargs
1043
- ) -> torch.Tensor:
1044
- """
1045
- P2 Sampling implementation for discrete diffusion models.
1046
- Reference: https://arxiv.org/pdf/2502.03540
1047
- """
1048
- dt = 1 / num_steps
1049
- fix_mask = (xt != mask_id)
1050
-
1051
- for i in tqdm(range(1, num_steps + 1)):
1052
- t = i * dt
1053
- kappa_t = kappa_fn(t)
1054
-
1055
- logits = model(xt).double()
1056
- last_mask = (xt == mask_id)
1057
- unmask_t = ~last_mask & ~fix_mask
1058
-
1059
- x0, score, _ = sample_categorical(logits, temperature=tau)
1060
- score = score.masked_fill(fix_mask, float("inf"))
1061
- score[unmask_t] *= eta
1062
-
1063
- num_to_mask = ((~fix_mask).sum(dim=1, keepdim=True).float() * (1 - kappa_t)).long()
1064
- to_mask = topk_masking(score, num_to_mask, mode="lowest")
1065
-
1066
- xt[to_mask] = mask_id
1067
- mask_2_x0 = last_mask & ~to_mask
1068
- xt[mask_2_x0] = x0[mask_2_x0]
1069
-
1070
- xt[xt == mask_id] = x0[xt == mask_id]
1071
- return xt
1072
-
1073
-
1074
-
1075
- if is_liger_kernel_available():
1076
- apply_rotary_pos_emb = liger_rotary_pos_emb
1077
- Qwen2RMSNorm = LigerRMSNorm
1078
- Qwen2MLP = LigerSwiGLUMLP
1079
- logger.info_rank0("Apply liger kernel to Qwen2.")
1080
-
1081
-
1082
- ModelClass = Qwen2ForCausalLM
1083
-
1084
- __all__ = ["Qwen2ForCausalLM", "Qwen2Model", "Qwen2PreTrainedModel"]
 
1
+ import logging
2
+ from transformers import GenerationMixin
3
+ import torch
4
+ from typing import Optional, Union, List
5
+ from transformers.modeling_outputs import CausalLMOutputWithPast
6
  # Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
7
  # Copyright 2025 Bytedance Ltd. and/or its affiliates
8
  #
 
43
  replace_return_docstrings,
44
  )
45
 
 
46
47
  gather_heads_scatter_seq,
48
  gather_seq_scatter_heads,
49
  reduce_sequence_parallel_loss,
50
  )
 
 
51
 
52
 
53
+ if False:
54
  from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss # type: ignore
55
  from liger_kernel.transformers.rms_norm import LigerRMSNorm
56
  from liger_kernel.transformers.rope import liger_rotary_pos_emb
 
182
  query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
183
  key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
184
  value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
185
+ if False:
186
  query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
187
  key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
188
  value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)
 
228
  )
229
 
230
  attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()
231
+ if False:
232
  attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)
233
 
234
  attn_output = attn_output.reshape(bsz, q_len, self.config.hidden_size).contiguous()
 
532
  def set_input_embeddings(self, value):
533
  self.embed_tokens = value
534
 
535
+
536
  @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
537
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
538
  def forward(
539
  self,
540
  input_ids: torch.LongTensor = None,
541
  attention_mask: Optional[torch.Tensor] = None,
542
  position_ids: Optional[torch.LongTensor] = None,
543
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
544
  inputs_embeds: Optional[torch.FloatTensor] = None,
545
+ labels: Optional[torch.LongTensor] = None,
546
  use_cache: Optional[bool] = None,
547
  output_attentions: Optional[bool] = None,
548
  output_hidden_states: Optional[bool] = None,
549
  return_dict: Optional[bool] = None,
550
  cache_position: Optional[torch.LongTensor] = None,
551
  is_causal: bool = True,
552
+ **kwargs,
553
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
554
+ r"""
555
+ Args:
556
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
557
+ Labels for computing the masked language modeling loss. Indices should be in `[0, ...,
558
+ config.vocab_size - 1]`. Tokens with indices set to `-100` are ignored (masked), the loss is only
559
+ computed for the tokens with labels in `[0, ..., config.vocab_size - 1]`.
560
+ """
561
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
562
  output_hidden_states = (
563
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
564
  )
 
565
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
566
 
567
+ outputs = self.model(
568
+ input_ids=input_ids,
569
+ attention_mask=attention_mask,
570
+ position_ids=position_ids,
571
+ past_key_values=past_key_values,
572
+ inputs_embeds=inputs_embeds,
573
+ use_cache=use_cache,
574
+ output_attentions=output_attentions,
575
+ output_hidden_states=output_hidden_states,
576
+ return_dict=return_dict,
577
+ cache_position=cache_position,
578
+ is_causal=is_causal,
579
+ **kwargs,
580
  )
581
 
582
+ hidden_states = outputs[0]
583
+ logits = self.lm_head(hidden_states)
584
+ logits = logits.float()
585
+ loss = None
586
+
587
+ if labels is not None:
588
+ # Maintained for compatibility with Trainer API, but not essential for pure inference
589
+ shift_logits = logits[..., :-1, :].contiguous()
590
+ shift_labels = labels[..., 1:].contiguous()
591
+ loss_fct = torch.nn.CrossEntropyLoss()
592
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
593
+ shift_labels = shift_labels.view(-1)
594
+ shift_labels = shift_labels.to(shift_logits.device)
595
+ loss = loss_fct(shift_logits, shift_labels)
596
 
597
+ if not return_dict:
598
+ output = (logits,) + outputs[1:]
599
+ return (loss,) + output if loss is not None else output
600
 
601
+ return CausalLMOutputWithPast(
602
+ loss=loss,
603
+ logits=logits,
604
+ past_key_values=outputs.past_key_values,
605
+ hidden_states=outputs.hidden_states,
606
+ attentions=outputs.attentions,
607
  )
 
608
 
609
  def _update_causal_mask(
610
  self,
 
762
  class KwargsForCausalLM(FlashAttentionKwargs, ): ...
763
 
764
 
765
+ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
766
  _tied_weights_keys = ["lm_head.weight"]
767
  _tp_plan = {"lm_head": "colwise_rep"}
768
  _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
 
794
  def get_decoder(self):
795
  return self.model
796
 
797
+
798
  @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
799
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
800
  def forward(
 
805
  past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
806
  inputs_embeds: Optional[torch.FloatTensor] = None,
807
  labels: Optional[torch.LongTensor] = None,
 
808
  use_cache: Optional[bool] = None,
809
  output_attentions: Optional[bool] = None,
810
  output_hidden_states: Optional[bool] = None,
811
  return_dict: Optional[bool] = None,
812
  cache_position: Optional[torch.LongTensor] = None,
 
813
  is_causal: bool = True,
814
+ **kwargs,
815
  ) -> Union[Tuple, CausalLMOutputWithPast]:
816
+ r"""
817
  Args:
818
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
819
+ Labels for computing the masked language modeling loss. Indices should be in `[0, ...,
820
+ config.vocab_size - 1]`. Tokens with indices set to `-100` are ignored (masked), the loss is only
821
+ computed for the tokens with labels in `[0, ..., config.vocab_size - 1]`.
822
+ """
823
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
824
  output_hidden_states = (
825
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
826
  )
827
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
828
 
829
  outputs = self.model(
830
  input_ids=input_ids,
831
  attention_mask=attention_mask,
 
842
  )
843
 
844
  hidden_states = outputs[0]
845
+ logits = self.lm_head(hidden_states)
846
+ logits = logits.float()
 
 
847
  loss = None
848
+
849
  if labels is not None:
850
+ # Maintained for compatibility with Trainer API, but not essential for pure inference
851
+ shift_logits = logits[..., :-1, :].contiguous()
852
+ shift_labels = labels[..., 1:].contiguous()
853
+ loss_fct = torch.nn.CrossEntropyLoss()
854
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
855
+ shift_labels = shift_labels.view(-1)
856
+ shift_labels = shift_labels.to(shift_logits.device)
857
+ loss = loss_fct(shift_logits, shift_labels)
858
 
859
  if not return_dict:
860
  output = (logits,) + outputs[1:]
 
867
  hidden_states=outputs.hidden_states,
868
  attentions=outputs.attentions,
869
  )
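In the rewritten file, Qwen2ForCausalLM mixes in the standard GenerationMixin (replacing the earlier MDMGenerationMixin), so ordinary autoregressive generation is expected to apply once the model is loaded as in the sketch above. A minimal usage sketch with illustrative inputs, reusing the tokenizer and model from the loading example:

    inputs = tokenizer("Hello, how are you?", return_tensors="pt")
    # standard transformers generation loop via GenerationMixin
    output_ids = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))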