# File size: 2,024 Bytes
# 6dd0a9e
import os
import json
from transformers import PretrainedConfig
class PreTrainedRMTConfig(PretrainedConfig):
    """
    Configuration class for the Recurrent Memory Transformer (RMT).

    Wraps a base model's configuration and adds the RMT-specific settings
    (segment lengths, alignment, number of memory tokens, ...). All keys of
    the base model config are flattened onto this config so downstream code
    can read them directly.
    """
    model_type = "rmt"
    # Map this config class to its model class so AutoModelForCausalLM can
    # locate the remote-code implementation.
    # NOTE(review): "RecurrentMemoryTransofomer" looks misspelled, but it may
    # intentionally mirror the actual module filename — confirm before fixing.
    auto_map = {
        "AutoModelForCausalLM": "open_r1.rmt.RecurrentMemoryTransofomer.RecurrentMemoryTransformer"
    }

    def __init__(
        self,
        base_model_config=None,
        is_memory_all=True,
        max_n_segments=1,
        input_seg_len=512,
        output_seg_len=512,
        align="left",
        num_mem_tokens=10,
        **kwargs,
    ):
        """
        Args:
            base_model_config: Configuration of the wrapped base model —
                either a plain ``dict`` or an object exposing ``to_dict()``
                (e.g. a ``PretrainedConfig``). Must contain ``model_type``.
            is_memory_all: RMT memory flag — presumably whether memory is
                shared across all segments; TODO confirm against the model.
            max_n_segments: Maximum number of recurrent segments.
            input_seg_len: Token length of each input segment.
            output_seg_len: Token length of each output segment.
            align: Segment alignment direction (default ``"left"``).
            num_mem_tokens: Number of memory tokens.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.

        Raises:
            ValueError: If ``base_model_config`` has no ``model_type`` entry.
        """
        super().__init__(**kwargs)
        self.base_model_config = base_model_config
        self.is_memory_all = is_memory_all
        self.max_n_segments = max_n_segments
        self.input_seg_len = input_seg_len
        self.output_seg_len = output_seg_len
        self.align = align
        self.num_mem_tokens = num_mem_tokens
        if base_model_config is not None:
            # Accept both plain dicts and config objects.
            if isinstance(base_model_config, dict):
                dict_config: dict = base_model_config
            else:
                dict_config: dict = base_model_config.to_dict()
            # Flatten every base-model setting onto this config.
            for key, value in dict_config.items():
                setattr(self, key, value)
            self.base_model_type = dict_config.get("model_type")
            if self.base_model_type is None:
                raise ValueError("base_model_configにmodel_typeが指定されていません。")
            # NOTE(review): this mutates the *class* attribute, so every
            # instance created afterwards (and serialization reading
            # model_type) is affected — kept as-is to preserve behavior,
            # but it looks fragile for multi-config use.
            PreTrainedRMTConfig.model_type = "rmt_" + self.base_model_type
# Register this config class with transformers' auto classes so it can be
# resolved via AutoConfig when the model is loaded with trust_remote_code.
PreTrainedRMTConfig.register_for_auto_class()