NanoThink-5M / model.py
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig


class NanoThinkConfig(PretrainedConfig):
    """Configuration for the NanoThink toy language model."""

    model_type = "nanothink"

    def __init__(
        self,
        vocab_size=1229,
        dim=128,
        n_layers=4,
        n_heads=4,
        max_len=256,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.max_len = max_len


class NanoThinkModel(PreTrainedModel):
    config_class = NanoThinkConfig

    def __init__(self, config):
        super().__init__(config)
        # Learned token and absolute position embeddings.
        self.token_emb = nn.Embedding(config.vocab_size, config.dim)
        self.pos_emb = nn.Embedding(config.max_len, config.dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.dim,
            nhead=config.n_heads,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=config.n_layers,
        )
        self.ln = nn.LayerNorm(config.dim)
        # Project hidden states back to the vocabulary for next-token prediction.
        self.head = nn.Linear(config.dim, config.vocab_size)
        self.post_init()

    def forward(self, input_ids):
        B, T = input_ids.shape
        # Guard against sequences longer than the position-embedding table.
        assert T <= self.config.max_len, "sequence length exceeds max_len"
        pos = torch.arange(T, device=input_ids.device).unsqueeze(0)
        x = self.token_emb(input_ids) + self.pos_emb(pos)
        # Causal mask: True above the diagonal blocks attention to future
        # positions, so the encoder stack behaves as a decoder-only LM.
        mask = torch.triu(
            torch.ones(T, T, device=input_ids.device),
            diagonal=1,
        ).bool()
        x = self.transformer(x, mask=mask)
        x = self.ln(x)
        logits = self.head(x)  # (B, T, vocab_size)
        return logits
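

# --- Illustrative sketch, not part of the uploaded file. Registering the
# custom classes with the Auto* factories is assumed here as an optional
# convenience so AutoModel can resolve the "nanothink" model_type; the
# original repo may handle this differently (e.g. via auto_map in config.json).
from transformers import AutoConfig, AutoModel

AutoConfig.register("nanothink", NanoThinkConfig)
AutoModel.register(NanoThinkConfig, NanoThinkModel)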
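

# --- Usage sketch (an addition for illustration, assuming the default
# config). It runs a forward pass on random token ids and greedily decodes a
# few continuation tokens. The prompt ids are arbitrary placeholders, not the
# output of the repo's actual tokenizer.
if __name__ == "__main__":
    config = NanoThinkConfig()
    model = NanoThinkModel(config)
    model.eval()

    # Dummy prompt: batch of 1, eight token ids drawn from the vocabulary.
    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    with torch.no_grad():
        logits = model(input_ids)
    print(logits.shape)  # torch.Size([1, 8, 1229])

    # Greedy decoding: repeatedly append the argmax of the last position's
    # logits, stopping at the model's maximum context length.
    with torch.no_grad():
        for _ in range(16):
            if input_ids.shape[1] >= config.max_len:
                break
            next_id = model(input_ids)[:, -1, :].argmax(dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_id], dim=1)
    print(input_ids.tolist())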