File size: 2,024 Bytes
6dd0a9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
import json
from transformers import PretrainedConfig

class PreTrainedRMTConfig(PretrainedConfig):
    """
    Configuration class for the Recurrent Memory Transformer (RMT).

    Stores the RMT-specific settings and, when a base model configuration is
    supplied, copies every field of that base configuration onto this
    instance as top-level attributes so the config can stand in for it.

    Args:
        base_model_config: Configuration of the wrapped base model. Either a
            ``dict`` (or dict subclass) or an object exposing ``to_dict()``
            (e.g. a ``PretrainedConfig``). Must carry a non-``None``
            ``"model_type"`` entry. May be ``None``, in which case nothing is
            copied and ``base_model_type`` is not set.
        is_memory_all: Stored as ``self.is_memory_all``.
        max_n_segments: Stored as ``self.max_n_segments`` (presumably the
            maximum number of segments processed — confirm in the model).
        input_seg_len: Stored as ``self.input_seg_len`` (presumably the input
            segment length in tokens — confirm in the model).
        output_seg_len: Stored as ``self.output_seg_len`` (presumably the
            output segment length in tokens — confirm in the model).
        align: Stored as ``self.align``.
        num_mem_tokens: Stored as ``self.num_mem_tokens`` (presumably the
            number of memory tokens — confirm in the model).
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``base_model_config`` is given but its
            ``"model_type"`` entry is missing or ``None``.
    """

    model_type = "rmt"

    # Mapping that links this config class to the model class implementing it.
    # NOTE(review): the module path spells "RecurrentMemoryTransofomer"; confirm
    # it matches the actual file name, otherwise auto-loading will fail.
    auto_map = {
        "AutoModelForCausalLM": "open_r1.rmt.RecurrentMemoryTransofomer.RecurrentMemoryTransformer"
    }

    def __init__(
        self,
        base_model_config=None,
        is_memory_all=True,
        max_n_segments=1,
        input_seg_len=512,
        output_seg_len=512,
        align="left",
        num_mem_tokens=10,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.base_model_config = base_model_config
        self.is_memory_all = is_memory_all
        self.max_n_segments = max_n_segments
        self.input_seg_len = input_seg_len
        self.output_seg_len = output_seg_len
        self.align = align
        self.num_mem_tokens = num_mem_tokens

        if base_model_config is not None:
            # Plain dicts (including subclasses such as OrderedDict) are used
            # directly; anything else is expected to provide to_dict().
            if isinstance(base_model_config, dict):
                dict_config: dict = base_model_config
            else:
                dict_config = base_model_config.to_dict()

            # Flatten every base-config field onto this instance. Any key that
            # matches an attribute assigned above (e.g. one of the RMT fields)
            # overwrites it. A "model_type" key is copied too, so the
            # *instance* attribute self.model_type becomes the base type.
            for key, value in dict_config.items():
                setattr(self, key, value)

            self.base_model_type = dict_config.get("model_type")
            if self.base_model_type is None:
                # Message: "model_type is not specified in base_model_config."
                raise ValueError("base_model_configにmodel_typeが指定されていません。")

            # NOTE(review): this rebinds the *class* attribute, so it changes
            # model_type for every PreTrainedRMTConfig in the process, not only
            # this instance. It is presumably done because PretrainedConfig
            # serializes the class-level model_type — confirm before changing.
            PreTrainedRMTConfig.model_type = "rmt_" + self.base_model_type

# Register this custom config class with a transformers auto class at import
# time. Called with no arguments, so the library's default auto class is used
# (AutoConfig per the transformers docs — confirm for the installed version).
PreTrainedRMTConfig.register_for_auto_class()