File size: 4,087 Bytes
a69eabc
 
 
d1ee6e0
 
a69eabc
b074d6c
 
 
 
 
 
 
d1ee6e0
 
 
1df5077
d1ee6e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ffde90b
d1ee6e0
 
9125647
d1ee6e0
 
 
 
 
 
 
 
ffde90b
 
9125647
ffde90b
 
 
 
 
9125647
ffde90b
 
 
 
9125647
 
 
 
 
 
 
ffde90b
 
 
 
 
 
 
 
 
d1ee6e0
 
 
 
 
68e2663
1df5077
 
d1ee6e0
 
 
 
 
 
 
 
9125647
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import torch
import torch.nn as nn
from torch.nn import functional as F
import json
import re

class BPETokenizer:
    """Byte-pair-encoding tokenizer backed by a ``tiktoken`` encoding.

    The encoding is looked up by name with ``tiktoken.get_encoding`` and
    stored on ``self.enc``; ``encode``/``decode`` delegate to it.
    """

    def __init__(self, model_type="gpt2"):
        """Fetch the tiktoken encoding registered under *model_type*."""
        # Imported lazily so the module can be imported without tiktoken.
        from tiktoken import get_encoding

        self.enc = get_encoding(model_type)

    def encode(self, text):
        """Return the token ids for *text*.

        The literal ``<|endoftext|>`` marker is passed as an allowed
        special token to the underlying encoder.
        """
        specials = {'<|endoftext|>'}
        token_ids = self.enc.encode(text, allowed_special=specials)
        return token_ids

    def decode(self, ids):
        """Return the text the underlying encoding decodes from *ids*."""
        text = self.enc.decode(ids)
        return text

