voidrank committed on
Commit
0661abf
·
1 Parent(s): 200e3ef

revise generate

Browse files
Files changed (1) hide show
  1. modeling.py +91 -10
modeling.py CHANGED
@@ -6,6 +6,7 @@ from torch import nn
6
  import torch.nn.functional as F
7
  from functools import partial
8
 
 
9
  from transformers.activations import ACT2FN
10
  from transformers.cache_utils import Cache, DynamicCache
11
  from transformers.generation import GenerationMixin
@@ -643,7 +644,7 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
643
  loss=loss,
644
  logits=logits,
645
  past_key_values=outputs.past_key_values,
646
- hidden_states=outputs.hidden_states,
647
  attentions=outputs.attentions,
648
  block_past_key_values=outputs.block_past_key_values,
649
  )
@@ -652,7 +653,9 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
652
  def generate(
653
  self,
654
  input_ids,
655
- max_new_tokens,
 
 
656
  mask_id=151665,
657
  threshold=1,
658
  small_block_size=8,
@@ -662,14 +665,36 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
662
  top_p=0.95,
663
  temperature=0,
664
  use_block_cache=False,
 
 
 
665
  **kwargs
666
  ):
 
 
 
 
 
 
 
 
667
  num_blocks = max_new_tokens // block_size
668
  original_input_length = input_ids.shape[1]
669
 
670
  if input_ids.shape[1] > block_size:
671
- output = self.forward(input_ids=input_ids[:, :(input_ids.shape[1] // block_size * block_size)], use_cache=True, update_past_key_values=True, block_size=block_size)
 
 
 
 
 
672
  logits, past_key_values = output.logits, output.past_key_values
 
 
 
 
 
 
673
  if input_ids.shape[1] % block_size == 0:
674
  next_token = logits[:, -1:, :].argmax(dim=-1)
675
  input_ids = torch.cat([input_ids, next_token], dim=1)
@@ -683,30 +708,51 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
683
  break
684
  prompt_length = input_ids.shape[1]
685
  # Initialize x_init with mask_id
686
- x_init = mask_id * torch.ones((input_ids.shape[0], block_size-prompt_length%block_size), device=self.device, dtype=torch.long)
 
 
 
 
687
  x_init = torch.cat([input_ids, x_init], dim=1)
688
 
689
  x_t = x_init.clone()
690
  block_past_key_values = None
 
691
  while True:
692
  if stop_token in x_t[:, prompt_length:]:
693
  stop_token_idx = (x_t[:, prompt_length:] == stop_token).nonzero()[0][1]
694
  if (x_t[:, prompt_length:prompt_length+stop_token_idx] == mask_id).sum() == 0:
695
  break
696
  mask_idx = (x_t[:, -block_size:] == mask_id)
 
697
  # Decode a complete block, update cache, and generate the next token
698
  if mask_idx.sum() == 0:
699
- output = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=True, block_size=block_size)
 
 
 
 
 
 
700
  logits, past_key_values = output.logits, output.past_key_values
 
 
 
 
 
 
 
701
  next_token = logits[:, -1:, :].argmax(dim=-1)
702
  x_t = torch.cat([x_t, next_token], dim=1)
703
  break
 
704
  for small_block_idx in range(num_small_blocks):
705
  small_block_start_idx = small_block_idx * small_block_size
706
  small_block_end_idx = small_block_start_idx + small_block_size
707
 
708
  start = -block_size + small_block_start_idx
709
  end = None if block_size == small_block_end_idx else -block_size + small_block_end_idx
 
710
  while True:
711
  mask_idx = (x_t[:, -block_size:] == mask_id)
712
  if mask_idx[:, start:end].sum() == 0:
@@ -718,18 +764,43 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
718
 
719
  if use_block_cache:
720
  if block_past_key_values is None or (x_t[:, -block_size+small_block_start_idx] == mask_id).any():
721
- output = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True)
 
 
 
 
 
 
722
  logits, block_past_key_values = output.logits, output.block_past_key_values
723
  logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
724
  logits = logits[:, start:end]
725
  else:
726
- logits = self.forward(input_ids=x_t[:,start:end], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True, block_past_key_values=block_past_key_values, replace_position=small_block_start_idx).logits
 
 
 
 
 
 
 
 
 
727
  logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
728
  else:
729
- logits = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False).logits
 
 
 
 
 
 
730
  logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
