Chengyue Wu committed on
Commit
0f41374
·
1 Parent(s): e7a6fc6

add block cache

Browse files
Files changed (1) hide show
  1. modeling.py +38 -19
modeling.py CHANGED
@@ -554,6 +554,8 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
554
  stopping_criteria=None,
555
  top_p=0.95,
556
  temperature=0,
 
 
557
  **kwargs
558
  ):
559
  num_blocks = max_new_tokens // block_size
@@ -574,18 +576,20 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
574
  if stop_token in input_ids[:, original_input_length:]:
575
  break
576
  prompt_length = input_ids.shape[1]
577
- # 初始化x_initmask_id
578
  x_init = mask_id * torch.ones((input_ids.shape[0], block_size-prompt_length%block_size), device=self.device, dtype=torch.long)
579
  x_init = torch.cat([input_ids, x_init], dim=1)
580
 
581
  x_t = x_init.clone()
 
 
582
  while True:
583
  if stop_token in x_t[:, prompt_length:]:
584
  stop_token_idx = (x_t[:, prompt_length:] == stop_token).nonzero()[0][1]
585
  if (x_t[:, prompt_length:prompt_length+stop_token_idx] == mask_id).sum() == 0:
586
  break
587
  mask_idx = (x_t[:, -block_size:] == mask_id)
588
- # 解码完整的一个block,更新cache,并且生成下一个token
589
  if mask_idx.sum() == 0:
590
  output = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=True, block_size=block_size)
591
  logits, past_key_values = output.logits, output.past_key_values
@@ -595,43 +599,59 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
595
  for small_block_idx in range(num_small_blocks):
596
  small_block_start_idx = small_block_idx * small_block_size
597
  small_block_end_idx = small_block_start_idx + small_block_size
 
 
 
598
  while True:
599
  mask_idx = (x_t[:, -block_size:] == mask_id)
600
- if mask_idx[:, small_block_start_idx:small_block_end_idx].sum() == 0:
601
  break
602
  if stop_token in x_t[:, prompt_length:]:
603
  stop_token_idx = (x_t[:, prompt_length:] == stop_token).nonzero()[0][1]
604
  if (x_t[:, prompt_length:prompt_length+stop_token_idx] == mask_id).sum() == 0:
605
  break
606
- logits = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, block_size=block_size).logits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
607
  x_1, p_1t = self.sample_with_top_p(logits, top_p=top_p, temperature=temperature)
608
- # 选出p_1t中概率大于threshold的token
609
  x1_p = torch.squeeze(torch.gather(p_1t, dim=-1, index=torch.unsqueeze(x_1, -1)), -1)
610
  x1_p = torch.where(mask_idx, x1_p, -torch.inf)
611
- x1_p[:, small_block_end_idx:] = -torch.inf
612
  unmask_idx = (x1_p > threshold)
 
 
 
613
 
614
- if unmask_idx.sum() > 0:
615
- x_t[:, -block_size:][unmask_idx] = x_1[unmask_idx]
616
- else:
617
- # 选出p_1t中概率最大的那一个token
618
- token_position = x1_p.argmax()
619
- x_t[:, -block_size:][0, token_position] = x_1[0, token_position]
620
 
 
621
  input_ids = x_t
622
- # 截断stop_token
623
  if stop_token in input_ids[:, original_input_length:]:
624
  stop_token_idx = (input_ids[:, original_input_length:] == stop_token).nonzero()[0][1]
625
  input_ids = input_ids[:, :stop_token_idx+original_input_length+1]
626
  return input_ids
627
 
628
  def sample_with_top_p(self, logits, top_p=0.95, temperature=1.0):
629
- # 计算概率
630
  if temperature > 0:
631
  scaled_logits = logits / temperature
632
  else:
633
  p_1t = torch.softmax(logits, dim=-1)
634
- p_1t = torch.cat([p_1t[:, :1, :], p_1t[:, :-1, :]], dim=1)
635
  x_1 = p_1t.argmax(dim=-1)
636
  return x_1, p_1t
637
 
@@ -650,13 +670,12 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
650
 
651
  probs[indices_to_remove] = 0
652
 
653
- # 3. 重新归一化并采样
654
- # 重新归一化,使得剩余 token 的概率和为 1
655
- # 添加一个极小值 eps 防止除以零
656
  probs_sum = torch.sum(probs, dim=-1, keepdim=True)
657
  normalized_probs = probs / probs_sum
658
 
659
- p_1t = torch.cat([normalized_probs[:, :1, :], normalized_probs[:, :-1, :]], dim=1)
660
  x_1 = torch.multinomial(p_1t[0], num_samples=1).unsqueeze(0).squeeze(-1)
661
 
662
  return x_1, p_1t
 
554
  stopping_criteria=None,
555
  top_p=0.95,
556
  temperature=0,
