File size: 6,619 Bytes
c5f49b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import lru_cache
from dataclasses import dataclass, asdict
from typing import Any, Dict, Optional

# ================= DEVICE =================
# Default device string exported by this module. Nothing in the visible code
# moves modules or tensors to it; callers are presumably expected to do that
# themselves (NOTE(review): confirm against the training/inference scripts).
device = "cpu"
# Process-wide PyTorch setting, applied as an import-time side effect: selects
# the "high" internal precision mode for float32 matrix multiplications.
torch.set_float32_matmul_precision("high")

# ================= MODEL CONFIG =================
@dataclass(frozen=True)
class GPTConfig:
    """Immutable set of hyperparameters used to build the model in this file.

    Attributes:
        n_embd: Width of the embeddings and of every hidden representation.
        n_head: Number of attention heads; must divide ``n_embd`` evenly.
        n_layer: Number of transformer blocks that are stacked.
        block_size: Size of the positional-embedding table, i.e. the longest
            sequence the model accepts.
        dropout: Dropout probability shared by all dropout layers.
    """

    n_embd: int = 192
    n_head: int = 6
    n_layer: int = 6
    block_size: int = 256
    dropout: float = 0.1

    def validate(self) -> None:
        """Check that the hyperparameters are mutually consistent.

        The checks run in a fixed order and stop at the first failure, so the
        divisibility test is never reached with a non-positive ``n_head``.

        Raises:
            ValueError: If any size is not positive, ``block_size`` is not
                greater than 8, ``n_embd`` is not a multiple of ``n_head``,
                or ``dropout`` lies outside ``[0, 0.5]``.
        """
        sizes = (self.n_embd, self.n_head, self.n_layer)
        if any(size <= 0 for size in sizes):
            raise ValueError("Invalid config: n_embd/n_head/n_layer must be > 0")

        if self.block_size <= 8:
            raise ValueError("Invalid config: block_size must be > 8")

        remainder = self.n_embd % self.n_head
        if remainder != 0:
            raise ValueError("Invalid config: n_embd must be divisible by n_head")

        rate = float(self.dropout)
        in_range = 0.0 <= rate <= 0.5
        if not in_range:
            raise ValueError("Invalid config: dropout must be in [0, 0.5]")

    def to_dict(self) -> Dict[str, Any]:
        """Return the configuration as a plain ``{field_name: value}`` dict."""
        as_mapping = asdict(self)
        return as_mapping


# Shared default hyperparameters, validated once when the module is imported.
DEFAULT_CONFIG = GPTConfig()
DEFAULT_CONFIG.validate()

# Back-compat exports (older scripts import these symbols).
# They are plain copies of the default values taken at import time.
n_embd, n_head, n_layer, block_size, dropout = (
    DEFAULT_CONFIG.n_embd,
    DEFAULT_CONFIG.n_head,
    DEFAULT_CONFIG.n_layer,
    DEFAULT_CONFIG.block_size,
    DEFAULT_CONFIG.dropout,
)


# ================= RMSNorm =================
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain.

    Each vector along the last axis is divided by the square root of
    ``mean(x ** 2) + eps`` and then multiplied element-wise by ``weight``.
    No mean is subtracted and there is no bias term.

    Args:
        dim: Size of the last dimension of the inputs; ``weight`` has this
            many elements and is initialised to ones.
        eps: Non-negative constant added to the mean square before the
            reciprocal square root, to keep the result finite for all-zero
            inputs. Defaults to ``1e-6``.

    Raises:
        ValueError: If ``eps`` is negative.
    """

    # Class-level fallback so modules pickled before ``eps`` became
    # configurable (which carry no instance attribute) keep working.
    eps: float = 1e-6

    def __init__(self, dim, eps: float = 1e-6):
        super().__init__()
        if eps < 0:
            raise ValueError(f"eps must be non-negative, got {eps}")
        # Plain attribute, not a buffer: the state_dict layout stays {"weight"}.
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return ``x`` rescaled to unit RMS along the last axis, times ``weight``."""
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        # Same multiplication order as before, so default results are unchanged.
        return self.weight * x * inv_rms

    def extra_repr(self) -> str:
        """Show the normalised width and epsilon in ``repr(module)``."""
        return f"dim={self.weight.numel()}, eps={self.eps}"


