| from transformers import PreTrainedModel | |
| from .configuration_ss4m import SimpleStories4MConfig | |
| from .nano_gpt_model import NanoGPT | |
class SimpleStories4MModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around the project's ``NanoGPT`` network.

    The wrapped network is constructed from fields of the supplied
    configuration object and stored on ``self.model``.
    """

    # Configuration class paired with this model.
    config_class = SimpleStories4MConfig

    def __init__(self, config):
        """Build the wrapped ``NanoGPT`` from ``config``.

        Args:
            config: Configuration object exposing ``vocab_size``,
                ``block_size``, ``n_embed``, ``n_heads``, ``n_layers`` and
                ``dropout`` attributes (presumably a
                ``SimpleStories4MConfig`` -- confirm against callers).
        """
        super().__init__(config)

        # Each listed config attribute is passed to NanoGPT under the same
        # key; the order matches the order the dict is populated in.
        field_names = (
            "vocab_size",
            "block_size",
            "n_embed",
            "n_heads",
            "n_layers",
            "dropout",
        )
        nano_gpt_settings = {
            field: getattr(config, field) for field in field_names
        }
        self.model = NanoGPT(nano_gpt_settings)

    def forward(self, tensor, targets=None):
        """Delegate the forward pass to the wrapped ``NanoGPT``.

        Args:
            tensor: Input handed to ``NanoGPT`` as its first positional
                argument (presumably token ids -- TODO confirm).
            targets: Optional second positional argument for ``NanoGPT``;
                defaults to ``None``.

        Returns:
            Whatever the wrapped ``NanoGPT`` call returns, unchanged.
        """
        output = self.model(tensor, targets)
        return output