| from dataclasses import dataclass |
| from transformers import PretrainedConfig |
|
|
class GPTConfig(PretrainedConfig):
    """
    Configuration class for custom GPT model.

    Holds the hyper-parameters of the custom GPT architecture and plugs into
    the Hugging Face ``PretrainedConfig`` machinery (``save_pretrained`` /
    ``from_pretrained`` / ``to_dict``).

    This class intentionally does NOT use ``@dataclass``. A dataclass-generated
    ``__init__``:

    * never calls ``PretrainedConfig.__init__``, so base attributes such as
      ``return_dict`` are never set; and
    * rejects unknown keyword arguments. ``from_pretrained`` re-instantiates
      the class with every key stored in the saved config (including
      ``model_type`` and ``transformers_version``), so loading failed with a
      ``TypeError``.

    The explicit ``__init__`` below forwards extra keyword arguments to the
    base class, which is the pattern Hugging Face documents for custom configs.
    """

    # Hugging Face model-type identifier for this configuration.
    model_type = "custom_gpt"

    # Class-level defaults. They stay readable as ``GPTConfig.<name>`` (as they
    # were under the former dataclass). They are also reused as the
    # ``__init__`` parameter defaults, so each default value lives in one place.
    block_size: int = 768  # presumably the maximum context length in tokens
    vocab_size: int = 50257  # matches the GPT-2 BPE vocab size; confirm tokenizer
    n_layer: int = 8  # presumably the number of transformer blocks
    n_head: int = 8  # presumably the number of attention heads per block
    n_embd: int = 768  # presumably the embedding / hidden width
    dropout: float = 0.1  # presumably the dropout probability used by the model

    def __init__(
        self,
        # Default expressions are evaluated in the class body at definition
        # time, so e.g. ``block_size`` here resolves to the class-level 768.
        block_size: int = block_size,
        vocab_size: int = vocab_size,
        n_layer: int = n_layer,
        n_head: int = n_head,
        n_embd: int = n_embd,
        dropout: float = dropout,
        **kwargs,
    ):
        """
        Initialize the configuration.

        The parameter order matches the former dataclass fields, so positional
        construction keeps working.

        Args:
            block_size: Stored as ``self.block_size``.
            vocab_size: Stored as ``self.vocab_size``.
            n_layer: Stored as ``self.n_layer``.
            n_head: Stored as ``self.n_head``.
            n_embd: Stored as ``self.n_embd``.
            dropout: Stored as ``self.dropout``.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``. This covers
                base-class options and any extra keys read back from a saved
                configuration.
        """
        # Model-specific fields are set before the base initializer runs, the
        # usual ordering in Hugging Face custom configs.
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.dropout = dropout
        # NOTE(review): n_embd is presumably expected to be divisible by
        # n_head; that is not validated here. Confirm against the model code.
        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """
        Load a configuration via ``PretrainedConfig.from_pretrained``.

        Currently this delegates unchanged to the base implementation, passing
        all positional and keyword arguments through. It is kept as an
        extension point for custom loading logic.
        """
        return super().from_pretrained(*args, **kwargs)