if001 commited on
Commit
b9d23dd
·
verified ·
1 Parent(s): 8a8bd55

fix generate method

Browse files
Files changed (1) hide show
  1. modeling_residualnet.py +48 -117
modeling_residualnet.py CHANGED
@@ -586,123 +586,54 @@ class ResidualNetForCausalLM(Phi3PreTrainedModel, GenerationMixin):
586
  def base_model(self):
587
  return self.model
588
 
589
- @torch.no_grad()
590
- def generate(self, *args, **kwargs):
591
- return super().generate(*args, **kwargs, custom_generate=window3_generate)
592
-
593
-
594
- def window3_generate(
595
- model,
596
- input_ids: torch.LongTensor,
597
- logits_processor,
598
- stopping_criteria,
599
- generation_config,
600
- synced_gpus,
601
- streamer=None,
602
- **model_kwargs,
603
- ):
604
- """
605
- 要件:
606
- - i 番目の生成は 0..i-1 の文脈(通常のprefill)
607
- - i+1 生成時の入力は [i-2, i-1, i]、使用する KV は 0..i-2
608
- - i+2 生成時の入力は [i-1, i, i+1]、使用する KV は 0..i-1
609
- 実現方法:
610
- - Cache は in-place 更新。KV の「どこに書くか」は cache_position で制御
611
- - マスク/RoPE 整合のため past_kv_len = cache_position[0]
612
- 依存:
613
- - model.forward (past_key_values=Cache, cache_position=LongTensor) を受け付ける
614
- - GenerationMixin.generate() が step 7 で Cache を model_kwargs["past_key_values"] に用意済み
615
- """
616
- device = input_ids.device
617
- batch_size = input_ids.size(0)
618
- assert (
619
- batch_size == 1
620
- ), "window3_decode はまず単一バッチで運用してください(拡張は容易)"
621
-
622
- # 必須前提:use_cache=True(generate() がすでに設定)
623
- model_kwargs["use_cache"] = True
624
-
625
- # Cache 取得(デフォルトのキーは "past_key_values")
626
- cache = model_kwargs.get("past_key_values", None)
627
- if cache is None:
628
- # HFの既定では _prepare_cache_for_generation がここを必ず埋めます
629
- raise RuntimeError(
630
- "past_key_values (Cache) が見つかりません。generate() の step 7 で設定されている必要があります。"
631
- )
632
-
633
- # ---- 1) prefill: 通常の全文脈で i を生成 ----
634
- # cache_position は [0..L0-1]
635
- seq_len0 = input_ids.size(1)
636
- cache_pos = torch.arange(0, seq_len0, device=device) # [L0]
637
- model_kwargs["cache_position"] = cache_pos
638
- # attention_mask は 1 埋め(左パディング運用ならそのまま 0/1 を渡す)
639
- if "attention_mask" not in model_kwargs or model_kwargs["attention_mask"] is None:
640
- model_kwargs["attention_mask"] = torch.ones_like(input_ids, device=device)
641
-
642
- outputs = model(
643
- input_ids=input_ids,
644
- **model_kwargs,
645
- )
646
- # logits -> processors -> next token(ここは greedy。sampling は必要に応じて拡張)
647
- next_token_logits = outputs.logits[:, -1, :]
648
- next_token_scores = logits_processor(input_ids, next_token_logits)
649
- next_tokens = torch.argmax(next_token_scores, dim=-1, keepdim=True) # [1,1]
650
-
651
- if streamer is not None:
652
- streamer.put(next_tokens.cpu())
653
-
654
- sequences = torch.cat([input_ids, next_tokens], dim=1) # 0..i
655
- # cur_len = sequences.size(1)
656
-
657
- # ---- 2) 以降: 毎回 3 トークン窓で前進、KV は “2つ前まで” を可視に ----
658
- # stopping_criteria は generate() 側で組み立て済み
659
- N = 3
660
- while True:
661
- # 停止判定(EOS, max_length, 任意の criteria)
662
- if stopping_criteria(sequences, None):
663
- break
664
- if sequences.size(1) >= generation_config.max_length:
665
- break
666
-
667
- # 直近 index t(直前に確定した末尾)
668
- t = sequences.size(1) - 1 # i, i+1, ...
669
-
670
- # KV を使わせる過去長 keep_len = t-2(= i+1 生成時に i-2 まで)
671
- keep_len = max(0, t - N - 1)
672
-
673
- # 入力は直近3トークン(不足時は短くなるのでそのまま)
674
- window = sequences[:, -N:] if sequences.size(1) >= N else sequences
675
- # この窓を書き込む位置を明示: [keep_len .. keep_len+len(window)-1]
676
  Lw = window.size(1)
