import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import tiktoken
import gradio as gr
import os
# Model definition (copied from the training script)

# hyperparameters
batch_size = 24  # how many independent sequences will we process in parallel?
block_size = 256  # what is the maximum context length for predictions?
max_iters = int(160000 * 64 / batch_size)  # how many batches to train on
eval_interval = 500  # how often to evaluate the model
learning_rate = 3e-4  # learning rate for the optimizer
device = ('mps' if torch.backends.mps.is_available()
          else 'cuda' if torch.cuda.is_available() else 'cpu')  # use GPU if available
eval_iters = 200  # how many batches to use for evaluation
n_embd = 384  # embedding dimension
n_head = 6  # number of attention heads
n_layer = 6  # number of transformer blocks
dropout = 0.2  # dropout rate
sliding_window_len = 128
# Tokenizer: GPT-2 BPE via tiktoken (created once and reused)
enc = tiktoken.get_encoding("gpt2")
vocab_size = enc.n_vocab


# Encoder/decoder functions
def encode(string):
    return enc.encode(string)


def decode(indices):
    return enc.decode(indices)
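# Illustrative round trip (the actual token IDs depend on the GPT-2 BPE vocabulary):
#   ids = encode("hello world")   # a short list of integer token IDs
#   decode(ids)                   # -> "hello world": decoding recovers the original string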
class FlashAttentionHead(nn.Module):
    """A single attention head using PyTorch's fused scaled_dot_product_attention."""

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.o_proj = nn.Linear(head_size, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # batch size, sequence length, embedding dimension (n_embd)
        B, T, C = x.shape
        k = self.key(x)        # (B, T, head_size)
        q = self.query(x)      # (B, T, head_size)
        value = self.value(x)  # (B, T, head_size)
        # fused causal attention; apply attention dropout only during training
        output = F.scaled_dot_product_attention(
            q, k, value, attn_mask=None,
            dropout_p=dropout if self.training else 0.0,
            is_causal=True)
        output = self.o_proj(output)
        output = self.dropout(output)
        return output
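# Note: ignoring dropout, the fused call above computes the classic causal attention
# sketched below, where causal_mask stands for a lower-triangular (T, T) matrix of ones
# and head_size is the per-head dimension passed to __init__:
#   att = (q @ k.transpose(-2, -1)) / math.sqrt(head_size)   # (B, T, T) attention scores
#   att = att.masked_fill(causal_mask == 0, float('-inf'))    # hide future positions
#   output = F.softmax(att, dim=-1) @ value                   # (B, T, head_size)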
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList(
            [FlashAttentionHead(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # run each head independently, then concatenate along the channel dimension
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out
class FFN(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
class EfficientMoEFFN(nn.Module):
    def __init__(self, n_embd, num_experts=4, num_experts_per_token=2):
        super().__init__()
        self.num_experts_per_token = num_experts_per_token
        self.num_experts = num_experts
        self.experts = nn.ModuleList([FFN(n_embd) for _ in range(num_experts)])
        self.gate = nn.Linear(n_embd, num_experts)

    def forward(self, x):
        B, T, C = x.shape
        x_flat = x.view(B * T, C)  # flatten tokens to (B*T, n_embd)

        # Gating: pick the top-k experts per token
        gate_scores = self.gate(x_flat)  # (B*T, num_experts)
        topk_scores, topk_indices = torch.topk(
            gate_scores, self.num_experts_per_token, dim=-1)  # (B*T, k)
        topk_probs = F.softmax(topk_scores, dim=-1)  # (B*T, k), normalized over selected experts

        # Output buffer
        out = torch.zeros_like(x_flat)

        # For each expert: route only the tokens assigned to it
        for expert_id, expert in enumerate(self.experts):
            # Find where this expert was selected
            mask = (topk_indices == expert_id)  # (B*T, k)
            if not mask.any():
                continue  # this expert is not in the top-k for any token, so skip it
            token_ids, which_slot = mask.nonzero(as_tuple=True)
            # Select the actual tokens routed to this expert
            tokens_for_expert = x_flat[token_ids]
            # Apply the expert FFN
            expert_out = expert(tokens_for_expert)  # (num_tokens, C)
            # Scale by the gating probability
            probs = topk_probs[token_ids, which_slot].unsqueeze(-1)
            expert_out = expert_out * probs
            # Scatter-add back into the output buffer
            out.index_add_(0, token_ids, expert_out)

        return out.view(B, T, C)
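# Quick sanity-check sketch of the MoE routing (illustrative, not executed by the app;
# the shapes assume the hyperparameters defined above):
#   moe = EfficientMoEFFN(n_embd, num_experts=4, num_experts_per_token=2)
#   x = torch.randn(2, block_size, n_embd)   # (B=2, T=256, C=384)
#   y = moe(x)                               # (2, 256, 384), same shape as x
#   # each token's output is the probability-weighted sum of its top-2 experts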
class Block(nn.Module):
    # transformer block: multi-head attention followed by a feed-forward (MoE) layer,
    # each applied after layer normalization with a residual connection
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = EfficientMoEFFN(n_embd, num_experts=4)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x
class LanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embed_table = nn.Embedding(vocab_size, n_embd)
        self.position_embed_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            *[Block(n_embd, n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)  # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        token_emb = self.token_embed_table(idx)  # (B, T, n_embd)
        position_emb = self.position_embed_table(
            torch.arange(T, device=idx.device))  # (T, n_embd)
        x = token_emb + position_emb  # (B, T, n_embd), position embedding broadcasts over B
        x = self.blocks(x)  # (B, T, n_embd)
        x = self.ln_f(x)  # (B, T, n_embd)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        # idx is a (B, T) tensor of token indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, _ = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            # apply temperature scaling
            if temperature != 1.0:
                logits = logits / temperature
            # optionally restrict the logits to the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append the sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx
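# Illustrative usage of generate (not executed by the app; assumes the model
# instance created below has been loaded with trained weights):
#   context = torch.tensor([encode("Once upon a time")], dtype=torch.long, device=device)
#   with torch.no_grad():
#       tokens = model.generate(context, max_new_tokens=50, temperature=0.8, top_k=40)
#   print(decode(tokens[0].tolist()))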
# Load the model
model = LanguageModel().to(device)
model_path = "./model_v6_flash_attn.pth"

# Check if the model file exists
if os.path.exists(model_path):
    model.load_state_dict(torch.load(
        model_path, map_location=device, weights_only=False))
    model.eval()
    print("Model loaded successfully")
else:
    print("Model file not found")

# Compile the model for better performance
model = torch.compile(model)
def generate_text(prompt, max_tokens, temperature, top_k):
    if not os.path.exists(model_path):
        return "Model not found. Please train the model first."
    # Gradio sliders may pass floats, so cast the integer-valued parameters
    max_tokens = int(max_tokens)
    top_k = int(top_k)
    # Encode the prompt
    idx = torch.tensor(encode(prompt), dtype=torch.long,
                       device=device).unsqueeze(0)
    # Generate text
    with torch.no_grad():
        generated_idx = model.generate(
            idx, max_tokens, temperature=temperature, top_k=top_k)
    # Decode the generated text
    generated_text = decode(generated_idx[0].tolist())
    return generated_text[len(prompt):]  # return only the generated part
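# Illustrative direct call, equivalent to what the Gradio UI does below
# (assumes the checkpoint at model_path exists; the prompt is just an example):
#   sample = generate_text("The quick brown fox", max_tokens=100,
#                          temperature=1.0, top_k=50)
#   print(sample)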
# Create the Gradio interface
interface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(lines=5, label="Input Prompt",
                   placeholder="Enter your text prompt here..."),
        gr.Slider(1, 500, value=100, step=1, label="Max Tokens"),
        gr.Slider(0.1, 2.0, value=1.0, label="Temperature"),
        gr.Slider(1, 100, value=50, step=1, label="Top K"),
    ],
    outputs=gr.Textbox(label="Generated Text", lines=10),
    title="Text Generation with Transformer Model",
    description="Generate text using a trained transformer model. Adjust the parameters to control the output."
)
# Launch the app
if __name__ == "__main__":
    interface.launch()