rupakrpk93 committed
Commit 43e9ef6 · verified · 1 Parent(s): e9ae34f

Committed the model architecture file so that you can refer to it when creating the model in code. Please see the inference documents for more details.

Files changed (1)
  1. model_architecture.py +138 -0
model_architecture.py ADDED
@@ -0,0 +1,138 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass


class LayerNorm(nn.Module):
    """LayerNorm with an optional bias term."""

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, x):
        return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5)


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention; uses fused scaled_dot_product_attention when available."""

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # query, key, value projections for all heads, computed in one linear layer
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.flash = hasattr(F, 'scaled_dot_product_attention')
        if not self.flash:
            # causal mask for the manual attention path
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                 .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        if self.flash:
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.attn_dropout.p if self.training else 0.0, is_causal=True)
        else:
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v
        # re-assemble all head outputs side by side
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))
        return y


class MLP(nn.Module):
    """Position-wise feed-forward network with a 4x hidden expansion and GELU."""

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))


class Block(nn.Module):
    """Transformer block: pre-norm attention and pre-norm MLP, each with a residual connection."""

    def __init__(self, config):
        super().__init__()
        self.ln1 = LayerNorm(config.n_embd, config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln2 = LayerNorm(config.n_embd, config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


@dataclass
class GPTConfig:
    block_size: int = 256
    vocab_size: int = 50000
    n_layer: int = 24
    n_head: int = 12
    n_embd: int = 768
    dropout: float = 0.1
    bias: bool = True


class GPT(nn.Module):
    """GPT language model: token and positional embeddings, a stack of Blocks, and a tied LM head."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=LayerNorm(config.n_embd, config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # weight tying between the token embedding and the output head
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)
        # scaled init for the residual projections
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size
        pos = torch.arange(0, t, dtype=torch.long, device=device)
        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        if targets is not None:
            # training: compute logits at every position and the cross-entropy loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
            return logits, loss
        else:
            # inference: only the logits at the last position are needed
            logits = self.lm_head(x[:, [-1], :])
            return logits, None

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            # crop the running context to at most block_size tokens
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # keep only the top_k most likely tokens
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
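
A minimal usage sketch, assuming the file above is importable as model_architecture and that real token ids come from whatever tokenizer the inference documents describe; the random ids and the checkpoint path below are placeholders, not part of this repository:

import torch
from model_architecture import GPT, GPTConfig

# Build the model with the defaults from GPTConfig; override fields here
# if the checkpoint you are loading was trained with different sizes.
config = GPTConfig()
model = GPT(config)
model.eval()

# Optionally load trained weights (placeholder path):
# model.load_state_dict(torch.load("model.pt", map_location="cpu"))

# Dummy prompt: a batch of 1 sequence of 8 token ids within the vocabulary.
idx = torch.randint(0, config.vocab_size, (1, 8))

# Inference-mode forward pass returns logits for the last position only.
logits, _ = model(idx)
print(logits.shape)  # torch.Size([1, 1, 50000])

# Autoregressive sampling with the built-in generate() helper.
out = model.generate(idx, max_new_tokens=20, temperature=0.8, top_k=50)
print(out.shape)  # torch.Size([1, 28])

Note that wte.weight is tied to lm_head.weight, so a state dict produced by this class stores those parameters once; keep that in mind if you port weights from a model that does not tie them.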