| import json |
| import os |
| import math |
| import torch |
| import torch.nn as nn |
| from torch.nn import functional as F |
| from dataclasses import dataclass |
|
|
@dataclass
class GPTConfig:
    """Hyperparameters that fully determine the shape of an AttoGPT model.

    Construction validates the attention geometry so that an inconsistent
    config is rejected immediately rather than deep inside model building.

    Raises:
        ValueError: if ``n_head`` is not positive or does not divide ``n_embd``.
    """

    block_size: int = 1024  # max sequence length: size of the position table and causal mask
    vocab_size: int = 50304  # rows of the token embedding / width of the output logits
    n_layer: int = 12  # number of stacked transformer blocks
    n_head: int = 12  # attention heads per block; must divide n_embd
    n_embd: int = 768  # model (residual-stream) width
    bias: bool = True  # add bias terms to the Linear and LayerNorm layers that honour it

    def __post_init__(self):
        # Fail fast with a readable message; the per-layer check in the attention
        # module is an `assert`, which is removed when Python runs with -O.
        if self.n_head <= 0:
            raise ValueError(f"n_head must be positive, got {self.n_head}")
        if self.n_embd % self.n_head != 0:
            raise ValueError(
                f"n_embd ({self.n_embd}) must be divisible by n_head ({self.n_head})"
            )
|
|
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with an optional bias.

    Args:
        ndim: size of the normalized (last) dimension.
        bias: when truthy, a learnable additive bias (initialised to zeros) is
            created; otherwise ``self.bias`` is ``None`` and only the learnable
            scale (initialised to ones) is applied.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        """Normalize ``input`` over its trailing ``ndim`` features (eps = 1e-5)."""
        eps = 1e-5
        normalized_shape = self.weight.shape
        return F.layer_norm(
            input,
            normalized_shape,
            weight=self.weight,
            bias=self.bias,
            eps=eps,
        )
|
|
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which each position attends only to itself
    and earlier positions.

    ``config`` must provide ``n_embd``, ``n_head`` (dividing ``n_embd``),
    ``block_size`` and ``bias``.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # Creation order of the Linear layers is kept so random init is unchanged.
        # Fused projection producing query, key and value in one matmul.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Output projection applied after the heads are merged back together.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # Lower-triangular ones of shape (1, 1, block_size, block_size): the causal mask.
        # Do not rename: buffer names are state_dict keys used when loading weights.
        mask = torch.tril(torch.ones(config.block_size, config.block_size))
        self.register_buffer("bias", mask.view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, C) with C == n_embd and T <= block_size.

        Returns a tensor of the same shape (B, T, C).
        """
        batch, seq_len, channels = x.size()
        head_dim = channels // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        query, key, value = (to_heads(t) for t in self.c_attn(x).split(self.n_embd, dim=2))

        scale = 1.0 / math.sqrt(key.size(-1))
        scores = (query @ key.transpose(-2, -1)) * scale
        causal = self.bias[:, :, :seq_len, :seq_len]
        # Future positions get -inf so softmax assigns them zero weight.
        scores = scores.masked_fill(causal == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)

        # Merge heads: (B, n_head, T, head_dim) -> (B, T, C).
        merged = (weights @ value).transpose(1, 2).contiguous().view(batch, seq_len, channels)
        return self.c_proj(merged)
|
|
class MLP(nn.Module):
    """Position-wise feed-forward network: widen to 4 * n_embd, apply GELU,
    project back to n_embd. ``config.bias`` controls both Linear biases."""

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=config.bias)
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=config.bias)

    def forward(self, x):
        """Apply the feed-forward transform independently at every position."""
        return self.c_proj(F.gelu(self.c_fc(x)))
|
|
class Block(nn.Module):
    """One pre-norm transformer layer.

    Two residual sub-layers run in sequence: causal self-attention, then the
    feed-forward MLP. Each sees a LayerNorm-ed copy of the stream, and its
    output is added back onto the stream.
    """

    def __init__(self, config):
        super().__init__()
        # Construction order matches the original so parameter init is unchanged.
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        for norm, sublayer in ((self.ln_1, self.attn), (self.ln_2, self.mlp)):
            x = x + sublayer(norm(x))
        return x
|
|
class AttoGPT(nn.Module):
    """Decoder-only transformer language model.

    Token embeddings and learned position embeddings are summed, then passed
    through ``config.n_layer`` ``Block`` layers and a final LayerNorm. The
    result is projected to vocabulary logits by ``lm_head``, a separate
    bias-free Linear whose weight is not tied to the token embedding here.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),  # token id -> vector
            wpe=nn.Embedding(config.block_size, config.n_embd),  # position -> vector
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    def forward(self, idx):
        """Compute next-token logits.

        Args:
            idx: integer token ids of shape (batch, seq_len), with
                ``seq_len <= config.block_size``.

        Returns:
            Logits of shape (batch, seq_len, vocab_size).

        Raises:
            ValueError: if ``idx`` is not 2-D or ``seq_len`` exceeds ``block_size``.
        """
        if idx.dim() != 2:
            raise ValueError(f"idx must have shape (batch, seq_len), got {tuple(idx.shape)}")
        t = idx.size(1)
        # Without this check the position-embedding lookup fails with an opaque
        # out-of-range index error (a device-side assert on CUDA).
        if t > self.config.block_size:
            raise ValueError(
                f"sequence length {t} exceeds block_size {self.config.block_size}; "
                "crop the context before calling the model"
            )
        pos = torch.arange(0, t, dtype=torch.long, device=idx.device)
        tok_emb = self.transformer.wte(idx)  # (batch, t, n_embd)
        pos_emb = self.transformer.wpe(pos)  # (t, n_embd), broadcast over the batch
        x = tok_emb + pos_emb
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        return self.lm_head(x)
|
|
def load_model(path):
    """Build an AttoGPT from a JSON export and return it with its vocabulary.

    The JSON object is read through three top-level keys:
      * ``"config"``  - keyword arguments passed to ``GPTConfig``;
      * ``"weights"`` - state-dict key -> (nested) list of numbers, each turned
        into a tensor with ``torch.tensor``;
      * ``"vocab"``   - token id (a JSON string key, converted to ``int``) -> string.

    Args:
        path: filesystem path of the JSON file.

    Returns:
        ``(model, itos, stoi, cfg)``: the model with weights loaded; the
        id -> string map; its inverse (if two ids share a string, the later
        one wins); and the ``GPTConfig`` used to build the model.

    Raises:
        RuntimeError: from ``load_state_dict`` (strict by default) if the
            weight keys or shapes do not match the model.
    """
    # JSON text is UTF-8 by specification; pin it so loading does not depend on
    # the platform's locale encoding (the vocabulary may hold non-ASCII text).
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    cfg = GPTConfig(**data["config"])
    model = AttoGPT(cfg)
    state_dict = {name: torch.tensor(values) for name, values in data["weights"].items()}
    model.load_state_dict(state_dict)
    itos = {int(k): v for k, v in data["vocab"].items()}
    stoi = {v: k for k, v in itos.items()}
    return model, itos, stoi, cfg
