File size: 6,994 Bytes
683cc33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import torch
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoConfig
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from open_r1.rmt.MemoryCell import MemoryCell
from open_r1.rmt.RecurrentWrapper import RecurrentWrapper
from open_r1.rmt.PreTrainedRMTConfig import PreTrainedRMTConfig


# @register_for_auto_class("AutoModelForCausalLM")
class RecurrentMemoryTransformer(PreTrainedModel):
    """
    Recurrent Memory Transformer モデルクラス
    長い文脈をセグメント単位で処理し、メモリを使って情報を保持するトランスフォーマーモデル
    """
    
    config_class = PreTrainedRMTConfig
    auto_model_class = "AutoModelForCausalLM"
    
    # マッピングを定義してAutoクラスが適切なモデルを見つけられるようにする
    _keys_to_ignore_on_load_missing = [r"position_ids"]
    
    # AUTO_MAPを定義(モデル名からクラスへのマッピング)
    AUTO_MAP = {
        "AutoModelForCausalLM": "RecurrentMemoryTransformer",
    }
    
    def __init__(self, config, base_model=None):
        """
        初期化
        
        Parameters
        ----------
        config : PreTrainedRMTConfig
            モデルの設定
        base_model : PreTrainedModel, optional
            ベースとなるトランスフォーマーモデル
        """
        super().__init__(config)
        
        # base_modelが指定されていない場合は、configから自動生成
        if base_model is None:
            # ベースモデルのタイプを確認
            if not hasattr(config, "base_model_type"):
                raise ValueError("configにbase_model_typeが指定されていません。RMTの設定にはベースモデルタイプが必要です。")
            base_model_type = config.base_model_type
            
            # ベースモデル用の設定を作成
            base_config = AutoConfig.from_pretrained(base_model_type)
            
            # RMT固有のパラメータを除外してベースモデルの設定を作成
            rmt_specific_params = ['model_type', 'is_memory_all', 'max_n_segments', 'input_seg_len', 
                                  'output_seg_len', 'align', 'num_mem_tokens', 'base_model_type']
            for key, value in config.__dict__.items():
                if key not in rmt_specific_params and not key.startswith('_'):
                    setattr(base_config, key, value)
            
            # ベースモデルを作成
            base_model = AutoModelForCausalLM.from_config(base_config)
        
        # MemoryCellとRecurrentWrapperの初期化
        memory_cell = MemoryCell(base_model, config.num_mem_tokens)
        self.recurrent_wrapper = RecurrentWrapper(
            memory_cell=memory_cell,
            is_memory_all=config.is_memory_all,
            max_n_segments=config.max_n_segments,
            input_seg_len=config.input_seg_len,
            output_seg_len=config.output_seg_len,
            align=config.align
        )
        
    def get_base_model(self):
        """
        ベースモデルを取得
        """
        return self.recurrent_wrapper.memory_cell.model
    
    def forward(self, input_ids=None, attention_mask=None, labels=None, labels_mask=None, 
                inputs_embeds=None, output_attentions=None, output_hidden_states=None):
        """
        モデルの順伝播
        
        Parameters
        ----------
        input_ids : torch.Tensor, optional
            入力テンソル
        attention_mask : torch.Tensor, optional
            アテンションマスク
        labels : torch.Tensor, optional
            ラベルテンソル
        labels_mask : torch.Tensor, optional
            ラベルマスク
        inputs_embeds : torch.Tensor, optional
            入力埋め込み
        output_attentions : bool, optional
            アテンション重みを出力するかどうか
        output_hidden_states : bool, optional
            隠れ状態を出力するかどうか
        """
        forward_kwargs = {}
        if input_ids is not None:
            forward_kwargs["input_ids"] = input_ids
        if labels is not None:
            forward_kwargs["labels"] = labels
        if attention_mask is not None:
            forward_kwargs["attention_mask"] = attention_mask
        if labels_mask is not None:
            forward_kwargs["labels_mask"] = labels_mask
        if inputs_embeds is not None:
            forward_kwargs["inputs_embeds"] = inputs_embeds
        if output_attentions is not None:
            forward_kwargs["output_attentions"] = output_attentions
        if output_hidden_states is not None:
            forward_kwargs["output_hidden_states"] = output_hidden_states
        
        #forward_kwargs.update(kwargs)
        
        # 通常の順伝播処理
        out = self.recurrent_wrapper.forward(**forward_kwargs)
        """
        # デバッグ出力を削除(または必要に応じてコメント化)
        # print(out["loss"])
        
        # 分散環境で損失が二重計算されないよう、ワールドサイズで割る
        # これは処理済みの場合は不要なので、環境変数などで制御することも可能
        if torch.distributed.is_initialized() and "loss" in out and out["loss"] is not None:
            # 既にDeepSpeedが処理している可能性があるため、確認が必要
            # テスト目的で一時的に追加(実際の環境に合わせて調整が必要)
            # world_size = torch.distributed.get_world_size()
            # out["loss"] = out["loss"] / world_size
            pass
        """
        return out
    
    def generate(self, **kwargs):
        """
        テキスト生成
        """
        return self.recurrent_wrapper.generate(**kwargs)
    
    def generate_with_tokenizer(self, tokenizer, input_text, **kwargs):
        """
        トークナイザーを用いたテキスト生成
        """
        return self.recurrent_wrapper.generate_with_tokenizer(tokenizer, input_text, **kwargs)
    
    def get_input_embeddings(self):
        """
        入力埋め込みを取得
        """
        return self.get_base_model().get_input_embeddings()
    
    def set_input_embeddings(self, embeddings):
        """
        入力埋め込みを設定
        """
        self.get_base_model().set_input_embeddings(embeddings)
    
    def get_output_embeddings(self):
        """
        出力埋め込みを取得
        """
        return self.get_base_model().get_output_embeddings()
    
    def resize_token_embeddings(self, new_num_tokens):
        """
        トークン埋め込みのサイズを変更
        """
        self.get_base_model().resize_token_embeddings(new_num_tokens)
        return self.get_input_embeddings()

RecurrentMemoryTransformer.register_for_auto_class("AutoModelForCausalLM")