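"""Recurrent Memory Transformer (RMT) components for multi-segment reasoning.

A Hugging Face causal LM is wrapped with learned memory tokens (`MemoryCell`);
recurrent wrappers carry the memory state across segments for training and for
segment-by-segment generation, and `RMTForReasoning` exposes the whole stack as
a `PreTrainedModel` configured by `RMTConfig`.
"""
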
import math

import torch
from torch.nn import CrossEntropyLoss

from transformers import StoppingCriteria
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions


class RMTConfig(PretrainedConfig):
    """Configuration for RMT models: the backbone model name, the number of
    learned memory tokens, the maximum number of segments, and the special
    token ids used to steer segment-by-segment generation."""

    model_type = "rmt"

    def __init__(self,
                 base_model_name="HuggingFaceTB/SmolLM2-135M",
                 num_mem_tokens=16,
                 max_n_segments=10,
                 think_token_id=None,
                 answer_token_id=None,
                 bos_token_id=None,
                 eos_token_id=None,
                 **kwargs):
        super().__init__(**kwargs)
        self.base_model_name = base_model_name
        self.num_mem_tokens = num_mem_tokens
        self.max_n_segments = max_n_segments
        self.think_token_id = think_token_id
        self.answer_token_id = answer_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.memory_cell_cls = "MemoryCell"
        self.recurrent_wrapper_cls = "RecurrentWrapperNoSegmentationGenerate"

    def get(self, attr: str, default=None):
        """Dict-style accessor so the config can stand in for the plain
        ``rmt_config`` dict used by the recurrent wrappers."""
        return getattr(self, attr, default)


class RMTForReasoning(PreTrainedModel):
    """`PreTrainedModel` wrapper that builds the backbone causal LM, wraps it in a
    `MemoryCell`, and delegates `forward`/`generate` to the recurrent wrapper."""

    config_class = RMTConfig

    def __init__(self, config: RMTConfig, **kwargs):
        super().__init__(config, **kwargs)
        from transformers import AutoConfig, AutoModelForCausalLM
        base_config = AutoConfig.from_pretrained(config.base_model_name)
        base_model = AutoModelForCausalLM.from_config(base_config)

        self.rmt_config = config
        memory_cell = MemoryCell(base_model, num_mem_tokens=config.num_mem_tokens)
        self.rmt = RecurrentWrapperNoSegmentationGenerate(
            memory_cell,
            max_n_segments=config.max_n_segments,
            think_token_id=config.think_token_id,
            answer_token_id=config.answer_token_id,
            bos_token_id=config.bos_token_id,
            eos_token_id=config.eos_token_id,
        )

    def forward(self, *args, **kwargs):
        return self.rmt(*args, **kwargs)

    def generate(self, *args, **kwargs):
        return self.rmt.generate(*args, **kwargs)

    def load_state_dict(self, state_dict, strict=True, assign=False):
        # Checkpoints saved from the bare recurrent wrapper lack the `rmt.` prefix,
        # so fall back to loading them directly into the wrapper.
        try:
            return super().load_state_dict(state_dict, strict=strict, assign=assign)
        except RuntimeError:
            print("Failed to load state, retrying with RMT loader.")
            result = self.rmt.load_state_dict(state_dict, strict=True, assign=assign)
            print("Success!")
            return result

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, config=None, *args, **kwargs):
        from transformers.utils.hub import cached_file
        from huggingface_hub.utils import HfHubHTTPError

        if config is None:
            config = RMTConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        model = cls(config)

        # Prefer safetensors weights, fall back to a pickled checkpoint, and keep
        # the randomly initialized model if neither is available.
        state_dict = None
        try:
            weights_path = cached_file(pretrained_model_name_or_path, "model.safetensors", **kwargs)
            from safetensors.torch import load_file
            state_dict = load_file(weights_path, device="cpu")
        except (OSError, HfHubHTTPError):
            try:
                weights_path = cached_file(pretrained_model_name_or_path, "pytorch_model.bin", **kwargs)
                state_dict = torch.load(weights_path, map_location="cpu")
            except (OSError, HfHubHTTPError):
                print(f"Warning: Could not find weights for {pretrained_model_name_or_path}. "
                      f"The model is initialized randomly.")

        if state_dict is not None:
            model.load_state_dict(state_dict, strict=False)

        return model


class MemoryCell(torch.nn.Module):
    """Wraps the backbone LM with learned memory tokens. For training segments the
    input embeddings are laid out as [read memory | segment tokens | write memory];
    the final hidden states of the write positions become the next memory state."""

    def __init__(self, base_model, num_mem_tokens):
        super().__init__()
        self.model = base_model
        self.create_memory(num_mem_tokens)

    def create_memory(self, num_mem_tokens):
        self.num_mem_tokens = num_mem_tokens
        embeddings = self.model.get_input_embeddings()
        memory_dim = getattr(self.model.config, 'n_embd', self.model.config.hidden_size)
        # Initialize the memory embeddings at the same scale as the token embeddings.
        memory_weights = torch.randn((num_mem_tokens, memory_dim)) * embeddings.weight.data.std()
        self.register_parameter('memory', torch.nn.Parameter(memory_weights, requires_grad=True))

        self.read_memory_position = range(num_mem_tokens)
        self.write_memory_position = range(-num_mem_tokens, 0)

    def set_memory(self, input_shape):
        memory = self.memory.repeat(input_shape[0], 1, 1)
        return memory

    def forward(self, input_ids, memory_state=None, **kwargs):
        if memory_state is None:
            memory_state = self.set_memory(input_ids.shape)

        seg_kwargs = self.process_input(input_ids, memory_state, write_mem=True, **kwargs)
        out = self.model(**seg_kwargs)
        out, new_memory_state = self.process_output(out, **kwargs)

        return out, new_memory_state

    def generate(self, input_ids, memory_state, attention_mask=None, **generate_kwargs):
        if memory_state is None:
            memory_state = self.set_memory(input_ids.shape)

        # During generation only the read memory is prepended; no write block is appended.
        seg_kwargs = self.process_input(input_ids, memory_state, attention_mask=attention_mask, write_mem=False)
        out = self.model.generate(inputs_embeds=seg_kwargs['inputs_embeds'],
                                  attention_mask=seg_kwargs['attention_mask'],
                                  **generate_kwargs)
        return out

    def process_input(self, input_ids, memory_state, write_mem, **kwargs):
        seg_kwargs = dict(**kwargs)

        inputs_embeds = kwargs.get('inputs_embeds')
        if inputs_embeds is None:
            inputs_embeds = self.model.get_input_embeddings()(input_ids)

        if self.num_mem_tokens > 0:
            if write_mem:
                inputs_embeds = torch.cat([memory_state, inputs_embeds, memory_state], dim=1)
            else:
                inputs_embeds = torch.cat([memory_state, inputs_embeds], dim=1)

        seg_kwargs['input_ids'] = None
        seg_kwargs['inputs_embeds'] = inputs_embeds
        if kwargs.get('attention_mask') is not None:
            seg_kwargs['attention_mask'] = self.pad_attention_mask(kwargs['attention_mask'], inputs_embeds.shape)
        seg_kwargs['output_hidden_states'] = True
        return seg_kwargs

    def pad_attention_mask(self, attention_mask, shape):
        # Memory positions are always attended to; only the original segment mask is copied in.
        if self.num_mem_tokens in {0, None}:
            return attention_mask
        else:
            mask = torch.ones(*shape[:2], dtype=torch.int64).to(attention_mask.device)
            mask[:, self.num_mem_tokens: self.num_mem_tokens + attention_mask.shape[1]] = attention_mask
            return mask

    def process_output(self, model_outputs, **kwargs):
        if self.num_mem_tokens not in {0, None}:
            out = CausalLMOutputWithCrossAttentions()
            # Final hidden states of the write-memory positions become the new memory state.
            memory_state = model_outputs.hidden_states[-1][:, -self.num_mem_tokens:]
            out['logits'] = model_outputs.logits[:, self.num_mem_tokens:-self.num_mem_tokens]

            if kwargs.get('output_hidden_states'):
                out['hidden_states'] = [lh[:, self.num_mem_tokens:-self.num_mem_tokens]
                                        for lh in model_outputs.hidden_states]
            if kwargs.get('output_attentions'):
                out['attentions'] = model_outputs['attentions']
        else:
            memory_state = None
            out = model_outputs

        return out, memory_state


class RecurrentWrapper(torch.nn.Module):
    """Runs the memory cell over input segments sequentially, passing the memory
    state from one segment to the next and aggregating per-segment outputs."""

    def __init__(self, memory_cell, **rmt_kwargs):
        super().__init__()
        self.memory_cell = memory_cell
        self.rmt_config = rmt_kwargs

    def forward(self, input_ids, labels=None, labels_mask=None, inputs_embeds=None, attention_mask=None,
                output_attentions=None, output_hidden_states=None):
        memory_state = None
        segmented = self.segment(input_ids=input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask)

        cell_outputs = []
        for seg_num, segment in enumerate(segmented):
            cell_out, memory_state = self.memory_cell(**segment, memory_state=memory_state, output_hidden_states=True)
            cell_outputs.append(cell_out)
            memory_state = self.manage_gradients(memory_state, seg_num)

        out = self.process_outputs(cell_outputs, labels=labels,
                                   labels_mask=labels_mask,
                                   output_attentions=output_attentions,
                                   output_hidden_states=output_hidden_states)
        return out

    def generate(self, input_ids, attention_mask=None, **generate_kwargs):
        memory_state = None
        segmented = self.segment(input_ids=input_ids, attention_mask=attention_mask)

        # Encode all but the last segment to build up the memory state, then generate from the last one.
        for seg_num, segment in enumerate(segmented[:-1]):
            cell_out, memory_state = self.memory_cell(**segment, memory_state=memory_state, output_hidden_states=True)

        final_segment = segmented[-1]
        out = self.memory_cell.generate(**final_segment, memory_state=memory_state, **generate_kwargs)

        return out

    def segment(self, **kwargs):
        """Split every provided tensor along the sequence dimension and regroup the
        pieces into per-segment keyword dictionaries."""
        segments = []
        for k, tensor in kwargs.items():
            if tensor is not None:
                k_segments = self.split_tensor(tensor)
                for s, k_seg in enumerate(k_segments):
                    if s < len(segments):
                        segments[s][k] = k_seg
                    else:
                        segments.append({k: k_seg})

        return segments

    def split_tensor(self, tensor):
        align = self.rmt_config.get('segment_alignment')
        segment_size = self.rmt_config.get('segment_size')
        if align in {'left', None}:
            split_inds = list(range(0, tensor.shape[1], segment_size)) + [tensor.shape[1]]
            segments = [tensor[:, start:end] for (start, end) in zip(split_inds, split_inds[1:])]
        elif align == 'right':
            split_inds = (list(range(tensor.shape[1], 0, -segment_size)) + [0])[::-1]
            segments = [tensor[:, start:end] for (start, end) in zip(split_inds, split_inds[1:])]
        elif align == 'center':
            n_seg = math.ceil(tensor.shape[1] / segment_size)
            segments = torch.chunk(tensor, n_seg, dim=1)
        else:
            raise NotImplementedError
        return segments

    def process_outputs(self, cell_outputs, **kwargs):
        out = CausalLMOutputWithCrossAttentions()
        full_logits = torch.cat([o.logits for o in cell_outputs], dim=1)
        full_hidden_states = tuple([torch.cat(layer_hs, dim=1)
                                    for layer_hs in zip(*[o.hidden_states for o in cell_outputs])])

        labels = kwargs.get('labels')
        if labels is not None:
            # Standard causal LM shift: predict token t+1 from position t.
            shift_labels = labels[..., 1:].contiguous()
            shift_logits = full_logits[..., :-1, :].contiguous()
            flat_labels = shift_labels.view(-1)
            flat_logits = shift_logits.view(-1, shift_logits.size(-1))

            loss_fct = CrossEntropyLoss()
            labels_mask = kwargs.get('labels_mask')
            if labels_mask is not None:
                shift_mask = labels_mask[..., :-1].contiguous()

                flat_labels = flat_labels[shift_mask.view(-1).bool()]
                flat_logits = flat_logits[shift_mask.view(-1).bool()]

            out['loss'] = loss_fct(flat_logits, flat_labels)
        else:
            out['loss'] = 0

        out['logits'] = full_logits
        segment_keys = ['loss', 'logits']
        if kwargs.get('output_attentions'):
            segment_keys.append('attentions')
        if kwargs.get('output_hidden_states'):
            segment_keys.append('hidden_states')
            out['hidden_states'] = full_hidden_states

        return out

    def manage_gradients(self, memory_state, seg_num):
        # Keep gradients flowing through the memory only for the first segment and
        # the last `k2` segments; detach everywhere else to bound backprop depth.
        k2, max_n_segments = self.rmt_config.get('k2'), self.rmt_config.get('max_n_segments')
        if seg_num == 0 \
                or k2 in {-1, None} \
                or seg_num + k2 > max_n_segments:
            return memory_state

        memory_state = memory_state.detach()
        return memory_state

    def gradient_checkpointing_enable(self, *args, **kwargs):
        self.memory_cell.model.gradient_checkpointing_enable(*args, **kwargs)


class RecurrentWrapperNoSegmentation(RecurrentWrapper):
    """Variant that expects pre-segmented inputs: a list of dicts holding
    `input_ids`, `attention_mask`, and optionally `labels`/`labels_mask`, instead
    of a single long sequence to be split by the wrapper."""

    def forward(self, segments, labels, output_attentions=None, output_hidden_states=None):
        memory_state = None

        cell_outputs = []
        for seg_num, segment in enumerate(segments):
            cell_out, memory_state = self.memory_cell(input_ids=segment['input_ids'],
                                                      attention_mask=segment['attention_mask'],
                                                      memory_state=memory_state, output_hidden_states=True)
            cell_outputs.append(cell_out)
            memory_state = self.manage_gradients(memory_state, seg_num)

        out = self.process_outputs(cell_outputs, segments,
                                   output_attentions=output_attentions,
                                   output_hidden_states=output_hidden_states)
        return out

    def generate(self, segments, **generate_kwargs):
        raise NotImplementedError("Generation not implemented for this wrapper.")

    def process_outputs(self, cell_outputs, segments, **kwargs):
        out = CausalLMOutputWithCrossAttentions()
        proxy_out = {}
        for seg_num, segment in enumerate(segments):
            cell_out = cell_outputs[seg_num]

            full_logits = cell_out.logits

            labels = segment.get('labels')
            if labels is not None:
                shift_labels = labels[..., 1:].contiguous()
                shift_logits = full_logits[..., :-1, :].contiguous()
                flat_labels = shift_labels.view(-1)
                flat_logits = shift_logits.view(-1, shift_logits.size(-1))

                loss_fct = CrossEntropyLoss()
                labels_mask = segment.get('labels_mask')
                if labels_mask is not None:
                    shift_mask = labels_mask[..., :-1].contiguous()

                    flat_labels = flat_labels[shift_mask.view(-1).bool()]
                    flat_logits = flat_logits[shift_mask.view(-1).bool()]

                # An all-zero mask selects no tokens; skip the loss to avoid NaN.
                if labels_mask is not None and labels_mask.sum() == 0:
                    loss_value = 0
                else:
                    loss_value = loss_fct(flat_logits, flat_labels)

                proxy_out[f'loss_{seg_num}'] = loss_value
            else:
                proxy_out[f'loss_{seg_num}'] = 0

            segment_keys = ['loss']
            if kwargs.get('output_attentions'):
                segment_keys.append('attentions')
            if kwargs.get('output_hidden_states'):
                segment_keys.append('hidden_states')

            for key, value in cell_out.items():
                if any([sk in key for sk in segment_keys]):
                    proxy_out[f'{key}_{seg_num}'] = value

        # Average the per-segment losses and concatenate logits over segments.
        num_segments = len(segments)
        out['loss'] = sum([proxy_out[f'loss_{seg_num}'] for seg_num in range(num_segments)]) / num_segments
        out['logits'] = torch.cat([cell_out.logits for cell_out in cell_outputs], dim=1)

        return out

    def gradient_checkpointing_enable(self, *args, **kwargs):
        if hasattr(self.memory_cell.model, "gradient_checkpointing_enable"):
            return self.memory_cell.model.gradient_checkpointing_enable(*args, **kwargs)


class StopOnSpecialTokenCriteria(StoppingCriteria):
    """Stops generation as soon as the last produced token is one of the given
    special token ids (only the first sequence in the batch is checked)."""

    def __init__(self, special_token_ids):
        self.special_token_ids = set(special_token_ids)

    def __call__(self, input_ids, scores, **kwargs):
        last_token = input_ids[0, -1].item()
        return last_token in self.special_token_ids


class RecurrentWrapperNoSegmentationGenerate(RecurrentWrapperNoSegmentation):
    """Adds segment-by-segment generation: the prompt segments are encoded into
    the memory state, then new segments are generated from a BOS token until
    every sequence in the batch has produced an EOS token or the segment budget
    is exhausted."""

    def forward(self, segments, labels, output_attentions=None, output_hidden_states=None):
        memory_state = None

        cell_outputs = []
        for seg_num, segment in enumerate(segments):
            cell_out, memory_state = self.memory_cell(input_ids=segment['input_ids'],
                                                      attention_mask=segment['attention_mask'],
                                                      memory_state=memory_state, output_hidden_states=True)
            cell_outputs.append(cell_out)
            memory_state = self.manage_gradients(memory_state, seg_num)

        out = self.process_outputs(cell_outputs, segments,
                                   output_attentions=output_attentions,
                                   output_hidden_states=output_hidden_states)
        return out

    def generate(self, segments, **kwargs):
        memory_state = None

        # Encode the prompt segments to build up the memory state.
        for seg_num, segment in enumerate(segments):
            cell_out, memory_state = self.memory_cell(input_ids=segment['input_ids'],
                                                      attention_mask=segment['attention_mask'],
                                                      memory_state=memory_state, output_hidden_states=True)

        # Generate new segments until every sequence has emitted EOS or the budget runs out.
        generated_segments = []
        for seg_num in range(len(segments), self.rmt_config.get("max_n_segments", 32)):
            output_ids, memory_state = self.generate_segment(memory_state=memory_state, **kwargs)
            generated_segments.append(output_ids)

            if self.all_done(generated_segments):
                break

        return generated_segments

    def generate_segment(self, memory_state, **kwargs):
        input_ids = self.get_bos_tensor(memory_state)
        attention_mask = torch.ones_like(input_ids).bool()

        generated = self.memory_cell.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            memory_state=memory_state,
            stopping_criteria=self.make_custom_stopping_criteria(),
            **kwargs
        )

        # Re-encode BOS plus the generated tokens (dropping the final one) through the
        # memory cell so the memory state reflects the newly generated segment.
        fwd_inputs = torch.cat((input_ids, generated), dim=1)[:, :-1]
        _, memory_state = self.memory_cell(input_ids=fwd_inputs, memory_state=memory_state)

        return generated, memory_state

    def get_bos_tensor(self, memory_state):
        bos = self.rmt_config["bos_token_id"]
        bos_tensor = torch.tensor([bos] * memory_state.shape[0]).reshape(-1, 1)
        return bos_tensor.to(memory_state.device)

    def all_done(self, generated_segments):
        # Done once every sequence in the batch has produced EOS in some generated segment.
        eos = self.rmt_config['eos_token_id']
        bs = generated_segments[0].shape[0]
        have_eos = [any([eos in seg[i] for seg in generated_segments]) for i in range(bs)]
        return all(have_eos)

    def make_custom_stopping_criteria(self):
        return [StopOnSpecialTokenCriteria([self.rmt_config['think_token_id'], self.rmt_config['answer_token_id']])]
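

# Minimal usage sketch (not part of the original module): it shows how the classes
# above are expected to fit together. It assumes the SmolLM2-135M tokenizer is
# downloadable and reuses its EOS id as placeholder "think"/"answer" ids, since the
# real special-token setup is defined elsewhere.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
    eos = tokenizer.eos_token_id
    config = RMTConfig(
        num_mem_tokens=16,
        max_n_segments=4,
        bos_token_id=tokenizer.bos_token_id if tokenizer.bos_token_id is not None else eos,
        eos_token_id=eos,
        think_token_id=eos,   # placeholder; a dedicated special token is assumed in training
        answer_token_id=eos,  # placeholder; a dedicated special token is assumed in training
    )
    model = RMTForReasoning(config)  # randomly initialized backbone

    # The "NoSegmentation" wrappers expect pre-segmented inputs: a list of dicts.
    enc = tokenizer("2 + 2 = ?", return_tensors="pt")
    segments = [{"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}]

    with torch.no_grad():
        generated = model.generate(segments, max_new_tokens=16, do_sample=False)
    # `generate` returns one tensor of token ids per generated segment.
    print([tokenizer.decode(seg[0]) for seg in generated])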