File size: 4,722 Bytes
---
license: mit
datasets:
- karpathy/tiny_shakespeare
---
# ShakeGPT
**ShakeGPT** is a lightweight, decoder-only Transformer language model trained on the Tiny Shakespeare dataset. It is designed to capture the stylistic patterns, vocabulary, and structure of Shakespearean English at a character level.
## Model Description
* **Architecture:** Transformer Decoder
* **Parameters:** ~0.6M
* **Training Data:** Tiny Shakespeare (1.6MB of raw text)
* **Tokenization:** Character-level
* **Context Window:** 128 characters
## Technical Specifications
| Feature | Value |
| :--- | :--- |
| `n_embd` (Embedding Dimension) | 128 |
| `n_layer` (Transformer Blocks) | 3 |
| `n_head` (Attention Heads) | 4 |
| `block_size` (Context Length) | 128 |
| `dropout` | 0.1 |
---
## Inference Script
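This script rebuilds the **ShakeGPT** architecture, loads your saved weights from `gpt_weights_best.pth`, and generates new text. It expects `input.txt` (the text used to build the character vocabulary) to be in the working directory.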
```python
import torch
import torch.nn as nn
from torch.nn import functional as F
import os
# ==========================================
# HYPERPARAMETERS (Matched to gpt.py)
# ==========================================
# Device used for the model and for every tensor this script creates.
device = 'cpu'

# Architecture settings. They are meant to mirror the training script
# (gpt.py); a checkpoint trained with different values will not load.
n_embd = 128      # embedding width
n_head = 4        # attention heads per transformer block
n_layer = 3       # number of transformer blocks
block_size = 128  # maximum context length, in tokens (characters)
dropout = 0.1     # probability passed to every nn.Dropout
weights_path = 'gpt_weights_best.pth'

# Derive the vocabulary from the same text file the training script reads,
# so that each token id maps back to the same character.
with open('input.txt', mode='r', encoding='utf-8') as corpus:
    text = corpus.read()

chars = sorted(set(text))      # every distinct character, in sorted order
vocab_size = len(chars)
itos = dict(enumerate(chars))  # token id -> character


def decode(l):
    """Return the string spelled by the iterable of token ids ``l``."""
    return ''.join(itos[i] for i in l)
# ==========================================
# MODEL ARCHITECTURE (Must be identical)
# ==========================================
class Head(nn.Module):
    """One head of causal (masked) scaled dot-product self-attention."""

    def __init__(self, head_size):
        """Create the projections and causal mask for a head of width ``head_size``."""
        super().__init__()
        # Bias-free projections from the embedding width to this head's width.
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular matrix of ones, registered as a buffer rather
        # than a parameter so it is not trained.
        lower_triangle = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('tril', lower_triangle)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x`` (a 3-D tensor) and return the weighted values.

        Each position may only attend to itself and to earlier positions.
        """
        _, T, _ = x.shape
        keys = self.key(x)
        queries = self.query(x)
        values = self.value(x)

        # Scale the raw affinities by 1/sqrt(head_size).
        scale = keys.shape[-1] ** -0.5
        scores = queries @ keys.transpose(-2, -1) * scale

        # Positions above the diagonal are in the future: give them -inf so
        # the softmax assigns them zero weight.
        is_future = self.tril[:T, :T] == 0
        scores = scores.masked_fill(is_future, float('-inf'))

        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)
        return weights @ values
class MultiHeadAttention(nn.Module):
    """Several attention heads applied to the same input and then merged."""

    def __init__(self, num_heads, head_size):
        """Build ``num_heads`` heads of width ``head_size`` plus the output projection."""
        super().__init__()
        head_modules = [Head(head_size) for _ in range(num_heads)]
        self.heads = nn.ModuleList(head_modules)
        # Maps the concatenated head outputs back to the embedding width.
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run every head on ``x``, concatenate on the last axis, project, apply dropout."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)
class FeedFoward(nn.Module):
    """Position-wise MLP: expand to 4x the width, GELU, project back, dropout."""

    def __init__(self, n_embd):
        """Build the MLP for inputs whose last dimension is ``n_embd``."""
        super().__init__()
        hidden_width = 4 * n_embd
        expand = nn.Linear(n_embd, hidden_width)
        contract = nn.Linear(hidden_width, n_embd)
        # Keep this exact layer order: checkpoint keys are indexed by
        # position inside the Sequential (net.0, net.2).
        self.net = nn.Sequential(expand, nn.GELU(), contract, nn.Dropout(dropout))

    def forward(self, x):
        """Apply the MLP to ``x``."""
        return self.net(x)
class Block(nn.Module):
    """Transformer block: pre-LayerNorm self-attention, then a pre-LayerNorm MLP."""

    def __init__(self, n_embd, n_head):
        """Build a block of width ``n_embd`` whose attention uses ``n_head`` heads."""
        super().__init__()
        # The embedding width is split evenly across the heads.
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Apply both sub-layers to ``x``, each wrapped in a residual connection."""
        after_attention = x + self.sa(self.ln1(x))
        return after_attention + self.ffwd(self.ln2(after_attention))
class GPTLanguageModel(nn.Module):
    """Decoder-only Transformer language model (inference variant).

    Token and learned position embeddings are summed, passed through
    ``n_layer`` transformer blocks and a final LayerNorm, then projected to
    one logit per vocabulary entry.
    """

    def __init__(self):
        """Build the model from the module-level hyperparameters."""
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits for a batch of token ids.

        Args:
            idx: integer tensor of shape (B, T) with T <= block_size.
            targets: accepted so the signature matches a training-style call,
                but ignored; this script never computes a loss.

        Returns:
            A ``(logits, None)`` tuple, where ``logits`` has shape
            (B, T, vocab_size).
        """
        _, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        # Build the position indices on the input's own device, not the
        # module-level ``device``, so the model keeps working after being
        # moved with ``.to(...)``.
        positions = torch.arange(T, device=idx.device)
        pos_emb = self.position_embedding_table(positions)
        x = self.blocks(tok_emb + pos_emb)  # pos_emb broadcasts over the batch
        logits = self.lm_head(self.ln_f(x))
        return logits, None

    @torch.no_grad()  # sampling needs no gradients; skip autograd bookkeeping
    def generate(self, idx, max_new_tokens):
        """Extend ``idx`` autoregressively by ``max_new_tokens`` sampled tokens.

        Args:
            idx: integer tensor of shape (B, T) holding the starting context.
            max_new_tokens: number of tokens to append to every row.

        Returns:
            Integer tensor of shape (B, T + max_new_tokens).
        """
        for _ in range(max_new_tokens):
            # The position table has only block_size rows, so feed at most
            # the last block_size tokens.
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            # Only the final position predicts the next token.
            probs = F.softmax(logits[:, -1, :], dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
# ==========================================
# EXECUTION
# ==========================================
model = GPTLanguageModel().to(device)

if not os.path.exists(weights_path):
    # SystemExit with a message prints it to stderr and exits with status 1.
    # The bare exit() helper is injected by the `site` module, is not always
    # available, and reported success (status 0) on this failure path.
    raise SystemExit("Error: Train the model first.")

# weights_only=True (PyTorch 1.13+) restricts unpickling to tensors and
# plain containers, so an untrusted checkpoint cannot execute code on load.
state_dict = torch.load(weights_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # disable dropout for generation
print(f"Loaded weights from {weights_path}")

# An empty answer falls back to 100 tokens; anything that is not an integer
# is reported cleanly instead of surfacing a raw traceback.
requested = input("Tokens to generate: ").strip()
try:
    num_tokens = int(requested) if requested else 100
except ValueError:
    raise SystemExit(
        f"Error: expected a whole number of tokens, got {requested!r}."
    ) from None

with torch.no_grad():
    # Start from a single token with id 0, i.e. the first character of the
    # sorted vocabulary.
    context = torch.zeros((1, 1), dtype=torch.long, device=device)
    generated = model.generate(context, max_new_tokens=num_tokens)
    print("\n--- GENERATED ---\n" + decode(generated[0].tolist()))
```