fix generate method

modeling_residualnet.py (+48 −117)
@@ -586,123 +586,54 @@ class ResidualNetForCausalLM(Phi3PreTrainedModel, GenerationMixin):
     def base_model(self):
         return self.model
 
-        # Fetch the cache (the default key is "past_key_values")
-        cache = model_kwargs.get("past_key_values", None)
-        if cache is None:
-            # By default, HF's _prepare_cache_for_generation always fills this in
-            raise RuntimeError(
-                "past_key_values (Cache) not found. It must be set in step 7 of generate()."
-            )
-
-        # ---- 1) prefill: generate token i from the full context as usual ----
-        # cache_position is [0..L0-1]
-        seq_len0 = input_ids.size(1)
-        cache_pos = torch.arange(0, seq_len0, device=device)  # [L0]
-        model_kwargs["cache_position"] = cache_pos
-        # Fill attention_mask with ones (with left padding, pass the 0/1 mask through as-is)
-        if "attention_mask" not in model_kwargs or model_kwargs["attention_mask"] is None:
-            model_kwargs["attention_mask"] = torch.ones_like(input_ids, device=device)
-
-        outputs = model(
-            input_ids=input_ids,
-            **model_kwargs,
-        )
-        # logits -> processors -> next token (greedy here; extend with sampling as needed)
-        next_token_logits = outputs.logits[:, -1, :]
-        next_token_scores = logits_processor(input_ids, next_token_logits)
-        next_tokens = torch.argmax(next_token_scores, dim=-1, keepdim=True)  # [1,1]
-
-        if streamer is not None:
-            streamer.put(next_tokens.cpu())
-
-        sequences = torch.cat([input_ids, next_tokens], dim=1)  # 0..i
-        # cur_len = sequences.size(1)
-
-        # ---- 2) from then on: advance with a 3-token window each step; KV visible only up to "two tokens back" ----
-        # stopping_criteria has already been assembled on the generate() side
-        N = 3
-        while True:
-            # Stop checks (EOS, max_length, any extra criteria)
-            if stopping_criteria(sequences, None):
-                break
-            if sequences.size(1) >= generation_config.max_length:
-                break
-
-            # Latest index t (the tail token committed in the previous step)
-            t = sequences.size(1) - 1  # i, i+1, ...
-
-            # Past length visible through the KV cache: keep_len = t-2 (i.e. up to i-2 when generating i+1)
-            keep_len = max(0, t - N - 1)
-
-            # Input is the last 3 tokens (shorter when fewer are available)
-            window = sequences[:, -N:] if sequences.size(1) >= N else sequences
-            # Make the write positions of this window explicit: [keep_len .. keep_len+len(window)-1]
-            Lw = window.size(1)
-
-            # Step forward
-            outputs = model(
-                input_ids=window,
-                **model_kwargs,  # past_key_values is the same Cache (updated in place)
-            )
-            next_token_logits = outputs.logits[:, -1, :]
-            next_token_scores = logits_processor(sequences, next_token_logits)
-
-            # greedy (add sampling if needed)
-            next_tokens = torch.argmax(next_token_scores, dim=-1, keepdim=True)  # [1,1]
-
-            sequences = torch.cat([sequences, next_tokens], dim=1)
-
-            if streamer is not None:
-                streamer.put(next_tokens.cpu())
-
-        )
-        return sequences
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional["Cache"] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        **kwargs,
+    ):
+        """
+        - prefill (past_key_values is None): pass the full sequence as-is, with cache_position = [0..L-1]
+        - decode (past_key_values present): feed only the last 3 tokens, with
+          cache_position = [keep_len .. keep_len+Lw-1], where
+          * t = index of the tail token committed in the previous step (= input_ids.size(1)-1)
+          * keep_len = max(0, t-2)
+        - attention_mask is all-ones for simplicity (with left padding, reuse the given attention_mask as-is)
+        """
+        device = input_ids.device
+        use_cache = True if use_cache is None else use_cache
+
+        if past_key_values is None or not use_cache:
+            # ---- prefill: full sequence ----
+            L = input_ids.size(1)
+            cache_position = torch.arange(0, L, device=device)  # [L]
+            if attention_mask is None:
+                attention_mask = torch.ones_like(input_ids, device=device)
+            return {
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
+                "past_key_values": past_key_values,
+                "use_cache": use_cache,
+                "cache_position": cache_position,  # ★ [L]
+            }
+
+        # ---- decode: 3-token window ----
+        # input_ids is the full sequence so far, so slice to the last 3 tokens here
+        window = input_ids[:, -3:]  # [B, Lw<=3]
+        Lw = window.size(1)
+        t = input_ids.size(1) - 1  # i, i+1, ...
+        keep_len = max(0, t - 2)  # spec: past KV visible only up to "two tokens back"
+        cache_position = torch.arange(keep_len, keep_len + Lw, device=device)  # ★ [Lw]
+
+        # Match attention_mask to the window (replace with a proper slice if using left padding)
+        attn = torch.ones_like(window, device=device)
+
+        return {
+            "input_ids": window,
+            "attention_mask": attn,
+            "past_key_values": past_key_values,  # Cache (updated in place)
+            "use_cache": use_cache,
+            "cache_position": cache_position,  # ★ [Lw]
+        }
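To sanity-check the decode-branch arithmetic, here is a minimal sketch of the window/cache_position computation in isolation (plain PyTorch, no model involved; the sequence of 10 token ids is an arbitrary example):

```python
import torch

# Hypothetical decode step: 10 tokens (indices 0..9) are committed so far.
input_ids = torch.arange(10).unsqueeze(0)   # [B=1, 10]

window = input_ids[:, -3:]                  # last 3 tokens: ids at indices 7, 8, 9
Lw = window.size(1)                         # 3
t = input_ids.size(1) - 1                   # 9 = tail committed in the previous step
keep_len = max(0, t - 2)                    # 7 -> past KV visible only up to index 6
cache_position = torch.arange(keep_len, keep_len + Lw)  # tensor([7, 8, 9])

print(window.tolist(), cache_position.tolist())
# [[7, 8, 9]] [7, 8, 9]: the window is written back at cache slots 7..9
```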
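Since the stock GenerationMixin loop calls prepare_inputs_for_generation on every step, the 3-token windowing above should now flow through the standard generate() path instead of a hand-rolled loop. A minimal usage sketch, assuming a remote-code checkpoint that exposes ResidualNetForCausalLM (the repo id is a placeholder):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "org/residualnet-checkpoint"  # placeholder repo id
tok = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

inputs = tok("Hello", return_tensors="pt")
# Greedy decoding with the KV cache enabled, mirroring the removed argmax loop.
out = model.generate(**inputs, max_new_tokens=20, do_sample=False, use_cache=True)
print(tok.decode(out[0], skip_special_tokens=True))
```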