booydar committed
Commit 2a081d3 · verified · 1 Parent(s): f5425a8

Upload RMTForReasoning

Files changed (2):
  1. config.json +5 -5
  2. huggingface.py +451 -0
config.json CHANGED
@@ -1,20 +1,20 @@
 {
   "answer_token_id": 10,
   "architectures": [
-    "RMT"
+    "RMTForReasoning"
   ],
   "auto_map": {
-    "AutoConfig": "language_modeling.RMTConfig",
-    "AutoModel": "language_modeling.RMT"
+    "AutoConfig": "huggingface.RMTConfig",
+    "AutoModel": "huggingface.RMTForReasoning"
   },
   "base_model_name": "HuggingFaceTB/SmolLM2-135M",
   "bos_token_id": 0,
   "eos_token_id": 0,
   "max_n_segments": 10,
-  "memory_cell_cls": "modeling_rmt.language_modeling:MemoryCell",
+  "memory_cell_cls": "MemoryCell",
   "model_type": "rmt",
   "num_mem_tokens": 32,
-  "recurrent_wrapper_cls": "modeling_rmt.experimental:RecurrentWrapperNoSegmentationGenerate",
+  "recurrent_wrapper_cls": "RecurrentWrapperNoSegmentationGenerate",
   "think_token_id": 8,
   "torch_dtype": "float32",
   "transformers_version": "4.54.1"
huggingface.py ADDED
@@ -0,0 +1,451 @@
import math
import torch
from torch.nn import CrossEntropyLoss

from transformers import StoppingCriteria
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions


class RMTConfig(PretrainedConfig):
    model_type = "rmt"

    def __init__(self,
                 base_model_name="HuggingFaceTB/SmolLM2-135M",
                 num_mem_tokens=16,
                 max_n_segments=10,
                 think_token_id=None,
                 answer_token_id=None,
                 bos_token_id=None,
                 eos_token_id=None,
                 **kwargs):
        super().__init__(**kwargs)
        self.base_model_name = base_model_name
        self.num_mem_tokens = num_mem_tokens
        self.max_n_segments = max_n_segments
        self.think_token_id = think_token_id
        self.answer_token_id = answer_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.memory_cell_cls = "MemoryCell"
        self.recurrent_wrapper_cls = "RecurrentWrapperNoSegmentationGenerate"

    def get(self, attr: str, default=None):
        if hasattr(self, attr):
            return getattr(self, attr)
        else:
            return default


class RMTForReasoning(PreTrainedModel):
    config_class = RMTConfig

    def __init__(self, config: RMTConfig, **kwargs):
        super().__init__(config, **kwargs)
        from transformers import AutoConfig, AutoModelForCausalLM
        base_config = AutoConfig.from_pretrained(config.base_model_name)
        base_model = AutoModelForCausalLM.from_config(base_config)

        self.rmt_config = config
        memory_cell = MemoryCell(base_model, num_mem_tokens=config.num_mem_tokens)
        self.rmt = RecurrentWrapperNoSegmentationGenerate(
            memory_cell,
            max_n_segments=config.max_n_segments,
            think_token_id=config.think_token_id,
            answer_token_id=config.answer_token_id,
            bos_token_id=config.bos_token_id,
            eos_token_id=config.eos_token_id
        )

    def forward(self, *args, **kwargs):
        return self.rmt(*args, **kwargs)

    def generate(self, *args, **kwargs):
        return self.rmt.generate(*args, **kwargs)

    def load_state_dict(self, state_dict, strict=True, assign=False):
        try:
            return super().load_state_dict(state_dict, strict, assign)
        except RuntimeError:
            # Checkpoints saved from the bare wrapper lack the "rmt." prefix,
            # so retry loading directly into the wrapped module.
            print("Failed to load state, retrying with RMT loader.")
            result = self.rmt.load_state_dict(state_dict, strict=True, assign=assign)
            print("Success!")
            return result

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, config=None, *args, **kwargs):
        from transformers.utils.hub import cached_file
        # HfHubHTTPError is defined in huggingface_hub, so import it from there.
        from huggingface_hub.utils import HfHubHTTPError
        import torch

        if config is None:
            config = RMTConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        model = cls(config)

        state_dict = None
        try:
            weights_path = cached_file(pretrained_model_name_or_path, "model.safetensors", **kwargs)
            from safetensors.torch import load_file
            state_dict = load_file(weights_path, device="cpu")
        except (OSError, HfHubHTTPError):
            try:
                weights_path = cached_file(pretrained_model_name_or_path, "pytorch_model.bin", **kwargs)
                state_dict = torch.load(weights_path, map_location="cpu")
            except (OSError, HfHubHTTPError):
                print(f"Warning: Could not find weights for {pretrained_model_name_or_path}. "
                      f"The model is initialized randomly.")

        if state_dict is not None:
            model.load_state_dict(state_dict, strict=False)

        return model


class MemoryCell(torch.nn.Module):
    def __init__(self, base_model, num_mem_tokens):
        super().__init__()
        self.model = base_model
        self.create_memory(num_mem_tokens)

    def create_memory(self, num_mem_tokens):
        self.num_mem_tokens = num_mem_tokens
        embeddings = self.model.get_input_embeddings()
        memory_dim = getattr(self.model.config, 'n_embd', self.model.config.hidden_size)
        memory_weights = torch.randn((num_mem_tokens, memory_dim)) * embeddings.weight.data.std()
        self.register_parameter('memory', torch.nn.Parameter(memory_weights, requires_grad=True))

        self.read_memory_position = range(num_mem_tokens)
        self.write_memory_position = range(-num_mem_tokens, 0)

    def set_memory(self, input_shape):
        memory = self.memory.repeat(input_shape[0], 1, 1)
        return memory

    def forward(self, input_ids, memory_state=None, **kwargs):
        if memory_state is None:
            memory_state = self.set_memory(input_ids.shape)

        seg_kwargs = self.process_input(input_ids, memory_state, write_mem=True, **kwargs)
        out = self.model(**seg_kwargs)
        out, new_memory_state = self.process_output(out, **kwargs)

        return out, new_memory_state

    def generate(self, input_ids, memory_state, attention_mask=None, **generate_kwargs):
        if memory_state is None:
            memory_state = self.set_memory(input_ids.shape)

        seg_kwargs = self.process_input(input_ids, memory_state, attention_mask=attention_mask, write_mem=False)
        out = self.model.generate(inputs_embeds=seg_kwargs['inputs_embeds'],
                                  attention_mask=seg_kwargs['attention_mask'],
                                  **generate_kwargs)
        return out

    def process_input(self, input_ids, memory_state, write_mem, **kwargs):
        seg_kwargs = dict(**kwargs)

        inputs_embeds = kwargs.get('inputs_embeds')
        if inputs_embeds is None:
            inputs_embeds = self.model.get_input_embeddings()(input_ids)

        # Prepend the read memory; append the write memory only when the cell
        # should update its state on this pass.
        if self.num_mem_tokens > 0:
            if write_mem:
                inputs_embeds = torch.cat([memory_state, inputs_embeds, memory_state], dim=1)
            else:
                inputs_embeds = torch.cat([memory_state, inputs_embeds], dim=1)

        seg_kwargs['input_ids'] = None
        seg_kwargs['inputs_embeds'] = inputs_embeds
        if kwargs.get('attention_mask') is not None:
            seg_kwargs['attention_mask'] = self.pad_attention_mask(kwargs['attention_mask'], inputs_embeds.shape)
        seg_kwargs['output_hidden_states'] = True
        return seg_kwargs

    def pad_attention_mask(self, attention_mask, shape):
        if self.num_mem_tokens in {0, None}:
            return attention_mask
        else:
            mask = torch.ones(*shape[:2], dtype=torch.int64).to(attention_mask.device)
            mask[:, self.num_mem_tokens: self.num_mem_tokens + attention_mask.shape[1]] = attention_mask
            return mask

    def process_output(self, model_outputs, **kwargs):
        if self.num_mem_tokens not in {0, None}:
            out = CausalLMOutputWithCrossAttentions()
            # The final hidden states of the write-memory tokens become the next memory state.
            memory_state = model_outputs.hidden_states[-1][:, -self.num_mem_tokens:]
            out['logits'] = model_outputs.logits[:, self.num_mem_tokens:-self.num_mem_tokens]

            if kwargs.get('output_hidden_states'):
                out['hidden_states'] = [lh[:, self.num_mem_tokens:-self.num_mem_tokens]
                                        for lh in model_outputs.hidden_states]
            if kwargs.get('output_attentions'):
                out['attentions'] = model_outputs['attentions']
        else:
            memory_state = None
            out = model_outputs

        return out, memory_state


class RecurrentWrapper(torch.nn.Module):
    def __init__(self, memory_cell, **rmt_kwargs):
        super().__init__()
        self.memory_cell = memory_cell
        self.rmt_config = rmt_kwargs

    def forward(self, input_ids, labels=None, labels_mask=None, inputs_embeds=None, attention_mask=None,
                output_attentions=None, output_hidden_states=None):
        memory_state = None
        segmented = self.segment(input_ids=input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask)

        cell_outputs = []
        for seg_num, segment in enumerate(segmented):
            cell_out, memory_state = self.memory_cell(**segment, memory_state=memory_state, output_hidden_states=True)
            cell_outputs.append(cell_out)
            memory_state = self.manage_gradients(memory_state, seg_num)

        out = self.process_outputs(cell_outputs, labels=labels,
                                   labels_mask=labels_mask,
                                   output_attentions=output_attentions,
                                   output_hidden_states=output_hidden_states)
        return out

    def generate(self, input_ids, attention_mask=None, **generate_kwargs):
        memory_state = None
        segmented = self.segment(input_ids=input_ids, attention_mask=attention_mask)

        # Run all but the last segment to build up the memory state, then generate from the last one.
        for seg_num, segment in enumerate(segmented[:-1]):
            cell_out, memory_state = self.memory_cell(**segment, memory_state=memory_state, output_hidden_states=True)

        final_segment = segmented[-1]
        out = self.memory_cell.generate(**final_segment, memory_state=memory_state, **generate_kwargs)

        return out

    def segment(self, **kwargs):
        segments = []
        for k, tensor in kwargs.items():
            if tensor is not None:
                k_segments = self.split_tensor(tensor)
                for s, k_seg in enumerate(k_segments):
                    if s < len(segments):
                        segments[s][k] = k_seg
                    else:
                        segments.append({k: k_seg})

        return segments

    def split_tensor(self, tensor):
        align = self.rmt_config.get('segment_alignment')
        segment_size = self.rmt_config.get('segment_size')
        if align in {'left', None}:
            split_inds = list(range(0, tensor.shape[1], segment_size)) + [tensor.shape[1]]
            segments = [tensor[:, start:end] for (start, end) in zip(split_inds, split_inds[1:])]
        elif align in {'right', None}:
            split_inds = (list(range(tensor.shape[1], 0, -segment_size)) + [0])[::-1]
            segments = [tensor[:, start:end] for (start, end) in zip(split_inds, split_inds[1:])]
        elif align == 'center':
            n_seg = math.ceil(tensor.shape[1] / segment_size)
            segments = torch.chunk(tensor, n_seg, dim=1)
        else:
            raise NotImplementedError
        return segments

    def process_outputs(self, cell_outputs, **kwargs):
        out = CausalLMOutputWithCrossAttentions()
        full_logits = torch.cat([o.logits for o in cell_outputs], dim=1)
        full_hidden_states = tuple([torch.cat(layer_hs, dim=1)
                                    for layer_hs in zip(*[o.hidden_states for o in cell_outputs])])

        labels = kwargs.get('labels')
        if labels is not None:
            shift_labels = labels[..., 1:].contiguous()
            shift_logits = full_logits[..., :-1, :].contiguous()
            flat_labels = shift_labels.view(-1)
            flat_logits = shift_logits.view(-1, shift_logits.size(-1))

            loss_fct = CrossEntropyLoss()
            labels_mask = kwargs.get('labels_mask')
            if labels_mask is not None:
                shift_mask = labels_mask[..., :-1].contiguous()

                flat_labels = flat_labels[shift_mask.view(-1)]
                flat_logits = flat_logits[shift_mask.view(-1)]

            out['loss'] = loss_fct(flat_logits, flat_labels)
        else:
            out['loss'] = 0

        out['logits'] = full_logits
        segment_keys = ['loss', 'logits']
        if kwargs.get('output_attentions'):
            segment_keys.append('attentions')
        if kwargs.get('output_hidden_states'):
            segment_keys.append('hidden_states')
            out['hidden_states'] = full_hidden_states

        return out

    def manage_gradients(self, memory_state, seg_num):
        # Detach the memory state unless gradients should still flow through this segment
        # (k2 controls how many of the final segments keep the gradient path).
        k2, max_n_segments = self.rmt_config.get('k2'), self.rmt_config.get('max_n_segments')
        if seg_num == 0 \
                or k2 in {-1, None} \
                or seg_num + k2 > max_n_segments:
            return memory_state

        memory_state = memory_state.detach()
        return memory_state

    def gradient_checkpointing_enable(self, *args, **kwargs):
        self.memory_cell.model.gradient_checkpointing_enable(*args, **kwargs)


class RecurrentWrapperNoSegmentation(RecurrentWrapper):
    def forward(self, segments, labels, output_attentions=None, output_hidden_states=None):
        memory_state = None

        cell_outputs = []
        for seg_num, segment in enumerate(segments):
            cell_out, memory_state = self.memory_cell(input_ids=segment['input_ids'],
                                                      attention_mask=segment['attention_mask'],
                                                      memory_state=memory_state, output_hidden_states=True)
            cell_outputs.append(cell_out)
            memory_state = self.manage_gradients(memory_state, seg_num)

        out = self.process_outputs(cell_outputs, segments,
                                   output_attentions=output_attentions,
                                   output_hidden_states=output_hidden_states)
        return out

    def generate(self, segments, **generate_kwargs):
        raise NotImplementedError("Generation not implemented for this wrapper.")

    def process_outputs(self, cell_outputs, segments, **kwargs):
        out = CausalLMOutputWithCrossAttentions()
        proxy_out = {}
        for seg_num, segment in enumerate(segments):
            cell_out = cell_outputs[seg_num]

            full_logits = cell_out.logits

            labels = segment.get('labels')
            if labels is not None:
                shift_labels = labels[..., 1:].contiguous()
                shift_logits = full_logits[..., :-1, :].contiguous()
                flat_labels = shift_labels.view(-1)
                flat_logits = shift_logits.view(-1, shift_logits.size(-1))

                loss_fct = CrossEntropyLoss()
                labels_mask = segment.get('labels_mask')
                if labels_mask is not None:
                    shift_mask = labels_mask[..., :-1].contiguous()

                    flat_labels = flat_labels[shift_mask.view(-1)]
                    flat_logits = flat_logits[shift_mask.view(-1)]

                # Guard against segments whose labels are fully masked out, which would
                # otherwise yield a NaN cross-entropy loss.
                if labels_mask is not None and labels_mask.sum() == 0:
                    loss_value = 0
                else:
                    loss_value = loss_fct(flat_logits, flat_labels)

                proxy_out[f'loss_{seg_num}'] = loss_value
            else:
                proxy_out[f'loss_{seg_num}'] = 0

            segment_keys = ['loss']
            if kwargs.get('output_attentions'):
                segment_keys.append('attentions')
            if kwargs.get('output_hidden_states'):
                segment_keys.append('hidden_states')

            for key, value in cell_out.items():
                if any([sk in key for sk in segment_keys]):
                    proxy_out[f'{key}_{seg_num}'] = value

        num_segments = len(segments)
        out['loss'] = sum([proxy_out[f'loss_{seg_num}'] for seg_num in range(num_segments)]) / num_segments
        out['logits'] = torch.cat([cell_out.logits for cell_out in cell_outputs], dim=1)

        return out

    def gradient_checkpointing_enable(self, *args, **kwargs):
        if hasattr(self.memory_cell.model, "gradient_checkpointing_enable"):
            return self.memory_cell.model.gradient_checkpointing_enable(*args, **kwargs)


class StopOnSpecialTokenCriteria(StoppingCriteria):
    def __init__(self, special_token_ids):
        self.special_token_ids = set(special_token_ids)

    def __call__(self, input_ids, scores, **kwargs):
        # Stop as soon as the last generated token is one of the special ids
        # (the think or answer markers).
        last_token = input_ids[0, -1].item()
        return last_token in self.special_token_ids


class RecurrentWrapperNoSegmentationGenerate(RecurrentWrapperNoSegmentation):
    def forward(self, segments, labels, output_attentions=None, output_hidden_states=None):
        memory_state = None

        cell_outputs = []
        for seg_num, segment in enumerate(segments):
            cell_out, memory_state = self.memory_cell(input_ids=segment['input_ids'],
                                                      attention_mask=segment['attention_mask'],
                                                      memory_state=memory_state, output_hidden_states=True)
            cell_outputs.append(cell_out)
            self.manage_gradients(memory_state, seg_num)

        out = self.process_outputs(cell_outputs, segments,
                                   output_attentions=output_attentions,
                                   output_hidden_states=output_hidden_states)
        return out

    def generate(self, segments, **kwargs):
        memory_state = None

        # First pass over the provided segments only builds up the memory state.
        for seg_num, segment in enumerate(segments):
            cell_out, memory_state = self.memory_cell(input_ids=segment['input_ids'],
                                                      attention_mask=segment['attention_mask'],
                                                      memory_state=memory_state, output_hidden_states=True)

        # Then emit new segments one by one until every sequence has produced an EOS
        # token or the segment budget is exhausted.
        generated_segments = []
        for seg_num in range(len(segments), self.rmt_config.get("max_n_segments", 32)):
            output_ids, memory_state = self.generate_segment(memory_state=memory_state, **kwargs)
            generated_segments.append(output_ids)

            if self.all_done(generated_segments):
                break

        return generated_segments

    def generate_segment(self, memory_state, **kwargs):
        input_ids = self.get_bos_tensor(memory_state)
        attention_mask = torch.ones_like(input_ids).bool()

        generated = self.memory_cell.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            memory_state=memory_state,
            stopping_criteria=self.make_custom_stopping_criteria(),
            **kwargs
        )

        # Update memory state from generation
        fwd_inputs = torch.cat((input_ids, generated), dim=1)[:, :-1]
        _, memory_state = self.memory_cell(input_ids=fwd_inputs, memory_state=memory_state)

        return generated, memory_state

    def get_bos_tensor(self, memory_state):
        bos = self.rmt_config["bos_token_id"]
        bos_tensor = torch.tensor([bos] * memory_state.shape[0]).reshape(-1, 1)
        return bos_tensor.to(memory_state.device)

    def all_done(self, generated_segments):
        eos = self.rmt_config['eos_token_id']
        bs = generated_segments[0].shape[0]
        have_eos = [any([eos in seg[i] for seg in generated_segments]) for i in range(bs)]
        all_done = all(have_eos)
        return all_done

    def make_custom_stopping_criteria(self):
        return [StopOnSpecialTokenCriteria([self.rmt_config['think_token_id'], self.rmt_config['answer_token_id']])]
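
For reference, a rough usage sketch of the interface above: unlike a plain causal LM, RecurrentWrapperNoSegmentationGenerate.forward and .generate expect a list of pre-segmented batches, each a dict with input_ids and attention_mask (plus labels and labels_mask during training). The text chunks and generation settings below are illustrative assumptions; model, config and tokenizer are the objects loaded in the example after config.json above.

import torch

# Hypothetical long prompt, split into segments by the caller.
chunks = ["First part of a long reasoning problem...", "...second part of the problem."]

segments = []
for chunk in chunks:
    enc = tokenizer(chunk, return_tensors="pt")
    segments.append({"input_ids": enc["input_ids"],
                     "attention_mask": enc["attention_mask"]})

# generate() replays the given segments to build the memory state, then keeps
# producing new segments (each started from bos_token_id) until every sequence
# has emitted an EOS token or max_n_segments is reached.
with torch.no_grad():
    generated_segments = model.generate(segments, max_new_tokens=128, do_sample=False)

print([tokenizer.decode(seg[0]) for seg in generated_segments])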