# File size: 1,508 Bytes (revision 59f1228)
from transformers import PretrainedConfig
class TinyGPTConfig(PretrainedConfig):
    """Configuration container registered under the ``basemini`` model type.

    The constructor stores every hyperparameter after coercing it to a plain
    Python type and mirrors the short GPT-style names onto the longer
    attribute names as well (``ctx_len`` -> ``max_position_embeddings``,
    ``n_layer`` -> ``num_hidden_layers``, ``n_head`` ->
    ``num_attention_heads``, ``n_embd`` -> ``hidden_size``).

    Args:
        vocab_size: Stored as ``int``.
        ctx_len: Stored as ``int`` under both ``ctx_len`` and
            ``max_position_embeddings``.
        n_layer: Stored as ``int`` under both ``n_layer`` and
            ``num_hidden_layers``.
        n_head: Stored as ``int`` under both ``n_head`` and
            ``num_attention_heads``.
        n_embd: Stored as ``int`` under both ``n_embd`` and ``hidden_size``.
        dropout: Stored as ``float``.
        attention_backend: One of ``"sage"``, ``"torch"``, ``"flash2"`` or
            ``"flash3"``; stored as ``str``.
        torch_fallback: Stored as ``bool``. NOTE(review): presumably lets the
            model fall back to the torch backend -- confirm against the
            model code, which is not visible here.
        pad_token_id: Forwarded unchanged to ``PretrainedConfig.__init__``.
        bos_token_id: Forwarded unchanged to ``PretrainedConfig.__init__``.
        eos_token_id: Forwarded unchanged to ``PretrainedConfig.__init__``.
        sep_token_id: Forwarded unchanged to ``PretrainedConfig.__init__``.
        unk_token_id: Forwarded unchanged to ``PretrainedConfig.__init__``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``attention_backend`` is not one of the four accepted
            names.
    """

    model_type = "basemini"

    def __init__(
        self,
        vocab_size=32768,
        ctx_len=512,
        n_layer=4,
        n_head=4,
        n_embd=384,
        dropout=0.0,
        attention_backend="torch",
        torch_fallback=False,
        pad_token_id=None,
        bos_token_id=None,
        eos_token_id=None,
        sep_token_id=None,
        unk_token_id=None,
        **kwargs,
    ):
        # The base initializer runs first, so the values assigned below take
        # precedence over any same-named entries that arrived via **kwargs.
        special_token_ids = {
            "pad_token_id": pad_token_id,
            "bos_token_id": bos_token_id,
            "eos_token_id": eos_token_id,
            "sep_token_id": sep_token_id,
            "unk_token_id": unk_token_id,
        }
        super().__init__(**special_token_ids, **kwargs)

        known_backends = ("sage", "torch", "flash2", "flash3")
        backend_is_known = attention_backend in known_backends
        if not backend_is_known:
            raise ValueError("attention_backend must be sage, torch, flash2 or flash3")

        self.vocab_size = int(vocab_size)

        # Context length is exposed under both its short and long names.
        self.ctx_len = self.max_position_embeddings = int(ctx_len)

        # Short GPT-style names first, then the matching long aliases.
        layer_count = int(n_layer)
        head_count = int(n_head)
        embed_width = int(n_embd)
        self.n_layer = layer_count
        self.n_head = head_count
        self.n_embd = embed_width
        self.num_hidden_layers = layer_count
        self.num_attention_heads = head_count
        self.hidden_size = embed_width

        self.dropout = float(dropout)
        self.attention_backend = str(attention_backend)
        # A fresh list per instance, listing every accepted backend name.
        self.available_attention_backends = list(known_backends)
        self.torch_fallback = bool(torch_fallback)
|