557
+ use_block_cache=False,
558
+ block_cache_refresh_interval=16,
559
  **kwargs
560
  ):
561
  num_blocks = max_new_tokens // block_size
 
576
  if stop_token in input_ids[:, original_input_length:]:
577
  break
578
  prompt_length = input_ids.shape[1]
579
+ # Initialize x_init with mask_id
580
  x_init = mask_id * torch.ones((input_ids.shape[0], block_size-prompt_length%block_size), device=self.device, dtype=torch.long)
581
  x_init = torch.cat([input_ids, x_init], dim=1)
582
 
583
  x_t = x_init.clone()
584
+ step = 0
585
+ block_past_key_values = None
586
  while True:
587
  if stop_token in x_t[:, prompt_length:]:
588
  stop_token_idx = (x_t[:, prompt_length:] == stop_token).nonzero()[0][1]
589
  if (x_t[:, prompt_length:prompt_length+stop_token_idx] == mask_id).sum() == 0:
590
  break
591
  mask_idx = (x_t[:, -block_size:] == mask_id)
592
+ # Decode a complete block, update cache, and generate the next token
593
  if mask_idx.sum() == 0:
594
  output = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=True, block_size=block_size)
595
  logits, past_key_values = output.logits, output.past_key_values
 
599
  for small_block_idx in range(num_small_blocks):
600
  small_block_start_idx = small_block_idx * small_block_size
601
  small_block_end_idx = small_block_start_idx + small_block_size
602
+
603
+ start = -block_size + small_block_start_idx
604
+ end = None if block_size == small_block_end_idx else -block_size + small_block_end_idx
605
  while True:
606
  mask_idx = (x_t[:, -block_size:] == mask_id)
607
+ if mask_idx[:, start:end].sum() == 0:
608
  break
609
  if stop_token in x_t[:, prompt_length:]:
610
  stop_token_idx = (x_t[:, prompt_length:] == stop_token).nonzero()[0][1]
611
  if (x_t[:, prompt_length:prompt_length+stop_token_idx] == mask_id).sum() == 0:
612
  break
613
+
614
+ if use_block_cache:
615
+ if step % block_cache_refresh_interval == 0 or (x_t[:, -block_size+small_block_start_idx] == mask_id).any():
616
+ output = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True)
617
+ logits, block_past_key_values = output.logits, output.block_past_key_values
618
+ logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
619
+ logits = logits[:, start:end]
620
+ else:
621
+ logits = self.forward(input_ids=x_t[:, -block_size+small_block_start_idx:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True, block_past_key_values=block_past_key_values, replace_position=small_block_start_idx).logits
622
+ logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
623
+ else:
624
+ logits = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False).logits
625
+ logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
626
+ logits = logits[:, start:end]
627
+
628
+
629
  x_1, p_1t = self.sample_with_top_p(logits, top_p=top_p, temperature=temperature)
630
+ # Select tokens with probability greater than threshold from p_1t
631
  x1_p = torch.squeeze(torch.gather(p_1t, dim=-1, index=torch.unsqueeze(x_1, -1)), -1)
632
  x1_p = torch.where(mask_idx, x1_p, -torch.inf)
633
+
634
  unmask_idx = (x1_p > threshold)
635
+ max_prob_idx = x1_p.argmax(dim=-1)
636
+ unmask_idx[torch.arange(x_1.shape[0]), max_prob_idx] = True
637
+ unmask_idx = unmask_idx & mask_idx[:, start:end]
638
 
639
+ x_t[:, start:end][unmask_idx] = x_1[unmask_idx]
 
 
 
 
 
640
 
641
+ step += 1
642
  input_ids = x_t
643
+ # Truncate stop_token
644
  if stop_token in input_ids[:, original_input_length:]:
645
  stop_token_idx = (input_ids[:, original_input_length:] == stop_token).nonzero()[0][1]
646
  input_ids = input_ids[:, :stop_token_idx+original_input_length+1]
647
  return input_ids
648
 
649
  def sample_with_top_p(self, logits, top_p=0.95, temperature=1.0):
650
+ # Calculate probabilities
651
  if temperature > 0:
652
  scaled_logits = logits / temperature
653
  else:
654
  p_1t = torch.softmax(logits, dim=-1)
 
655
  x_1 = p_1t.argmax(dim=-1)
656
  return x_1, p_1t
657
 
 
670
 
671
  probs[indices_to_remove] = 0
672
 
673
+ # Renormalize so that the probabilities of remaining tokens sum to 1
674
+ # Add a small epsilon value to prevent division by zero
 
675
  probs_sum = torch.sum(probs, dim=-1, keepdim=True)
676
  normalized_probs = probs / probs_sum
677
 
678
+ p_1t = normalized_probs
679
  x_1 = torch.multinomial(p_1t[0], num_samples=1).unsqueeze(0).squeeze(-1)
680
 
681
  return x_1, p_1t