# baseline-wikitext-prism / modeling_baseline.py
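"""Baseline Transformer model wrapping x-transformers' TransformerWrapper behind the
Hugging Face PreTrainedModel interface, configured via BaselineConfig."""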
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from .configuration_baseline import BaselineConfig
try:
    from x_transformers import TransformerWrapper, Encoder
except ImportError as exc:
    raise ImportError(
        "x-transformers is required for BaselineModel: pip install x-transformers"
    ) from exc
class BaselineModel(PreTrainedModel):
    config_class = BaselineConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Backbone: x-transformers wrapper with rotary positions and flash attention.
        self.model = TransformerWrapper(
            num_tokens=config.vocab_size,
            max_seq_len=config.seq_len,
            use_abs_pos_emb=False,   # no absolute position embeddings; rotary is used below
            tie_embedding=True,      # share input embedding and output projection weights
            attn_layers=Encoder(
                dim=config.d_model,
                depth=config.depth,
                heads=config.heads,
                layer_dropout=config.dropout,
                attn_dropout=config.dropout,
                ff_dropout=config.dropout,
                rotary_pos_emb=True,
                attn_flash=True,
                use_scalenorm=False,
            ),
        )
        # TIE FIX: re-tie the output projection to the token embedding weights, guarding
        # against x-transformers versions where `tie_embedding` alone does not share them.
        if hasattr(self.model.token_emb, 'emb'):
            self.model.to_logits.weight = self.model.token_emb.emb.weight
        else:
            self.model.to_logits.weight = self.model.token_emb.weight
    def forward(self, input_ids, labels=None, mask=None):
        logits = self.model(input_ids, mask=mask)
        if labels is not None:
            # Flatten logits and labels for token-level cross-entropy.
            loss = nn.CrossEntropyLoss()(logits.view(-1, self.config.vocab_size), labels.view(-1))
            return {"loss": loss, "logits": logits}
        return logits
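

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the model definition. It assumes
    # BaselineConfig accepts the fields used above (vocab_size, seq_len, d_model,
    # depth, heads, dropout) as keyword arguments; the values here are illustrative.
    config = BaselineConfig(
        vocab_size=32000,
        seq_len=512,
        d_model=512,
        depth=6,
        heads=8,
        dropout=0.1,
    )
    model = BaselineModel(config)
    ids = torch.randint(0, config.vocab_size, (2, 128))
    out = model(ids, labels=ids)
    print(out["loss"].item(), out["logits"].shape)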