KotshinZ committed
Commit 7900f86 · verified · 1 Parent(s): 96dc2e4

Model save

MemoryCell.py ADDED
@@ -0,0 +1,208 @@
+ import math
+ import torch
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+ from transformers import PreTrainedModel
+
+ from .PreTrainedRMTConfig import PreTrainedRMTConfig
+
+ class MemoryCell(torch.nn.Module):
+     """Holds the memory tensor.
+     Replicates the memory tensor across the batch dimension,
+     adds memory tokens to the input tensor,
+     and processes the model output into a new memory state.
+
+     Parameters
+     ----------
+     base_model : PreTrainedModel
+         Base language model to wrap.
+     num_mem_tokens : int
+         Number of memory tokens.
+     """
+
+     def __init__(self, base_model, num_mem_tokens):
+         super().__init__()
+         self.model = base_model
+         self.create_memory(num_mem_tokens)
+         self.config = base_model.config
+
+         # Optional token type embeddings (currently disabled)
+         #self.token_type_embeddings = torch.nn.Embedding(2, getattr(self.model.config, "n_embd", self.model.config.hidden_size))
+
+     def create_memory(self, num_mem_tokens):
+         """Randomly initializes an embedding matrix for the memory tokens and registers it as a trainable parameter.
+         Also sets the read and write positions for the memory tokens.
+
+         Parameters
+         ----------
+         num_mem_tokens : int
+             Number of memory tokens.
+         """
+         self.read_memory_position = range(num_mem_tokens)
+         self.write_memory_position = range(-num_mem_tokens, 0)
+
+         self.num_mem_tokens = num_mem_tokens
+         embeddings = self.model.get_input_embeddings()
+         memory_dim = getattr(self.model.config, "n_embd", self.model.config.hidden_size)
+         memory_weights = (
+             torch.randn((num_mem_tokens, memory_dim))  # * embeddings.weight.data.std()
+         )
+
+         self.register_parameter(
+             "memory", torch.nn.Parameter(memory_weights, requires_grad=True)
+         )
+
+     def set_memory(self, input_shape):
+         """Replicates the memory tensor across the batch dimension.
+
+         Parameters
+         ----------
+         input_shape : torch.Size
+             Shape of the input tensor; only the batch dimension is used.
+
+         Returns
+         -------
+         torch.Tensor
+             Replicated memory tensor. (batch_size, num_mem_tokens, memory_dim)
+         """
+         memory = self.memory.repeat(
+             input_shape[0], 1, 1
+         )  # Replicate the memory tensor for each element of the batch
+         return memory  # (batch_size, num_mem_tokens, memory_dim)
+
+     def forward(self, input_ids, memory_state=None, **kwargs):
+         """Performs inference on one segment.
+
+         Parameters
+         ----------
+         input_ids : torch.Tensor
+             Input tensor.
+         memory_state : torch.Tensor, optional
+             Memory tensor, by default None. (batch_size, num_mem_tokens, memory_dim)
+
+         Returns
+         -------
+         tuple(CausalLMOutputWithCrossAttentions, torch.Tensor)
+             out : CausalLMOutputWithCrossAttentions
+                 Model output with the memory positions stripped.
+             new_memory_state : torch.Tensor
+                 New memory state.
+         """
+         if memory_state is None:
+             # Replicate the memory tensor across the batch dimension
+             memory_state = self.set_memory(input_ids.shape)
+
+         # Add the memory tokens to the input tensor
+         seg_kwargs = self.process_input(input_ids, memory_state, **kwargs)
+         out = self.model(**seg_kwargs)
+         #print(out)
+
+         # Process the model output and extract the new memory state
+         out, new_memory_state = self.process_output(out, **kwargs)
+
+         return out, new_memory_state
+
+     def process_input(self, input_ids, memory_state, **kwargs):
+         """Adds memory tokens to the input tensor and returns the model kwargs.
+
+         Parameters
+         ----------
+         input_ids : torch.Tensor
+             Input tensor.
+         memory_state : torch.Tensor
+             Memory tensor.
+
+         Returns
+         -------
+         dict
+             Model kwargs whose inputs_embeds carry the added memory tokens.
+             (batch_size, seq_len + 2 * num_mem_tokens, hidden_size)
+         """
+         seg_kwargs = dict(**kwargs)
+
+         inputs_embeds = kwargs.get("inputs_embeds")
+         if inputs_embeds is None:
+             inputs_embeds = self.model.get_input_embeddings()(input_ids)
+         if inputs_embeds.shape[0] != memory_state.shape[0]:  # Batch sizes differ
+             memory_state = self.set_memory(inputs_embeds.shape)
+
+         # Prepend and append the memory tokens to the input embeddings
+         inputs_embeds = torch.cat(
+             [memory_state, inputs_embeds, memory_state], dim=1
+         ).to(input_ids.device)
+         """
+         # Generate token_type_ids
+         token_type_ids = torch.zeros_like(inputs_embeds[:, :, 0], dtype=torch.long)
+         token_type_ids[:, self.num_mem_tokens:-self.num_mem_tokens] = 1
+
+         # Add token_type_embeddings and update the inputs
+         token_type_embeds = self.token_type_embeddings(token_type_ids)
+         inputs_embeds = inputs_embeds + token_type_embeds
+         """
+
+         seg_kwargs["input_ids"] = None
+         seg_kwargs["inputs_embeds"] = inputs_embeds
+         if kwargs.get("attention_mask") is not None:
+             seg_kwargs["attention_mask"] = self.pad_attention_mask(
+                 kwargs["attention_mask"], inputs_embeds.shape
+             )
+         seg_kwargs["output_hidden_states"] = True
+
+         # Positional embeddings: read memory, then the segment, then write memory
+         pos_mem1 = torch.arange(self.num_mem_tokens, device=input_ids.device)
+         pos_mem2 = torch.arange(self.num_mem_tokens, self.num_mem_tokens * 2, device=input_ids.device)
+         pos_seg = torch.arange(self.num_mem_tokens * 2, self.num_mem_tokens * 2 + input_ids.shape[1], device=input_ids.device)
+         pos = torch.cat([pos_mem1, pos_seg, pos_mem2], dim=0)
+         pos = pos.unsqueeze(0).expand(input_ids.shape[0], -1)
+         seg_kwargs["position_ids"] = pos
+
+         return seg_kwargs
+
+     def pad_attention_mask(self, attention_mask, shape):
+         if self.num_mem_tokens in {0, None}:
+             return attention_mask
+         else:
+             # Memory tokens on both sides are always attended to
+             attention_mask = torch.cat(
+                 [
+                     torch.ones(
+                         shape[0], self.num_mem_tokens, device=attention_mask.device
+                     ),
+                     attention_mask,
+                     torch.ones(
+                         shape[0], self.num_mem_tokens, device=attention_mask.device
+                     ),
+                 ],
+                 dim=1,
+             )
+             return attention_mask
+
+     @staticmethod
+     def compute_logpi(mean, stddev, action):
+         # Log-density of a diagonal Gaussian: -0.5*log(2*pi) - log(sigma) - 0.5*((a - mu)/sigma)**2
+         a1 = -0.5 * torch.log(2 * torch.full(stddev.shape, math.pi, device=stddev.device))
+         a2 = -torch.log(stddev)
+         a3 = -0.5 * (((action - mean) / stddev) ** 2)
+         return a1 + a2 + a3
+
+     def process_output(self, model_outputs, **kwargs):
+         if self.num_mem_tokens not in {0, None}:
+             out = CausalLMOutputWithCrossAttentions()
+             # The write memory is the final hidden state at the trailing memory positions
+             memory_state = model_outputs.hidden_states[-1][:, -self.num_mem_tokens :]
+             out["logits"] = model_outputs.logits[
+                 :, self.num_mem_tokens : -self.num_mem_tokens
+             ]
+
+             if kwargs.get("output_hidden_states"):
+                 out["hidden_states"] = [
+                     lh[:, self.num_mem_tokens : -self.num_mem_tokens]
+                     for lh in model_outputs.hidden_states
+                 ]
+             if kwargs.get("output_attentions"):
+                 out["attentions"] = model_outputs["attentions"]
+         else:
+             memory_state = None
+             out = model_outputs
+
+         return out, memory_state
+
+     def generate(self, input_ids, memory_state, attention_mask=None, **generate_kwargs):
+         if memory_state is None:
+             memory_state = self.set_memory(input_ids.shape)
+
+         seg_kwargs = self.process_input(input_ids, memory_state, attention_mask=attention_mask)
+         out = self.model.generate(inputs_embeds=seg_kwargs['inputs_embeds'], attention_mask=seg_kwargs['attention_mask'], **generate_kwargs)
+         return out
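For orientation, a minimal usage sketch of `MemoryCell` (not part of this commit): the import path is illustrative, and `num_mem_tokens` is kept small here so the shifted position ids stay within GPT-2's 1024-position table, whereas the checkpoint in this repo uses 512 memory tokens.

```python
# Hypothetical usage sketch: wrap GPT-2 in a MemoryCell and run two segments.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from MemoryCell import MemoryCell  # import path depends on your layout

base = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
cell = MemoryCell(base, num_mem_tokens=8)

ids = tok("Recurrent memory transformers read text in segments.", return_tensors="pt").input_ids
out, memory_state = cell(ids)                # first segment: memory comes from self.memory
out, memory_state = cell(ids, memory_state)  # later segments reuse the returned state
print(out.logits.shape)  # (1, seq_len, vocab): memory positions already stripped
```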
PreTrainedRMTConfig.py CHANGED
@@ -1,15 +1,21 @@
  import os
  import json
- from transformers import PretrainedConfig
+ from typing import Type
+ from transformers import AutoConfig, PretrainedConfig
+
+ def register_to_hf_auto_config(
+     config_class: Type[PretrainedConfig],
+ ) -> Type[PretrainedConfig]:
+     AutoConfig.register(config_class.model_type, config_class)
+     return config_class
 
  class PreTrainedRMTConfig(PretrainedConfig):
      """
-     Recurrent Memory Transformer の設定クラス
+     Recurrent Memory Transformer configuration class
      """
 
      model_type = "rmt"
 
-     # Add mapping information (associates the config class with the model class)
      auto_map = {
          "AutoModelForCausalLM": "open_r1.rmt.RecurrentMemoryTransofomer.RecurrentMemoryTransformer"
      }
@@ -45,12 +51,10 @@ class PreTrainedRMTConfig(PretrainedConfig):
          self.base_model_type = dict_config.get("model_type")
          if self.base_model_type is None:
              raise ValueError("model_type is not specified in base_model_config.")
-         PreTrainedRMTConfig.model_type = "rmt_" + self.base_model_type
+         #PreTrainedRMTConfig.model_type = "rmt_" + self.base_model_type
          """
          def __repr__(self):
              return f"PreTrainedRMTConfig(is_memory_all={self.is_memory_all}, max_n_segments={self.max_n_segments}, " \
                     f"input_seg_len={self.input_seg_len}, output_seg_len={self.output_seg_len}, " \
                     f"align='{self.align}', num_mem_tokens={self.num_mem_tokens})"
-         """
-
-         PreTrainedRMTConfig.register_for_auto_class()
+         """
 
 
README.md CHANGED
@@ -1,11 +1,9 @@
  ---
  base_model: openai-community/gpt2
- datasets: HuggingFaceFW/fineweb-edu
  library_name: transformers
  model_name: gpt2-RMT-2-mem512
  tags:
  - generated_from_trainer
- - open-r1
  - trl
  - sft
  licence: license
@@ -13,7 +11,7 @@ licence: license
 
  # Model Card for gpt2-RMT-2-mem512
 
- This model is a fine-tuned version of [openai-community/gpt2](https://huggingface.co/openai-community/gpt2) on the [HuggingFaceFW/fineweb-edu](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu) dataset.
+ This model is a fine-tuned version of [openai-community/gpt2](https://huggingface.co/openai-community/gpt2).
  It has been trained using [TRL](https://github.com/huggingface/trl).
 
  ## Quick start
@@ -29,7 +27,7 @@ print(output["generated_text"])
 
  ## Training procedure
 
- [<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>](https://wandb.ai/shin2021001-osaka-city-university/huggingface/runs/nt4l8say)
+ [<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>](https://wandb.ai/shin2021001-osaka-city-university/huggingface/runs/p1finncz)
 
 
  This model was trained with SFT.
@@ -38,7 +36,7 @@ This model was trained with SFT.
 
  - TRL: 0.15.2
  - Transformers: 4.50.0.dev0
- - Pytorch: 2.5.1
+ - Pytorch: 2.5.1+cu121
  - Datasets: 3.3.2
  - Tokenizers: 0.21.0
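Since `auto_map` in this repo points at custom RMT classes, loading the checkpoint presumably requires `trust_remote_code=True`. A hedged loading sketch; the repo id is inferred from the model name and not confirmed by this diff:

```python
# Assumed quick start; "KotshinZ/gpt2-RMT-2-mem512" is a guess at the repo id.
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained(
    "KotshinZ/gpt2-RMT-2-mem512",
    trust_remote_code=True,  # resolves the custom classes declared in auto_map
)
print(model.generate_with_tokenizer(tok, "Segment-level recurrence lets GPT-2 read past its context window"))
```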
RecurrentMemoryTransformer.py ADDED
@@ -0,0 +1,171 @@
+ import torch
+ from transformers import PreTrainedModel, AutoModelForCausalLM, AutoConfig
+ from transformers.models.auto.auto_factory import _BaseAutoModelClass
+ from .MemoryCell import MemoryCell
+ from .RecurrentWrapper import RecurrentWrapper
+ from .PreTrainedRMTConfig import PreTrainedRMTConfig
+
+
+ # @register_for_auto_class("AutoModelForCausalLM")
+ class RecurrentMemoryTransformer(PreTrainedModel):
+     """
+     Recurrent Memory Transformer model class.
+     A transformer model that processes long context in segments and retains information using memory.
+     """
+
+     config_class = PreTrainedRMTConfig
+     auto_model_class = "AutoModelForCausalLM"
+
+     # Define the mapping so the Auto classes can find the right model
+     _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+     # Define AUTO_MAP (mapping from model name to class)
+     AUTO_MAP = {
+         "AutoModelForCausalLM": "RecurrentMemoryTransformer",
+     }
+
+     def __init__(self, config, base_model=None):
+         """
+         Initialization
+
+         Parameters
+         ----------
+         config : PreTrainedRMTConfig
+             Model configuration
+         base_model : PreTrainedModel, optional
+             Base transformer model
+         """
+         super().__init__(config)
+
+         # If base_model is not given, build it automatically from the config
+         if base_model is None:
+             # Check the base model type
+             if not hasattr(config, "base_model_type"):
+                 raise ValueError("base_model_type is not specified in the config. The RMT configuration requires a base model type.")
+             base_model_type = config.base_model_type
+
+             # Create the configuration for the base model
+             base_config = AutoConfig.from_pretrained(base_model_type)
+
+             # Build the base model config while excluding RMT-specific parameters
+             rmt_specific_params = ['model_type', 'is_memory_all', 'max_n_segments', 'input_seg_len',
+                                    'output_seg_len', 'align', 'num_mem_tokens', 'base_model_type']
+             for key, value in config.__dict__.items():
+                 if key not in rmt_specific_params and not key.startswith('_'):
+                     setattr(base_config, key, value)
+
+             # Create the base model
+             base_model = AutoModelForCausalLM.from_config(base_config)
+
+         # Initialize the MemoryCell and RecurrentWrapper
+         memory_cell = MemoryCell(base_model, config.num_mem_tokens)
+         self.recurrent_wrapper = RecurrentWrapper(
+             memory_cell=memory_cell,
+             is_memory_all=config.is_memory_all,
+             max_n_segments=config.max_n_segments,
+             input_seg_len=config.input_seg_len,
+             output_seg_len=config.output_seg_len,
+             align=config.align
+         )
+
+     def get_base_model(self):
+         """
+         Get the base model
+         """
+         return self.recurrent_wrapper.memory_cell.model
+
+     def forward(self, input_ids=None, attention_mask=None, labels=None, labels_mask=None,
+                 inputs_embeds=None, output_attentions=None, output_hidden_states=None):
+         """
+         Forward pass of the model
+
+         Parameters
+         ----------
+         input_ids : torch.Tensor, optional
+             Input tensor
+         attention_mask : torch.Tensor, optional
+             Attention mask
+         labels : torch.Tensor, optional
+             Label tensor
+         labels_mask : torch.Tensor, optional
+             Label mask
+         inputs_embeds : torch.Tensor, optional
+             Input embeddings
+         output_attentions : bool, optional
+             Whether to output attention weights
+         output_hidden_states : bool, optional
+             Whether to output hidden states
+         """
+         forward_kwargs = {}
+         if input_ids is not None:
+             forward_kwargs["input_ids"] = input_ids
+         if labels is not None:
+             forward_kwargs["labels"] = labels
+         if attention_mask is not None:
+             forward_kwargs["attention_mask"] = attention_mask
+         if labels_mask is not None:
+             forward_kwargs["labels_mask"] = labels_mask
+         if inputs_embeds is not None:
+             forward_kwargs["inputs_embeds"] = inputs_embeds
+         if output_attentions is not None:
+             forward_kwargs["output_attentions"] = output_attentions
+         if output_hidden_states is not None:
+             forward_kwargs["output_hidden_states"] = output_hidden_states
+
+         #forward_kwargs.update(kwargs)
+
+         # Standard forward pass
+         out = self.recurrent_wrapper.forward(**forward_kwargs)
+         """
+         # Debug output removed (re-enable if needed)
+         # print(out["loss"])
+
+         # Divide by the world size so the loss is not double-counted in a distributed environment.
+         # This is unnecessary if already handled elsewhere, so it could also be controlled by an environment variable.
+         if torch.distributed.is_initialized() and "loss" in out and out["loss"] is not None:
+             # DeepSpeed may already have handled this, so verification is needed
+             # Added temporarily for testing (adjust to the actual environment)
+             # world_size = torch.distributed.get_world_size()
+             # out["loss"] = out["loss"] / world_size
+             pass
+         """
+         return out
+
+     def generate(self, **kwargs):
+         """
+         Text generation
+         """
+         return self.recurrent_wrapper.generate(**kwargs)
+
+     def generate_with_tokenizer(self, tokenizer, input_text, **kwargs):
+         """
+         Text generation using a tokenizer
+         """
+         return self.recurrent_wrapper.generate_with_tokenizer(tokenizer, input_text, **kwargs)
+
+     def get_input_embeddings(self):
+         """
+         Get input embeddings
+         """
+         return self.get_base_model().get_input_embeddings()
+
+     def set_input_embeddings(self, embeddings):
+         """
+         Set input embeddings
+         """
+         self.get_base_model().set_input_embeddings(embeddings)
+
+     def get_output_embeddings(self):
+         """
+         Get output embeddings
+         """
+         return self.get_base_model().get_output_embeddings()
+
+     def resize_token_embeddings(self, new_num_tokens):
+         """
+         Resize token embeddings
+         """
+         self.get_base_model().resize_token_embeddings(new_num_tokens)
+         return self.get_input_embeddings()
+
+ RecurrentMemoryTransformer.register_for_auto_class("AutoModelForCausalLM")
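A construction sketch under stated assumptions: the config is loaded rather than built by hand (the full `PreTrainedRMTConfig` constructor is not shown in this commit), and both the import path and the repo id are illustrative:

```python
# Hypothetical: rebuild the RMT around a fresh GPT-2 base model.
from transformers import AutoConfig, AutoModelForCausalLM
from RecurrentMemoryTransformer import RecurrentMemoryTransformer  # illustrative import path

cfg = AutoConfig.from_pretrained("KotshinZ/gpt2-RMT-2-mem512", trust_remote_code=True)  # repo id assumed
base = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
model = RecurrentMemoryTransformer(cfg, base_model=base)
# With base_model=None, __init__ instead derives the base from cfg.base_model_type
# via AutoConfig.from_pretrained + AutoModelForCausalLM.from_config.
```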
RecurrentWrapper.py ADDED
@@ -0,0 +1,519 @@
+ import math
+ import torch
+ from torch.nn import CrossEntropyLoss
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+ from .PreTrainedRMTConfig import PreTrainedRMTConfig
+ from .MemoryCell import MemoryCell
+ from torch.nn.utils.rnn import pad_sequence
+ from transformers import PreTrainedModel
+
+ class RecurrentWrapper(torch.nn.Module):
+     #config_class = PreTrainedRMTConfig
+
+     def __init__(
+         self,
+         memory_cell: MemoryCell,
+         is_memory_all: bool,
+         max_n_segments: int,
+         input_seg_len: int,
+         output_seg_len: int,
+         align: str = "left"):
+
+         super().__init__()
+         self.memory_cell: MemoryCell = memory_cell
+         self.is_memory_all = is_memory_all  # Whether to share the memory state between calls
+         self.memory_state: torch.Tensor = None  # Memory state
+         self.config = memory_cell.config  # Model configuration
+         self.max_n_segments = max_n_segments  # Maximum number of segments for backpropagation
+         self.input_seg_len = input_seg_len  # Segment size
+         self.output_seg_len = output_seg_len
+         self.align = align  # Segment alignment, default: left
+
+     def forward(
+         self,
+         input_ids,
+         labels=None,
+         labels_mask=None,
+         inputs_embeds=None,
+         attention_mask=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         **kwargs
+     ):
+         """Performs inference.
+
+         Parameters
+         ----------
+         input_ids : torch.Tensor
+             Input tensor. (batch_size, seq_len * n_segments)
+         labels : torch.Tensor, optional
+             Label tensor. (batch_size, seq_len * n_segments)
+
+         Returns
+         ----------
+         dict
+             "loss" : torch.Tensor
+                 Loss value.
+             "logits" : torch.Tensor
+                 Model output.
+             "{key}_{seg_num}" : torch.Tensor
+                 Output for each segment.
+         """
+         if self.memory_state is not None:
+             if self.is_memory_all is False:
+                 self.memory_state = None
+             else:
+                 self.memory_state = self.memory_state.detach()  # Stop gradients flowing into the carried-over memory state
+
+         # Split the input tensor into segments (a segment is the subset of the input passed to the model in one step)
+         segmented = self.segment(
+             self.input_seg_len,
+             input_ids=input_ids,
+             inputs_embeds=inputs_embeds,
+             attention_mask=attention_mask,
+         )
+
+         cell_outputs = []  # List collecting the output of each segment
+         for seg_num, segment in enumerate(segmented):
+             cell_out, self.memory_state = self.memory_cell(
+                 **segment, memory_state=self.memory_state, **kwargs
+             )
+             cell_outputs.append(cell_out)
+             a = self.manage_gradients(
+                 self.memory_state, seg_num, len(segmented)
+             )  # Control gradient computation for the memory state
+             #print(seg_num, a)
+
+         out = self.process_outputs(
+             cell_outputs,
+             labels=labels,
+             labels_mask=labels_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+         )
+         return out
+
+     def log(self, t, eps=1e-20):
+         return torch.log(t.clamp(min=eps))
+
+     def gumbel_noise(self, t):
+         noise = torch.zeros_like(t).uniform_(0, 1)
+         return -self.log(-self.log(noise))
+
+     def gumbel_sample(self, t, temperature=1., dim=-1):
+         return ((t / max(float(temperature), float(1e-10))) + self.gumbel_noise(t)).argmax(dim=dim)
+
+     def top_k(self, logits, thres=0.9):
+         k = math.ceil((1 - thres) * logits.shape[-1])
+         val, ind = torch.topk(logits, k)
+         probs = torch.full_like(logits, float('-inf'))
+         probs.scatter_(1, ind, val)
+         return probs
+
+     def segment(self, seg_len, **kwargs):
+         """
+         Segments input tensors and adjusts their size. Returns a list of dicts.
+
+         Parameters
+         ----------
+         **kwargs : dict
+             Tensors to be segmented.
+             Specify tensors that need to be split in keyword-argument form.
+             Example: segment(input_ids=tensor1, attention_mask=tensor2)
+
+         Returns
+         -------
+         segments : list of dict
+             List of dictionaries containing segmented tensors.
+             Example: [{'input_ids': segment1, 'attention_mask': segment1}, {'input_ids': segment2, 'attention_mask': segment2}, ...]
+
+         Notes
+         -----
+         - This function uses the `self.split_tensor` method, so `self` must implement it.
+         - Each tensor is split in a specific way by `self.split_tensor`. The same keys are stored with the same order of indices.
+         """
+         segments = []  # Initialize the list holding each segment
+         for k, tensor in kwargs.items():  # Iterate over keys
+             if tensor is not None:
+                 k_segments = self.split_tensor(
+                     tensor, seg_len
+                 )  # Split the 2-D tensor into segments
+                 for s, k_seg in enumerate(k_segments):
+                     if s < len(segments):
+                         segments[s][k] = k_seg
+                     else:
+                         segments.append({k: k_seg})  # Create a new dict {k: k_seg} and append it to the segments list
+
+         return segments
+
+     def split_tensor(self, tensor, seg_len):
+         if self.align in {"left", None}:
+             split_inds = list(range(0, tensor.shape[1], seg_len)) + [
+                 tensor.shape[1]
+             ]
+             segments = [
+                 tensor[:, start:end] for (start, end) in zip(split_inds, split_inds[1:])
+             ]
+         elif self.align == "right":  # None is already handled by the "left" branch above
+             split_inds = (list(range(tensor.shape[1], 0, -seg_len)) + [0])[::-1]
+             segments = [
+                 tensor[:, start:end] for (start, end) in zip(split_inds, split_inds[1:])
+             ]
+         elif self.align == "center":
+             n_seg = math.ceil(tensor.shape[1] / seg_len)
+             segments = torch.chunk(tensor, n_seg, dim=1)
+         else:
+             split_inds = list(range(0, tensor.shape[1], seg_len)) + [
+                 tensor.shape[1]
+             ]
+             segments = [
+                 tensor[:, start:end] for (start, end) in zip(split_inds, split_inds[1:])
+             ]
+         return segments
+
+     def process_outputs(self, cell_outputs, **kwargs):
+         """Calculates the loss over the list of segment outputs and concatenates the logits.
+
+         Parameters
+         ----------
+         cell_outputs : list of torch.Tensor
+             List containing outputs from each segment.
+
+         Returns
+         -------
+         dict
+             "loss" : torch.Tensor
+                 Loss value.
+             "logits" : torch.Tensor
+                 Model output.
+             "{key}_{seg_num}" : torch.Tensor
+                 Output for each segment.
+         """
+         out = CausalLMOutputWithCrossAttentions()
+         full_logits = torch.cat(
+             [o.logits for o in cell_outputs], dim=1
+         )  # Concatenate the per-segment logits (batch_size, seq_len * n_segments, vocab_size)
+
+         if kwargs.get("output_hidden_states"):
+             full_hidden_states = tuple(
+                 [
+                     torch.cat(layer_hs, dim=1)
+                     for layer_hs in zip(*[o.hidden_states for o in cell_outputs])
+                 ]
+             )
+
+         labels = kwargs.get("labels")
+         if labels is not None:  # Compute the loss only when labels are given
+
+             shift_labels = labels[..., 1:].contiguous()  # When the dataset does not shift the labels
+             shift_logits = full_logits[..., :-1, :].contiguous()  # When the dataset does not shift the labels
+             #shift_labels = labels.contiguous()  # When the dataset shifts the labels
+             #shift_logits = full_logits.contiguous()  # When the dataset shifts the labels
+
+             flat_labels = shift_labels.view(
+                 -1
+             )  # Flatten the batch and sequence dimensions (batch_size * (seq_len - 1))
+             flat_logits = shift_logits.view(
+                 -1, shift_logits.size(-1)
+             )  # Flatten the batch and sequence dimensions (batch_size * (seq_len - 1), vocab_size)
+
+             loss_fct = CrossEntropyLoss()
+             labels_mask = kwargs.get("labels_mask")
+             if labels_mask is not None:
+                 shift_mask = labels_mask[..., :-1].contiguous()
+
+                 flat_labels = flat_labels[shift_mask.view(-1)]
+                 flat_logits = flat_logits[shift_mask.view(-1)]
+             out["loss"] = loss_fct(flat_logits, flat_labels)
+         else:
+             out["loss"] = 0
+             print("labels is None")
+
+         out["logits"] = full_logits
+         segment_keys = ["loss", "logits"]
+         if kwargs.get("output_attentions"):
+             segment_keys.append("attentions")
+         if kwargs.get("output_hidden_states"):
+             segment_keys.append("hidden_states")
+             out["hidden_states"] = full_hidden_states
+
+         for seg_num, o in enumerate(cell_outputs):
+             for key, value in o.items():
+                 if any([sk in key for sk in segment_keys]):
+                     out[f"{key}_{seg_num}"] = value
+
+         return out
+
+     def manage_gradients(self, memory_state, seg_num, seg_len):
+         """Controls gradient calculation for the memory state.
+
+         Parameters
+         ----------
+         memory_state : torch.Tensor
+             Memory state. (batch_size, num_mem_tokens, memory_dim)
+         seg_num : int
+             Number of the segment currently being processed.
+         seg_len : int
+             Total number of segments.
+
+         Returns
+         ----------
+         bool
+             Whether gradients are kept. True: keep gradients, False: do not keep gradients
+         """
+
+         # max_n_segments: the maximum number of segments to backpropagate through; used to decide whether to detach the memory.
+
+         # When seg_num is 0 there is no recurrence yet, so gradients are kept.
+         # Gradients are also kept for the last few segments.
+         if seg_num == 0 or self.max_n_segments in {-1, None} or seg_len - seg_num <= self.max_n_segments:
+             self.memory_state = memory_state  # Retain gradients
+             return True
+         else:
+             self.memory_state = memory_state.detach()  # Detach to stop gradient tracking
+             return False
+
+     def generate_groq(
+         self,
+         input_ids,
+         max_length=25,
+         temperature=1.0,
+         top_k=None,
+         top_p=None,
+         do_sample=True,
+         pad_token_id=None,
+         eos_token_id=None,
+         **kwargs
+     ):
+         """
+         Generate new tokens based on the input sequence.
+
+         Parameters
+         ----------
+         input_ids : torch.Tensor
+             Initial input sequence. Shape: (batch_size, seq_len)
+         max_length : int
+             Maximum number of tokens to generate (including the initial sequence length).
+         temperature : float, default 1.0
+             Temperature parameter for sampling. Lower values make generation more deterministic.
+         top_k : int, optional
+             Used to sample from the top k tokens.
+         top_p : float, optional
+             Used to filter tokens based on cumulative probability p.
+         do_sample : bool, default True
+             If True, use probabilistic sampling. If False, use greedy decoding.
+         pad_token_id : int, optional
+             ID of the padding token.
+         eos_token_id : int, optional
+             ID of the end-of-sequence token.
+         **kwargs : dict
+             Additional arguments passed to MemoryCell.
+
+         Returns
+         -------
+         torch.Tensor
+             Generated token sequence. Shape: (batch_size, generated_seq_len)
+         """
+         # Process the initial input sequence
+         segmented = self.segment(self.input_seg_len, input_ids=input_ids)
+         memory_state = None
+         for segment in segmented:
+             cell_out, memory_state = self.memory_cell(
+                 **segment, memory_state=memory_state, **kwargs
+             )
+
+         # Generation loop
+         output_ids = input_ids
+         while output_ids.shape[1] < max_length:
+             # Use the last token as input_ids
+             last_token = output_ids[:, -1:]
+             # Pass it to the MemoryCell
+             cell_out, memory_state = self.memory_cell(
+                 input_ids=last_token, memory_state=memory_state, **kwargs
+             )
+             # Get the logits of the last token
+             logits = cell_out.logits[:, -1, :]
+             # Sample the next token
+             next_token = self.sample_next_token(
+                 logits, temperature, top_k, top_p, do_sample
+             )
+             # Append it to the output sequence
+             output_ids = torch.cat([output_ids, next_token], dim=1)
+             # Check the stop condition
+             if eos_token_id is not None and next_token.item() == eos_token_id:
+                 break
+
+         return output_ids
+
+     def sample_next_token(self, logits, temperature=1, top_k=50, top_p=0.9, do_sample=False):
+         """
+         Samples the next token from the logits.
+
+         Parameters
+         ----------
+         logits : torch.Tensor
+             Prediction scores for the tokens. Shape: (batch_size, vocab_size)
+         temperature : float
+             Temperature parameter for sampling.
+         top_k : int, optional
+             Used to sample from the top k tokens.
+         top_p : float, optional
+             Used to filter tokens based on cumulative probability p.
+         do_sample : bool
+             If True, use probabilistic sampling. If False, use greedy decoding.
+
+         Returns
+         -------
+         torch.Tensor
+             Sampled token. Shape: (batch_size, 1)
+         """
+         if do_sample:
+             if temperature != 1.0:
+                 logits = logits / temperature
+             if top_k is not None:
+                 logits = self.top_k_groq(logits, top_k)
+             if top_p is not None:
+                 logits = self.top_p(logits, top_p)
+             probs = torch.softmax(logits, dim=-1)
+             next_token = torch.multinomial(probs, num_samples=1)
+         else:
+             next_token = torch.argmax(logits, dim=-1, keepdim=True)
+         return next_token
+
+     def top_k_groq(self, logits, k):
+         """
+         Filters the logits so that only the top k tokens are considered.
+
+         Parameters
+         ----------
+         logits : torch.Tensor
+             Prediction scores for the tokens. Shape: (batch_size, vocab_size)
+         k : int
+             Keep the top k tokens.
+
+         Returns
+         -------
+         torch.Tensor
+             Filtered logits. Shape: (batch_size, vocab_size)
+         """
+         values, indices = torch.topk(logits, k, dim=-1)
+         min_values = values[:, -1].unsqueeze(-1).expand_as(logits)
+         return torch.where(
+             logits >= min_values, logits, torch.full_like(logits, float('-inf'))
+         )
+
+     def top_p(self, logits, p):
+         """
+         Filters tokens based on the cumulative probability p (nucleus sampling).
+
+         Parameters
+         ----------
+         logits : torch.Tensor
+             Prediction scores for the tokens. Shape: (batch_size, vocab_size)
+         p : float
+             Cumulative probability threshold.
+
+         Returns
+         -------
+         torch.Tensor
+             Filtered logits. Shape: (batch_size, vocab_size)
+         """
+         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+         cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+         sorted_indices_to_remove = cumulative_probs > p
+         sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
+         sorted_indices_to_remove[:, 0] = 0
+         # Map the mask, computed in sorted order, back to the original vocabulary order
+         indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+         logits = logits.masked_fill(indices_to_remove, float('-inf'))
+         return logits
+
+     def generate_default(self, input_ids, attention_mask=None, **generate_kwargs):
+         memory_state = None
+         segmented = self.segment(self.input_seg_len, input_ids=input_ids, attention_mask=attention_mask)
+
+         for seg_num, segment in enumerate(segmented[:-1]):
+             cell_out, memory_state = self.memory_cell(**segment, memory_state=memory_state)
+
+         final_segment = segmented[-1]
+         out = self.memory_cell.generate(**final_segment, memory_state=memory_state, **generate_kwargs)
+
+         return out
+
+     def generate(self, input_ids: torch.Tensor, **generate_kwargs):
+         with torch.no_grad():
+             if self.is_memory_all is False:
+                 self.memory_state = None
+             elif self.memory_state is not None:
+                 self.memory_state = self.memory_state.detach()  # Stop gradients flowing into the carried-over memory state
+
+             # Segment the input tensor and adjust its size. return: [{'input_ids': part1, 'attention_mask': part1}, {'input_ids': part2, 'attention_mask': part2}, ...]
+             segmented = self.segment(self.input_seg_len, input_ids=input_ids)
+
+             for seg_num, segment in enumerate(segmented[:-1]):  # All segments except the last
+                 # Pass the input tensor to the memory cell; get the output and the new memory state
+                 cell_out, self.memory_state = self.memory_cell(
+                     **segment, memory_state=self.memory_state, output_hidden_states=True
+                 )
+
+             curr_segment = segmented[-1]
+             """
+             outs = []
+             for i in range(math.ceil(generate_kwargs["max_length"] / self.input_seg_len)):
+                 out = self.memory_cell.generate(
+                     **curr_segment,
+                     memory_state=self.memory_state,
+                     max_length=min(generate_kwargs["max_length"] - i * self.input_seg_len, self.input_seg_len - curr_segment["input_ids"].shape[-1]),
+                     **generate_kwargs)
+                 outs.append(out)
+
+             for out in outs:
+                 for key, value in out.items():
+                     curr_segment[key] = torch.cat((curr_segment[key], value), dim=-1)
+                 self.memory_state = out["memory_state"]
+             """
+
+             output_ids = None
+             if generate_kwargs.get("max_length") is None:
+                 length = generate_kwargs.get("max_new_tokens", 25)
+             else:
+                 length = generate_kwargs.get("max_length") - curr_segment["input_ids"].shape[-1]
+
+             for ind in range(length):
+                 # Pass the input tensor to the memory cell; get the output and the new memory state
+                 out, next_memories = self.memory_cell(**curr_segment, memory_state=self.memory_state, output_hidden_states=True)
+                 logits = out["logits"][:, -1]  # (batch_size, vocab_size)
+                 sampled = self.sample_next_token(logits, temperature=generate_kwargs.get("temperature", 1), top_k=generate_kwargs.get("top_k", 50), top_p=generate_kwargs.get("top_p", 0.9), do_sample=generate_kwargs.get("do_sample", False))  # Sampling (batch_size, 1)
+                 #filtered_logits = self.top_k(logits, generate_kwargs.get("top_k", 0.9))  # Take the top-k probabilities
+                 #sampled = self.gumbel_sample(filtered_logits, temperature=generate_kwargs.get("temperature", 1)).unsqueeze(1)  # Sampling (batch_size, 1)
+
+                 output_ids = sampled if output_ids is None else torch.cat((output_ids, sampled), dim=1)
+
+                 curr_segment["input_ids"] = torch.cat((curr_segment["input_ids"], sampled), dim=-1)  # Append the sampled token to the segment (batch_size, seq_len)
+                 #curr_segment["attention_mask"] = torch.cat((curr_segment["attention_mask"], torch.ones_like(sampled)), dim=-1)  # Update the segment's attention mask
+
+                 if curr_segment["input_ids"].shape[-1] > self.input_seg_len:  # When the segment size is exceeded
+                     for key, value in curr_segment.items():
+                         curr_segment[key] = value[:, -1:]  # Truncate back to the segment size
+                     self.memory_state = next_memories  # Update the memory state
+
+             return output_ids
+
+     def generate_with_tokenizer(self, tokenizer, input_text, **generate_kwargs):
+         if isinstance(input_text, str):
+             tok = tokenizer(input_text, return_tensors="pt")
+             tok["input_ids"] = tok["input_ids"]
+             tok["attention_mask"] = tok["attention_mask"]
+         else:
+             tok = tokenizer(input_text)
+             for k, v in tok.items():
+                 pd = tokenizer.pad_token_id if k != 'attention_mask' else 0
+                 tok[k] = pad_sequence([torch.tensor(o) for o in v], padding_value=pd, padding_side="left").T
+
+         output_ids = self.generate(tok["input_ids"], **generate_kwargs)
+
+         if isinstance(input_text, str):
+             return tokenizer.decode(torch.cat((tok["input_ids"][0], output_ids[0]), dim=0), skip_special_tokens=True)
+         else:
+             return tokenizer.batch_decode(torch.cat((tok["input_ids"], output_ids), dim=-1), skip_special_tokens=True)
+
+     def can_generate(self):
+         return True
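To make the segmentation contract concrete, here is a standalone check of the `align="left"` branch of `split_tensor` above; this is an illustration, not code from the commit:

```python
# Mirrors split_tensor's "left" branch: the remainder lands in the last segment.
import torch

t = torch.arange(10).unsqueeze(0)  # shape (1, 10)
seg_len = 4
split_inds = list(range(0, t.shape[1], seg_len)) + [t.shape[1]]  # [0, 4, 8, 10]
segments = [t[:, s:e] for s, e in zip(split_inds, split_inds[1:])]
print([seg.shape[1] for seg in segments])  # [4, 4, 2]
# With align="right" the remainder leads instead: [2, 4, 4].
```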
all_results.json CHANGED
@@ -3,10 +3,10 @@
   "eval_samples": 100,
   "eval_samples_per_second": 376.2,
   "eval_steps_per_second": 23.56,
- "total_flos": 5419008396361728.0,
- "train_loss": 4.076150745858688,
- "train_runtime": 7573.4415,
+ "total_flos": 5418484972388352.0,
+ "train_loss": 3.606253622488408,
+ "train_runtime": 424.9732,
  "train_samples": 19883,
- "train_samples_per_second": 87.49,
- "train_steps_per_second": 2.734
+ "train_samples_per_second": 48.742,
+ "train_steps_per_second": 1.522
  }
config.json CHANGED
@@ -103,12 +103,12 @@
   "embd_pdrop": 0.1,
   "eos_token_id": 50256,
   "initializer_range": 0.02,
- "input_seg_len": 16,
+ "input_seg_len": 512,
  "is_memory_all": false,
  "layer_norm_epsilon": 1e-05,
  "max_n_segments": 2,
  "memory_size": 512,
- "model_type": "rmt_gpt2",
+ "model_type": "rmt",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
model.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:65cd08da7dc4048511a86bef339939ec8531258568d7775e32f42921c96aaab4
+ oid sha256:41c0d6e17ff62620d8f534dc2766060257a8bd950e39f3902a2e65e00a21481c
  size 248915448
train_results.json CHANGED
@@ -1,8 +1,8 @@
  {
- "total_flos": 5419008396361728.0,
- "train_loss": 4.076150745858688,
- "train_runtime": 7573.4415,
+ "total_flos": 5418484972388352.0,
+ "train_loss": 3.606253622488408,
+ "train_runtime": 424.9732,
  "train_samples": 19883,
- "train_samples_per_second": 87.49,
- "train_steps_per_second": 2.734
+ "train_samples_per_second": 48.742,
+ "train_steps_per_second": 1.522
  }
trainer_state.json CHANGED
The diff for this file is too large to render. See raw diff
 
training_args.bin CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:3e506da221a187e12d0f07664922d412eb372d52d46c7a3b6e4d2d2ee1a0abcd
+ oid sha256:cbb45d4b8223f141e7950f15066fdb3796697a543d5274ffce9e5110eceddf62
  size 7352