import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig


class NanoThinkConfig(PretrainedConfig):
    model_type = "nanothink"

    def __init__(
        self,
        vocab_size=1229,
        dim=128,
        n_layers=4,
        n_heads=4,
        max_len=256,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.max_len = max_len


class NanoThinkModel(PreTrainedModel):
    config_class = NanoThinkConfig

    def __init__(self, config):
        super().__init__(config)

        # Learned token and absolute position embeddings.
        self.token_emb = nn.Embedding(config.vocab_size, config.dim)
        self.pos_emb = nn.Embedding(config.max_len, config.dim)

        # Stack of standard encoder layers; the causal mask applied in
        # forward() makes the model autoregressive.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.dim,
            nhead=config.n_heads,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=config.n_layers
        )

        # Final norm and language-modeling head projecting back to the vocabulary.
        self.ln = nn.LayerNorm(config.dim)
        self.head = nn.Linear(config.dim, config.vocab_size)

        self.post_init()

    def forward(self, input_ids):
        B, T = input_ids.shape
        pos = torch.arange(T, device=input_ids.device).unsqueeze(0)

        # Sum token and position embeddings: (B, T, dim).
        x = self.token_emb(input_ids) + self.pos_emb(pos)

        # Causal mask: True above the diagonal blocks attention to future positions.
        mask = torch.triu(
            torch.ones(T, T, device=input_ids.device),
            diagonal=1
        ).bool()

        x = self.transformer(x, mask=mask)
        x = self.ln(x)

        # Per-token logits over the vocabulary: (B, T, vocab_size).
        logits = self.head(x)
        return logits
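

# A minimal usage sketch (not part of the original file): instantiate the
# config and model, then run a forward pass on random token ids as a quick
# sanity check of the output shape.
if __name__ == "__main__":
    config = NanoThinkConfig()
    model = NanoThinkModel(config)

    # Random token ids of shape (batch=2, seq_len=16), within vocab_size.
    dummy_ids = torch.randint(0, config.vocab_size, (2, 16))
    logits = model(dummy_ids)
    print(logits.shape)  # expected: torch.Size([2, 16, 1229])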