revise generate
Browse files — modeling.py: +91 -10
modeling.py
CHANGED
|
@@ -6,6 +6,7 @@ from torch import nn
|
|
| 6 |
import torch.nn.functional as F
|
| 7 |
from functools import partial
|
| 8 |
|
|
|
|
| 9 |
from transformers.activations import ACT2FN
|
| 10 |
from transformers.cache_utils import Cache, DynamicCache
|
| 11 |
from transformers.generation import GenerationMixin
|
|
@@ -643,7 +644,7 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
|
|
| 643 |
loss=loss,
|
| 644 |
logits=logits,
|
| 645 |
past_key_values=outputs.past_key_values,
|
| 646 |
-
hidden_states=
|
| 647 |
attentions=outputs.attentions,
|
| 648 |
block_past_key_values=outputs.block_past_key_values,
|
| 649 |
)
|
|
@@ -652,7 +653,9 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
|
|
| 652 |
def generate(
|
| 653 |
self,
|
| 654 |
input_ids,
|
| 655 |
-
max_new_tokens,
|
|
|
|
|
|
|
| 656 |
mask_id=151665,
|
| 657 |
threshold=1,
|
| 658 |
small_block_size=8,
|
|
@@ -662,14 +665,36 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
|
|
| 662 |
top_p=0.95,
|
| 663 |
temperature=0,
|
| 664 |
use_block_cache=False,
|
|
|
|
|
|
|
|
|
|
| 665 |
**kwargs
|
| 666 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 667 |
num_blocks = max_new_tokens // block_size
|
| 668 |
original_input_length = input_ids.shape[1]
|
| 669 |
|
| 670 |
if input_ids.shape[1] > block_size:
|
| 671 |
-
output = self.forward(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 672 |
logits, past_key_values = output.logits, output.past_key_values
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 673 |
if input_ids.shape[1] % block_size == 0:
|
| 674 |
next_token = logits[:, -1:, :].argmax(dim=-1)
|
| 675 |
input_ids = torch.cat([input_ids, next_token], dim=1)
|
|
@@ -683,30 +708,51 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
|
|
| 683 |
break
|
| 684 |
prompt_length = input_ids.shape[1]
|
| 685 |
# Initialize x_init with mask_id
|
| 686 |
-
x_init = mask_id * torch.ones(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 687 |
x_init = torch.cat([input_ids, x_init], dim=1)
|
| 688 |
|
| 689 |
x_t = x_init.clone()
|
| 690 |
block_past_key_values = None
|
|
|
|
| 691 |
while True:
|
| 692 |
if stop_token in x_t[:, prompt_length:]:
|
| 693 |
stop_token_idx = (x_t[:, prompt_length:] == stop_token).nonzero()[0][1]
|
| 694 |
if (x_t[:, prompt_length:prompt_length+stop_token_idx] == mask_id).sum() == 0:
|
| 695 |
break
|
| 696 |
mask_idx = (x_t[:, -block_size:] == mask_id)
|
|
|
|
| 697 |
# Decode a complete block, update cache, and generate the next token
|
| 698 |
if mask_idx.sum() == 0:
|
| 699 |
-
output = self.forward(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 700 |
logits, past_key_values = output.logits, output.past_key_values
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 701 |
next_token = logits[:, -1:, :].argmax(dim=-1)
|
| 702 |
x_t = torch.cat([x_t, next_token], dim=1)
|
| 703 |
break
|
|
|
|
| 704 |
for small_block_idx in range(num_small_blocks):
|
| 705 |
small_block_start_idx = small_block_idx * small_block_size
|
| 706 |
small_block_end_idx = small_block_start_idx + small_block_size
|
| 707 |
|
| 708 |
start = -block_size + small_block_start_idx
|
| 709 |
end = None if block_size == small_block_end_idx else -block_size + small_block_end_idx
|
|
|
|
| 710 |
while True:
|
| 711 |
mask_idx = (x_t[:, -block_size:] == mask_id)
|
| 712 |
if mask_idx[:, start:end].sum() == 0:
|
|
@@ -718,18 +764,43 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
|
|
| 718 |
|
| 719 |
if use_block_cache:
|
| 720 |
if block_past_key_values is None or (x_t[:, -block_size+small_block_start_idx] == mask_id).any():
|
| 721 |
-
output = self.forward(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 722 |
logits, block_past_key_values = output.logits, output.block_past_key_values
|
| 723 |
logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
|
| 724 |
logits = logits[:, start:end]
|
| 725 |
else:
|
| 726 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 727 |
logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
|
| 728 |
else:
|
| 729 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 730 |
logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
|
| 731 |
logits = logits[:, start:end]
|
| 732 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 733 |
|
| 734 |
x_1, p_1t = self.sample_with_top_p(logits, top_p=top_p, temperature=temperature)
|
| 735 |
# Select tokens with probability greater than threshold from p_1t
|
|
@@ -744,11 +815,21 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
|
|
| 744 |
x_t[:, start:end][unmask_idx] = x_1[unmask_idx]
|
| 745 |
|
| 746 |
input_ids = x_t
|
|
|
|
| 747 |
# Truncate stop_token
|
| 748 |
if stop_token in input_ids[:, original_input_length:]:
|
| 749 |
stop_token_idx = (input_ids[:, original_input_length:] == stop_token).nonzero()[0][1]
|
| 750 |
input_ids = input_ids[:, :stop_token_idx+original_input_length+1]
|
| 751 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 752 |
|
| 753 |
def sample_with_top_p(self, logits, top_p=0.95, temperature=1.0):
|
| 754 |
# Calculate probabilities
|
|
@@ -782,4 +863,4 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
|
|
| 782 |
p_1t = normalized_probs
|
| 783 |
x_1 = torch.multinomial(p_1t[0], num_samples=1).unsqueeze(0).squeeze(-1)
|
| 784 |
|
| 785 |
-
return x_1, p_1t
|
|
|
|
| 6 |
import torch.nn.functional as F
|
| 7 |
from functools import partial
|
| 8 |
|
| 9 |
+
from transformers.generation.utils import GenerateDecoderOnlyOutput
|
| 10 |
from transformers.activations import ACT2FN
|
| 11 |
from transformers.cache_utils import Cache, DynamicCache
|
| 12 |
from transformers.generation import GenerationMixin
|
|
|
|
| 644 |
loss=loss,
|
| 645 |
logits=logits,
|
| 646 |
past_key_values=outputs.past_key_values,
|
| 647 |
+
hidden_states=hidden_states,
|
| 648 |
attentions=outputs.attentions,
|
| 649 |
block_past_key_values=outputs.block_past_key_values,
|
| 650 |
)
|
|
|
|
| 653 |
def generate(
|
| 654 |
self,
|
| 655 |
input_ids,
|
| 656 |
+
max_new_tokens=None,
|
| 657 |
+
max_length=None,
|
| 658 |
+
tokenizer=None,
|
| 659 |
mask_id=151665,
|
| 660 |
threshold=1,
|
| 661 |
small_block_size=8,
|
|
|
|
| 665 |
top_p=0.95,
|
| 666 |
temperature=0,
|
| 667 |
use_block_cache=False,
|
| 668 |
+
return_dict_in_generate=False,
|
| 669 |
+
output_scores=False,
|
| 670 |
+
output_hidden_states=False,
|
| 671 |
**kwargs
|
| 672 |
):
|
| 673 |
+
if max_new_tokens is None and max_length is None:
|
| 674 |
+
raise ValueError("Either max_new_tokens or max_length must be specified")
|
| 675 |
+
if max_new_tokens is None:
|
| 676 |
+
max_new_tokens = max_length - input_ids.shape[1]
|
| 677 |
+
|
| 678 |
+
scores_list = [] if output_scores else None
|
| 679 |
+
decoder_hidden_states = [] if output_hidden_states else None
|
| 680 |
+
|
| 681 |
num_blocks = max_new_tokens // block_size
|
| 682 |
original_input_length = input_ids.shape[1]
|
| 683 |
|
| 684 |
if input_ids.shape[1] > block_size:
|
| 685 |
+
output = self.forward(
|
| 686 |
+
input_ids=input_ids[:, :(input_ids.shape[1] // block_size * block_size)],
|
| 687 |
+
use_cache=True,
|
| 688 |
+
update_past_key_values=True,
|
| 689 |
+
block_size=block_size
|
| 690 |
+
)
|
| 691 |
logits, past_key_values = output.logits, output.past_key_values
|
| 692 |
+
|
| 693 |
+
if output_scores:
|
| 694 |
+
scores_list.append(logits)
|
| 695 |
+
if output_hidden_states and hasattr(output, 'hidden_states'):
|
| 696 |
+
decoder_hidden_states.append(output.hidden_states)
|
| 697 |
+
|
| 698 |
if input_ids.shape[1] % block_size == 0:
|
| 699 |
next_token = logits[:, -1:, :].argmax(dim=-1)
|
| 700 |
input_ids = torch.cat([input_ids, next_token], dim=1)
|
|
|
|
| 708 |
break
|
| 709 |
prompt_length = input_ids.shape[1]
|
| 710 |
# Initialize x_init with mask_id
|
| 711 |
+
x_init = mask_id * torch.ones(
|
| 712 |
+
(input_ids.shape[0], block_size-prompt_length%block_size),
|
| 713 |
+
device=self.device,
|
| 714 |
+
dtype=torch.long
|
| 715 |
+
)
|
| 716 |
x_init = torch.cat([input_ids, x_init], dim=1)
|
| 717 |
|
| 718 |
x_t = x_init.clone()
|
| 719 |
block_past_key_values = None
|
| 720 |
+
|
| 721 |
while True:
|
| 722 |
if stop_token in x_t[:, prompt_length:]:
|
| 723 |
stop_token_idx = (x_t[:, prompt_length:] == stop_token).nonzero()[0][1]
|
| 724 |
if (x_t[:, prompt_length:prompt_length+stop_token_idx] == mask_id).sum() == 0:
|
| 725 |
break
|
| 726 |
mask_idx = (x_t[:, -block_size:] == mask_id)
|
| 727 |
+
|
| 728 |
# Decode a complete block, update cache, and generate the next token
|
| 729 |
if mask_idx.sum() == 0:
|
| 730 |
+
output = self.forward(
|
| 731 |
+
input_ids=x_t[:, -block_size:],
|
| 732 |
+
use_cache=True,
|
| 733 |
+
past_key_values=past_key_values,
|
| 734 |
+
update_past_key_values=True,
|
| 735 |
+
block_size=block_size
|
| 736 |
+
)
|
| 737 |
logits, past_key_values = output.logits, output.past_key_values
|
| 738 |
+
|
| 739 |
+
# Collect output information (scores and hidden states)
|
| 740 |
+
if output_scores:
|
| 741 |
+
scores_list.append(logits)
|
| 742 |
+
if output_hidden_states and hasattr(output, 'hidden_states'):
|
| 743 |
+
decoder_hidden_states.append(output.hidden_states)
|
| 744 |
+
|
| 745 |
next_token = logits[:, -1:, :].argmax(dim=-1)
|
| 746 |
x_t = torch.cat([x_t, next_token], dim=1)
|
| 747 |
break
|
| 748 |
+
|
| 749 |
for small_block_idx in range(num_small_blocks):
|
| 750 |
small_block_start_idx = small_block_idx * small_block_size
|
| 751 |
small_block_end_idx = small_block_start_idx + small_block_size
|
| 752 |
|
| 753 |
start = -block_size + small_block_start_idx
|
| 754 |
end = None if block_size == small_block_end_idx else -block_size + small_block_end_idx
|
| 755 |
+
|
| 756 |
while True:
|
| 757 |
mask_idx = (x_t[:, -block_size:] == mask_id)
|
| 758 |
if mask_idx[:, start:end].sum() == 0:
|
|
|
|
| 764 |
|
| 765 |
if use_block_cache:
|
| 766 |
if block_past_key_values is None or (x_t[:, -block_size+small_block_start_idx] == mask_id).any():
|
| 767 |
+
output = self.forward(
|
| 768 |
+
input_ids=x_t[:, -block_size:],
|
| 769 |
+
use_cache=True,
|
| 770 |
+
past_key_values=past_key_values,
|
| 771 |
+
update_past_key_values=False,
|
| 772 |
+
use_block_cache=True,
|
| 773 |
+
)
|
| 774 |
logits, block_past_key_values = output.logits, output.block_past_key_values
|
| 775 |
logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
|
| 776 |
logits = logits[:, start:end]
|
| 777 |
else:
|
| 778 |
+
output = self.forward(
|
| 779 |
+
input_ids=x_t[:,start:end],
|
| 780 |
+
use_cache=True,
|
| 781 |
+
past_key_values=past_key_values,
|
| 782 |
+
update_past_key_values=False,
|
| 783 |
+
use_block_cache=True,
|
| 784 |
+
block_past_key_values=block_past_key_values,
|
| 785 |
+
replace_position=small_block_start_idx
|
| 786 |
+
)
|
| 787 |
+
logits = output.logits
|
| 788 |
logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
|
| 789 |
else:
|
| 790 |
+
output = self.forward(
|
| 791 |
+
input_ids=x_t[:, -block_size:],
|
| 792 |
+
use_cache=True,
|
| 793 |
+
past_key_values=past_key_values,
|
| 794 |
+
update_past_key_values=False
|
| 795 |
+
)
|
| 796 |
+
logits = output.logits
|
| 797 |
logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
|
| 798 |
logits = logits[:, start:end]
|
| 799 |
|
| 800 |
+
if output_scores:
|
| 801 |
+
scores_list.append(logits)
|
| 802 |
+
if output_hidden_states and hasattr(output, 'hidden_states'):
|
| 803 |
+
decoder_hidden_states.append(output.hidden_states)
|
| 804 |
|
| 805 |
x_1, p_1t = self.sample_with_top_p(logits, top_p=top_p, temperature=temperature)
|
| 806 |
# Select tokens with probability greater than threshold from p_1t
|
|
|
|
| 815 |
x_t[:, start:end][unmask_idx] = x_1[unmask_idx]
|
| 816 |
|
| 817 |
input_ids = x_t
|
| 818 |
+
|
| 819 |
# Truncate stop_token
|
| 820 |
if stop_token in input_ids[:, original_input_length:]:
|
| 821 |
stop_token_idx = (input_ids[:, original_input_length:] == stop_token).nonzero()[0][1]
|
| 822 |
input_ids = input_ids[:, :stop_token_idx+original_input_length+1]
|
| 823 |
+
|
| 824 |
+
if return_dict_in_generate:
|
| 825 |
+
return GenerateDecoderOnlyOutput(
|
| 826 |
+
sequences=input_ids,
|
| 827 |
+
scores=tuple(scores_list) if output_scores and scores_list else None,
|
| 828 |
+
hidden_states=tuple(decoder_hidden_states) if output_hidden_states and decoder_hidden_states else None,
|
| 829 |
+
)
|
| 830 |
+
else:
|
| 831 |
+
return input_ids
|
| 832 |
+
|
| 833 |
|
| 834 |
def sample_with_top_p(self, logits, top_p=0.95, temperature=1.0):
|
| 835 |
# Calculate probabilities
|
|
|
|
| 863 |
p_1t = normalized_probs
|
| 864 |
x_1 = torch.multinomial(p_1t[0], num_samples=1).unsqueeze(0).squeeze(-1)
|
| 865 |
|
| 866 |
+
return x_1, p_1t
|