import math

import torch
from torch.nn import CrossEntropyLoss
from transformers import StoppingCriteria
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions


class RMTConfig(PretrainedConfig):
    """Configuration for :class:`RMTForReasoning`.

    Args:
        base_model_name: Hub id or local path whose *config* is used to build the
            base causal LM (weights are loaded separately).
        num_mem_tokens: Number of trainable memory embeddings carried between segments.
        max_n_segments: Upper bound on the segment index during generation; also read
            by ``RecurrentWrapper.manage_gradients`` when ``k2`` is configured.
        think_token_id: Token id that ends a generated segment.
        answer_token_id: Token id that ends a generated segment.
        bos_token_id: Token id fed as the first input of every generated segment.
        eos_token_id: Token id that marks a sequence as finished during generation.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = "rmt"

    def __init__(self, base_model_name="HuggingFaceTB/SmolLM2-135M", num_mem_tokens=16, max_n_segments=10,
                 think_token_id=None, answer_token_id=None, bos_token_id=None, eos_token_id=None, **kwargs):
        super().__init__(**kwargs)
        self.base_model_name = base_model_name
        self.num_mem_tokens = num_mem_tokens
        self.max_n_segments = max_n_segments
        self.think_token_id = think_token_id
        self.answer_token_id = answer_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        # Names of the classes this config is meant to be used with.
        self.memory_cell_cls = "MemoryCell"
        self.recurrent_wrapper_cls = "RecurrentWrapperNoSegmentationGenerate"

    def get(self, attr: str, default=None):
        """Dict-style lookup: return attribute ``attr`` if present, otherwise ``default``."""
        return getattr(self, attr, default)


class RMTForReasoning(PreTrainedModel):
    """Hugging Face model wrapping a recurrent-memory transformer for segment-wise reasoning.

    The base causal LM is built from ``config.base_model_name``'s config (random
    weights), wrapped in a :class:`MemoryCell`, and driven segment by segment by a
    :class:`RecurrentWrapperNoSegmentationGenerate` stored as ``self.rmt``.
    """

    config_class = RMTConfig

    def __init__(self, config: RMTConfig, **kwargs):
        super().__init__(config, **kwargs)
        from transformers import AutoConfig, AutoModelForCausalLM

        # Only the architecture is created here; weights come from load_state_dict/from_pretrained.
        base_config = AutoConfig.from_pretrained(config.base_model_name)
        base_model = AutoModelForCausalLM.from_config(base_config)

        self.rmt_config = config
        memory_cell = MemoryCell(base_model, num_mem_tokens=config.num_mem_tokens)
        self.rmt = RecurrentWrapperNoSegmentationGenerate(
            memory_cell,
            max_n_segments=config.max_n_segments,
            think_token_id=config.think_token_id,
            answer_token_id=config.answer_token_id,
            bos_token_id=config.bos_token_id,
            eos_token_id=config.eos_token_id,
        )

    def forward(self, *args, **kwargs):
        """Delegate to ``self.rmt.forward``."""
        return self.rmt(*args, **kwargs)

    def generate(self, *args, **kwargs):
        """Delegate to ``self.rmt.generate``."""
        return self.rmt.generate(*args, **kwargs)

    def load_state_dict(self, state_dict, strict=True, assign=False):
        """Load weights into this model, falling back to loading them into ``self.rmt``.

        The fallback handles state dicts whose keys do not carry this module's
        ``rmt.`` prefix (presumably ones saved from the recurrent wrapper itself —
        confirm against how checkpoints are produced).

        Returns:
            The ``_IncompatibleKeys`` result of whichever ``load_state_dict`` succeeded.
        Raises:
            RuntimeError: if the fallback load into ``self.rmt`` also fails.
        """
        own_keys = set(self.state_dict().keys())
        if state_dict and own_keys.isdisjoint(state_dict.keys()):
            # No key matches this module. With strict=False PyTorch would raise nothing and
            # silently load no weights at all, so go straight to the RMT loader.
            return self._load_into_rmt(state_dict, strict=strict, assign=assign)
        try:
            return super().load_state_dict(state_dict, strict=strict, assign=assign)
        except RuntimeError:
            return self._load_into_rmt(state_dict, strict=True, assign=assign)

    def _load_into_rmt(self, state_dict, strict, assign):
        """Load ``state_dict`` directly into the inner recurrent wrapper, reporting the retry."""
        print("Failed to load state, retrying with RMT loader.")
        result = self.rmt.load_state_dict(state_dict, strict=strict, assign=assign)
        print("Success!")
        return result

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, config=None, *args, **kwargs):
        """Build the model from a config and load ``model.safetensors`` or ``pytorch_model.bin``.

        If neither weight file can be resolved, a warning is printed and the model
        keeps its random initialization. ``kwargs`` are forwarded both to
        ``RMTConfig.from_pretrained`` (when ``config`` is None) and to ``cached_file``.
        """
        from transformers.utils.hub import cached_file, HfHubHTTPError

        if config is None:
            config = RMTConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        model = cls(config)

        state_dict = None
        try:
            weights_path = cached_file(pretrained_model_name_or_path, "model.safetensors", **kwargs)
            from safetensors.torch import load_file
            state_dict = load_file(weights_path, device="cpu")
        except (OSError, HfHubHTTPError):
            try:
                weights_path = cached_file(pretrained_model_name_or_path, "pytorch_model.bin", **kwargs)
                # weights_only=True: never unpickle arbitrary objects from a downloaded checkpoint.
                state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
            except (OSError, HfHubHTTPError):
                print(f"Warning: Could not find weights for {pretrained_model_name_or_path}. "
                      f"The model is initialized randomly.")

        if state_dict is not None:
            model.load_state_dict(state_dict, strict=False)
        return model


class MemoryCell(torch.nn.Module):
    """Wrap a causal LM with a block of trainable memory embeddings.

    For a training/forward step the memory is placed both before and after the
    segment's token embeddings; the final-layer hidden states at the trailing
    memory positions become the next memory state. For generation only the leading
    (read) copy is used.
    """

    def __init__(self, base_model, num_mem_tokens):
        super().__init__()
        self.model = base_model
        self.create_memory(num_mem_tokens)

    def create_memory(self, num_mem_tokens):
        """Register the ``memory`` parameter of shape ``(num_mem_tokens, embedding_dim)``."""
        self.num_mem_tokens = num_mem_tokens
        embeddings = self.model.get_input_embeddings()
        model_config = self.model.config
        # Prefer `n_embd` when the config defines it; only then touch `hidden_size`.
        memory_dim = getattr(model_config, 'n_embd', None)
        if memory_dim is None:
            memory_dim = model_config.hidden_size
        # Scale the random init to the spread of the token embedding table.
        memory_weights = torch.randn((num_mem_tokens, memory_dim)) * embeddings.weight.data.std()
        self.register_parameter('memory', torch.nn.Parameter(memory_weights, requires_grad=True))

        self.read_memory_position = range(num_mem_tokens)
        self.write_memory_position = range(-num_mem_tokens, 0)

    def set_memory(self, input_shape):
        """Return the initial memory repeated over the batch: ``(input_shape[0], num_mem_tokens, dim)``."""
        return self.memory.repeat(input_shape[0], 1, 1)

    def forward(self, input_ids, memory_state=None, **kwargs):
        """Process one segment with memory read and write.

        Returns:
            ``(outputs, new_memory_state)`` where memory positions are stripped from
            ``outputs`` (see ``process_output``).
        """
        if memory_state is None:
            memory_state = self.set_memory(input_ids.shape)

        seg_kwargs = self.process_input(input_ids, memory_state, write_mem=True, **kwargs)
        out = self.model(**seg_kwargs)
        out, new_memory_state = self.process_output(out, **kwargs)
        return out, new_memory_state

    def generate(self, input_ids, memory_state, attention_mask=None, **generate_kwargs):
        """Run the base model's ``generate`` with memory prepended (read-only, no write slots).

        Only ``inputs_embeds`` (not token ids) are passed to the base model; callers in
        this file treat the result as the newly generated token ids.
        """
        if memory_state is None:
            memory_state = self.set_memory(input_ids.shape)

        seg_kwargs = self.process_input(input_ids, memory_state, attention_mask=attention_mask, write_mem=False)
        return self.model.generate(inputs_embeds=seg_kwargs['inputs_embeds'],
                                   attention_mask=seg_kwargs['attention_mask'], **generate_kwargs)

    def process_input(self, input_ids, memory_state, write_mem, **kwargs):
        """Build base-model kwargs: embeddings with memory inserted, padded mask, hidden states on."""
        seg_kwargs = dict(kwargs)

        inputs_embeds = kwargs.get('inputs_embeds')
        if inputs_embeds is None:
            inputs_embeds = self.model.get_input_embeddings()(input_ids)
        if self.num_mem_tokens > 0:
            if write_mem:
                inputs_embeds = torch.cat([memory_state, inputs_embeds, memory_state], dim=1)
            else:
                inputs_embeds = torch.cat([memory_state, inputs_embeds], dim=1)

        seg_kwargs['input_ids'] = None
        seg_kwargs['inputs_embeds'] = inputs_embeds
        if kwargs.get('attention_mask') is not None:
            seg_kwargs['attention_mask'] = self.pad_attention_mask(kwargs['attention_mask'], inputs_embeds.shape)
        # Needed by process_output to read the new memory from the last layer.
        seg_kwargs['output_hidden_states'] = True
        return seg_kwargs

    def pad_attention_mask(self, attention_mask, shape):
        """Extend ``attention_mask`` to ``shape[:2]`` with ones at all memory positions."""
        if self.num_mem_tokens in {0, None}:
            return attention_mask
        mask = torch.ones(*shape[:2], dtype=torch.int64).to(attention_mask.device)
        start = self.num_mem_tokens
        mask[:, start: start + attention_mask.shape[1]] = attention_mask
        return mask

    def process_output(self, model_outputs, **kwargs):
        """Strip memory positions from the base outputs and extract the new memory state.

        Assumes the input was built with ``write_mem=True`` (memory on both sides).
        Returns ``(model_outputs, None)`` unchanged when there are no memory tokens.
        """
        if self.num_mem_tokens in {0, None}:
            return model_outputs, None

        n_mem = self.num_mem_tokens
        out = CausalLMOutputWithCrossAttentions()
        memory_state = model_outputs.hidden_states[-1][:, -n_mem:]
        out['logits'] = model_outputs.logits[:, n_mem:-n_mem]
        if kwargs.get('output_hidden_states'):
            out['hidden_states'] = [lh[:, n_mem:-n_mem] for lh in model_outputs.hidden_states]
        if kwargs.get('output_attentions'):
            out['attentions'] = model_outputs['attentions']
        return out, memory_state


def _flatten_shifted_targets(logits, labels, labels_mask=None):
    """Align next-token logits with labels and flatten both for a token-level loss.

    Logit position ``t`` is scored against ``labels[..., t + 1]``. When ``labels_mask``
    is given, only positions where ``labels_mask[..., :-1]`` is truthy are kept; the
    mask is cast to bool so an integer 0/1 mask selects positions instead of being
    used as gather indices.

    Returns:
        ``(flat_logits, flat_labels)``.
    """
    shift_labels = labels[..., 1:].contiguous()
    shift_logits = logits[..., :-1, :].contiguous()
    flat_labels = shift_labels.view(-1)
    flat_logits = shift_logits.view(-1, shift_logits.size(-1))
    if labels_mask is not None:
        # NOTE(review): the mask is shifted like the logits ([..., :-1]), not like the
        # labels ([..., 1:]); confirm this matches how labels_mask is constructed.
        keep = labels_mask[..., :-1].reshape(-1).bool()
        flat_labels = flat_labels[keep]
        flat_logits = flat_logits[keep]
    return flat_logits, flat_labels


class RecurrentWrapper(torch.nn.Module):
    """Split inputs into segments and run them through a memory cell sequentially.

    ``rmt_kwargs`` is kept as a plain dict; keys read here are ``segment_size``,
    ``segment_alignment``, ``k2`` and ``max_n_segments``.
    """

    def __init__(self, memory_cell, **rmt_kwargs):
        super().__init__()
        self.memory_cell = memory_cell
        self.rmt_config = rmt_kwargs

    def forward(self, input_ids, labels=None, labels_mask=None, inputs_embeds=None, attention_mask=None,
                output_attentions=None, output_hidden_states=None):
        """Run every segment, carrying memory forward, and compute the LM loss over all logits."""
        memory_state = None
        segmented = self.segment(input_ids=input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask)

        cell_outputs = []
        for seg_num, segment in enumerate(segmented):
            cell_out, memory_state = self.memory_cell(**segment, memory_state=memory_state,
                                                      output_hidden_states=True)
            cell_outputs.append(cell_out)
            memory_state = self.manage_gradients(memory_state, seg_num)

        return self.process_outputs(cell_outputs, labels=labels, labels_mask=labels_mask,
                                    output_attentions=output_attentions,
                                    output_hidden_states=output_hidden_states)

    @torch.no_grad()
    def generate(self, input_ids, attention_mask=None, **generate_kwargs):
        """Encode all but the last segment into memory, then generate from the last one."""
        memory_state = None
        segmented = self.segment(input_ids=input_ids, attention_mask=attention_mask)

        for segment in segmented[:-1]:
            _, memory_state = self.memory_cell(**segment, memory_state=memory_state, output_hidden_states=True)

        final_segment = segmented[-1]
        return self.memory_cell.generate(**final_segment, memory_state=memory_state, **generate_kwargs)

    def segment(self, **kwargs):
        """Split every non-None tensor kwarg along dim 1 into a list of per-segment dicts."""
        segments = []
        for key, tensor in kwargs.items():
            if tensor is None:
                continue
            for s, piece in enumerate(self.split_tensor(tensor)):
                if s < len(segments):
                    segments[s][key] = piece
                else:
                    segments.append({key: piece})
        return segments

    def split_tensor(self, tensor):
        """Split ``tensor`` along dim 1 into chunks of ``segment_size``.

        ``segment_alignment`` ``'left'``/None puts any short chunk at the end,
        ``'right'`` puts it at the start, ``'center'`` uses ``torch.chunk`` into
        ``ceil(len / segment_size)`` near-equal pieces.

        Raises:
            NotImplementedError: for any other alignment value.
        """
        align = self.rmt_config.get('segment_alignment')
        segment_size = self.rmt_config.get('segment_size')
        seq_len = tensor.shape[1]
        if align in {'left', None}:
            bounds = list(range(0, seq_len, segment_size)) + [seq_len]
        elif align == 'right':
            bounds = (list(range(seq_len, 0, -segment_size)) + [0])[::-1]
        elif align == 'center':
            n_seg = math.ceil(seq_len / segment_size)
            return torch.chunk(tensor, n_seg, dim=1)
        else:
            raise NotImplementedError
        return [tensor[:, start:end] for start, end in zip(bounds, bounds[1:])]

    def process_outputs(self, cell_outputs, **kwargs):
        """Concatenate per-segment logits (and optionally hidden states) and compute the loss.

        ``loss`` is 0 when no ``labels`` are given. ``output_attentions`` is accepted
        but attentions are not included in the result.
        """
        out = CausalLMOutputWithCrossAttentions()
        full_logits = torch.cat([o.logits for o in cell_outputs], dim=1)

        labels = kwargs.get('labels')
        if labels is not None:
            flat_logits, flat_labels = _flatten_shifted_targets(full_logits, labels, kwargs.get('labels_mask'))
            out['loss'] = CrossEntropyLoss()(flat_logits, flat_labels)
        else:
            out['loss'] = 0

        out['logits'] = full_logits
        if kwargs.get('output_hidden_states'):
            out['hidden_states'] = tuple(torch.cat(layer_hs, dim=1)
                                         for layer_hs in zip(*[o.hidden_states for o in cell_outputs]))
        return out

    def manage_gradients(self, memory_state, seg_num):
        """Detach ``memory_state`` for truncated backprop when ``k2`` is configured.

        The memory is left attached for the first segment, when ``k2`` is -1/None,
        or when ``seg_num + k2 > max_n_segments``.
        """
        k2 = self.rmt_config.get('k2')
        max_n_segments = self.rmt_config.get('max_n_segments')
        keep_graph = seg_num == 0 or k2 in {-1, None} or seg_num + k2 > max_n_segments
        return memory_state if keep_graph else memory_state.detach()

    def gradient_checkpointing_enable(self, *args, **kwargs):
        """Enable gradient checkpointing on the underlying base model."""
        self.memory_cell.model.gradient_checkpointing_enable(*args, **kwargs)


class RecurrentWrapperNoSegmentation(RecurrentWrapper):
    """Recurrent wrapper for inputs that arrive already split into segments.

    Each segment is a dict with ``input_ids`` and ``attention_mask`` and optionally
    ``labels`` / ``labels_mask``; the loss is averaged over segments.
    """

    def forward(self, segments, labels, output_attentions=None, output_hidden_states=None):
        """Run pre-split ``segments`` through the memory cell. ``labels`` is unused (labels are per segment)."""
        memory_state = None
        cell_outputs = []
        for seg_num, segment in enumerate(segments):
            cell_out, memory_state = self.memory_cell(input_ids=segment['input_ids'],
                                                      attention_mask=segment['attention_mask'],
                                                      memory_state=memory_state, output_hidden_states=True)
            cell_outputs.append(cell_out)
            memory_state = self.manage_gradients(memory_state, seg_num)

        return self.process_outputs(cell_outputs, segments,
                                    output_attentions=output_attentions,
                                    output_hidden_states=output_hidden_states)

    def generate(self, segments, **generate_kwargs):
        raise NotImplementedError("Generation not implemented for this wrapper.")

    def process_outputs(self, cell_outputs, segments, **kwargs):
        """Average the per-segment LM losses and concatenate logits across segments.

        A segment without ``labels``, or whose mask selects no target positions,
        contributes 0 to the sum (instead of a NaN mean over an empty set).
        ``output_attentions`` / ``output_hidden_states`` are accepted for interface
        compatibility but not added to the result.
        """
        out = CausalLMOutputWithCrossAttentions()
        loss_fct = CrossEntropyLoss()

        segment_losses = []
        for segment, cell_out in zip(segments, cell_outputs):
            labels = segment.get('labels')
            if labels is None:
                segment_losses.append(0)
                continue
            flat_logits, flat_labels = _flatten_shifted_targets(cell_out.logits, labels, segment.get('labels_mask'))
            if flat_labels.numel() == 0:
                segment_losses.append(0)
            else:
                segment_losses.append(loss_fct(flat_logits, flat_labels))

        num_segments = len(segments)
        out['loss'] = sum(segment_losses) / num_segments
        out['logits'] = torch.cat([cell_out.logits for cell_out in cell_outputs], dim=1)
        return out

    def gradient_checkpointing_enable(self, *args, **kwargs):
        """Enable gradient checkpointing on the base model if it supports it."""
        if hasattr(self.memory_cell.model, "gradient_checkpointing_enable"):
            return self.memory_cell.model.gradient_checkpointing_enable(*args, **kwargs)


class StopOnSpecialTokenCriteria(StoppingCriteria):
    """Stop generation once the last generated token is one of ``special_token_ids``.

    NOTE(review): only batch row 0 is inspected, so for batch size > 1 every row stops
    when row 0 emits a special token.
    """

    def __init__(self, special_token_ids):
        self.special_token_ids = set(special_token_ids)

    def __call__(self, input_ids, scores, **kwargs):
        last_token = input_ids[0, -1].item()
        return last_token in self.special_token_ids


class RecurrentWrapperNoSegmentationGenerate(RecurrentWrapperNoSegmentation):
    """Pre-segmented recurrent wrapper that can also generate new segments.

    Generation encodes the prompt segments into memory, then repeatedly generates
    a segment that starts at ``bos_token_id`` and stops at ``think_token_id`` /
    ``answer_token_id``, writing each generated segment back into memory.
    """

    def forward(self, segments, labels, output_attentions=None, output_hidden_states=None):
        """Same as the parent forward (memory gradients are managed between segments)."""
        return super().forward(segments, labels, output_attentions=output_attentions,
                               output_hidden_states=output_hidden_states)

    @torch.no_grad()
    def generate(self, segments, **kwargs):
        """Generate segments until every row has produced ``eos`` or the segment budget is spent.

        Returns:
            A list of token-id tensors, one per generated segment.
        """
        memory_state = None
        for segment in segments:
            _, memory_state = self.memory_cell(input_ids=segment['input_ids'],
                                               attention_mask=segment['attention_mask'],
                                               memory_state=memory_state, output_hidden_states=True)

        generated_segments = []
        for _ in range(len(segments), self.rmt_config.get("max_n_segments", 32)):
            output_ids, memory_state = self.generate_segment(memory_state=memory_state, **kwargs)
            generated_segments.append(output_ids)
            if self.all_done(generated_segments):
                break
        return generated_segments

    def generate_segment(self, memory_state, **kwargs):
        """Generate one segment from BOS and return ``(generated_ids, updated_memory_state)``."""
        input_ids = self.get_bos_tensor(memory_state)
        attention_mask = torch.ones_like(input_ids).bool()
        generated = self.memory_cell.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            memory_state=memory_state,
            stopping_criteria=self.make_custom_stopping_criteria(),
            **kwargs,
        )

        # Write the new segment into memory: BOS + generated ids without the final token
        # (presumably the stop token that ended the segment — confirm).
        fwd_inputs = torch.cat((input_ids, generated), dim=1)[:, :-1]
        _, memory_state = self.memory_cell(input_ids=fwd_inputs, memory_state=memory_state)
        return generated, memory_state

    def get_bos_tensor(self, memory_state):
        """Return a ``(batch, 1)`` tensor of ``bos_token_id`` on ``memory_state``'s device."""
        bos = self.rmt_config["bos_token_id"]
        batch_size = memory_state.shape[0]
        return torch.tensor([bos] * batch_size).reshape(-1, 1).to(memory_state.device)

    def all_done(self, generated_segments):
        """True when every batch row contains ``eos_token_id`` in some generated segment."""
        eos = self.rmt_config['eos_token_id']
        batch_size = generated_segments[0].shape[0]
        return all(any(eos in seg[row] for seg in generated_segments) for row in range(batch_size))

    def make_custom_stopping_criteria(self):
        """Stopping criteria that end a segment at the think or answer token."""
        return [StopOnSpecialTokenCriteria([self.rmt_config['think_token_id'], self.rmt_config['answer_token_id']])]