677
- cache_pos = torch.arange(keep_len, keep_len + Lw, device=device) # [Lw]
678
- model_kwargs["cache_position"] = cache_pos
679
- model_kwargs["attention_mask"] = torch.ones_like(window, device=device)
680
-
681
- # 前進
682
- outputs = model(
683
- input_ids=window,
684
- **model_kwargs, # past_key_values は同じ Cache(in-place 更新)
685
- )
686
- next_token_logits = outputs.logits[:, -1, :]
687
- next_token_scores = logits_processor(sequences, next_token_logits)
688
-
689
- # greedy(必要に応じて sampling を追加)
690
- next_tokens = torch.argmax(next_token_scores, dim=-1, keepdim=True) # [1,1]
691
-
692
- sequences = torch.cat([sequences, next_tokens], dim=1)
693
-
694
- if streamer is not None:
695
- streamer.put(next_tokens.cpu())
696
 
697
- if streamer is not None:
698
- streamer.end()
699
 
700
- # return_dict_in_generate を尊重(最低限の互換)
701
- if generation_config.return_dict_in_generate:
702
- return GenerateDecoderOnlyOutput(
703
- sequences=sequences,
704
- scores=None,
705
- attentions=None,
706
- hidden_states=None,
707
- )
708
- return sequences
 
586
  def base_model(self):
587
  return self.model
588
 
589
+ def prepare_inputs_for_generation(
590
+ self,
591
+ input_ids: torch.LongTensor,
592
+ past_key_values: Optional["Cache"] = None,
593
+ attention_mask: Optional[torch.LongTensor] = None,
594
+ use_cache: Optional[bool] = None,
595
+ **kwargs,
596
+ ):
597
+ """
598
+ - prefill(past_key_values None)では全文をそのまま渡し、cache_position = [0..L-1]
599
+ - decode(past_key_values あり)では直近3トークンだけを入力し、
600
+ cache_position = [keep_len .. keep_len+Lw-1] とする
601
+ * t = 直前に確定した末尾 index(= input_ids.size(1)-1)
602
+ * keep_len = max(0, t-2)
603
+ - attention_mask は簡単のため 1 埋め(左PAD運用ならそのまま attention_mask を流用)
604
+ """
605
+ device = input_ids.device
606
+ use_cache = True if use_cache is None else use_cache
607
+
608
+ if past_key_values is None or not use_cache:
609
+ # ---- prefill: 全文 ----
610
+ L = input_ids.size(1)
611
+ cache_position = torch.arange(0, L, device=device) # [L]
612
+ if attention_mask is None:
613
+ attention_mask = torch.ones_like(input_ids, device=device)
614
+ return {
615
+ "input_ids": input_ids,
616
+ "attention_mask": attention_mask,
617
+ "past_key_values": past_key_values,
618
+ "use_cache": use_cache,
619
+ "cache_position": cache_position, # ★ [L]
620
+ }
621
+
622
+ # ---- decode: 3トークン窓 ----
623
+ # input_ids は「これまでの全文」なので、ここで直近3に切る
624
+ window = input_ids[:, -3:] # [B, Lw<=3]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
625
  Lw = window.size(1)
626
+ t = input_ids.size(1) - 1 # i, i+1, ...
627
+ keep_len = max(0, t - 2) # 仕様: “2つ前まで”を過去KVとして可視
628
+ cache_position = torch.arange(keep_len, keep_len + Lw, device=device) # ★ [Lw]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
629
 
630
+ # attention_mask は窓に合わせる(左PADを使っているなら適宜スライスに置換)
631
+ attn = torch.ones_like(window, device=device)
632
 
633
+ return {
634
+ "input_ids": window,
635
+ "attention_mask": attn,
636
+ "past_key_values": past_key_values, # Cache(in-place更新)
637
+ "use_cache": use_cache,
638
+ "cache_position": cache_position, # ★ [Lw]
639
+ }