import transformers

import model


class AbcTransformerConfig(transformers.PretrainedConfig):
    """HF-compatible configuration for the local AbcTransformer model.

    Stores the hyperparameters needed to construct ``model.AbcTransformer``
    so the model can be saved/loaded via the ``transformers`` Auto* APIs.
    """

    model_type = 'abc-transformer'

    def __init__(
        self,
        vocab_size=113,
        n_embd=384,
        block_size=128,
        n_heads=6,
        n_layers=6,
        dropout=0.2,
        device=None,
        **kwargs,
    ):
        """Record model hyperparameters and forward the rest to PretrainedConfig.

        Args:
            vocab_size: Size of the token vocabulary.
            n_embd: Embedding / hidden dimension.
            block_size: Maximum sequence (context) length.
            n_heads: Number of attention heads.
            n_layers: Number of transformer layers.
            dropout: Dropout probability.
            device: Optional device spec passed through to the wrapped model.
            **kwargs: Standard ``PretrainedConfig`` keyword arguments.
        """
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.block_size = block_size
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.dropout = dropout
        self.device = device
        # Per HF convention, call super last so our attributes are set before
        # PretrainedConfig processes the remaining kwargs.
        super().__init__(**kwargs)


class AbcTransformer(transformers.PreTrainedModel):
    """Thin ``PreTrainedModel`` wrapper around the local ``model.AbcTransformer``.

    Exposes the project's model through the ``transformers`` save/load and
    Auto* machinery; all computation is delegated to the wrapped module.
    """

    config_class = AbcTransformerConfig

    def __init__(self, config):
        """Build the wrapped model from an ``AbcTransformerConfig``."""
        super().__init__(config)
        self.model = model.AbcTransformer(
            vocab_size=config.vocab_size,
            n_embd=config.n_embd,
            block_size=config.block_size,
            n_heads=config.n_heads,
            n_layers=config.n_layers,
            dropout=config.dropout,
            device=config.device,
        )

    def forward(self, tensor, labels=None):
        """Delegate the forward pass to the wrapped model.

        Args:
            tensor: Input token tensor.
            labels: Optional target tensor for loss computation. Defaults to
                ``None`` so the model is usable for inference (the HF pipeline
                and generation paths call ``forward`` without labels).

        Returns:
            Whatever ``model.AbcTransformer.forward`` returns (logits, and a
            loss when labels are provided — determined by the wrapped model).
        """
        return self.model(tensor, labels)


# Register with the Auto* factories so AutoConfig/AutoModel can resolve
# the 'abc-transformer' model_type, and mark the class for causal-LM
# auto-class serialization when pushed to the Hub.
transformers.AutoConfig.register('abc-transformer', AbcTransformerConfig)
AbcTransformer.register_for_auto_class("AutoModelForCausalLM")
transformers.AutoModel.register(AbcTransformerConfig, AbcTransformer)