from transformers import PretrainedConfig
class BaselineConfig(PretrainedConfig):
    """Configuration for a baseline transformer language model.

    Follows the Hugging Face ``PretrainedConfig`` pattern: every model
    hyperparameter is accepted as a keyword argument and stored as an
    attribute of the same name, so the config round-trips through
    ``to_dict()`` / ``from_dict()``.
    """

    # Identifier the HF auto-class machinery uses to map this config
    # to its model class.
    model_type = "baseline-xtransformers"

    def __init__(
        self,
        vocab_size=32768,
        d_model=512,
        seq_len=4096,
        depth=5,
        heads=8,
        dropout=0.0,
        **kwargs,
    ):
        """Create the config.

        Args:
            vocab_size: Size of the token vocabulary.
            d_model: Model (embedding) dimension.
            seq_len: Maximum sequence length.
            depth: Number of transformer layers.
            heads: Number of attention heads.
            dropout: Dropout probability.
            **kwargs: Forwarded to ``PretrainedConfig.__init__`` (handles
                framework-level options such as ``pad_token_id``).
        """
        # Let the base class consume framework-level kwargs first.
        super().__init__(**kwargs)
        # Store each hyperparameter under an attribute of the same name.
        for attr_name, attr_value in (
            ("vocab_size", vocab_size),
            ("d_model", d_model),
            ("seq_len", seq_len),
            ("depth", depth),
            ("heads", heads),
            ("dropout", dropout),
        ):
            setattr(self, attr_name, attr_value)