import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import tiktoken
import gradio as gr
import os
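
# Gradio inference app for a small GPT-style language model: it rebuilds the
# architecture from the training script, loads the trained weights from disk,
# and serves interactive text generation through a web UI.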

# Model definition (copied from the training script; it must match the
# architecture the checkpoint was trained with, or loading will fail)
# hyperparameters
batch_size = 24  # how many independent sequences to process in parallel (training-only)
block_size = 256  # maximum context length for predictions
max_iters = int(160000 * 64 / batch_size)  # how many batches to train on (training-only)
eval_interval = 500  # how often to evaluate the model (training-only)
learning_rate = 3e-4  # learning rate for the optimizer (training-only)
# use the best available device: Apple MPS, then CUDA, else CPU
device = ('mps' if torch.backends.mps.is_available()
          else 'cuda' if torch.cuda.is_available() else 'cpu')
eval_iters = 200  # how many batches to use for evaluation (training-only)
n_embd = 384  # embedding dimension
n_head = 6  # number of attention heads
n_layer = 6  # number of transformer blocks
dropout = 0.2  # dropout rate
sliding_window_len = 128  # unused in this script; kept for parity with the training script

# Tokenizer: GPT-2 BPE via tiktoken (built once and reused instead of
# rebuilding the encoding on every call)
enc = tiktoken.get_encoding("gpt2")
vocab_size = enc.n_vocab

# Encoder/decoder functions


def encode(string):
    return enc.encode(string)


def decode(indices):
    return enc.decode(indices)


class FlashAttentionHead(nn.Module):
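    """A single causal self-attention head built on PyTorch's fused
    scaled_dot_product_attention, which dispatches to Flash Attention
    kernels where the backend supports them."""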
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.o_proj = nn.Linear(head_size, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # batch size, sequence length, embedding dimension (n_embd)
        B, T, C = x.shape
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        v = self.value(x)  # (B, T, head_size)
        # fused causal attention; note dropout_p is not gated by
        # model.eval(), so attention dropout must be disabled explicitly
        # at inference time
        output = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None,
            dropout_p=dropout if self.training else 0.0, is_causal=True)
        output = self.o_proj(output)
        output = self.dropout(output)
        return output


class MultiHeadAttention(nn.Module):
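    """num_heads attention heads run in parallel; their outputs are
    concatenated and projected back to the embedding dimension."""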
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([FlashAttentionHead(head_size)
                                    for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out


class FFN(nn.Module):
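    """Position-wise feed-forward network with a 4x hidden expansion,
    used below as the expert inside the mixture-of-experts layer."""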
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class EfficientMoEFFN(nn.Module):
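    """Mixture-of-experts feed-forward layer with top-k token routing:
    a linear gate scores every expert for each token, the top
    num_experts_per_token experts process that token, and their outputs
    are combined, weighted by the softmax of the selected gate scores.
    Each expert only sees the tokens routed to it, so compute scales
    with k rather than with the total number of experts."""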
    def __init__(self, n_embd, num_experts=4, num_experts_per_token=2):
        super().__init__()
        self.num_experts_per_token = num_experts_per_token
        self.num_experts = num_experts
        self.experts = nn.ModuleList([FFN(n_embd) for _ in range(num_experts)])
        self.gate = nn.Linear(n_embd, num_experts)

    def forward(self, x):
        B, T, C = x.shape
        x_flat = x.view(B*T, C)  # Flatten tokens to (batch*tokens, d_model)

        # Gating
        gate_scores = self.gate(x_flat)   # (B*T, num_experts)
        topk_scores, topk_indices = torch.topk(
            gate_scores, self.num_experts_per_token, dim=-1
        )  # (B*T, k)
        topk_probs = F.softmax(topk_scores, dim=-1)  # (B*T, k), normalized

        # Output buffer
        out = torch.zeros_like(x_flat)

        # For each expert: route only the tokens assigned to it
        for expert_id, expert in enumerate(self.experts):
            # Find where this expert is selected
            mask = (topk_indices == expert_id)  # (B*T, k)
            if not mask.any():
                continue  # if it's not part of the top k selected experts for any token, skip it

            token_ids, which_slot = mask.nonzero(as_tuple=True)

            # Select actual tokens
            tokens_for_expert = x_flat[token_ids]

            # Apply expert FFN
            expert_out = expert(tokens_for_expert)  # (num_tokens, C)

            # Scale by probability
            probs = topk_probs[token_ids, which_slot].unsqueeze(-1)
            expert_out = expert_out * probs

            # Scatter-add back to output buffer
            out.index_add_(0, token_ids, expert_out)

        return out.view(B, T, C)


class Block(nn.Module):
    # transformer block: pre-norm multi-head attention and feed-forward
    # sub-layers, each wrapped in a residual connection
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = EfficientMoEFFN(n_embd, num_experts=4)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x


class LanguageModel(nn.Module):
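    """Decoder-only transformer language model with learned token and
    positional embeddings, n_layer pre-norm blocks, and a linear head
    over the GPT-2 vocabulary."""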
    def __init__(self):
        super().__init__()
        self.token_embed_table = nn.Embedding(vocab_size, n_embd)
        self.position_embed_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            *[Block(n_embd, n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)  # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        token_emb = self.token_embed_table(idx)  # (B, T, n_embd)
        position_emb = self.position_embed_table(
            torch.arange(T, device=idx.device))

        x = token_emb + position_emb  # (B, T, n_embd)
        x = self.blocks(x)  # (B, T, n_embd)
        x = self.ln_f(x)  # (B, T, n_embd)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            # apply temperature scaling
            if temperature != 1.0:
                logits = logits / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx


# Load the model
model = LanguageModel().to(device)
model_path = "./model_v6_flash_attn.pth"

# Check if model file exists
if os.path.exists(model_path):
    model.load_state_dict(torch.load(
        model_path, map_location=device, weights_only=False))
    model.eval()
    print("Model loaded successfully")
else:
    print(f"Model file not found at {model_path}; "
          "generation will return an error until it exists")

# Compile the model for faster inference (requires PyTorch 2.0+; the first
# generation request triggers compilation, so it is slower than later ones)
model = torch.compile(model)


def generate_text(prompt, max_tokens, temperature, top_k):
    if not os.path.exists(model_path):
        return "Model not found. Please train the model first."
    if not prompt:
        return "Please enter a non-empty prompt."

    # Encode the prompt
    idx = torch.tensor(encode(prompt), dtype=torch.long,
                       device=device).unsqueeze(0)

    # Generate text; Gradio sliders may deliver floats, so cast the
    # token counts to int before use
    with torch.no_grad():
        generated_idx = model.generate(
            idx, int(max_tokens), temperature=temperature, top_k=int(top_k))

    # Decode and return only the newly generated part
    generated_text = decode(generated_idx[0].tolist())
    return generated_text[len(prompt):]
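
# Example (assuming the checkpoint exists):
#   generate_text("Once upon a time", 50, 0.8, 40)
# samples 50 new tokens continuing the prompt, at temperature 0.8 with
# top-k filtering at 40.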


# Create Gradio interface
interface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(lines=5, label="Input Prompt",
                   placeholder="Enter your text prompt here..."),
        gr.Slider(1, 500, value=100, step=1, label="Max Tokens"),
        gr.Slider(0.1, 2.0, value=1.0, label="Temperature"),
        gr.Slider(1, 100, value=50, step=1, label="Top K")
    ],
    outputs=gr.Textbox(label="Generated Text", lines=10),
    title="Text Generation with Transformer Model",
    description="Generate text using a trained transformer model. Adjust the parameters to control the output."
)

# Launch the app
if __name__ == "__main__":
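    # By default Gradio serves the app locally at http://127.0.0.1:7860;
    # pass share=True to launch() for a temporary public link.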
    interface.launch()