from transformers import PreTrainedModel, PretrainedConfig, AutoModel, AutoConfig
import torch
from Naive_gpt.model import GPTLanguageModel


class GPTConfig(PretrainedConfig):
    """Hugging Face-compatible configuration for the custom GPT model.

    All hyperparameters are keyword arguments (with the previous hard-coded
    values as defaults) so that ``from_pretrained``/``save_pretrained`` can
    round-trip them correctly — a hard-coded attribute would silently
    override whatever was serialized in ``config.json``.
    """

    model_type = "gpt_custom"

    def __init__(
        self,
        vocab_size=20000,
        block_size=1024,
        n_embd=768,
        n_head=12,
        n_layer=12,
        dropout=0.2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_layer = n_layer
        self.dropout = dropout


class GPTModelHF(PreTrainedModel):
    """Thin ``PreTrainedModel`` adapter around ``GPTLanguageModel``.

    Delegates ``forward`` and ``generate`` to the wrapped model so the
    custom implementation can be used through the Hugging Face Auto* APIs.
    """

    config_class = GPTConfig

    def __init__(self, config):
        """Build the wrapped language model from *config*.

        Note: ``super().__init__(config)`` already stores the config on
        ``self.config``, so no separate assignment is needed.
        """
        super().__init__(config)
        self.model = GPTLanguageModel(config.vocab_size, config)

    def forward(self, x):
        """Delegate the forward pass to the wrapped GPTLanguageModel."""
        return self.model(x)

    def generate(self, idx, max_new_tokens):
        """Delegate autoregressive generation to the wrapped model.

        ``idx`` is the prompt token tensor; ``max_new_tokens`` bounds how
        many tokens are appended (semantics defined by GPTLanguageModel).
        """
        return self.model.generate(idx, max_new_tokens)


# Register with the Auto* factories at import time so that
# AutoConfig/AutoModel can resolve the "gpt_custom" model type.
AutoConfig.register("gpt_custom", GPTConfig)
AutoModel.register(GPTConfig, GPTModelHF)