# gpt2-RMT-8 / PreTrainedRMTConfig.py
# (Hugging Face Hub page header — uploaded by KotshinZ, commit 6dd0a9e "Model save")
import os
import json
from transformers import PretrainedConfig
class PreTrainedRMTConfig(PretrainedConfig):
    """
    Configuration class for a Recurrent Memory Transformer (RMT).

    Wraps an arbitrary base model configuration and adds RMT-specific
    settings (segmenting and memory-token parameters). The base config's
    fields are flattened onto this config so they are serialized together.
    """

    model_type = "rmt"
    # Auto-class mapping: associates this config with the model class so the
    # transformers Auto* machinery can locate the RMT implementation.
    # NOTE(review): "RecurrentMemoryTransofomer" looks like a typo, but it must
    # match the actual module path in the repository — confirm before changing.
    auto_map = {
        "AutoModelForCausalLM": "open_r1.rmt.RecurrentMemoryTransofomer.RecurrentMemoryTransformer"
    }

    def __init__(
        self,
        base_model_config=None,
        is_memory_all=True,
        max_n_segments=1,
        input_seg_len=512,
        output_seg_len=512,
        align="left",
        num_mem_tokens=10,
        **kwargs,
    ):
        """
        Args:
            base_model_config: Configuration of the wrapped base model —
                either a plain ``dict`` or an object exposing ``to_dict()``
                (e.g. a ``PretrainedConfig``). Must contain ``model_type``.
            is_memory_all: RMT memory flag (semantics defined by the model —
                presumably whether all segments share memory; confirm there).
            max_n_segments: Maximum number of segments processed recurrently.
            input_seg_len: Input segment length in tokens.
            output_seg_len: Output segment length in tokens.
            align: Segment alignment strategy, e.g. ``"left"``.
            num_mem_tokens: Number of memory tokens.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.

        Raises:
            ValueError: If ``base_model_config`` is given but has no
                ``model_type`` entry.
        """
        super().__init__(**kwargs)
        self.base_model_config = base_model_config
        self.is_memory_all = is_memory_all
        self.max_n_segments = max_n_segments
        self.input_seg_len = input_seg_len
        self.output_seg_len = output_seg_len
        self.align = align
        self.num_mem_tokens = num_mem_tokens
        if base_model_config is not None:
            # Accept either a dict or a config object with to_dict().
            if isinstance(base_model_config, dict):
                dict_config: dict = base_model_config
            else:
                dict_config: dict = base_model_config.to_dict()
            # Flatten the base config's fields onto this config.
            # NOTE(review): this can silently overwrite attributes set above
            # or by PretrainedConfig — confirm that is intended.
            for key, value in dict_config.items():
                setattr(self, key, value)
            self.base_model_type = dict_config.get("model_type")
            if self.base_model_type is None:
                raise ValueError("base_model_configにmodel_typeが指定されていません。")
            # Derive e.g. "rmt_gpt2" from the base model type. The class
            # attribute is mutated (kept for backward compatibility, though it
            # leaks across instances); the instance attribute is ALSO set,
            # fixing a bug where the setattr loop above left self.model_type
            # at the base model's value (e.g. "gpt2") instead of "rmt_gpt2".
            PreTrainedRMTConfig.model_type = "rmt_" + self.base_model_type
            self.model_type = PreTrainedRMTConfig.model_type
# Register this config class with transformers' auto-class machinery so it can
# be resolved automatically (together with the `auto_map` entries above) when
# the model is loaded as custom code from the Hub.
PreTrainedRMTConfig.register_for_auto_class()