File size: 6,763 Bytes
7900f86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import torch
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoConfig
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from .MemoryCell import MemoryCell
from .RecurrentWrapper import RecurrentWrapper
from .PreTrainedRMTConfig import PreTrainedRMTConfig


# @register_for_auto_class("AutoModelForCausalLM")
class RecurrentMemoryTransformer(PreTrainedModel):
    """
    Recurrent Memory Transformer Model Class
    A transformer model that processes long context in segments and retains information using memory
    """
    
    config_class = PreTrainedRMTConfig
    auto_model_class = "AutoModelForCausalLM"
    
    # マッピングを定義してAutoクラスが適切なモデルを見つけられるようにする
    _keys_to_ignore_on_load_missing = [r"position_ids"]
    
    # AUTO_MAPを定義(モデル名からクラスへのマッピング)
    AUTO_MAP = {
        "AutoModelForCausalLM": "RecurrentMemoryTransformer",
    }
    
    def __init__(self, config, base_model=None):
        """
        Initialization
        
        Parameters
        ----------
        config : PreTrainedRMTConfig
            Model configuration
        base_model : PreTrainedModel, optional
            Base transformer model
        """
        super().__init__(config)
        
        # base_modelが指定されていない場合は、configから自動生成
        if base_model is None:
            # ベースモデルのタイプを確認
            if not hasattr(config, "base_model_type"):
                raise ValueError("configにbase_model_typeが指定されていません。RMTの設定にはベースモデルタイプが必要です。")
            base_model_type = config.base_model_type
            
            # ベースモデル用の設定を作成
            base_config = AutoConfig.from_pretrained(base_model_type)
            
            # RMT固有のパラメータを除外してベースモデルの設定を作成
            rmt_specific_params = ['model_type', 'is_memory_all', 'max_n_segments', 'input_seg_len', 
                                  'output_seg_len', 'align', 'num_mem_tokens', 'base_model_type']
            for key, value in config.__dict__.items():
                if key not in rmt_specific_params and not key.startswith('_'):
                    setattr(base_config, key, value)
            
            # ベースモデルを作成
            base_model = AutoModelForCausalLM.from_config(base_config)
        
        # MemoryCellとRecurrentWrapperの初期化
        memory_cell = MemoryCell(base_model, config.num_mem_tokens)
        self.recurrent_wrapper = RecurrentWrapper(
            memory_cell=memory_cell,
            is_memory_all=config.is_memory_all,
            max_n_segments=config.max_n_segments,
            input_seg_len=config.input_seg_len,
            output_seg_len=config.output_seg_len,
            align=config.align
        )
        
    def get_base_model(self):
        """
        Get the base model
        """
        return self.recurrent_wrapper.memory_cell.model
    
    def forward(self, input_ids=None, attention_mask=None, labels=None, labels_mask=None, 
                inputs_embeds=None, output_attentions=None, output_hidden_states=None):
        """
        Forward pass of the model
        
        Parameters
        ----------
        input_ids : torch.Tensor, optional
            Input tensor
        attention_mask : torch.Tensor, optional
            Attention mask
        labels : torch.Tensor, optional
            Label tensor
        labels_mask : torch.Tensor, optional
            Label mask
        inputs_embeds : torch.Tensor, optional
            Input embeddings
        output_attentions : bool, optional
            Whether to output attention weights
        output_hidden_states : bool, optional
            Whether to output hidden states
        """
        forward_kwargs = {}
        if input_ids is not None:
            forward_kwargs["input_ids"] = input_ids
        if labels is not None:
            forward_kwargs["labels"] = labels
        if attention_mask is not None:
            forward_kwargs["attention_mask"] = attention_mask
        if labels_mask is not None:
            forward_kwargs["labels_mask"] = labels_mask
        if inputs_embeds is not None:
            forward_kwargs["inputs_embeds"] = inputs_embeds
        if output_attentions is not None:
            forward_kwargs["output_attentions"] = output_attentions
        if output_hidden_states is not None:
            forward_kwargs["output_hidden_states"] = output_hidden_states
        
        #forward_kwargs.update(kwargs)
        
        # 通常の順伝播処理
        out = self.recurrent_wrapper.forward(**forward_kwargs)
        """
        # デバッグ出力を削除(または必要に応じてコメント化)
        # print(out["loss"])
        
        # 分散環境で損失が二重計算されないよう、ワールドサイズで割る
        # これは処理済みの場合は不要なので、環境変数などで制御することも可能
        if torch.distributed.is_initialized() and "loss" in out and out["loss"] is not None:
            # 既にDeepSpeedが処理している可能性があるため、確認が必要
            # テスト目的で一時的に追加(実際の環境に合わせて調整が必要)
            # world_size = torch.distributed.get_world_size()
            # out["loss"] = out["loss"] / world_size
            pass
        """
        return out
    
    def generate(self, **kwargs):
        """
        Text generation
        """
        return self.recurrent_wrapper.generate(**kwargs)
    
    def generate_with_tokenizer(self, tokenizer, input_text, **kwargs):
        """
        Text generation using tokenizer
        """
        return self.recurrent_wrapper.generate_with_tokenizer(tokenizer, input_text, **kwargs)
    
    def get_input_embeddings(self):
        """
        Get input embeddings
        """
        return self.get_base_model().get_input_embeddings()
    
    def set_input_embeddings(self, embeddings):
        """
        Set input embeddings
        """
        self.get_base_model().set_input_embeddings(embeddings)
    
    def get_output_embeddings(self):
        """
        Get output embeddings
        """
        return self.get_base_model().get_output_embeddings()
    
    def resize_token_embeddings(self, new_num_tokens):
        """
        Resize token embeddings
        """
        self.get_base_model().resize_token_embeddings(new_num_tokens)
        return self.get_input_embeddings()

RecurrentMemoryTransformer.register_for_auto_class("AutoModelForCausalLM")