731
  logits = logits[:, start:end]
732
 
 
 
 
 
733
 
734
  x_1, p_1t = self.sample_with_top_p(logits, top_p=top_p, temperature=temperature)
735
  # Select tokens with probability greater than threshold from p_1t
@@ -744,11 +815,21 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
744
  x_t[:, start:end][unmask_idx] = x_1[unmask_idx]
745
 
746
  input_ids = x_t
 
747
  # Truncate stop_token
748
  if stop_token in input_ids[:, original_input_length:]:
749
  stop_token_idx = (input_ids[:, original_input_length:] == stop_token).nonzero()[0][1]
750
  input_ids = input_ids[:, :stop_token_idx+original_input_length+1]
751
- return input_ids
 
 
 
 
 
 
 
 
 
752
 
753
  def sample_with_top_p(self, logits, top_p=0.95, temperature=1.0):
754
  # Calculate probabilities
@@ -782,4 +863,4 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
782
  p_1t = normalized_probs
783
  x_1 = torch.multinomial(p_1t[0], num_samples=1).unsqueeze(0).squeeze(-1)
784
 
785
- return x_1, p_1t
 
6
  import torch.nn.functional as F
7
  from functools import partial
8
 
9
+ from transformers.generation.utils import GenerateDecoderOnlyOutput
10
  from transformers.activations import ACT2FN
11
  from transformers.cache_utils import Cache, DynamicCache
12
  from transformers.generation import GenerationMixin
 
644
  loss=loss,
645
  logits=logits,
646
  past_key_values=outputs.past_key_values,
647
+ hidden_states=hidden_states,
648
  attentions=outputs.attentions,
649
  block_past_key_values=outputs.block_past_key_values,
650
  )
 
653
  def generate(
654
  self,
655
  input_ids,
656
+ max_new_tokens=None,
657
+ max_length=None,
658
+ tokenizer=None,
659
  mask_id=151665,
660
  threshold=1,
661
  small_block_size=8,
 
665
  top_p=0.95,
666
  temperature=0,
667
  use_block_cache=False,
668
+ return_dict_in_generate=False,
669
+ output_scores=False,
670
+ output_hidden_states=False,
671
  **kwargs
672
  ):
673
+ if max_new_tokens is None and max_length is None:
674
+ raise ValueError("Either max_new_tokens or max_length must be specified")
675
+ if max_new_tokens is None:
676
+ max_new_tokens = max_length - input_ids.shape[1]
677
+
678
+ scores_list = [] if output_scores else None
679
+ decoder_hidden_states = [] if output_hidden_states else None
680
+
681
  num_blocks = max_new_tokens // block_size
682
  original_input_length = input_ids.shape[1]
683
 
684
  if input_ids.shape[1] > block_size:
