from typing import Optional

from transformers import PretrainedConfig


class TinyGPTConfig(PretrainedConfig):
    """Configuration for the TinyGPT model (registered as ``model_type`` "basemini").

    Hyper-parameters are stored under short GPT-style names (``ctx_len``,
    ``n_layer``, ``n_head``, ``n_embd``). They are also exposed under the
    conventional Hugging Face attribute names (``max_position_embeddings``,
    ``num_hidden_layers``, ``num_attention_heads``, ``hidden_size``).

    NOTE: the Hugging Face-named attributes are plain copies taken in
    ``__init__``. Assigning to e.g. ``config.n_layer`` after construction
    does NOT update ``config.num_hidden_layers``.
    """

    model_type = "basemini"

    # Single source of truth for the accepted ``attention_backend`` values.
    # This is a class attribute, so it is not serialized by ``to_dict()``.
    ATTENTION_BACKENDS = ("sage", "torch", "flash2", "flash3")

    def __init__(
        self,
        vocab_size: int = 32768,
        ctx_len: int = 512,
        n_layer: int = 4,
        n_head: int = 4,
        n_embd: int = 384,
        dropout: float = 0.0,
        attention_backend: str = "torch",
        torch_fallback: bool = False,
        pad_token_id: Optional[int] = None,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        sep_token_id: Optional[int] = None,
        unk_token_id: Optional[int] = None,
        **kwargs,
    ) -> None:
        """Build and validate the configuration.

        Args:
            vocab_size: Vocabulary size (coerced to ``int``).
            ctx_len: Context length. Also stored as ``max_position_embeddings``.
            n_layer: Number of layers. Also stored as ``num_hidden_layers``.
            n_head: Number of attention heads. Also stored as
                ``num_attention_heads``.
            n_embd: Embedding width. Also stored as ``hidden_size``.
            dropout: Dropout value (coerced to ``float``).
            attention_backend: One of ``ATTENTION_BACKENDS``.
            torch_fallback: Flag stored as a ``bool``. How it is consumed is
                not visible in this file.
            pad_token_id, bos_token_id, eos_token_id, sep_token_id,
            unk_token_id: Special-token ids, forwarded to
                ``PretrainedConfig.__init__``.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

        Raises:
            ValueError: If ``attention_backend`` is not one of
                ``ATTENTION_BACKENDS``.
        """
        # Validate before touching the base class so a bad config fails fast.
        if attention_backend not in self.ATTENTION_BACKENDS:
            raise ValueError(
                "attention_backend must be sage, torch, flash2 or flash3, "
                f"got {attention_backend!r}"
            )

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            sep_token_id=sep_token_id,
            unk_token_id=unk_token_id,
            **kwargs,
        )

        self.vocab_size = int(vocab_size)

        # Short GPT-style names.
        self.ctx_len = int(ctx_len)
        self.n_layer = int(n_layer)
        self.n_head = int(n_head)
        self.n_embd = int(n_embd)

        # Conventional Hugging Face aliases (copies, see class docstring).
        self.max_position_embeddings = self.ctx_len
        self.num_hidden_layers = self.n_layer
        self.num_attention_heads = self.n_head
        self.hidden_size = self.n_embd

        self.dropout = float(dropout)
        self.attention_backend = str(attention_backend)
        # Kept as an instance-level list so it stays in the serialized config.
        # Built fresh from the shared constant so the two can never drift.
        self.available_attention_backends = list(self.ATTENTION_BACKENDS)
        self.torch_fallback = bool(torch_fallback)