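"""Minimal from-scratch transformer language model: defines the architecture,
downloads a trained checkpoint from the Hugging Face Hub
(pierjoe/MiniTransformer), and samples a continuation of a short prompt."""
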
from huggingface_hub import hf_hub_download
import torch
from torch import nn
from torch.nn import functional as F
from transformers import AutoTokenizer
from safetensors.torch import load_file


class TransformerBlock(nn.Module):
    def __init__(self, emb_dim, num_heads, context_length, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ln2 = nn.LayerNorm(emb_dim)
        self.attn = nn.MultiheadAttention(
            emb_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, 4 * emb_dim),
            nn.GELU(),
            nn.Linear(4 * emb_dim, emb_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # pre-norm self-attention; note that no causal mask is applied (and the
        # context_length argument is unused), so attention here is bidirectional
        h = self.ln1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.ln2(x))
        return x
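
# Illustrative shape check (not part of the script's flow): a block maps
# (batch, seq_len, emb_dim) to the same shape, e.g.
#   blk = TransformerBlock(emb_dim=512, num_heads=8, context_length=256)
#   assert blk(torch.randn(2, 16, 512)).shape == (2, 16, 512)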


class MiniTransformer(nn.Module):
    def __init__(
        self,
        vocab_size,
        emb_dim,
        context_length,
        num_heads,
        num_layers,
        dropout=0.1,
    ):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.pos_emb = nn.Embedding(context_length, emb_dim)
        self.blocks = nn.Sequential(
            *[
                TransformerBlock(emb_dim, num_heads, context_length, dropout)
                for _ in range(num_layers)
            ]
        )
        self.ln_f = nn.LayerNorm(emb_dim)
        self.head = nn.Linear(emb_dim, vocab_size, bias=False)
        self.context_length = context_length

    def forward(self, x):
        B, T = x.shape
        pos = torch.arange(T, device=x.device)  # (T,)
        x = self.emb(x) + self.pos_emb(pos)  # token + position embeddings, (B, T, emb_dim)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)  # (B, T, vocab_size)
        return logits

    @torch.no_grad()
    def generate(self, x, max_new_tokens=20, temperature=1.0, top_k=None):
        """Autoregressively sample max_new_tokens tokens: temperature scales
        the logits and top_k (if given) truncates the distribution."""
        for _ in range(max_new_tokens):
            # keep only the last context_length tokens; pos_emb covers no more
            x_cond = x[:, -self.context_length :]

            # get predictions
            logits = self(x_cond)  # (B, T_cond, vocab_size)
            logits = logits[:, -1, :] / temperature  # only last position

            # optionally restrict sampling to the top-k most likely tokens
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("inf")

            probs = F.softmax(logits, dim=-1)

            # sample the next token from the distribution
            next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # greedy alternative: next_token = torch.argmax(probs, dim=1).unsqueeze(-1)
            # append to sequence
            x = torch.cat([x, next_token], dim=1)

        return x
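
# Illustrative call (hypothetical shapes; `model` is built below): for input ids
# of shape (batch, seq_len), generate returns (batch, seq_len + max_new_tokens),
# e.g. model.generate(torch.randint(0, 100, (1, 8)), max_new_tokens=5) -> (1, 13)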


# hyperparameters of the saved checkpoint; they must match what the weights were
# trained with, or load_state_dict below fails on shape mismatches
CONTEXT_LENGTH = 256
EMBEDDING_DIMENSION = 512
HEAD_NUMBER = 8
N_LAYER = 6
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# prefer CUDA, then Apple-silicon MPS, then fall back to CPU
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)

# download the checkpoint from the Hub (cached locally; returns the file path)
model_path = hf_hub_download(
    repo_id="pierjoe/MiniTransformer",
    filename="checkpoints/mini_transformer_v4/model_50.safetensors",
)

# rebuild the architecture and load the trained weights into it
model = MiniTransformer(
    vocab_size=tokenizer.vocab_size,
    emb_dim=EMBEDDING_DIMENSION,
    context_length=CONTEXT_LENGTH,
    num_heads=HEAD_NUMBER,
    num_layers=N_LAYER,
).to(device)
state_dict = load_file(model_path)
# strip the "_orig_mod." prefix that torch.compile adds to parameter names
state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

model.load_state_dict(state_dict)

model.eval()  # switch the dropout layers off before sampling
max_tokens = 100
prompt = "You are a helpful assistant. Provide clear, concise, and accurate responses to the user "
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
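# sampling knobs: temperature=5 flattens the next-token distribution, and
# top_k=10 restricts sampling to the ten likeliest tokens at each step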
output_ids = model.generate(
    input_ids, max_new_tokens=max_tokens, temperature=5, top_k=10
)
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(generated_text)