import os
import math

import torch
import torch.nn as nn
from torch.nn import functional as F
import tiktoken
import gradio as gr

# Model definition (copied from the training script)

# hyperparameters
batch_size = 24        # how many independent sequences are processed in parallel
block_size = 256       # maximum context length for predictions
max_iters = int(160000 * 64 / batch_size)  # how many batches to train on
eval_interval = 500    # how often to evaluate the model
learning_rate = 3e-4   # learning rate for the optimizer
device = ('mps' if torch.backends.mps.is_available()
          else 'cuda' if torch.cuda.is_available() else 'cpu')  # use a GPU if available
eval_iters = 200       # how many batches to use for evaluation
n_embd = 384           # embedding dimension
n_head = 6             # number of attention heads
n_layer = 6            # number of transformer blocks
dropout = 0.2          # dropout rate
sliding_window_len = 128  # kept from training; unused at inference time

# Tokenizer: reuse a single tiktoken encoding for the vocab size, encode and decode
enc = tiktoken.get_encoding("gpt2")
vocab_size = enc.n_vocab


def encode(string):
    return enc.encode(string)


def decode(indices):
    return enc.decode(indices)


class FlashAttentionHead(nn.Module):
    # a single attention head backed by F.scaled_dot_product_attention

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.o_proj = nn.Linear(head_size, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # batch size, sequence length, embedding dimension (n_embd)
        B, T, C = x.shape
        k = self.key(x)        # (B, T, head_size)
        q = self.query(x)      # (B, T, head_size)
        value = self.value(x)  # (B, T, head_size)
        # only apply attention dropout while training; it must be off in eval mode
        output = F.scaled_dot_product_attention(
            q, k, value, attn_mask=None,
            dropout_p=dropout if self.training else 0.0, is_causal=True)
        output = self.o_proj(output)
        output = self.dropout(output)
        return output


class MultiHeadAttention(nn.Module):

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList(
            [FlashAttentionHead(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out


class FFN(nn.Module):

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class EfficientMoEFFN(nn.Module):
    # mixture-of-experts FFN: each token is routed to its top-k experts

    def __init__(self, n_embd, num_experts=4, num_experts_per_token=2):
        super().__init__()
        self.num_experts_per_token = num_experts_per_token
        self.num_experts = num_experts
        self.experts = nn.ModuleList([FFN(n_embd) for _ in range(num_experts)])
        self.gate = nn.Linear(n_embd, num_experts)

    def forward(self, x):
        B, T, C = x.shape
        x_flat = x.view(B * T, C)  # flatten tokens to (batch*tokens, d_model)

        # Gating
        gate_scores = self.gate(x_flat)  # (B*T, num_experts)
        topk_scores, topk_indices = torch.topk(
            gate_scores, self.num_experts_per_token, dim=-1)  # (B*T, k)
        topk_probs = F.softmax(topk_scores, dim=-1)  # (B*T, k), normalized

        # Output buffer
        out = torch.zeros_like(x_flat)

        # For each expert: route only the tokens assigned to it
        for expert_id, expert in enumerate(self.experts):
            # Find where this expert is selected
            mask = (topk_indices == expert_id)  # (B*T, k)
            if not mask.any():
                continue  # this expert was not selected by any token; skip it
            token_ids, which_slot = mask.nonzero(as_tuple=True)

            # Select the actual tokens
            tokens_for_expert = x_flat[token_ids]

            # Apply the expert FFN
            expert_out = expert(tokens_for_expert)  # (num_tokens, C)

            # Scale by the gating probability
            probs = topk_probs[token_ids, which_slot].unsqueeze(-1)
            expert_out = expert_out * probs

            # Scatter-add back into the output buffer
            out.index_add_(0, token_ids, expert_out)

        return out.view(B, T, C)
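# Quick shape check for the MoE layer (an illustrative sketch, commented out so
# nothing runs at import time; assumes the default 4 experts with top-2 routing):
#   moe = EfficientMoEFFN(n_embd)
#   y = moe(torch.randn(2, 8, n_embd))   # (batch, tokens, channels)
#   assert y.shape == (2, 8, n_embd)     # top-k routing preserves the input shape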
class Block(nn.Module):
    # transformer block: multi-head attention then feedforward, each preceded by
    # layer normalization and wrapped in a residual connection

    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = EfficientMoEFFN(n_embd, num_experts=4)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x


class LanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.token_embed_table = nn.Embedding(vocab_size, n_embd)
        self.position_embed_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            *[Block(n_embd, n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)  # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        token_emb = self.token_embed_table(idx)  # (B, T, n_embd)
        position_emb = self.position_embed_table(
            torch.arange(T, device=idx.device))  # (T, n_embd)
        x = token_emb + position_emb  # (B, T, n_embd)
        x = self.blocks(x)            # (B, T, n_embd)
        x = self.ln_f(x)              # (B, T, n_embd)
        logits = self.lm_head(x)      # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        # idx is a (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, _ = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            # apply temperature scaling
            if temperature != 1.0:
                logits = logits / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append the sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx


# Load the model
model = LanguageModel().to(device)
model_path = "./model_v6_flash_attn.pth"

# Check if the model file exists
if os.path.exists(model_path):
    model.load_state_dict(torch.load(
        model_path, map_location=device, weights_only=False))
    model.eval()
    print("Model loaded successfully")
else:
    print("Model file not found")

# Compile the model for better performance
model = torch.compile(model)
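# Optional post-load sanity check (an illustrative sketch, commented out so the
# serving script's behavior is unchanged): reports the total parameter count of
# the loaded model, which still works through the torch.compile wrapper.
#   n_params = sum(p.numel() for p in model.parameters())
#   print(f"parameters: {n_params / 1e6:.1f}M")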
def generate_text(prompt, max_tokens, temperature, top_k):
    if not os.path.exists(model_path):
        return "Model not found. Please train the model first."

    # Encode the prompt
    idx = torch.tensor(encode(prompt), dtype=torch.long,
                       device=device).unsqueeze(0)

    # Generate text (Gradio sliders return floats, so cast the integer controls)
    with torch.no_grad():
        generated_idx = model.generate(
            idx, int(max_tokens), temperature=temperature, top_k=int(top_k))

    # Decode the generated text
    generated_text = decode(generated_idx[0].tolist())
    return generated_text[len(prompt):]  # return only the generated part


# Create the Gradio interface
interface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(lines=5, label="Input Prompt",
                   placeholder="Enter your text prompt here..."),
        gr.Slider(1, 500, value=100, step=1, label="Max Tokens"),
        gr.Slider(0.1, 2.0, value=1.0, label="Temperature"),
        gr.Slider(1, 100, value=50, step=1, label="Top K")
    ],
    outputs=gr.Textbox(label="Generated Text", lines=10),
    title="Text Generation with Transformer Model",
    description="Generate text using a trained transformer model. "
                "Adjust the parameters to control the output."
)

# Launch the app
if __name__ == "__main__":
    interface.launch()
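# Deployment note (an assumption about how this demo might be shared, not part of
# the original script): Gradio can expose a temporary public URL or bind to all
# network interfaces by passing options to launch, e.g.
#   interface.launch(share=True)             # temporary public gradio.live URL
#   interface.launch(server_name="0.0.0.0")  # reachable from the local network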