|
|
@torch.no_grad()
def generate(model, stoi, itos, block_size, prompt=" ", length=100, temperature=0.8):
    """Autoregressively extend ``prompt`` and return only the newly generated text.

    Args:
        model: module mapping a (1, T) LongTensor of token ids to logits with
            the vocabulary on the last axis; only the final time step is used.
            It is switched to eval mode.
        stoi: string -> token id. Characters missing from it are encoded as id 0.
        itos: token id -> string. Ids missing from it decode as ``"?"``.
        block_size: maximum context length fed to ``model``; older tokens are
            dropped from the front.
        prompt: seed text. An empty prompt is replaced by the single token id 0.
        length: number of tokens to generate.
        temperature: logits are divided by it before sampling (values below 1
            sharpen the distribution). ``0`` selects greedy (argmax) decoding
            instead of dividing by zero.

    Returns:
        The decoded generated tokens, excluding the prompt (or seed token).
    """
    model.eval()
    tokens = [stoi.get(c, 0) for c in prompt] or [0]
    # Build the context on the model's device so a GPU-resident model works too.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None
    idx = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)
    for _ in range(length):
        # Keep at most the last `block_size` tokens (a no-op for shorter contexts).
        idx_cond = idx[:, -block_size:]
        logits = model(idx_cond)[:, -1, :]
        if temperature == 0:
            # Limit of temperature -> 0: always take the most likely token.
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            probs = F.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        idx = torch.cat((idx, next_token), dim=1)
    generated = idx[0, len(tokens):].tolist()
    return "".join(itos.get(t, "?") for t in generated)
|
|
if __name__ == "__main__":
    # Sample a fixed set of prompts from every exported model found in ./models,
    # visiting the .json files in sorted order.
    models_dir = "models"
    json_files = sorted(name for name in os.listdir(models_dir) if name.endswith(".json"))
    for filename in json_files:
        model, itos, stoi, cfg = load_model(os.path.join(models_dir, filename))
        print(f"\n{'='*60}\n {filename}\n{'='*60}")
        for prompt in (" the ", " to be", " Ham"):
            # NOTE(review): the literals carry leading/trailing spaces, but they are
            # stripped before sampling - confirm that dropping that context is intended.
            text = generate(model, stoi, itos, cfg.block_size, prompt.strip(), length=80)
            print(f' prompt="{prompt.strip()}":\n {text}\n')
|
|