Lepish commited on
Commit
68603e7
·
verified ·
1 Parent(s): cef0483

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +137 -1
model.py CHANGED
@@ -1 +1,137 @@
1
- <TO_BE_INSERTED_MODEL_PY_CODE>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+
6
+ class LayerNorm(nn.Module):
7
+ def __init__(self, ndim, bias):
8
+ super().__init__()
9
+ self.weight = nn.Parameter(torch.ones(ndim))
10
+ self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
11
+
12
+ def forward(self, x):
13
+ return F.layer_norm(x, x.shape[-1:], self.weight, self.bias, eps=1e-5)
14
+
15
+ class CausalSelfAttention(nn.Module):
16
+ def __init__(self, config):
17
+ super().__init__()
18
+ assert config.n_embd % config.n_head == 0
19
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
20
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
21
+ self.attn_dropout = nn.Dropout(config.dropout)
22
+ self.resid_dropout = nn.Dropout(config.dropout)
23
+ self.n_head = config.n_head
24
+ self.n_embd = config.n_embd
25
+ self.dropout = config.dropout
26
+
27
+ self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
28
+ .view(1, 1, config.block_size, config.block_size))
29
+
30
+ def forward(self, x):
31
+ B, T, C = x.size()
32
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
33
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
34
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
35
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
36
+
37
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
38
+ att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
39
+ att = F.softmax(att, dim=-1)
40
+ att = self.attn_dropout(att)
41
+
42
+ y = att @ v
43
+ y = y.transpose(1, 2).contiguous().view(B, T, C)
44
+ y = self.resid_dropout(self.c_proj(y))
45
+ return y
46
+
47
+ class MLP(nn.Module):
48
+ def __init__(self, config):
49
+ super().__init__()
50
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
51
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
52
+ self.dropout = nn.Dropout(config.dropout)
53
+
54
+ def forward(self, x):
55
+ x = self.c_fc(x)
56
+ x = F.gelu(x)
57
+ x = self.c_proj(x)
58
+ return self.dropout(x)
59
+
60
+ class Block(nn.Module):
61
+ def __init__(self, config):
62
+ super().__init__()
63
+ self.ln1 = LayerNorm(config.n_embd, bias=config.bias)
64
+ self.attn = CausalSelfAttention(config)
65
+ self.ln2 = LayerNorm(config.n_embd, bias=config.bias)
66
+ self.mlp = MLP(config)
67
+
68
+ def forward(self, x):
69
+ x = x + self.attn(self.ln1(x))
70
+ x = x + self.mlp(self.ln2(x))
71
+ return x
72
+
73
+ class GPTConfig:
74
+ def __init__(self, **kwargs):
75
+ self.vocab_size = kwargs.get("vocab_size", 50304)
76
+ self.block_size = kwargs.get("block_size", 1024)
77
+ self.n_layer = kwargs.get("n_layer", 12)
78
+ self.n_head = kwargs.get("n_head", 12)
79
+ self.n_embd = kwargs.get("n_embd", 768)
80
+ self.dropout = kwargs.get("dropout", 0.1)
81
+ self.bias = kwargs.get("bias", True)
82
+
83
+ class GPT(nn.Module):
84
+ def __init__(self, config):
85
+ super().__init__()
86
+ self.config = config
87
+
88
+ self.transformer = nn.ModuleDict(dict(
89
+ wte = nn.Embedding(config.vocab_size, config.n_embd),
90
+ wpe = nn.Embedding(config.block_size, config.n_embd),
91
+ drop = nn.Dropout(config.dropout),
92
+ h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
93
+ ln_f = LayerNorm(config.n_embd, bias=config.bias),
94
+ ))
95
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
96
+
97
+ self.apply(self._init_weights)
98
+
99
+ def _init_weights(self, module):
100
+ if isinstance(module, nn.Linear):
101
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
102
+ if module.bias is not None:
103
+ nn.init.zeros_(module.bias)
104
+ elif isinstance(module, nn.Embedding):
105
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
106
+
107
+ def forward(self, idx, targets=None):
108
+ B, T = idx.size()
109
+ assert T <= self.config.block_size, "Cannot forward, sequence too long"
110
+
111
+ pos = torch.arange(0, T, dtype=torch.long, device=idx.device).unsqueeze(0)
112
+
113
+ tok_emb = self.transformer.wte(idx)
114
+ pos_emb = self.transformer.wpe(pos)
115
+ x = self.transformer.drop(tok_emb + pos_emb)
116
+ for block in self.transformer.h:
117
+ x = block(x)
118
+ x = self.transformer.ln_f(x)
119
+
120
+ logits = self.lm_head(x)
121
+
122
+ if targets is not None:
123
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
124
+ return logits, loss
125
+ else:
126
+ return logits, None
127
+
128
+ @torch.no_grad()
129
+ def generate(self, idx, max_new_tokens):
130
+ for _ in range(max_new_tokens):
131
+ idx_cond = idx[:, -self.config.block_size:]
132
+ logits, _ = self(idx_cond)
133
+ logits = logits[:, -1, :]
134
+ probs = F.softmax(logits, dim=-1)
135
+ next_token = torch.multinomial(probs, num_samples=1)
136
+ idx = torch.cat((idx, next_token), dim=1)
137
+ return idx