File size: 1,354 Bytes
cbda9b7 413ea27 cbda9b7 413ea27 cbda9b7 413ea27 cbda9b7 a2fbb2f cbda9b7 a2fbb2f cbda9b7 a2fbb2f cbda9b7 a2fbb2f cbda9b7 a2fbb2f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 | import torch
from transformers import PretrainedConfig, AutoConfig
class MiniTransformerConfig(PretrainedConfig):
    """Hyperparameter container for the ``minitransformer`` model type.

    Extends ``transformers.PretrainedConfig``; any extra keyword arguments are
    forwarded unchanged to ``PretrainedConfig.__init__``.

    Args:
        bsz: Batch size.
        dim: Model width. Also exposed as ``hidden_size``.
        num_heads: Number of attention heads.
        num_layers: Number of layers.
        seq_len: Sequence length.
        window_size: Window size (presumably an attention window — confirm
            against the model code).
        vocab_size: Vocabulary size.
        mlp_scale: Multiplier on ``dim``; ``intermediate_size = dim * mlp_scale``.
        bias: Bias flag (assumed to toggle bias terms in layers — verify).
        dropout: Dropout value.
        softcap: Soft-cap value (semantics defined by the model, not visible here).
        theta: Theta value (presumably a rotary-embedding base — confirm).
        use_alibi: ALiBi flag.
        torch_dtype: A ``torch.dtype`` or the name of one (e.g. ``"bfloat16"``,
            as found in a saved config file). Names are resolved to the
            matching ``torch.dtype``; unresolvable strings are kept as given.
        device: Device as a ``torch.device`` or string. Always stored as a
            string so the config stays JSON-serializable. When falsy, defaults
            to ``"cuda"`` if CUDA is available at construction time, else
            ``"cpu"``.
    """

    # Model-type identifier for this config class.
    model_type = "minitransformer"

    def __init__(
        self,
        bsz: int = 1,
        dim: int = 768,
        num_heads: int = 24,
        num_layers: int = 27,
        seq_len: int = 8192,
        window_size: int = 8192,
        vocab_size: int = 200064,
        mlp_scale: int = 4,
        bias: bool = False,
        dropout: float = 0.0,
        softcap: float = 50.0,
        theta: float = 10_000.0,
        use_alibi: bool = False,
        torch_dtype: "torch.dtype | str" = torch.bfloat16,
        device: "torch.device | str | None" = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.bsz = bsz
        self.dim = dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.seq_len = seq_len
        self.window_size = window_size
        self.vocab_size = vocab_size
        # Alias of ``dim`` under the attribute name transformers code commonly reads.
        self.hidden_size = dim
        self.mlp_scale = mlp_scale
        # Derived width; recomputed here, so a stale value coming back through
        # **kwargs from a saved config is overwritten.
        self.intermediate_size = self.dim * self.mlp_scale
        self.bias = bias
        self.dropout = dropout
        self.softcap = softcap
        self.theta = theta
        self.use_alibi = use_alibi
        # torch_dtype is consumed by this signature, so normalise names
        # (e.g. "bfloat16" from a reloaded config) back to a torch.dtype here.
        self.torch_dtype = self._resolve_dtype(torch_dtype)
        # Store the device as a string: a torch.device object is not
        # JSON-serializable. NOTE(review): the CUDA fallback is resolved at
        # construction time, so a saved config records the saving machine's device.
        self.device = (
            str(device)
            if device
            else ("cuda" if torch.cuda.is_available() else "cpu")
        )

    @staticmethod
    def _resolve_dtype(dtype):
        """Return ``dtype`` as a ``torch.dtype`` when it names one; otherwise return it unchanged.

        Accepts both ``"bfloat16"`` and ``"torch.bfloat16"`` spellings.
        """
        if isinstance(dtype, str):
            resolved = getattr(torch, dtype.split(".")[-1], None)
            if isinstance(resolved, torch.dtype):
                return resolved
        return dtype
|