File size: 4,483 Bytes
ff5d275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7517b8
ff5d275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7517b8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import json
import torch
import torch.nn as nn
from transformers import GPT2Tokenizer


def load_tinylm(model_dir, device="cpu"):
    """Load a TinyLM checkpoint from *model_dir*.

    Expects the directory to contain ``config.json``, ``pytorch_model.bin``
    and GPT-2 tokenizer files.

    Args:
        model_dir: path to the checkpoint directory.
        device: torch device spec the model is moved to (default ``"cpu"``).

    Returns:
        Tuple ``(model, tokenizer, config)``; the model is in eval mode.
    """
    # Load config
    with open(f"{model_dir}/config.json") as f:
        config = json.load(f)

    VOCAB_SIZE  = config["vocab_size"]
    EMBED_RANK  = config["embed_rank"]
    D_MODEL     = config["d_model"]
    N_HEADS     = config["n_heads"]
    FFN_DIM     = config["ffn_dim"]
    N_LAYERS    = config["n_layers"]
    MAX_SEQ_LEN = config["max_seq_len"]
    DROPOUT     = config["dropout"]

    class FactoredEmbedding(nn.Module):
        """Low-rank token embedding: vocab -> rank -> d_model."""

        def __init__(self, vocab_size, rank, d_model):
            super().__init__()
            self.in_proj  = nn.Embedding(vocab_size, rank)
            self.out_proj = nn.Linear(rank, d_model, bias=False)

        def forward(self, x):
            return self.out_proj(self.in_proj(x))

    class TransformerBlock(nn.Module):
        """Pre-LayerNorm transformer block (causal self-attention + GELU FFN)."""

        def __init__(self):
            super().__init__()
            self.ln1  = nn.LayerNorm(D_MODEL)
            self.attn = nn.MultiheadAttention(D_MODEL, N_HEADS, dropout=DROPOUT, batch_first=True)
            self.ln2  = nn.LayerNorm(D_MODEL)
            self.ffn  = nn.Sequential(
                nn.Linear(D_MODEL, FFN_DIM),
                nn.GELU(),
                nn.Linear(FFN_DIM, D_MODEL),
                nn.Dropout(DROPOUT),
            )

        def forward(self, x, attn_mask=None, key_padding_mask=None):
            # Pre-norm residual attention. is_causal=True asserts that the
            # supplied attn_mask is causal so the kernel can take a fast path.
            x_norm = self.ln1(x)
            attn_out, _ = self.attn(x_norm, x_norm, x_norm,
                                    attn_mask=attn_mask,
                                    key_padding_mask=key_padding_mask,
                                    is_causal=True)
            x = x + attn_out
            x = x + self.ffn(self.ln2(x))
            return x

    class TinyLM(nn.Module):
        """Decoder-only LM with factored embeddings and a factored, tied head."""

        def __init__(self):
            super().__init__()
            self.tok_emb  = FactoredEmbedding(VOCAB_SIZE, EMBED_RANK, D_MODEL)
            self.pos_emb  = nn.Embedding(MAX_SEQ_LEN, D_MODEL)
            self.drop     = nn.Dropout(DROPOUT)
            self.blocks   = nn.ModuleList([TransformerBlock() for _ in range(N_LAYERS)])
            self.ln_final = nn.LayerNorm(D_MODEL)
            self.head_down  = nn.Linear(D_MODEL, EMBED_RANK, bias=False)
            self.head_vocab = nn.Linear(EMBED_RANK, VOCAB_SIZE, bias=False)
            # Tie head_vocab to the input embedding by sharing the SAME
            # Parameter object. The previous `nn.Parameter(...)` wrapper
            # created a second leaf Parameter over the same storage, so
            # gradients and optimizer updates would not actually be shared.
            self.head_vocab.weight = self.tok_emb.in_proj.weight

        def forward(self, idx):
            # Crop over-long inputs rather than indexing pos_emb out of range.
            if idx.shape[1] > MAX_SEQ_LEN:
                idx = idx[:, :MAX_SEQ_LEN]
            T = idx.shape[1]
            positions = torch.arange(T, device=idx.device).unsqueeze(0)
            x = self.drop(self.tok_emb(idx) + self.pos_emb(positions))
            mask = nn.Transformer.generate_square_subsequent_mask(T, device=idx.device)
            for block in self.blocks:
                x = block(x, attn_mask=mask)
            x = self.ln_final(x)
            # Factored output head mirrors the factored embedding.
            return self.head_vocab(self.head_down(x))

    # Build the model and load the checkpoint tensors.
    model = TinyLM().to(device)
    # weights_only=True refuses to unpickle arbitrary objects from the
    # checkpoint (requires torch >= 1.13); a plain state_dict of tensors
    # loads fine under it.
    state_dict = torch.load(f"{model_dir}/pytorch_model.bin",
                            map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    # Load tokenizer; GPT-2 defines no pad token, so reuse EOS for padding.
    tokenizer = GPT2Tokenizer.from_pretrained(model_dir)
    tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer, config


def generate(model, tokenizer, prompt, max_new_tokens=100, temperature=0.1, top_k=25, device="cpu"):
    """Autoregressively sample a continuation of *prompt* from a TinyLM.

    Args:
        model: language model whose forward returns (B, T, vocab) logits and
            which exposes ``pos_emb.num_embeddings`` as its context length.
        tokenizer: HF-style tokenizer (callable encode, ``decode``,
            ``eos_token_id``).
        prompt: text to continue.
        max_new_tokens: maximum number of tokens to sample.
        temperature: logit divisor; lower is greedier.
        top_k: keep only the k most likely tokens, or None for no filtering.
        device: device the input ids are moved to.

    Returns:
        The decoded prompt plus sampled continuation (special tokens skipped).
    """
    MAX_SEQ_LEN = model.pos_emb.num_embeddings  # model's positional capacity
    model.eval()
    ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Crop the context so positions stay within the embedding table.
            idx_cond = ids[:, -MAX_SEQ_LEN:]
            logits = model(idx_cond)[:, -1, :] / temperature
            if top_k is not None:
                # Clamp to vocab size: torch.topk raises if k > dim size
                # (e.g. the default top_k=25 on a tiny vocabulary).
                k = min(top_k, logits.size(-1))
                threshold = torch.topk(logits, k).values[:, -1:]
                logits[logits < threshold] = -float("inf")
            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            if next_id.item() == tokenizer.eos_token_id:
                break  # stop without emitting the EOS token itself
            ids = torch.cat([ids, next_id], dim=1)

    return tokenizer.decode(ids[0], skip_special_tokens=True)


if __name__ == "__main__":
    model, tokenizer, config = load_tinylm("./tinylm")
    print("Model loaded!")
    print("Use 'module.generate(model, tokenizer, \"Once upon a time\")' to generate.")