685
+ output = self.forward(
686
+ input_ids=input_ids[:, :(input_ids.shape[1] // block_size * block_size)],
687
+ use_cache=True,
688
+ update_past_key_values=True,
689
+ block_size=block_size
690
+ )
691
  logits, past_key_values = output.logits, output.past_key_values
692
+
693
+ if output_scores:
694
+ scores_list.append(logits)
695
+ if output_hidden_states and hasattr(output, 'hidden_states'):
696
+ decoder_hidden_states.append(output.hidden_states)
697
+
698
  if input_ids.shape[1] % block_size == 0:
699
  next_token = logits[:, -1:, :].argmax(dim=-1)
700
  input_ids = torch.cat([input_ids, next_token], dim=1)
 
708
  break
709
  prompt_length = input_ids.shape[1]
710
  # Initialize x_init with mask_id
711
+ x_init = mask_id * torch.ones(
712
+ (input_ids.shape[0], block_size-prompt_length%block_size),
713
+ device=self.device,
714
+ dtype=torch.long
715
+ )
716
  x_init = torch.cat([input_ids, x_init], dim=1)
717
 
718
  x_t = x_init.clone()
719
  block_past_key_values = None
720
+
721
  while True:
722
  if stop_token in x_t[:, prompt_length:]:
723
  stop_token_idx = (x_t[:, prompt_length:] == stop_token).nonzero()[0][1]
724
  if (x_t[:, prompt_length:prompt_length+stop_token_idx] == mask_id).sum() == 0:
725
  break
726
  mask_idx = (x_t[:, -block_size:] == mask_id)
727
+
728
  # Decode a complete block, update cache, and generate the next token
729
  if mask_idx.sum() == 0:
730
+ output = self.forward(
731
+ input_ids=x_t[:, -block_size:],
732
+ use_cache=True,
733
+ past_key_values=past_key_values,
734
+ update_past_key_values=True,
735
+ block_size=block_size
736
+ )
737
  logits, past_key_values = output.logits, output.past_key_values
738
+
739
+ # Collect output information
740
+ if output_scores:
741
+ scores_list.append(logits)
742
+ if output_hidden_states and hasattr(output, 'hidden_states'):
743
+ decoder_hidden_states.append(output.hidden_states)
744
+
745
  next_token = logits[:, -1:, :].argmax(dim=-1)
746
  x_t = torch.cat([x_t, next_token], dim=1)
747
  break
748
+
749
  for small_block_idx in range(num_small_blocks):
750
  small_block_start_idx = small_block_idx * small_block_size
751
  small_block_end_idx = small_block_start_idx + small_block_size
752
 
753
  start = -block_size + small_block_start_idx
754
  end = None if block_size == small_block_end_idx else -block_size + small_block_end_idx
755
+
756
  while True:
757
  mask_idx = (x_t[:, -block_size:] == mask_id)
758
  if mask_idx[:, start:end].sum() == 0:
 
764
 
765
  if use_block_cache:
766
  if block_past_key_values is None or (x_t[:, -block_size+small_block_start_idx] == mask_id).any():
767
+ output = self.forward(
768
+ input_ids=x_t[:, -block_size:],
769
+ use_cache=True,
770
+ past_key_values=past_key_values,
771
+ update_past_key_values=False,
772
+ use_block_cache=True,
773
+ )
774
  logits, block_past_key_values = output.logits, output.block_past_key_values
775
  logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
776
  logits = logits[:, start:end]
777
  else:
778
+ output = self.forward(
779
+ input_ids=x_t[:,start:end],
780
+ use_cache=True,
781
+ past_key_values=past_key_values,
782
+ update_past_key_values=False,
783
+ use_block_cache=True,
784
+ block_past_key_values=block_past_key_values,
785
+ replace_position=small_block_start_idx
786
+ )
787
+ logits = output.logits
788
  logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
789
  else:
790
+ output = self.forward(
791
+ input_ids=x_t[:, -block_size:],
792
+ use_cache=True,
793
+ past_key_values=past_key_values,
794
+ update_past_key_values=False
795
+ )
796
+ logits = output.logits
797
  logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
798
  logits = logits[:, start:end]
799
 
800
+ if output_scores:
801
+ scores_list.append(logits)
802
+ if output_hidden_states and hasattr(output, 'hidden_states'):
803
+ decoder_hidden_states.append(output.hidden_states)
804
 
805
  x_1, p_1t = self.sample_with_top_p(logits, top_p=top_p, temperature=temperature)
806
  # Select tokens with probability greater than threshold from p_1t
 
815
  x_t[:, start:end][unmask_idx] = x_1[unmask_idx]
816
 
817
  input_ids = x_t
818
+
819
  # Truncate stop_token
820
  if stop_token in input_ids[:, original_input_length:]:
821
  stop_token_idx = (input_ids[:, original_input_length:] == stop_token).nonzero()[0][1]
822
  input_ids = input_ids[:, :stop_token_idx+original_input_length+1]
823
+
824
+ if return_dict_in_generate:
825
+ return GenerateDecoderOnlyOutput(
826
+ sequences=input_ids,
827
+ scores=tuple(scores_list) if output_scores and scores_list else None,
828
+ hidden_states=tuple(decoder_hidden_states) if output_hidden_states and decoder_hidden_states else None,
829
+ )
830
+ else:
831
+ return input_ids
832
+
833
 
834
  def sample_with_top_p(self, logits, top_p=0.95, temperature=1.0):
835
  # Calculate probabilities
 
863
  p_1t = normalized_probs
864
  x_1 = torch.multinomial(p_1t[0], num_samples=1).unsqueeze(0).squeeze(-1)
865
 
866
+ return x_1, p_1t