import os import json from transformers import PretrainedConfig class PreTrainedRMTConfig(PretrainedConfig): """ Recurrent Memory Transformer の設定クラス """ model_type = "rmt" # マッピング情報を追加(設定クラスとモデルクラスの関連付け) auto_map = { "AutoModelForCausalLM": "open_r1.rmt.RecurrentMemoryTransofomer.RecurrentMemoryTransformer" } def __init__( self, base_model_config=None, is_memory_all=True, max_n_segments=1, input_seg_len=512, output_seg_len=512, align="left", num_mem_tokens=10, **kwargs ): super().__init__(**kwargs) self.base_model_config = base_model_config self.is_memory_all = is_memory_all self.max_n_segments = max_n_segments self.input_seg_len = input_seg_len self.output_seg_len = output_seg_len self.align = align self.num_mem_tokens = num_mem_tokens if base_model_config is not None: if type(base_model_config) is not dict: dict_config: dict = base_model_config.to_dict() else: dict_config: dict = base_model_config for key, value in dict_config.items(): setattr(self, key, value) self.base_model_type = dict_config.get("model_type") if self.base_model_type is None: raise ValueError("base_model_configにmodel_typeが指定されていません。") PreTrainedRMTConfig.model_type = "rmt_" + self.base_model_type """ def __repr__(self): return f"PreTrainedRMTConfig(is_memory_all={self.is_memory_all}, max_n_segments={self.max_n_segments}, " \ f"input_seg_len={self.input_seg_len}, output_seg_len={self.output_seg_len}, " \ f"align='{self.align}', num_mem_tokens={self.num_mem_tokens})" """ PreTrainedRMTConfig.register_for_auto_class()