Base-mini / configuration_tiny_gpt.py
QuantaSparkLabs's picture
Push to HF
59f1228 verified
Raw
History Blame Contribute Delete
1.51 kB
from transformers import PretrainedConfig
class TinyGPTConfig(PretrainedConfig):
    """Hugging Face configuration for the Base-mini / TinyGPT model.

    Architecture hyper-parameters are stored under this project's short
    names (``ctx_len``, ``n_layer``, ``n_head``, ``n_embd``) and mirrored
    under the standard ``transformers`` names (``max_position_embeddings``,
    ``num_hidden_layers``, ``num_attention_heads``, ``hidden_size``), so
    generic library code can read them too.

    Args:
        vocab_size: Vocabulary size, stored as ``int``.
        ctx_len: Context length, stored as ``int``; mirrored as
            ``max_position_embeddings``.
        n_layer: Number of layers, stored as ``int``; mirrored as
            ``num_hidden_layers``.
        n_head: Number of attention heads, stored as ``int``; mirrored as
            ``num_attention_heads``.
        n_embd: Embedding width, stored as ``int``; mirrored as
            ``hidden_size``.
        dropout: Dropout value, stored as ``float``.
        attention_backend: One of ``"sage"``, ``"torch"``, ``"flash2"`` or
            ``"flash3"``.
        torch_fallback: Stored as ``bool``; how the model code uses it is
            not visible in this file.
        pad_token_id, bos_token_id, eos_token_id, sep_token_id, unk_token_id:
            Special-token ids, forwarded to ``PretrainedConfig.__init__``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``attention_backend`` is not a supported backend.
    """

    model_type = "basemini"

    # Single source of truth for the supported attention backends: used for
    # validation, for the error message and for
    # ``available_attention_backends``. It lives on the class, not the
    # instance.
    _ATTENTION_BACKENDS = ("sage", "torch", "flash2", "flash3")

    def __init__(
        self,
        vocab_size=32768,
        ctx_len=512,
        n_layer=4,
        n_head=4,
        n_embd=384,
        dropout=0.0,
        attention_backend="torch",
        torch_fallback=False,
        pad_token_id=None,
        bos_token_id=None,
        eos_token_id=None,
        sep_token_id=None,
        unk_token_id=None,
        **kwargs,
    ):
        # Validate first so an unsupported backend fails before any other
        # initialization work is done.
        if attention_backend not in self._ATTENTION_BACKENDS:
            *head, last = self._ATTENTION_BACKENDS
            raise ValueError(
                f"attention_backend must be {', '.join(head)} or {last}, "
                f"got {attention_backend!r}"
            )

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            sep_token_id=sep_token_id,
            unk_token_id=unk_token_id,
            **kwargs,
        )

        self.vocab_size = int(vocab_size)

        # Project-specific names.
        self.ctx_len = int(ctx_len)
        self.n_layer = int(n_layer)
        self.n_head = int(n_head)
        self.n_embd = int(n_embd)

        # Standard transformers aliases of the values above.
        self.max_position_embeddings = int(ctx_len)
        self.num_hidden_layers = int(n_layer)
        self.num_attention_heads = int(n_head)
        self.hidden_size = int(n_embd)

        self.dropout = float(dropout)
        self.attention_backend = str(attention_backend)
        # Build a fresh list for each instance so that mutating one config's
        # list cannot affect another config.
        self.available_attention_backends = list(self._ATTENTION_BACKENDS)
        self.torch_fallback = bool(torch_fallback)