class MiniTransformer(nn.Module):
    """Decoder-style language model built from ``nn.TransformerEncoderLayer``.

    Token and learned position embeddings are summed, passed through
    ``n_layers`` pre-norm (``norm_first=True``) GELU encoder layers under a
    causal mask, normalised by a final ``LayerNorm`` and projected to
    vocabulary logits by a bias-free linear head.
    """

    def __init__(self, vocab_size, emb_dim=768, n_layers=12, n_heads=12, ctx_len=1024, dropout=0.1):
        """Build the embeddings, transformer blocks and output head.

        Args:
            vocab_size: Number of rows in the token embedding and number of
                output logits.
            emb_dim: Embedding / model width.
            n_layers: Number of transformer blocks.
            n_heads: Attention heads per block.
            ctx_len: Number of rows in the position embedding table; also
                the window ``generate`` crops its input to.
            dropout: Dropout probability for the embeddings and each block.
        """
        super().__init__()
        self.ctx_len = ctx_len
        self.n_heads = n_heads
        self.emb_dim = emb_dim
        self.n_layers = n_layers
        self.token_embedding_table = nn.Embedding(vocab_size, emb_dim)
        self.position_embedding_table = nn.Embedding(ctx_len, emb_dim)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=emb_dim,
                nhead=n_heads,
                dim_feedforward=emb_dim * 4,
                dropout=dropout,
                batch_first=True,
                norm_first=True,
                activation='gelu'
            ) for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(emb_dim)
        self.lm_head = nn.Linear(emb_dim, vocab_size, bias=False)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: Integer tensor of token ids, unpacked as ``(B, T)``.
            targets: Optional tensor of target ids with ``B * T`` elements.

        Returns:
            ``(logits, loss)`` where ``logits`` is ``(B, T, vocab_size)`` and
            ``loss`` is ``None`` when *targets* is not given.
        """
        device = idx.device
        B, T = idx.shape
        # Out-of-range ids are clamped into the embedding table rather than
        # raising an index error.
        idx = torch.clamp(idx, 0, self.token_embedding_table.num_embeddings - 1)
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = self.drop(tok_emb + pos_emb)
        # True above the diagonal: position i may not attend to j > i.
        mask = torch.triu(torch.ones(T, T, device=device), diagonal=1).bool()
        for block in self.blocks:
            x = block(x, src_mask=mask, is_causal=True)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            # Same clamping policy as for the inputs.
            targets = torch.clamp(targets, 0, self.lm_head.out_features - 1)
            B, T, C = logits.shape
            loss = F.cross_entropy(logits.view(B * T, C), targets.view(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None,
                 repetition_penalty=1.0, eos_token_id=50256):
        """Sample up to *max_new_tokens* continuation tokens for each row.

        Args:
            idx: ``(B, T)`` integer tensor of prompt token ids.
            max_new_tokens: Upper bound on the number of sampling steps.
            temperature: Divisor applied to the last-position logits; must
                be non-zero.
            top_k: If given, only the ``top_k`` highest logits stay eligible.
            repetition_penalty: When not 1.0, the logit of every token that
                already appears in a row is divided by this value if
                non-negative and multiplied by it if negative.
            eos_token_id: Token id that ends a sequence. Generation stops
                once every row has produced it; rows that finish early are
                padded with it. ``None`` disables early stopping.

        Returns:
            The prompt with the sampled tokens appended along dim 1.
        """
        # Per-row flag; replaces a scalar `.item()` check that raised for
        # any batch size other than 1.
        finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)
        for _ in range(max_new_tokens):
            # Only the most recent ctx_len tokens are fed to the model.
            idx_cond = idx[:, -self.ctx_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if repetition_penalty != 1.0:
                # Penalise every token in the full history (not just the
                # cropped window). Duplicate ids gather the same original
                # logit, so the scatter writes one consistent value per id.
                history = idx.long()
                seen = torch.gather(logits, 1, history)
                seen = torch.where(
                    seen < 0,
                    seen * repetition_penalty,
                    seen / repetition_penalty,
                )
                logits.scatter_(1, history, seen)
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                # Everything below the k-th largest logit becomes ineligible.
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            if eos_token_id is not None:
                # Rows that already ended keep emitting EOS as padding.
                idx_next = torch.where(
                    finished.unsqueeze(1),
                    torch.full_like(idx_next, eos_token_id),
                    idx_next,
                )
                finished = finished | (idx_next.squeeze(1) == eos_token_id)
            idx = torch.cat((idx, idx_next), dim=1)
            if eos_token_id is not None and bool(finished.all()):
                break
        return idx

    @classmethod
    def load(cls, path, device='cpu'):
        """Build a model from a checkpoint and return it in eval mode.

        The checkpoint is either a bare state dict or a dict with a
        ``'model_state'`` entry and an optional ``'config'`` entry holding
        ``vocab_size``, ``emb_dim``, ``n_layers``, ``n_heads`` and
        ``ctx_len``. Config keys that are absent fall back to built-in
        defaults. A leading ``'module.'`` on parameter names is stripped.

        Args:
            path: Checkpoint location, passed straight to ``torch.load``.
            device: Device to map the tensors to and move the model to.

        Returns:
            The loaded model, on *device*, with ``eval()`` applied.
        """
        import warnings

        # SECURITY: weights_only=False unpickles arbitrary Python objects.
        # Only load checkpoints from trusted sources.
        ckpt = torch.load(path, map_location=device, weights_only=False)
        is_dict = isinstance(ckpt, dict)
        state_dict = ckpt['model_state'] if is_dict and 'model_state' in ckpt else ckpt

        cfg = {'vocab_size': 50257, 'emb_dim': 1024, 'n_layers': 24, 'n_heads': 16, 'ctx_len': 1024}
        if is_dict and 'config' in ckpt:
            # Merge so a partial config does not raise KeyError below.
            cfg.update(ckpt['config'])
        model = cls(cfg['vocab_size'], cfg['emb_dim'], cfg['n_layers'], cfg['n_heads'], cfg['ctx_len'])

        prefix = 'module.'
        new_state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
        # Loading stays non-strict, but a partial match is reported instead
        # of silently leaving randomly initialised weights in place.
        result = model.load_state_dict(new_state_dict, strict=False)
        if result.missing_keys or result.unexpected_keys:
            warnings.warn(
                f"Checkpoint did not fully match the model: "
                f"missing keys {list(result.missing_keys)}, "
                f"unexpected keys {list(result.unexpected_keys)}",
                RuntimeWarning,
                stacklevel=2,
            )
        model.to(device)
        model.eval()
        return model