import torch
import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_baseline import BaselineConfig

try:
    from x_transformers import TransformerWrapper, Encoder
except ImportError as e:
    raise ImportError(
        "x-transformers is required for BaselineModel; install it with `pip install x-transformers`"
    ) from e


class BaselineModel(PreTrainedModel):
    config_class = BaselineConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.model = TransformerWrapper(
            num_tokens=config.vocab_size,
            max_seq_len=config.seq_len,
            use_abs_pos_emb=False,   # rely on rotary embeddings instead of absolute positions
            tie_embedding=True,      # share the token embedding with the output projection
            attn_layers=Encoder(
                dim=config.d_model,
                depth=config.depth,
                heads=config.heads,
                layer_dropout=config.dropout,
                attn_dropout=config.dropout,
                ff_dropout=config.dropout,
                rotary_pos_emb=True,
                attn_flash=True,
                use_scalenorm=False,
            ),
        )
        # TIE FIX: explicitly re-tie the output projection to the token embedding,
        # covering both x-transformers layouts (a TokenEmbedding wrapping an inner
        # `.emb` module vs. a plain nn.Embedding).
        if hasattr(self.model.token_emb, "emb"):
            self.model.to_logits.weight = self.model.token_emb.emb.weight
        else:
            self.model.to_logits.weight = self.model.token_emb.weight

    def forward(self, input_ids, labels=None, mask=None):
        # `mask` follows the x-transformers convention: a boolean mask where True marks
        # positions that should be attended to.
        logits = self.model(input_ids, mask=mask)
        if labels is not None:
            loss = nn.CrossEntropyLoss()(
                logits.view(-1, self.config.vocab_size), labels.view(-1)
            )
            return {"loss": loss, "logits": logits}
        return logits
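

# Minimal smoke test (a sketch, not part of the model definition). It assumes
# BaselineConfig follows the usual PretrainedConfig pattern of accepting these
# hyperparameters as keyword arguments; adjust the names if your config differs.
# Run it from within the package so the relative import above resolves.
if __name__ == "__main__":
    config = BaselineConfig(
        vocab_size=32000,
        seq_len=512,
        d_model=256,
        depth=4,
        heads=4,
        dropout=0.1,
    )
    model = BaselineModel(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    mask = torch.ones(2, 16, dtype=torch.bool)

    # Without labels the model returns raw logits of shape (batch, seq, vocab).
    logits = model(input_ids, mask=mask)
    print(logits.shape)

    # With labels it returns a dict containing the cross-entropy loss and the logits.
    out = model(input_ids, labels=input_ids, mask=mask)
    print(out["loss"].item())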