File size: 4,577 Bytes
348cbf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# Minimal GPT-2-ish decoder-only LM, written for clarity.
from dataclasses import dataclass
import math
import torch
import torch.nn as nn

@dataclass
class GPTConfig:
    """Hyperparameters shared by every module of the tiny GPT-2 model.

    Field order matters: it is the positional order of the generated
    ``__init__``.
    """

    # Token vocabulary size (token-embedding rows and LM-head output width).
    vocab_size: int = 16_000
    # Number of transformer blocks stacked in the model.
    n_layer: int = 6
    # Attention heads per block; n_embed must be divisible by this.
    n_head: int = 6
    # Width of the residual stream / embeddings.
    n_embed: int = 384
    # Maximum context length: size of the position table and causal mask.
    block_size: int = 256
    # Dropout probability applied to attention weights.
    attn_pdrop: float = 0.
    # Dropout probability on residual branches (also reused for embeddings).
    resid_pdrop: float = 0.

class CausalSelfAttention(nn.Module):
    """Multi-head masked self-attention: each position attends only to itself
    and earlier positions (enforced by a lower-triangular mask).
    """

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        # An explicit raise (not `assert`) so the check survives `python -O`;
        # otherwise a bad config only fails later inside `view` in forward().
        if cfg.n_embed % cfg.n_head != 0:
            raise ValueError(
                f"n_embed ({cfg.n_embed}) must be divisible by n_head ({cfg.n_head})"
            )
        self.n_head = cfg.n_head
        # Attribute names/order define state_dict keys; keep them stable.
        self.key = nn.Linear(cfg.n_embed, cfg.n_embed, bias=False)
        self.query = nn.Linear(cfg.n_embed, cfg.n_embed, bias=False)
        self.value = nn.Linear(cfg.n_embed, cfg.n_embed, bias=False)
        self.proj = nn.Linear(cfg.n_embed, cfg.n_embed, bias=False)
        self.attn_drop = nn.Dropout(cfg.attn_pdrop)
        self.resid_drop = nn.Dropout(cfg.resid_pdrop)
        # Lower-triangular ones of shape (1, 1, block_size, block_size) so it
        # broadcasts over the batch and head dimensions.
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(cfg.block_size, cfg.block_size)).view(
                1, 1, cfg.block_size, cfg.block_size
            ),
        )

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (B, T, C); returns (B, T, C)."""
        B, T, C = x.size()
        H = self.n_head
        head_dim = C // H

        def split_heads(t):
            # (B, T, C) -> (B, H, T, head_dim)
            return t.view(B, T, H, head_dim).transpose(1, 2)

        k = split_heads(self.key(x))
        q = split_heads(self.query(x))
        v = split_heads(self.value(x))
        # Scaled dot-product scores: (B, H, T, T).
        att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
        # Hide future positions; the diagonal is always kept, so no row is all -inf.
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = torch.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v
        # Merge heads back: (B, H, T, head_dim) -> (B, T, C).
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_drop(self.proj(y))

class Block(nn.Module):
    """A pre-LayerNorm transformer block.

    Causal self-attention, then a position-wise MLP; each sub-layer reads a
    LayerNorm-ed input and adds its output back onto the residual stream.
    """

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        width = cfg.n_embed
        hidden = 4 * width  # MLP expansion factor
        # Registration order (ln1, attn, ln2, mlp) and the Sequential indices
        # fix the state_dict keys and parameter order -- keep them stable.
        self.ln1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(cfg)
        self.ln2 = nn.LayerNorm(width)
        layers = [
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, width),
            nn.Dropout(cfg.resid_pdrop),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Run the attention and MLP residual sub-layers over ``x``."""
        after_attn = x + self.attn(self.ln1(x))
        after_mlp = after_attn + self.mlp(self.ln2(after_attn))
        return after_mlp

class TinyGPT2(nn.Module):
    """Minimal decoder-only language model.

    Token and learned position embeddings are summed, passed through
    ``cfg.n_layer`` ``Block``s and a final LayerNorm, then projected to
    vocabulary logits by a bias-free linear head (not tied to ``tok_emb``).
    """

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.n_embed)
        self.pos_emb = nn.Embedding(cfg.block_size, cfg.n_embed)
        # Embedding dropout reuses the residual dropout rate.
        self.drop = nn.Dropout(cfg.resid_pdrop)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layer)])
        self.ln_f = nn.LayerNorm(cfg.n_embed)
        self.head = nn.Linear(cfg.n_embed, cfg.vocab_size, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, 0.02) and zero Linear biases.

        Other modules (e.g. LayerNorm) keep PyTorch's default initialisation.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        if isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=64, top_k=50, top_p=0.95, temperature=1.0):
        """Autoregressively sample tokens and append them to ``idx``.

        Args:
            idx: (B, T) tensor of token ids used as the prompt.
            max_new_tokens: number of tokens to sample.
            top_k: keep only the ``top_k`` largest logits (0 or None disables).
            top_p: nucleus-sampling threshold (1.0 disables).
            temperature: logits are divided by ``max(temperature, 1e-5)``.

        Returns:
            (B, T + max_new_tokens) tensor: the prompt followed by the samples.

        Sampling runs in eval mode. The module's previous train/eval mode is
        restored afterwards, even if sampling raises, so calling this in the
        middle of training does not silently disable dropout.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Only the most recent block_size tokens fit the position table.
                idx_cond = idx[:, -self.cfg.block_size:]
                logits = self(idx_cond)[:, -1, :] / max(temperature, 1e-5)
                logits = self._top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
                probs = torch.softmax(logits, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)
                idx = torch.cat([idx, next_id], dim=1)
        finally:
            self.train(was_training)
        return idx

    @staticmethod
    def _top_k_top_p_filtering(logits, top_k=0, top_p=1.0):
        """Set logits outside the top-k and nucleus (top-p) sets to -inf.

        ``logits`` must be 2-D (batch, vocab). It is modified in place and the
        same tensor is returned.

        If ``top_k`` exceeds the vocabulary size it is clamped, so filtering
        keeps everything instead of raising. Logits tied with the k-th largest
        value are all kept. For top-p, the first token that pushes the
        cumulative probability past ``top_p`` is kept, and the single most
        likely token is always kept.
        """
        if top_k and top_k > 0:
            # torch.topk raises if k is larger than the dimension size.
            top_k = min(top_k, logits.size(-1))
            kth_largest = torch.topk(logits, top_k).values[:, [-1]]
            logits[logits < kth_largest] = -float("inf")
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumprobs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            remove = cumprobs > top_p
            # Shift right so the token that crosses the threshold survives.
            remove[..., 1:] = remove[..., :-1].clone()
            remove[..., 0] = False
            sorted_logits[remove] = -float("inf")
            # Write the filtered values back into their original vocab positions.
            logits.scatter_(1, sorted_indices, sorted_logits)
        return logits

    def forward(self, idx):
        """Map token ids ``idx`` of shape (B, T) to per-position vocabulary logits.

        Returns:
            Tensor of shape (B, T, vocab_size).

        Raises:
            ValueError: if T exceeds ``cfg.block_size`` (no position embedding
                or mask entry exists for later positions).
        """
        _, T = idx.size()
        if T > self.cfg.block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.cfg.block_size}"
            )
        pos = torch.arange(0, T, device=idx.device).unsqueeze(0)
        x = self.tok_emb(idx) + self.pos_emb(pos)
        x = self.drop(x)
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        return self.head(x)