# ================= SELF ATTENTION =================
class SelfAttention(nn.Module):
    """Multi-head causal self-attention with a fused query/key/value projection.

    Args:
        cfg: Hyperparameters; ``n_embd``, ``n_head`` and ``dropout`` are read.
    """

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.cfg = cfg
        width = cfg.n_embd
        # One bias-free matmul yields queries, keys and values side by side.
        self.qkv = nn.Linear(width, 3 * width, bias=False)
        self.proj = nn.Linear(width, width)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x):
        """Attend over a rank-3 input and return a tensor of the same shape.

        ``x`` is unpacked as ``(batch, time, channels)``. Attention is causal,
        and attention dropout is only active while the module is in training
        mode.
        """
        batch, steps, width = x.size()
        heads = self.cfg.n_head
        head_dim = width // heads

        def split_heads(t):
            # (batch, time, channels) -> (batch, heads, time, head_dim)
            return t.view(batch, steps, heads, head_dim).transpose(1, 2)

        q, k, v = (split_heads(part) for part in self.qkv(x).chunk(3, dim=-1))

        attn_drop = self.cfg.dropout if self.training else 0.0
        ctx = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            dropout_p=attn_drop,
            is_causal=True,
        )

        # Put the heads back next to each other before the output projection.
        merged = ctx.transpose(1, 2).contiguous().view(batch, steps, width)
        projected = self.proj(merged)
        return self.dropout(projected)


# ================= FEED FORWARD =================
class FeedForward(nn.Module):
    """Position-wise MLP: expand to four times the width, GELU, project back, dropout.

    Args:
        cfg: Hyperparameters; ``n_embd`` and ``dropout`` are read.
    """

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.cfg = cfg
        width = cfg.n_embd
        hidden = 4 * width
        # Layer order fixes the ``net.<index>`` parameter names; keep it stable.
        layers = [
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, width),
            nn.Dropout(cfg.dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the MLP and return the result."""
        out = self.net(x)
        return out


# ================= TRANSFORMER BLOCK =================
class Block(nn.Module):
    """Pre-norm transformer block: attention then feed-forward, each with a residual.

    Args:
        cfg: Hyperparameters passed on to the attention and feed-forward
            sub-modules; ``n_embd`` sizes the two norms.
    """

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.cfg = cfg
        width = cfg.n_embd
        self.ln1 = RMSNorm(width)
        self.ln2 = RMSNorm(width)
        self.attn = SelfAttention(cfg)
        self.ff = FeedForward(cfg)

    def forward(self, x):
        """Return ``x`` after both residual sub-layers have been applied."""
        # Out-of-place additions: the caller's tensor is never modified.
        attended = self.attn(self.ln1(x))
        x = x + attended
        transformed = self.ff(self.ln2(x))
        return x + transformed


# ================= GPT MODEL =================
class GPT(nn.Module):
    """GPT-style language model assembled from the blocks in this file.

    Token embeddings and learned positional embeddings are summed, passed
    through dropout and ``cfg.n_layer`` ``Block`` modules, normalised with a
    final ``RMSNorm`` and projected to vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids; sizes the token embedding
            table and the output projection.
        cfg: Hyperparameters. ``DEFAULT_CONFIG`` is used when omitted. The
            config is validated before any layer is built.

    Raises:
        ValueError: If ``cfg`` fails ``GPTConfig.validate``.
    """

    def __init__(self, vocab_size: int, cfg: Optional[GPTConfig] = None):
        super().__init__()
        cfg = cfg or DEFAULT_CONFIG
        cfg.validate()
        self.cfg = cfg

        self.token_emb = nn.Embedding(vocab_size, cfg.n_embd)
        self.pos_emb = nn.Embedding(cfg.block_size, cfg.n_embd)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.Sequential(*[Block(cfg) for _ in range(cfg.n_layer)])
        self.ln_f = RMSNorm(cfg.n_embd)
        self.head = nn.Linear(cfg.n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute logits for ``idx`` and, optionally, the cross-entropy loss.

        Args:
            idx: Rank-2 tensor of token ids, unpacked as ``(batch, time)``.
                ``time`` must not exceed ``cfg.block_size``.
            targets: Optional tensor of target ids holding one id per position
                of ``idx``. It may be non-contiguous, e.g. a slice of a larger
                batch tensor.

        Returns:
            ``(logits, loss)``. ``logits`` has a trailing dimension of
            ``vocab_size``; ``loss`` is ``None`` when ``targets`` is ``None``,
            otherwise the mean cross-entropy over all positions.

        Raises:
            ValueError: If the sequence is longer than ``cfg.block_size``.
        """
        bsz, tsz = idx.shape
        if tsz > self.cfg.block_size:
            raise ValueError(
                f"Sequence length {tsz} exceeds block_size {self.cfg.block_size}."
            )

        pos = torch.arange(0, tsz, device=idx.device)
        # Positional embeddings are shared by every sequence in the batch.
        x = self.token_emb(idx) + self.pos_emb(pos)[None, :, :]
        x = self.drop(x)
        x = self.blocks(x)
        logits = self.head(self.ln_f(x))

        loss = None
        if targets is not None:
            # reshape() rather than view(): view() raises on non-contiguous
            # tensors such as ``data[:, 1:]``, whereas reshape() returns a view
            # when it can and copies only when it must.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
            )
        return logits, loss


