import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig


class NanoThinkConfig(PretrainedConfig):
    """Configuration for :class:`NanoThinkModel`.

    Args:
        vocab_size: Size of the token vocabulary.
        dim: Hidden/embedding dimension of the transformer.
        n_layers: Number of transformer encoder layers.
        n_heads: Number of attention heads per layer.
        max_len: Maximum sequence length (size of the learned
            positional-embedding table).
    """

    model_type = "nanothink"

    def __init__(
        self,
        vocab_size=1229,
        dim=128,
        n_layers=4,
        n_heads=4,
        max_len=256,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.max_len = max_len


class NanoThinkModel(PreTrainedModel):
    """Minimal decoder-style language model.

    Learned token + absolute position embeddings feed an
    ``nn.TransformerEncoder`` run with a causal (upper-triangular)
    attention mask, followed by a final LayerNorm and a linear
    projection to vocabulary logits.
    """

    config_class = NanoThinkConfig

    def __init__(self, config):
        super().__init__(config)
        self.token_emb = nn.Embedding(config.vocab_size, config.dim)
        self.pos_emb = nn.Embedding(config.max_len, config.dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.dim, nhead=config.n_heads, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer, num_layers=config.n_layers
        )
        self.ln = nn.LayerNorm(config.dim)
        self.head = nn.Linear(config.dim, config.vocab_size)
        # HF weight init + any post-processing hooks.
        self.post_init()

    def forward(self, input_ids):
        """Compute next-token logits.

        Args:
            input_ids: Long tensor of shape ``(batch, seq_len)``.

        Returns:
            Float tensor of logits, shape ``(batch, seq_len, vocab_size)``.

        Raises:
            ValueError: If ``seq_len`` exceeds ``config.max_len`` (the
                positional table has no entries beyond that).
        """
        B, T = input_ids.shape
        if T > self.config.max_len:
            raise ValueError(
                f"Sequence length {T} exceeds max_len {self.config.max_len}"
            )
        pos = torch.arange(T, device=input_ids.device).unsqueeze(0)
        x = self.token_emb(input_ids) + self.pos_emb(pos)
        # Causal mask: True above the diagonal means position i may not
        # attend to any j > i. Built directly as bool to avoid a
        # float->bool round-trip and PyTorch mask-dtype warnings.
        mask = torch.triu(
            torch.ones(T, T, device=input_ids.device, dtype=torch.bool),
            diagonal=1,
        )
        x = self.transformer(x, mask=mask)
        x = self.ln(x)
        logits = self.head(x)
        return logits