|
|
import os |
|
|
import json |
|
|
from transformers import PretrainedConfig |
|
|
|
|
|
class PreTrainedRMTConfig(PretrainedConfig):
    """Configuration class for the Recurrent Memory Transformer (RMT).

    Wraps a base (backbone) model configuration and adds RMT-specific
    settings for segmenting long inputs and carrying memory tokens across
    segments. All keys of the base model config are flattened onto this
    config so it is self-contained when serialized.
    """

    model_type = "rmt"

    # Maps Auto* classes to the implementing class for trust_remote_code
    # loading via save_pretrained/from_pretrained.
    # NOTE(review): "Transofomer" in the module path looks like a typo, but
    # this string must match the actual module filename on disk — confirm
    # against the repository layout before changing it.
    auto_map = {
        "AutoModelForCausalLM": "open_r1.rmt.RecurrentMemoryTransofomer.RecurrentMemoryTransformer"
    }

    def __init__(
        self,
        base_model_config=None,
        is_memory_all=True,
        max_n_segments=1,
        input_seg_len=512,
        output_seg_len=512,
        align="left",
        num_mem_tokens=10,
        **kwargs,
    ):
        """Build an RMT config.

        Args:
            base_model_config: Backbone config, either a plain ``dict`` or an
                object with ``to_dict()`` (e.g. a ``PretrainedConfig``). Must
                contain a ``model_type`` key when provided.
            is_memory_all: Whether memory is applied to all segments
                (semantics defined by the model implementation).
            max_n_segments: Maximum number of segments the input is split into.
            input_seg_len: Token length of each input segment.
            output_seg_len: Token length of each output segment.
            align: Segment alignment strategy (e.g. ``"left"``).
            num_mem_tokens: Number of memory tokens prepended/appended per
                segment.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.

        Raises:
            ValueError: If ``base_model_config`` is given but has no
                ``model_type`` entry.
        """
        super().__init__(**kwargs)
        self.base_model_config = base_model_config
        self.is_memory_all = is_memory_all
        self.max_n_segments = max_n_segments
        self.input_seg_len = input_seg_len
        self.output_seg_len = output_seg_len
        self.align = align
        self.num_mem_tokens = num_mem_tokens

        if base_model_config is not None:
            # Accept either a plain dict or a config object exposing to_dict().
            if isinstance(base_model_config, dict):
                dict_config: dict = base_model_config
            else:
                dict_config: dict = base_model_config.to_dict()

            # Flatten every backbone setting onto this config so downstream
            # code can read them directly from the RMT config.
            for key, value in dict_config.items():
                setattr(self, key, value)

            self.base_model_type = dict_config.get("model_type")
            if self.base_model_type is None:
                raise ValueError("base_model_configにmodel_typeが指定されていません。")
            # NOTE(review): this mutates the *class* attribute, so every
            # existing and future instance sees the new model_type. It appears
            # intentional (dynamic model_type like "rmt_llama" for Auto*
            # registration), but it is a global side effect — confirm callers
            # depend on it before changing to an instance attribute.
            PreTrainedRMTConfig.model_type = "rmt_" + self.base_model_type
|
|
|
|
|
PreTrainedRMTConfig.register_for_auto_class() |