import torch
from transformers import AutoModelForCausalLM, PreTrainedModel
from .config import EngramConfig
from .engram import EngramLayer

class GPT2WithEngram(PreTrainedModel):
    """GPT-2 causal LM with an Engram layer injected into one transformer block."""

    config_class = EngramConfig

    def __init__(self, config, vocab_map=None, nanogpt_config=None):
        super().__init__(config)
        self.config = config

        # Fall back to a small GPT-2 backbone when no explicit config is given.
        if nanogpt_config is None:
            from transformers import GPT2Config
            nanogpt_config = GPT2Config(
                vocab_size=config.vocab_size,
                n_embd=config.d_model,
                n_layer=6,
                n_head=6,
                resid_pdrop=0.1,
                embd_pdrop=0.1,
                attn_pdrop=0.1,
            )
        self.gpt2 = AutoModelForCausalLM.from_config(nanogpt_config)

        # Default vocab map: every token id maps to bucket 0.
        if vocab_map is None:
            vocab_map = torch.zeros(config.vocab_size, dtype=torch.long)
        self.engram = EngramLayer(config, vocab_map)

        # Index of the transformer block whose output receives the engram signal.
        self.target_layer_idx = config.injection_layer
    def forward(self, input_ids, labels=None, attention_mask=None):
        # Stash the token ids so the hook can read them during the backbone pass.
        self.current_input_ids = input_ids

        def engram_hook(module, args, output):
            # GPT-2 block outputs are tuples; the hidden states come first.
            hidden_state = output[0]
            engram_out = self.engram(self.current_input_ids, hidden_state)
            # print(f"engram_out: {engram_out.abs().mean()}")
            # Add the engram signal residually and pass the rest of the tuple through.
            new_hidden = hidden_state + engram_out
            return (new_hidden,) + output[1:]

        # Attach the hook to the target block for this call only, and always detach it.
        layer_module = self.gpt2.transformer.h[self.target_layer_idx]
        hook_handle = layer_module.register_forward_hook(engram_hook)
        try:
            outputs = self.gpt2(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        finally:
            hook_handle.remove()
            self.current_input_ids = None
        return outputs

# Let AutoModelForCausalLM.from_config / from_pretrained resolve EngramConfig to this class.
AutoModelForCausalLM.register(EngramConfig, GPT2WithEngram)
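

# Minimal usage sketch (illustrative only). It assumes EngramConfig accepts
# vocab_size, d_model, and injection_layer as keyword arguments (the only
# attributes this file reads) and that injection_layer is a valid index into
# the 6-layer default backbone; adjust to the real EngramConfig signature.
if __name__ == "__main__":
    config = EngramConfig(vocab_size=50257, d_model=384, injection_layer=3)
    model = GPT2WithEngram(config)

    # Random token ids stand in for a tokenized batch of shape (batch, seq_len).
    dummy_ids = torch.randint(0, config.vocab_size, (2, 16))
    outputs = model(dummy_ids, labels=dummy_ids)
    print("loss:", outputs.loss.item())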