import torch
from transformers import AutoModelForCausalLM, PreTrainedModel

from .config import EngramConfig
from .engram import EngramLayer


class GPT2WithEngram(PreTrainedModel):
    """GPT-2 causal LM with an EngramLayer injected after one transformer block."""

    config_class = EngramConfig

    def __init__(self, config, vocab_map=None, nanogpt_config=None):
        super().__init__(config)
        self.config = config

        # Fall back to a small GPT-2 backbone configuration when none is supplied.
        if nanogpt_config is None:
            from transformers import GPT2Config

            nanogpt_config = GPT2Config(
                vocab_size=config.vocab_size,
                n_embd=config.d_model,
                n_layer=6,
                n_head=6,
                resid_pdrop=0.1,
                embd_pdrop=0.1,
                attn_pdrop=0.1,
            )
        self.gpt2 = AutoModelForCausalLM.from_config(nanogpt_config)

        # Default to an all-zeros vocab map when none is provided.
        if vocab_map is None:
            vocab_map = torch.zeros(config.vocab_size, dtype=torch.long)

        self.engram = EngramLayer(config, vocab_map)
        self.target_layer_idx = config.injection_layer

    def forward(self, input_ids, labels=None, attention_mask=None):
        # Stash the token ids so the hook can see them; a forward hook only
        # receives the block's inputs and outputs, not the original token ids.
        self.current_input_ids = input_ids

        def engram_hook(module, args, output):
            hidden_state = output[0]
            engram_out = self.engram(self.current_input_ids, hidden_state)
            # Add the engram contribution residually to the block's hidden states.
            new_hidden = hidden_state + engram_out
            return (new_hidden,) + output[1:]

        # Attach the hook to the target transformer block for this call only.
        layer_module = self.gpt2.transformer.h[self.target_layer_idx]
        hook_handle = layer_module.register_forward_hook(engram_hook)

        try:
            outputs = self.gpt2(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        finally:
            hook_handle.remove()
            self.current_input_ids = None

        return outputs


# Register so AutoModelForCausalLM can resolve EngramConfig to this class.
AutoModelForCausalLM.register(EngramConfig, GPT2WithEngram)
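
# Minimal usage sketch (illustrative only): the field values below are
# hypothetical examples, not defaults defined by EngramConfig; only vocab_size,
# d_model, and injection_layer are assumed, since those fields are used above.
#
#     config = EngramConfig(vocab_size=50257, d_model=384, injection_layer=3)
#     model = GPT2WithEngram(config)
#     input_ids = torch.randint(0, config.vocab_size, (1, 16))
#     outputs = model(input_ids, labels=input_ids)
#     loss = outputs.loss  # standard causal-LM loss from the wrapped GPT-2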