AxionLab-official committed on
Commit 93ca81b · verified · 1 Parent(s): 6723353

Create model.py

Files changed (1)
  1. model.py +67 -0
model.py ADDED
@@ -0,0 +1,67 @@
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig


class NanoThinkConfig(PretrainedConfig):
    model_type = "nanothink"

    def __init__(
        self,
        vocab_size=1229,
        dim=128,
        n_layers=4,
        n_heads=4,
        max_len=256,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.max_len = max_len


class NanoThinkModel(PreTrainedModel):
    config_class = NanoThinkConfig

    def __init__(self, config):
        super().__init__(config)

        # Token embeddings plus learned absolute position embeddings
        self.token_emb = nn.Embedding(config.vocab_size, config.dim)
        self.pos_emb = nn.Embedding(config.max_len, config.dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.dim,
            nhead=config.n_heads,
            batch_first=True
        )

        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=config.n_layers
        )

        self.ln = nn.LayerNorm(config.dim)
        self.head = nn.Linear(config.dim, config.vocab_size)

        self.post_init()

    def forward(self, input_ids):
        B, T = input_ids.shape
        pos = torch.arange(T, device=input_ids.device).unsqueeze(0)

        x = self.token_emb(input_ids) + self.pos_emb(pos)

        # Causal mask: True above the diagonal blocks attention to future positions
        mask = torch.triu(
            torch.ones(T, T, device=input_ids.device),
            diagonal=1
        ).bool()

        x = self.transformer(x, mask=mask)
        x = self.ln(x)

        # Project hidden states to vocabulary size for next-token logits
        logits = self.head(x)

        return logits
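
A minimal smoke-test sketch for the classes above; the batch size, sequence length, and random token IDs are illustrative assumptions, not part of this commit:

import torch

config = NanoThinkConfig()
model = NanoThinkModel(config)
model.eval()

# Random token IDs with shape (batch, seq_len); seq_len must not exceed config.max_len
input_ids = torch.randint(0, config.vocab_size, (2, 16))

with torch.no_grad():
    logits = model(input_ids)

print(logits.shape)  # expected: torch.Size([2, 16, 1229])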