# ================= SIMPLE BPE TOKENIZER =================
class SimpleBPETokenizer:
    """Byte-level BPE tokenizer driven by two lookup tables.

    Both tables start empty and are filled in by the caller:

    * ``vocab``  maps a token id to the bytes it stands for (used by decode).
    * ``merges`` maps a pair of adjacent token ids to the id that replaces
      the pair (used by encode). Lower ids are applied first.

    Encodings are memoised per instance. The memo is dropped automatically
    when ``merges`` is rebound to another dict or changes size, so texts
    encoded before the merge table was loaded are not served stale. Changing
    the value of an existing pair in place cannot be detected; call
    ``clear_cache()`` after such an edit.
    """

    # Upper bound on the number of distinct texts remembered per instance.
    _CACHE_SIZE = 32768

    def __init__(self):
        self.vocab = {}    # {int: bytes}
        self.merges = {}   # {(int, int): int}

    def __getstate__(self):
        """Return the picklable state, leaving the encode memo out."""
        state = dict(self.__dict__)
        state.pop("_encode_cache_state", None)
        return state

    def clear_cache(self) -> None:
        """Forget every memoised encoding of this instance."""
        self.__dict__.pop("_encode_cache_state", None)

    def _current_cache(self):
        """Return the memo matching the current merge table, resetting it if needed."""
        merges = self.merges
        state = getattr(self, "_encode_cache_state", None)
        if state is not None:
            owner, size, cache = state
            # Holding ``owner`` keeps the old dict alive, so the identity test
            # cannot be fooled by a recycled object id.
            if owner is merges and size == len(merges):
                return cache
        cache = {}
        self._encode_cache_state = (merges, len(merges), cache)
        return cache

    def _merge_bytes(self, text: str):
        """Encode ``text`` from scratch and return the token ids as a tuple."""
        merges = self.merges
        tokens = list(text.encode("utf-8", errors="ignore"))

        while len(tokens) >= 2:
            best_i = None
            best_id = None
            # Find the leftmost adjacent pair whose replacement id is lowest.
            for i, pair in enumerate(zip(tokens, tokens[1:])):
                new_id = merges.get(pair)
                if new_id is None:
                    continue
                if best_id is None or new_id < best_id:
                    best_i, best_id = i, new_id

            if best_i is None:
                break

            # Replace the pair in place instead of rebuilding the whole list.
            tokens[best_i:best_i + 2] = [best_id]

        return tuple(tokens)

    def _encode_cached(self, text: str):
        """Return the token ids for ``text`` as a tuple, using the instance memo."""
        cache = self._current_cache()
        ids = cache.pop(text, None)
        if ids is None:
            ids = self._merge_bytes(text)
            if len(cache) >= self._CACHE_SIZE:
                # Dicts keep insertion order, so the first key is the least
                # recently used entry.
                cache.pop(next(iter(cache)), None)
        # (Re)insert at the end to mark the entry as most recently used.
        cache[text] = ids
        return ids

    def encode(self, text: str):
        """Return the token ids for ``text`` as a new list.

        Characters that cannot be encoded as UTF-8 are dropped.
        """
        return list(self._encode_cached(text))

    def decode(self, tokens):
        """Return the text spelled by ``tokens``.

        Ids missing from ``vocab`` contribute nothing, and byte sequences
        that are not valid UTF-8 are dropped (best-effort decoding).
        """
        byte_stream = b"".join(self.vocab.get(t, b"") for t in tokens)
        return byte_stream.decode("utf-8", errors="ignore")


def config_from_dict(d: Optional[Dict[str, Any]]) -> GPTConfig:
    """Build a validated ``GPTConfig`` from a possibly partial mapping.

    Args:
        d: Mapping of field names to values. ``None`` or an empty mapping
            selects ``DEFAULT_CONFIG`` itself. Missing fields take their
            value from ``DEFAULT_CONFIG``; keys that are not config fields
            are ignored.

    Returns:
        ``DEFAULT_CONFIG`` for empty input, otherwise a new ``GPTConfig`` with
        integer fields coerced via ``int`` and ``dropout`` via ``float``.

    Raises:
        ValueError: If a value cannot be coerced or the resulting config
            fails validation.
        TypeError: If a value has a type ``int``/``float`` cannot accept.
    """
    if not d:
        return DEFAULT_CONFIG

    # Field name -> coercion, in the order the fields are processed.
    coercions = {
        "n_embd": int,
        "n_head": int,
        "n_layer": int,
        "block_size": int,
        "dropout": float,
    }
    values = {
        name: convert(d.get(name, getattr(DEFAULT_CONFIG, name)))
        for name, convert in coercions.items()
    }

    cfg = GPTConfig(**values)
    cfg.validate()
    return cfg