|
|
from transformers import PreTrainedModel, PretrainedConfig, AutoModel, AutoConfig |
|
|
import torch |
|
|
from Naive_gpt.model import GPTLanguageModel |
|
|
|
|
|
class GPTConfig(PretrainedConfig):
    """Hugging Face configuration for the custom GPT language model.

    Exposes every architecture hyperparameter as a constructor keyword so
    that ``save_pretrained`` / ``from_pretrained`` round-trips correctly:
    ``PretrainedConfig`` rebuilds a config by passing the saved JSON values
    back as kwargs, so hard-coding them in the body would silently discard
    any persisted overrides.

    Args:
        vocab_size: Size of the token vocabulary.
        block_size: Maximum context length (in tokens).
        n_embd: Embedding / hidden dimension.
        n_head: Number of attention heads per layer.
        n_layer: Number of transformer blocks.
        dropout: Dropout probability used throughout the model.
        **kwargs: Forwarded to ``PretrainedConfig`` (e.g. ``bos_token_id``).
    """

    model_type = "gpt_custom"

    def __init__(
        self,
        vocab_size=20000,
        block_size=1024,
        n_embd=768,
        n_head=12,
        n_layer=12,
        dropout=0.2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_layer = n_layer
        self.dropout = dropout
|
|
|
|
|
class GPTModelHF(PreTrainedModel):
    """Thin ``PreTrainedModel`` adapter around the project's ``GPTLanguageModel``.

    Wrapping the custom model in this class makes it loadable/saveable via
    the Hugging Face ``from_pretrained`` / ``save_pretrained`` machinery and
    the Auto classes registered at module bottom.
    """

    config_class = GPTConfig

    def __init__(self, config):
        """Build the wrapped model from a ``GPTConfig``.

        ``PreTrainedModel.__init__`` already stores ``config`` as
        ``self.config``, so no duplicate assignment is needed here.
        """
        super().__init__(config)
        self.model = GPTLanguageModel(config.vocab_size, config)

    def forward(self, x):
        """Delegate the forward pass to the wrapped model.

        Args:
            x: Input token indices (whatever ``GPTLanguageModel`` expects —
               presumably a ``(batch, seq)`` LongTensor; confirm against
               ``Naive_gpt.model``).

        Returns:
            The wrapped model's output, unchanged.
        """
        return self.model(x)

    # NOTE(review): this shadows PreTrainedModel.generate, which has a
    # different, kwargs-based signature — HF utilities that call
    # ``model.generate(input_ids, max_new_tokens=...)`` positionally will
    # still work, but keyword-heavy callers will not. Kept as-is because
    # existing callers rely on this exact signature.
    def generate(self, idx, max_new_tokens):
        """Delegate autoregressive sampling to the wrapped model.

        Args:
            idx: Prompt token indices to continue from.
            max_new_tokens: Number of tokens to sample.

        Returns:
            Whatever ``GPTLanguageModel.generate`` returns (the extended
            token sequence).
        """
        return self.model.generate(idx, max_new_tokens)
|
|
|
|
|
|
|
|
# Register the custom architecture with the transformers Auto classes so that
# AutoConfig.from_pretrained / AutoModel.from_pretrained resolve the
# "gpt_custom" model_type to these classes. Order matters: the config type
# must be known before the model is registered against it.
AutoConfig.register("gpt_custom", GPTConfig)


AutoModel.register(GPTConfig, GPTModelHF)