import torch
import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_baseline import BaselineConfig

try:
    from x_transformers import TransformerWrapper, Encoder
except ImportError as e:
    raise ImportError("x-transformers is required: pip install x-transformers") from e


class BaselineModel(PreTrainedModel):
    config_class = BaselineConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.model = TransformerWrapper(
            num_tokens=config.vocab_size,
            max_seq_len=config.seq_len,
            use_abs_pos_emb=False,   # no absolute position embeddings; rotary embeddings are enabled below
            tie_embedding=True,      # share weights between the token embedding and the output projection
            attn_layers=Encoder(
                dim=config.d_model,
                depth=config.depth,
                heads=config.heads,
                layer_dropout=config.dropout,
                attn_dropout=config.dropout,
                ff_dropout=config.dropout,
                rotary_pos_emb=True,
                attn_flash=True,
                use_scalenorm=False,
            ),
        )
        # Weight tying is already handled by tie_embedding=True above; with tying
        # enabled, x-transformers computes logits from the token embedding matrix,
        # so no manual assignment to self.model.to_logits is needed.

    def forward(self, input_ids, labels=None, mask=None):
        # `mask` is the boolean padding mask expected by x-transformers (True = attend).
        logits = self.model(input_ids, mask=mask)
        if labels is not None:
            loss = nn.CrossEntropyLoss()(
                logits.view(-1, self.config.vocab_size), labels.view(-1)
            )
            return {"loss": loss, "logits": logits}
        return logits
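

# --- Minimal usage sketch (illustrative only, not part of the model code) ---
# Assumes BaselineConfig accepts the field names read in __init__ above
# (vocab_size, seq_len, d_model, depth, heads, dropout) as keyword arguments;
# adjust to the actual configuration_baseline definition.
if __name__ == "__main__":
    config = BaselineConfig(
        vocab_size=1000,
        seq_len=128,
        d_model=64,
        depth=2,
        heads=4,
        dropout=0.1,
    )
    model = BaselineModel(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    mask = torch.ones(2, 16, dtype=torch.bool)  # True = attend (no padding in this toy batch)
    labels = torch.randint(0, config.vocab_size, (2, 16))

    out = model(input_ids, labels=labels, mask=mask)
    print(out["loss"].item(), out["logits"].shape)  # logits: (2, 16, vocab_size)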