from typing import Optional, Union

import torch
from transformers import PretrainedConfig, AutoConfig


class MiniTransformerConfig(PretrainedConfig):
    """Configuration container for the ``minitransformer`` model type.

    Every constructor argument is stored on the instance under the same
    name. Two additional attributes are derived from them:

    * ``hidden_size`` -- a copy of ``dim``.
    * ``intermediate_size`` -- ``dim * mlp_scale``.

    Args:
        bsz: Stored as ``self.bsz`` (presumably the batch size -- confirm
            against the model code).
        dim: Model width; also exposed as ``hidden_size``.
        num_heads: Stored as ``self.num_heads``.
        num_layers: Stored as ``self.num_layers``.
        seq_len: Stored as ``self.seq_len``.
        window_size: Stored as ``self.window_size``.
        vocab_size: Stored as ``self.vocab_size``.
        mlp_scale: Multiplier applied to ``dim`` to obtain
            ``intermediate_size``.
        bias: Stored as ``self.bias``.
        dropout: Stored as ``self.dropout``.
        softcap: Stored as ``self.softcap``.
        theta: Stored as ``self.theta``.
        use_alibi: Stored as ``self.use_alibi``.
        torch_dtype: A ``torch.dtype``, or its name as a string (for example
            ``"bfloat16"`` or ``"torch.bfloat16"``). Recognised names are
            resolved to the matching ``torch.dtype``; anything else is
            stored unchanged.
        device: A ``torch.device`` or device string. When falsy (``None`` or
            ``""``), ``"cuda"`` is chosen if ``torch.cuda.is_available()``
            and ``"cpu"`` otherwise. Always stored as a ``str``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "minitransformer"

    def __init__(
        self,
        bsz: int = 1,
        dim: int = 768,
        num_heads: int = 24,
        num_layers: int = 27,
        seq_len: int = 8192,
        window_size: int = 8192,
        vocab_size: int = 200064,
        mlp_scale: int = 4,
        bias: bool = False,
        dropout: float = 0.0,
        softcap: float = 50.0,
        theta: float = 10_000.0,
        use_alibi: bool = False,
        torch_dtype: Union[torch.dtype, str] = torch.bfloat16,
        device: Optional[Union[torch.device, str]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.bsz = bsz
        self.dim = dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.seq_len = seq_len
        self.window_size = window_size
        self.vocab_size = vocab_size
        self.hidden_size = dim
        self.mlp_scale = mlp_scale
        self.intermediate_size = self.dim * self.mlp_scale
        self.bias = bias
        self.dropout = dropout
        self.softcap = softcap
        self.theta = theta
        self.use_alibi = use_alibi

        # A config rebuilt from a serialized dict hands the dtype back as a
        # string. Because ``torch_dtype`` is a named parameter here it is not
        # part of ``kwargs``, so it has to be turned back into a real
        # ``torch.dtype`` in this constructor.
        if isinstance(torch_dtype, str):
            resolved = getattr(torch, torch_dtype.split(".")[-1], None)
            if isinstance(resolved, torch.dtype):
                torch_dtype = resolved
        self.torch_dtype = torch_dtype

        # Always store the device as a plain string: a ``torch.device``
        # object is not JSON-serializable and would break saving the config.
        self.device = str(
            device or ("cuda" if torch.cuda.is_available() else "cpu")
        )