KotshinZ committed
Commit 6dd0a9e · verified · 1 Parent(s): a7a763a

Model save
PreTrainedRMTConfig.py ADDED
@@ -0,0 +1,56 @@
+ import os
+ import json
+ from transformers import PretrainedConfig
+
+ class PreTrainedRMTConfig(PretrainedConfig):
+     """
+     Configuration class for the Recurrent Memory Transformer.
+     """
+
+     model_type = "rmt"
+
+     # Auto-class mapping: associates this config class with the model class.
+     auto_map = {
+         "AutoModelForCausalLM": "open_r1.rmt.RecurrentMemoryTransofomer.RecurrentMemoryTransformer"
+     }
+
+     def __init__(
+         self,
+         base_model_config=None,
+         is_memory_all=True,
+         max_n_segments=1,
+         input_seg_len=512,
+         output_seg_len=512,
+         align="left",
+         num_mem_tokens=10,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.base_model_config = base_model_config
+         self.is_memory_all = is_memory_all
+         self.max_n_segments = max_n_segments
+         self.input_seg_len = input_seg_len
+         self.output_seg_len = output_seg_len
+         self.align = align
+         self.num_mem_tokens = num_mem_tokens
+
+         if base_model_config is not None:
+             # Accept either a config object or a plain dict.
+             if not isinstance(base_model_config, dict):
+                 dict_config: dict = base_model_config.to_dict()
+             else:
+                 dict_config: dict = base_model_config
+
+             # Copy every base-model field onto this config.
+             for key, value in dict_config.items():
+                 setattr(self, key, value)
+             self.base_model_type = dict_config.get("model_type")
+             if self.base_model_type is None:
+                 raise ValueError("base_model_config does not specify a model_type.")
+             # Note: this rebinds the class attribute, so model_type changes for all instances.
+             PreTrainedRMTConfig.model_type = "rmt_" + self.base_model_type
+     """
+     def __repr__(self):
+         return f"PreTrainedRMTConfig(is_memory_all={self.is_memory_all}, max_n_segments={self.max_n_segments}, " \
+                f"input_seg_len={self.input_seg_len}, output_seg_len={self.output_seg_len}, " \
+                f"align='{self.align}', num_mem_tokens={self.num_mem_tokens})"
+     """
+
+ PreTrainedRMTConfig.register_for_auto_class()
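
For orientation, a minimal sketch of constructing this config around a GPT-2 base config. The parameter values mirror the `config.json` in this commit, but the call itself is illustrative, not taken from the repo:

```python
# Illustrative construction of an RMT config (values match config.json below).
from transformers import AutoConfig
from open_r1.rmt.PreTrainedRMTConfig import PreTrainedRMTConfig

base = AutoConfig.from_pretrained("openai-community/gpt2")
rmt_config = PreTrainedRMTConfig(
    base_model_config=base,  # its fields get copied onto the RMT config
    is_memory_all=False,
    max_n_segments=2,
    input_seg_len=1004,
    output_seg_len=1004,
    num_mem_tokens=10,
)
print(rmt_config.model_type)  # -> "rmt_gpt2" (class attribute is rebound)
```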
README.md CHANGED
@@ -1,11 +1,9 @@
  ---
  base_model: KotshinZ/gpt2-RMT-7
- datasets: HuggingFaceH4/Bespoke-Stratos-17k
  library_name: transformers
  model_name: gpt2-RMT-8
  tags:
  - generated_from_trainer
- - open-r1
  - trl
  - sft
  licence: license
@@ -13,7 +11,7 @@ licence: license
 
  # Model Card for gpt2-RMT-8
 
- This model is a fine-tuned version of [KotshinZ/gpt2-RMT-7](https://huggingface.co/KotshinZ/gpt2-RMT-7) on the [HuggingFaceH4/Bespoke-Stratos-17k](https://huggingface.co/datasets/HuggingFaceH4/Bespoke-Stratos-17k) dataset.
+ This model is a fine-tuned version of [KotshinZ/gpt2-RMT-7](https://huggingface.co/KotshinZ/gpt2-RMT-7).
  It has been trained using [TRL](https://github.com/huggingface/trl).
 
  ## Quick start
@@ -29,14 +27,14 @@ print(output["generated_text"])
 
  ## Training procedure
 
- [<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>](https://wandb.ai/s18574s18574-/huggingface/runs/ch6o4apk)
+ [<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>](https://wandb.ai/shin2021001-osaka-city-university/huggingface/runs/4gmo4uif)
 
 
  This model was trained with SFT.
 
  ### Framework versions
 
- - TRL: 0.16.0.dev0
+ - TRL: 0.15.2
  - Transformers: 4.50.0.dev0
  - Pytorch: 2.5.1
  - Datasets: 3.3.2
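
Note: the `## Quick start` section appears above only as hunk context. TRL-generated model cards typically fill it with a `pipeline` snippet along these lines — a hedged reconstruction, not the file's verbatim content (the question string and device are illustrative). Also note that this same commit removes the tokenizer's `chat_template`, so the chat-style call below may no longer work against this revision:

```python
# Typical TRL-generated quick-start snippet (hedged reconstruction).
from transformers import pipeline

question = "If you had a time machine, but could only go to the past or the future once and never return, which would you choose and why?"
generator = pipeline("text-generation", model="KotshinZ/gpt2-RMT-8", device="cuda")
output = generator([{"role": "user", "content": question}], max_new_tokens=128, return_full_text=False)[0]
print(output["generated_text"])
```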
RecurrentMemoryTransofomer.py ADDED
@@ -0,0 +1,171 @@
+ import torch
+ from transformers import PreTrainedModel, AutoModelForCausalLM, AutoConfig
+ from transformers.models.auto.auto_factory import _BaseAutoModelClass
+ from open_r1.rmt.MemoryCell import MemoryCell
+ from open_r1.rmt.RecurrentWrapper import RecurrentWrapper
+ from open_r1.rmt.PreTrainedRMTConfig import PreTrainedRMTConfig
+
+
+ # @register_for_auto_class("AutoModelForCausalLM")
+ class RecurrentMemoryTransformer(PreTrainedModel):
+     """
+     Recurrent Memory Transformer model class.
+     A transformer that processes long contexts segment by segment,
+     carrying information across segments with memory tokens.
+     """
+
+     config_class = PreTrainedRMTConfig
+     auto_model_class = "AutoModelForCausalLM"
+
+     # Mapping so the Auto classes can locate the right model.
+     _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+     # AUTO_MAP (mapping from model name to class).
+     AUTO_MAP = {
+         "AutoModelForCausalLM": "RecurrentMemoryTransformer",
+     }
+
+     def __init__(self, config, base_model=None):
+         """
+         Initialize the model.
+
+         Parameters
+         ----------
+         config : PreTrainedRMTConfig
+             Model configuration.
+         base_model : PreTrainedModel, optional
+             The underlying transformer model.
+         """
+         super().__init__(config)
+
+         # If no base_model is given, build one from the config.
+         if base_model is None:
+             # The base model type is required.
+             if not hasattr(config, "base_model_type"):
+                 raise ValueError("config does not specify base_model_type; an RMT config requires a base model type.")
+             base_model_type = config.base_model_type
+
+             # Create the config for the base model.
+             base_config = AutoConfig.from_pretrained(base_model_type)
+
+             # Copy everything except the RMT-specific parameters onto it.
+             rmt_specific_params = ['model_type', 'is_memory_all', 'max_n_segments', 'input_seg_len',
+                                    'output_seg_len', 'align', 'num_mem_tokens', 'base_model_type']
+             for key, value in config.__dict__.items():
+                 if key not in rmt_specific_params and not key.startswith('_'):
+                     setattr(base_config, key, value)
+
+             # Instantiate the base model.
+             base_model = AutoModelForCausalLM.from_config(base_config)
+
+         # Initialize the MemoryCell and RecurrentWrapper.
+         memory_cell = MemoryCell(base_model, config.num_mem_tokens)
+         self.recurrent_wrapper = RecurrentWrapper(
+             memory_cell=memory_cell,
+             is_memory_all=config.is_memory_all,
+             max_n_segments=config.max_n_segments,
+             input_seg_len=config.input_seg_len,
+             output_seg_len=config.output_seg_len,
+             align=config.align
+         )
+
+     def get_base_model(self):
+         """
+         Return the underlying base model.
+         """
+         return self.recurrent_wrapper.memory_cell.model
+
+     def forward(self, input_ids=None, attention_mask=None, labels=None, labels_mask=None,
+                 inputs_embeds=None, output_attentions=None, output_hidden_states=None):
+         """
+         Forward pass.
+
+         Parameters
+         ----------
+         input_ids : torch.Tensor, optional
+             Input token ids.
+         attention_mask : torch.Tensor, optional
+             Attention mask.
+         labels : torch.Tensor, optional
+             Label tensor.
+         labels_mask : torch.Tensor, optional
+             Label mask.
+         inputs_embeds : torch.Tensor, optional
+             Input embeddings.
+         output_attentions : bool, optional
+             Whether to return attention weights.
+         output_hidden_states : bool, optional
+             Whether to return hidden states.
+         """
+         # Only pass along the arguments that were actually provided.
+         forward_kwargs = {}
+         if input_ids is not None:
+             forward_kwargs["input_ids"] = input_ids
+         if labels is not None:
+             forward_kwargs["labels"] = labels
+         if attention_mask is not None:
+             forward_kwargs["attention_mask"] = attention_mask
+         if labels_mask is not None:
+             forward_kwargs["labels_mask"] = labels_mask
+         if inputs_embeds is not None:
+             forward_kwargs["inputs_embeds"] = inputs_embeds
+         if output_attentions is not None:
+             forward_kwargs["output_attentions"] = output_attentions
+         if output_hidden_states is not None:
+             forward_kwargs["output_hidden_states"] = output_hidden_states
+
+         # forward_kwargs.update(kwargs)
+
+         # Delegate the forward pass to the recurrent wrapper.
+         out = self.recurrent_wrapper.forward(**forward_kwargs)
+         """
+         # Debug output removed (re-enable if needed):
+         # print(out["loss"])
+
+         # To avoid double-counting the loss in distributed runs, it could be
+         # divided by the world size. DeepSpeed may already handle this, so it
+         # is left disabled here; adjust for the actual training environment.
+         if torch.distributed.is_initialized() and "loss" in out and out["loss"] is not None:
+             # world_size = torch.distributed.get_world_size()
+             # out["loss"] = out["loss"] / world_size
+             pass
+         """
+         return out
+
+     def generate(self, **kwargs):
+         """
+         Text generation.
+         """
+         return self.recurrent_wrapper.generate(**kwargs)
+
+     def generate_with_tokenizer(self, tokenizer, input_text, **kwargs):
+         """
+         Text generation using a tokenizer.
+         """
+         return self.recurrent_wrapper.generate_with_tokenizer(tokenizer, input_text, **kwargs)
+
+     def get_input_embeddings(self):
+         """
+         Return the input embeddings.
+         """
+         return self.get_base_model().get_input_embeddings()
+
+     def set_input_embeddings(self, embeddings):
+         """
+         Set the input embeddings.
+         """
+         self.get_base_model().set_input_embeddings(embeddings)
+
+     def get_output_embeddings(self):
+         """
+         Return the output embeddings.
+         """
+         return self.get_base_model().get_output_embeddings()
+
+     def resize_token_embeddings(self, new_num_tokens):
+         """
+         Resize the token embeddings.
+         """
+         self.get_base_model().resize_token_embeddings(new_num_tokens)
+         return self.get_input_embeddings()
+
+ RecurrentMemoryTransformer.register_for_auto_class("AutoModelForCausalLM")
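
Both custom classes register themselves for the Auto API (`register_for_auto_class`), so the intended loading path is presumably `AutoModelForCausalLM` with remote code enabled. A minimal sketch, assuming the hub repo ships the custom code and that `generate_with_tokenizer` returns decodable text (its `RecurrentWrapper` implementation is not part of this commit):

```python
# Minimal loading sketch (assumptions noted above; output format not verified).
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "KotshinZ/gpt2-RMT-8"
tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModelForCausalLM.from_pretrained(
    repo,
    trust_remote_code=True,  # the config/model classes live in the repo itself
)

# Both generate() and generate_with_tokenizer() delegate to RecurrentWrapper.
result = model.generate_with_tokenizer(
    tokenizer, "A long input is split into segments of input_seg_len tokens.",
)
print(result)
```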
all_results.json CHANGED
@@ -1,13 +1,8 @@
  {
-   "eval_loss": 1.9101824760437012,
-   "eval_runtime": 2.9445,
-   "eval_samples": 100,
-   "eval_samples_per_second": 372.563,
-   "eval_steps_per_second": 93.395,
-   "total_flos": 2.3971976183808e+16,
-   "train_loss": 1.9497717069673088,
-   "train_runtime": 714.9694,
-   "train_samples": 16610,
-   "train_samples_per_second": 128.305,
-   "train_steps_per_second": 8.02
+   "total_flos": 5452245613150208.0,
+   "train_loss": 3.1222703313253013,
+   "train_runtime": 348.715,
+   "train_samples": 19883,
+   "train_samples_per_second": 7.605,
+   "train_steps_per_second": 0.238
  }
config.json CHANGED
@@ -1,22 +1,126 @@
  {
-   "_name_or_path": "KotshinZ/gpt2-RMT-7",
    "activation_function": "gelu_new",
+   "align": "left",
    "architectures": [
-     "GPT2LMHeadModel"
+     "RecurrentMemoryTransformer"
    ],
    "attn_pdrop": 0.1,
+   "base_model_config": {
+     "_attn_implementation_autoset": true,
+     "_name_or_path": "openai-community/gpt2",
+     "activation_function": "gelu_new",
+     "add_cross_attention": false,
+     "architectures": [
+       "GPT2LMHeadModel"
+     ],
+     "attn_pdrop": 0.1,
+     "bad_words_ids": null,
+     "begin_suppress_tokens": null,
+     "bos_token_id": 50256,
+     "chunk_size_feed_forward": 0,
+     "cross_attention_hidden_size": null,
+     "decoder_start_token_id": null,
+     "diversity_penalty": 0.0,
+     "do_sample": false,
+     "early_stopping": false,
+     "embd_pdrop": 0.1,
+     "encoder_no_repeat_ngram_size": 0,
+     "eos_token_id": 50256,
+     "exponential_decay_length_penalty": null,
+     "finetuning_task": null,
+     "forced_bos_token_id": null,
+     "forced_eos_token_id": null,
+     "id2label": {
+       "0": "LABEL_0",
+       "1": "LABEL_1"
+     },
+     "initializer_range": 0.02,
+     "is_decoder": false,
+     "is_encoder_decoder": false,
+     "label2id": {
+       "LABEL_0": 0,
+       "LABEL_1": 1
+     },
+     "layer_norm_epsilon": 1e-05,
+     "length_penalty": 1.0,
+     "max_length": 20,
+     "min_length": 0,
+     "model_type": "gpt2",
+     "n_ctx": 1024,
+     "n_embd": 768,
+     "n_head": 12,
+     "n_inner": null,
+     "n_layer": 12,
+     "n_positions": 1024,
+     "no_repeat_ngram_size": 0,
+     "num_beam_groups": 1,
+     "num_beams": 1,
+     "num_return_sequences": 1,
+     "output_attentions": false,
+     "output_hidden_states": false,
+     "output_scores": false,
+     "pad_token_id": null,
+     "prefix": null,
+     "problem_type": null,
+     "pruned_heads": {},
+     "remove_invalid_values": false,
+     "reorder_and_upcast_attn": false,
+     "repetition_penalty": 1.0,
+     "resid_pdrop": 0.1,
+     "return_dict": true,
+     "return_dict_in_generate": false,
+     "scale_attn_by_inverse_layer_idx": false,
+     "scale_attn_weights": true,
+     "sep_token_id": null,
+     "summary_activation": null,
+     "summary_first_dropout": 0.1,
+     "summary_proj_to_labels": true,
+     "summary_type": "cls_index",
+     "summary_use_proj": true,
+     "suppress_tokens": null,
+     "task_specific_params": {
+       "text-generation": {
+         "do_sample": true,
+         "max_length": 50
+       }
+     },
+     "temperature": 1.0,
+     "tf_legacy_loss": false,
+     "tie_encoder_decoder": false,
+     "tie_word_embeddings": true,
+     "tokenizer_class": null,
+     "top_k": 50,
+     "top_p": 1.0,
+     "torch_dtype": "bfloat16",
+     "torchscript": false,
+     "typical_p": 1.0,
+     "use_bfloat16": false,
+     "use_cache": false,
+     "vocab_size": 50257
+   },
+   "base_model_type": "gpt2",
    "bos_token_id": 50256,
    "embd_pdrop": 0.1,
    "eos_token_id": 50256,
+   "id2label": {
+     "0": "LABEL_0",
+     "1": "LABEL_1"
+   },
    "initializer_range": 0.02,
+   "input_seg_len": 1004,
+   "is_memory_all": false,
    "layer_norm_epsilon": 1e-05,
-   "model_type": "gpt2",
+   "max_n_segments": 2,
+   "memory_size": 10,
+   "model_type": "rmt_gpt2",
    "n_ctx": 1024,
    "n_embd": 768,
    "n_head": 12,
    "n_inner": null,
    "n_layer": 12,
    "n_positions": 1024,
+   "num_mem_tokens": 10,
+   "output_seg_len": 1004,
    "reorder_and_upcast_attn": false,
    "resid_pdrop": 0.1,
    "scale_attn_by_inverse_layer_idx": false,
@@ -34,6 +138,6 @@
    },
    "torch_dtype": "bfloat16",
    "transformers_version": "4.50.0.dev0",
-   "use_cache": true,
+   "use_cache": false,
    "vocab_size": 50257
  }
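
One detail worth flagging (an inference from the values above, not stated anywhere in the repo): the new `input_seg_len` of 1004 plus a read and a write memory block of `num_mem_tokens` (10 each) exactly fills GPT-2's 1024-position context window:

```python
# Hedged sanity check, assuming RMT prepends read-memory and appends
# write-memory tokens to each segment (the usual RMT layout).
input_seg_len, num_mem_tokens, n_positions = 1004, 10, 1024
assert input_seg_len + 2 * num_mem_tokens == n_positions  # 1004 + 20 == 1024
```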
model.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:785b3bc458e117f5a369b848617750c5e7a0951dee5013c3a213521f521a682e
- size 326089656
+ oid sha256:8a608b7f1a18a1d7c70b018ba505eb4edce00cd300ac3027ffc3d60408e9c2c8
+ size 248915448
tokenizer_config.json CHANGED
@@ -11,7 +11,6 @@
      }
    },
    "bos_token": "<|endoftext|>",
-   "chat_template": "{% for message in messages %}{{ message['role'] }}: {{ message['content'] }}\n{% endfor %}",
    "clean_up_tokenization_spaces": false,
    "eos_token": "<|endoftext|>",
    "extra_special_tokens": {},
train_results.json CHANGED
@@ -1,8 +1,8 @@
  {
-   "total_flos": 2.3971976183808e+16,
-   "train_loss": 1.9497717069673088,
-   "train_runtime": 714.9694,
-   "train_samples": 16610,
-   "train_samples_per_second": 128.305,
-   "train_steps_per_second": 8.02
+   "total_flos": 5452245613150208.0,
+   "train_loss": 3.1222703313253013,
+   "train_runtime": 348.715,
+   "train_samples": 19883,
+   "train_samples_per_second": 7.605,
+   "train_steps_per_second": 0.238
  }
trainer_state.json CHANGED
The diff for this file is too large to render. See raw diff
 
training_args.bin CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:4b104884d6ba1f77b5b69bc7b531f0c6d732ce72aed65b9dab9261c66a5de002
+ oid sha256:ec32a102e3bf57cfe8a1fabd9e01b5ba70a29d6a8e8a4c20779e360df3b2ef1b
  size 7352