File size: 6,994 Bytes
8ff980e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import torch
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoConfig
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from open_r1.rmt.MemoryCell import MemoryCell
from open_r1.rmt.RecurrentWrapper import RecurrentWrapper
from open_r1.rmt.PreTrainedRMTConfig import PreTrainedRMTConfig
# @register_for_auto_class("AutoModelForCausalLM")
class RecurrentMemoryTransformer(PreTrainedModel):
"""
Recurrent Memory Transformer モデルクラス
長い文脈をセグメント単位で処理し、メモリを使って情報を保持するトランスフォーマーモデル
"""
config_class = PreTrainedRMTConfig
auto_model_class = "AutoModelForCausalLM"
# マッピングを定義してAutoクラスが適切なモデルを見つけられるようにする
_keys_to_ignore_on_load_missing = [r"position_ids"]
# AUTO_MAPを定義(モデル名からクラスへのマッピング)
AUTO_MAP = {
"AutoModelForCausalLM": "RecurrentMemoryTransformer",
}
def __init__(self, config, base_model=None):
"""
初期化
Parameters
----------
config : PreTrainedRMTConfig
モデルの設定
base_model : PreTrainedModel, optional
ベースとなるトランスフォーマーモデル
"""
super().__init__(config)
# base_modelが指定されていない場合は、configから自動生成
if base_model is None:
# ベースモデルのタイプを確認
if not hasattr(config, "base_model_type"):
raise ValueError("configにbase_model_typeが指定されていません。RMTの設定にはベースモデルタイプが必要です。")
base_model_type = config.base_model_type
# ベースモデル用の設定を作成
base_config = AutoConfig.from_pretrained(base_model_type)
# RMT固有のパラメータを除外してベースモデルの設定を作成
rmt_specific_params = ['model_type', 'is_memory_all', 'max_n_segments', 'input_seg_len',
'output_seg_len', 'align', 'num_mem_tokens', 'base_model_type']
for key, value in config.__dict__.items():
if key not in rmt_specific_params and not key.startswith('_'):
setattr(base_config, key, value)
# ベースモデルを作成
base_model = AutoModelForCausalLM.from_config(base_config)
# MemoryCellとRecurrentWrapperの初期化
memory_cell = MemoryCell(base_model, config.num_mem_tokens)
self.recurrent_wrapper = RecurrentWrapper(
memory_cell=memory_cell,
is_memory_all=config.is_memory_all,
max_n_segments=config.max_n_segments,
input_seg_len=config.input_seg_len,
output_seg_len=config.output_seg_len,
align=config.align
)
def get_base_model(self):
"""
ベースモデルを取得
"""
return self.recurrent_wrapper.memory_cell.model
def forward(self, input_ids=None, attention_mask=None, labels=None, labels_mask=None,
inputs_embeds=None, output_attentions=None, output_hidden_states=None):
"""
モデルの順伝播
Parameters
----------
input_ids : torch.Tensor, optional
入力テンソル
attention_mask : torch.Tensor, optional
アテンションマスク
labels : torch.Tensor, optional
ラベルテンソル
labels_mask : torch.Tensor, optional
ラベルマスク
inputs_embeds : torch.Tensor, optional
入力埋め込み
output_attentions : bool, optional
アテンション重みを出力するかどうか
output_hidden_states : bool, optional
隠れ状態を出力するかどうか
"""
forward_kwargs = {}
if input_ids is not None:
forward_kwargs["input_ids"] = input_ids
if labels is not None:
forward_kwargs["labels"] = labels
if attention_mask is not None:
forward_kwargs["attention_mask"] = attention_mask
if labels_mask is not None:
forward_kwargs["labels_mask"] = labels_mask
if inputs_embeds is not None:
forward_kwargs["inputs_embeds"] = inputs_embeds
if output_attentions is not None:
forward_kwargs["output_attentions"] = output_attentions
if output_hidden_states is not None:
forward_kwargs["output_hidden_states"] = output_hidden_states
#forward_kwargs.update(kwargs)
# 通常の順伝播処理
out = self.recurrent_wrapper.forward(**forward_kwargs)
"""
# デバッグ出力を削除(または必要に応じてコメント化)
# print(out["loss"])
# 分散環境で損失が二重計算されないよう、ワールドサイズで割る
# これは処理済みの場合は不要なので、環境変数などで制御することも可能
if torch.distributed.is_initialized() and "loss" in out and out["loss"] is not None:
# 既にDeepSpeedが処理している可能性があるため、確認が必要
# テスト目的で一時的に追加(実際の環境に合わせて調整が必要)
# world_size = torch.distributed.get_world_size()
# out["loss"] = out["loss"] / world_size
pass
"""
return out
def generate(self, **kwargs):
"""
テキスト生成
"""
return self.recurrent_wrapper.generate(**kwargs)
def generate_with_tokenizer(self, tokenizer, input_text, **kwargs):
"""
トークナイザーを用いたテキスト生成
"""
return self.recurrent_wrapper.generate_with_tokenizer(tokenizer, input_text, **kwargs)
def get_input_embeddings(self):
"""
入力埋め込みを取得
"""
return self.get_base_model().get_input_embeddings()
def set_input_embeddings(self, embeddings):
"""
入力埋め込みを設定
"""
self.get_base_model().set_input_embeddings(embeddings)
def get_output_embeddings(self):
"""
出力埋め込みを取得
"""
return self.get_base_model().get_output_embeddings()
def resize_token_embeddings(self, new_num_tokens):
"""
トークン埋め込みのサイズを変更
"""
self.get_base_model().resize_token_embeddings(new_num_tokens)
return self.get_input_embeddings()
RecurrentMemoryTransformer.register_for_auto_class("AutoModelForCausalLM") |