"""Baseline model: a HuggingFace PreTrainedModel wrapper around an
x-transformers encoder."""

import torch
import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_baseline import BaselineConfig

try:
    from x_transformers import TransformerWrapper, Encoder
except ImportError as exc:
    raise ImportError(
        "BaselineModel requires the x-transformers package; "
        "install it with `pip install x-transformers`."
    ) from exc


class BaselineModel(PreTrainedModel):
    """Bidirectional encoder baseline built on x-transformers' TransformerWrapper."""

    config_class = BaselineConfig

    def __init__(self, config):
        super().__init__(config)  # PreTrainedModel stores `config` as self.config
        self.model = TransformerWrapper(
            num_tokens=config.vocab_size,
            max_seq_len=config.seq_len,
            use_abs_pos_emb=False,  # positions come from rotary embeddings below
            tie_embedding=True,     # share the input embedding with the output head
            attn_layers=Encoder(
                dim=config.d_model,
                depth=config.depth,
                heads=config.heads,
                layer_dropout=config.dropout,
                attn_dropout=config.dropout,
                ff_dropout=config.dropout,
                rotary_pos_emb=True,
                attn_flash=True,  # requires PyTorch >= 2.0 (SDPA kernels)
                use_scalenorm=False,
            ),
        )
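
        # Note on weight tying: with `tie_embedding=True`, recent x-transformers
        # releases already tie the output projection to the token embedding
        # internally (`to_logits` becomes a matmul against the embedding
        # matrix), so the explicit re-tie below is a no-op there. It is kept,
        # guarded, for older releases where `to_logits` is a plain nn.Linear.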
        if isinstance(self.model.to_logits, nn.Linear):
            # Newer versions wrap the embedding (token_emb.emb); older ones
            # expose a plain nn.Embedding.
            emb = getattr(self.model.token_emb, "emb", self.model.token_emb)
            self.model.to_logits.weight = emb.weight

    def forward(self, input_ids, labels=None, mask=None):
        logits = self.model(input_ids, mask=mask)
        if labels is not None:
            loss = nn.CrossEntropyLoss()(
                logits.view(-1, self.config.vocab_size), labels.view(-1)
            )
            return {"loss": loss, "logits": logits}
        return logits
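

if __name__ == "__main__":
    # Minimal smoke test -- a sketch, not part of the model definition. It
    # assumes BaselineConfig accepts the keyword arguments consumed above
    # (vocab_size, seq_len, d_model, depth, heads, dropout); adjust to the
    # actual signature in configuration_baseline.py.
    config = BaselineConfig(
        vocab_size=1000,
        seq_len=128,
        d_model=64,
        depth=2,
        heads=4,
        dropout=0.1,
    )
    model = BaselineModel(config)
    input_ids = torch.randint(0, config.vocab_size, (2, config.seq_len))
    labels = torch.randint(0, config.vocab_size, (2, config.seq_len))
    out = model(input_ids, labels=labels)
    print("loss:", out["loss"].item(), "logits:", tuple(out["logits"].shape))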