wizardoftrap commited on
Commit
f7e18f7
·
verified ·
1 Parent(s): dd24a17

Upload sp_lm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. sp_lm.py +146 -0
sp_lm.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ from dataclasses import dataclass
6
+ import numpy as np
7
+ from tqdm.auto import tqdm
8
+ from contextlib import nullcontext
9
+ import os
10
+
11
+ class LayerNorm(nn.Module):
12
+ def __init__(self, ndim, bias):
13
+ super().__init__()
14
+ self.weight = nn.Parameter(torch.ones(ndim))
15
+ self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
16
+ def forward(self, x):
17
+ return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5)
18
+
19
+ class CausalSelfAttention(nn.Module):
20
+ def __init__(self, config):
21
+ super().__init__()
22
+ assert config.n_embd % config.n_head == 0
23
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
24
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
25
+ self.attn_dropout = nn.Dropout(config.dropout)
26
+ self.resid_dropout = nn.Dropout(config.dropout)
27
+ self.n_head = config.n_head
28
+ self.n_embd = config.n_embd
29
+ self.flash = hasattr(F, 'scaled_dot_product_attention')
30
+ if not self.flash:
31
+ self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
32
+ .view(1, 1, config.block_size, config.block_size))
33
+
34
+ def forward(self, x):
35
+ B, T, C = x.size()
36
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
37
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
38
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
39
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
40
+
41
+ if self.flash:
42
+ y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.attn_dropout.p if self.training else 0.0, is_causal=True)
43
+ else:
44
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
45
+ att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
46
+ att = F.softmax(att, dim=-1)
47
+ att = self.attn_dropout(att)
48
+ y = att @ v
49
+
50
+ y = y.transpose(1, 2).contiguous().view(B, T, C)
51
+ y = self.resid_dropout(self.c_proj(y))
52
+ return y
53
+
54
+ class MLP(nn.Module):
55
+ def __init__(self, config):
56
+ super().__init__()
57
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
58
+ self.gelu = nn.GELU()
59
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
60
+ self.dropout = nn.Dropout(config.dropout)
61
+ def forward(self, x):
62
+ return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
63
+
64
+ class Block(nn.Module):
65
+ def __init__(self, config):
66
+ super().__init__()
67
+ self.ln1 = LayerNorm(config.n_embd, config.bias)
68
+ self.attn = CausalSelfAttention(config)
69
+ self.ln2 = LayerNorm(config.n_embd, config.bias)
70
+ self.mlp = MLP(config)
71
+ def forward(self, x):
72
+ x = x + self.attn(self.ln1(x))
73
+ x = x + self.mlp(self.ln2(x))
74
+ return x
75
+
76
+ @dataclass
77
+ class GPTConfig:
78
+ block_size: int
79
+ vocab_size: int
80
+ n_layer: int
81
+ n_head: int
82
+ n_embd: int
83
+ dropout: float = 0.0
84
+ bias: bool = True
85
+
86
+ class GPT(nn.Module):
87
+ def __init__(self, config):
88
+ super().__init__()
89
+ self.config = config
90
+ self.transformer = nn.ModuleDict(dict(
91
+ wte=nn.Embedding(config.vocab_size, config.n_embd),
92
+ wpe=nn.Embedding(config.block_size, config.n_embd),
93
+ drop=nn.Dropout(config.dropout),
94
+ h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
95
+ ln_f=LayerNorm(config.n_embd, config.bias),
96
+ ))
97
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
98
+ self.transformer.wte.weight = self.lm_head.weight # weight tying
99
+
100
+ self.apply(self._init_weights)
101
+ for pn, p in self.named_parameters():
102
+ if pn.endswith('c_proj.weight'):
103
+ nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
104
+
105
+ def _init_weights(self, module):
106
+ if isinstance(module, nn.Linear):
107
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
108
+ if module.bias is not None:
109
+ nn.init.zeros_(module.bias)
110
+ elif isinstance(module, nn.Embedding):
111
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
112
+
113
+ def forward(self, idx, targets=None):
114
+ device = idx.device
115
+ b, t = idx.size()
116
+ assert t <= self.config.block_size
117
+ pos = torch.arange(0, t, dtype=torch.long, device=device)
118
+
119
+ tok_emb = self.transformer.wte(idx)
120
+ pos_emb = self.transformer.wpe(pos)
121
+ x = self.transformer.drop(tok_emb + pos_emb)
122
+ for block in self.transformer.h:
123
+ x = block(x)
124
+ x = self.transformer.ln_f(x)
125
+
126
+ if targets is not None:
127
+ logits = self.lm_head(x)
128
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
129
+ return logits, loss
130
+ else:
131
+ logits = self.lm_head(x[:, [-1], :])
132
+ return logits, None
133
+
134
+ @torch.no_grad()
135
+ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
136
+ for _ in range(max_new_tokens):
137
+ idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
138
+ logits, _ = self(idx_cond)
139
+ logits = logits[:, -1, :] / temperature
140
+ if top_k is not None:
141
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
142
+ logits[logits < v[:, [-1]]] = -float('Inf')
143
+ probs = F.softmax(logits, dim=-1)
144
+ idx_next = torch.multinomial(probs, num_samples=1)
145
+ idx = torch.cat((idx, idx_next), dim=1)
146
+ return idx