Chengyue Wu commited on
Commit ·
0f41374
1
Parent(s): e7a6fc6
add block cache
Browse files- modeling.py +38 -19
modeling.py
CHANGED
|
@@ -554,6 +554,8 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
|
|
| 554 |
stopping_criteria=None,
|
| 555 |
top_p=0.95,
|
| 556 |
temperature=0,
|
|
|
|
|
|
|
| 557 |
**kwargs
|
| 558 |
):
|
| 559 |
num_blocks = max_new_tokens // block_size
|
|
@@ -574,18 +576,20 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
|
|
| 574 |
if stop_token in input_ids[:, original_input_length:]:
|
| 575 |
break
|
| 576 |
prompt_length = input_ids.shape[1]
|
| 577 |
-
#
|
| 578 |
x_init = mask_id * torch.ones((input_ids.shape[0], block_size-prompt_length%block_size), device=self.device, dtype=torch.long)
|
| 579 |
x_init = torch.cat([input_ids, x_init], dim=1)
|
| 580 |
|
| 581 |
x_t = x_init.clone()
|
|
|
|
|
|
|
| 582 |
while True:
|
| 583 |
if stop_token in x_t[:, prompt_length:]:
|
| 584 |
stop_token_idx = (x_t[:, prompt_length:] == stop_token).nonzero()[0][1]
|
| 585 |
if (x_t[:, prompt_length:prompt_length+stop_token_idx] == mask_id).sum() == 0:
|
| 586 |
break
|
| 587 |
mask_idx = (x_t[:, -block_size:] == mask_id)
|
| 588 |
-
#
|
| 589 |
if mask_idx.sum() == 0:
|
| 590 |
output = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=True, block_size=block_size)
|
| 591 |
logits, past_key_values = output.logits, output.past_key_values
|
|
@@ -595,43 +599,59 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
|
|
| 595 |
for small_block_idx in range(num_small_blocks):
|
| 596 |
small_block_start_idx = small_block_idx * small_block_size
|
| 597 |
small_block_end_idx = small_block_start_idx + small_block_size
|
|
|
|
|
|
|
|
|
|
| 598 |
while True:
|
| 599 |
mask_idx = (x_t[:, -block_size:] == mask_id)
|
| 600 |
-
if mask_idx[:,
|
| 601 |
break
|
| 602 |
if stop_token in x_t[:, prompt_length:]:
|
| 603 |
stop_token_idx = (x_t[:, prompt_length:] == stop_token).nonzero()[0][1]
|
| 604 |
if (x_t[:, prompt_length:prompt_length+stop_token_idx] == mask_id).sum() == 0:
|
| 605 |
break
|
| 606 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 607 |
x_1, p_1t = self.sample_with_top_p(logits, top_p=top_p, temperature=temperature)
|
| 608 |
-
#
|
| 609 |
x1_p = torch.squeeze(torch.gather(p_1t, dim=-1, index=torch.unsqueeze(x_1, -1)), -1)
|
| 610 |
x1_p = torch.where(mask_idx, x1_p, -torch.inf)
|
| 611 |
-
|
| 612 |
unmask_idx = (x1_p > threshold)
|
|
|
|
|
|
|
|
|
|
| 613 |
|
| 614 |
-
|
| 615 |
-
x_t[:, -block_size:][unmask_idx] = x_1[unmask_idx]
|
| 616 |
-
else:
|
| 617 |
-
# 选出p_1t中概率最大的那一个token
|
| 618 |
-
token_position = x1_p.argmax()
|
| 619 |
-
x_t[:, -block_size:][0, token_position] = x_1[0, token_position]
|
| 620 |
|
|
|
|
| 621 |
input_ids = x_t
|
| 622 |
-
#
|
| 623 |
if stop_token in input_ids[:, original_input_length:]:
|
| 624 |
stop_token_idx = (input_ids[:, original_input_length:] == stop_token).nonzero()[0][1]
|
| 625 |
input_ids = input_ids[:, :stop_token_idx+original_input_length+1]
|
| 626 |
return input_ids
|
| 627 |
|
| 628 |
def sample_with_top_p(self, logits, top_p=0.95, temperature=1.0):
|
| 629 |
-
#
|
| 630 |
if temperature > 0:
|
| 631 |
scaled_logits = logits / temperature
|
| 632 |
else:
|
| 633 |
p_1t = torch.softmax(logits, dim=-1)
|
| 634 |
-
p_1t = torch.cat([p_1t[:, :1, :], p_1t[:, :-1, :]], dim=1)
|
| 635 |
x_1 = p_1t.argmax(dim=-1)
|
| 636 |
return x_1, p_1t
|
| 637 |
|
|
@@ -650,13 +670,12 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
|
|
| 650 |
|
| 651 |
probs[indices_to_remove] = 0
|
| 652 |
|
| 653 |
-
#
|
| 654 |
-
#
|
| 655 |
-
# 添加一个极小值 eps 防止除以零
|
| 656 |
probs_sum = torch.sum(probs, dim=-1, keepdim=True)
|
| 657 |
normalized_probs = probs / probs_sum
|
| 658 |
|
| 659 |
-
p_1t =
|
| 660 |
x_1 = torch.multinomial(p_1t[0], num_samples=1).unsqueeze(0).squeeze(-1)
|
| 661 |
|
| 662 |
return x_1, p_1t
|
|
|
|
| 554 |
stopping_criteria=None,
|
| 555 |
top_p=0.95,
|
| 556 |
temperature=0,
|
| 557 |
+
use_block_cache=False,
|
| 558 |
+
block_cache_refresh_interval=16,
|
| 559 |
**kwargs
|
| 560 |
):
|
| 561 |
num_blocks = max_new_tokens // block_size
|
|
|
|
| 576 |
if stop_token in input_ids[:, original_input_length:]:
|
| 577 |
break
|
| 578 |
prompt_length = input_ids.shape[1]
|
| 579 |
+
# Initialize x_init with mask_id
|
| 580 |
x_init = mask_id * torch.ones((input_ids.shape[0], block_size-prompt_length%block_size), device=self.device, dtype=torch.long)
|
| 581 |
x_init = torch.cat([input_ids, x_init], dim=1)
|
| 582 |
|
| 583 |
x_t = x_init.clone()
|
| 584 |
+
step = 0
|
| 585 |
+
block_past_key_values = None
|
| 586 |
while True:
|
| 587 |
if stop_token in x_t[:, prompt_length:]:
|
| 588 |
stop_token_idx = (x_t[:, prompt_length:] == stop_token).nonzero()[0][1]
|
| 589 |
if (x_t[:, prompt_length:prompt_length+stop_token_idx] == mask_id).sum() == 0:
|
| 590 |
break
|
| 591 |
mask_idx = (x_t[:, -block_size:] == mask_id)
|
| 592 |
+
# Decode a complete block, update cache, and generate the next token
|
| 593 |
if mask_idx.sum() == 0:
|
| 594 |
output = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=True, block_size=block_size)
|
| 595 |
logits, past_key_values = output.logits, output.past_key_values
|
|
|
|
| 599 |
for small_block_idx in range(num_small_blocks):
|
| 600 |
small_block_start_idx = small_block_idx * small_block_size
|
| 601 |
small_block_end_idx = small_block_start_idx + small_block_size
|
| 602 |
+
|
| 603 |
+
start = -block_size + small_block_start_idx
|
| 604 |
+
end = None if block_size == small_block_end_idx else -block_size + small_block_end_idx
|
| 605 |
while True:
|
| 606 |
mask_idx = (x_t[:, -block_size:] == mask_id)
|
| 607 |
+
if mask_idx[:, start:end].sum() == 0:
|
| 608 |
break
|
| 609 |
if stop_token in x_t[:, prompt_length:]:
|
| 610 |
stop_token_idx = (x_t[:, prompt_length:] == stop_token).nonzero()[0][1]
|
| 611 |
if (x_t[:, prompt_length:prompt_length+stop_token_idx] == mask_id).sum() == 0:
|
| 612 |
break
|
| 613 |
+
|
| 614 |
+
if use_block_cache:
|
| 615 |
+
if step % block_cache_refresh_interval == 0 or (x_t[:, -block_size+small_block_start_idx] == mask_id).any():
|
| 616 |
+
output = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True)
|
| 617 |
+
logits, block_past_key_values = output.logits, output.block_past_key_values
|
| 618 |
+
logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
|
| 619 |
+
logits = logits[:, start:end]
|
| 620 |
+
else:
|
| 621 |
+
logits = self.forward(input_ids=x_t[:, -block_size+small_block_start_idx:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True, block_past_key_values=block_past_key_values, replace_position=small_block_start_idx).logits
|
| 622 |
+
logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
|
| 623 |
+
else:
|
| 624 |
+
logits = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False).logits
|
| 625 |
+
logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
|
| 626 |
+
logits = logits[:, start:end]
|
| 627 |
+
|
| 628 |
+
|
| 629 |
x_1, p_1t = self.sample_with_top_p(logits, top_p=top_p, temperature=temperature)
|
| 630 |
+
# Select tokens with probability greater than threshold from p_1t
|
| 631 |
x1_p = torch.squeeze(torch.gather(p_1t, dim=-1, index=torch.unsqueeze(x_1, -1)), -1)
|
| 632 |
x1_p = torch.where(mask_idx, x1_p, -torch.inf)
|
| 633 |
+
|
| 634 |
unmask_idx = (x1_p > threshold)
|
| 635 |
+
max_prob_idx = x1_p.argmax(dim=-1)
|
| 636 |
+
unmask_idx[torch.arange(x_1.shape[0]), max_prob_idx] = True
|
| 637 |
+
unmask_idx = unmask_idx & mask_idx[:, start:end]
|
| 638 |
|
| 639 |
+
x_t[:, start:end][unmask_idx] = x_1[unmask_idx]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 640 |
|
| 641 |
+
step += 1
|
| 642 |
input_ids = x_t
|
| 643 |
+
# Truncate stop_token
|
| 644 |
if stop_token in input_ids[:, original_input_length:]:
|
| 645 |
stop_token_idx = (input_ids[:, original_input_length:] == stop_token).nonzero()[0][1]
|
| 646 |
input_ids = input_ids[:, :stop_token_idx+original_input_length+1]
|
| 647 |
return input_ids
|
| 648 |
|
| 649 |
def sample_with_top_p(self, logits, top_p=0.95, temperature=1.0):
|
| 650 |
+
# Calculate probabilities
|
| 651 |
if temperature > 0:
|
| 652 |
scaled_logits = logits / temperature
|
| 653 |
else:
|
| 654 |
p_1t = torch.softmax(logits, dim=-1)
|
|
|
|
| 655 |
x_1 = p_1t.argmax(dim=-1)
|
| 656 |
return x_1, p_1t
|
| 657 |
|
|
|
|
| 670 |
|
| 671 |
probs[indices_to_remove] = 0
|
| 672 |
|
| 673 |
+
# Renormalize so that the probabilities of remaining tokens sum to 1
|
| 674 |
+
# Add a small epsilon value to prevent division by zero
|
|
|
|
| 675 |
probs_sum = torch.sum(probs, dim=-1, keepdim=True)
|
| 676 |
normalized_probs = probs / probs_sum
|
| 677 |
|
| 678 |
+
p_1t = normalized_probs
|
| 679 |
x_1 = torch.multinomial(p_1t[0], num_samples=1).unsqueeze(0).squeeze(-1)
|
| 680 |
|
| 681 |
return x_1, p_1t
|