sreedhayan commited on
Commit
18dbeb4
·
verified ·
1 Parent(s): dc3929d

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +151 -0
README.md CHANGED
@@ -1,3 +1,154 @@
1
  ---
2
  license: mit
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: mit
3
+ datasets:
4
+ - karpathy/tiny_shakespeare
5
  ---
6
+ ShakeGPT
7
+
8
+ **ShakeGPT** is a lightweight, decoder-only Transformer language model trained on the Tiny Shakespeare dataset. It is designed to capture the stylistic patterns, vocabulary, and structure of Shakespearean English at a character level.
9
+
10
+ ## Model Description
11
+ * **Architecture:** Transformer Decoder
12
+ * **Parameters:** ~0.6M
13
+ * **Training Data:** Tiny Shakespeare (1.6MB of raw text)
14
+ * **Tokenization:** Character-level
15
+ * **Context Window:** 128 characters
16
+
17
+ ## Technical Specifications
18
+ | Feature | Value |
19
+ | :--- | :--- |
20
+ | `n_embd` (Embedding Dimension) | 128 |
21
+ | `n_layer` (Transformer Blocks) | 3 |
22
+ | `n_head` (Attention Heads) | 4 |
23
+ | `block_size` (Context Length) | 128 |
24
+ | `dropout` | 0.1 |
25
+
26
+ ---
27
+
28
+ ## Inference Script
29
+
30
+ This script initializes the **ShakeGPT** architecture and loads your saved weights to generate new text.
31
+
32
+ ```python
33
+ import torch
34
+ import torch.nn as nn
35
+ from torch.nn import functional as F
36
+ import os
37
+
38
+ # ==========================================
39
+ # HYPERPARAMETERS (Matched to gpt.py)
40
+ # ==========================================
41
+ device = 'cpu'
42
+ n_embd = 128
43
+ n_head = 4
44
+ n_layer = 3
45
+ block_size = 128 # Fixed mismatch
46
+ dropout = 0.1
47
+ weights_path = 'gpt_weights_best.pth'
48
+
49
+ # Load vocab from same source
50
+ with open('input.txt', 'r', encoding='utf-8') as f:
51
+ text = f.read()
52
+
53
+ chars = sorted(list(set(text)))
54
+ vocab_size = len(chars)
55
+ itos = { i:ch for i,ch in enumerate(chars) }
56
+ decode = lambda l: ''.join([itos[i] for i in l])
57
+
58
+ # ==========================================
59
+ # MODEL ARCHITECTURE (Must be identical)
60
+ # ==========================================
61
+
62
+ class Head(nn.Module):
63
+ def __init__(self, head_size):
64
+ super().__init__()
65
+ self.key = nn.Linear(n_embd, head_size, bias=False)
66
+ self.query = nn.Linear(n_embd, head_size, bias=False)
67
+ self.value = nn.Linear(n_embd, head_size, bias=False)
68
+ self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
69
+ self.dropout = nn.Dropout(dropout)
70
+
71
+ def forward(self, x):
72
+ B,T,C = x.shape
73
+ k, q, v = self.key(x), self.query(x), self.value(x)
74
+ wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5
75
+ wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
76
+ wei = F.softmax(wei, dim=-1)
77
+ return self.dropout(wei) @ v
78
+
79
+ class MultiHeadAttention(nn.Module):
80
+ def __init__(self, num_heads, head_size):
81
+ super().__init__()
82
+ self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
83
+ self.proj = nn.Linear(head_size * num_heads, n_embd)
84
+ self.dropout = nn.Dropout(dropout)
85
+
86
+ def forward(self, x):
87
+ out = torch.cat([h(x) for h in self.heads], dim=-1)
88
+ return self.dropout(self.proj(out))
89
+
90
+ class FeedFoward(nn.Module):
91
+ def __init__(self, n_embd):
92
+ super().__init__()
93
+ self.net = nn.Sequential(
94
+ nn.Linear(n_embd, 4 * n_embd), # Fixed mismatch (4x)
95
+ nn.GELU(), # Fixed mismatch (GELU)
96
+ nn.Linear(4 * n_embd, n_embd),
97
+ nn.Dropout(dropout),
98
+ )
99
+ def forward(self, x): return self.net(x)
100
+
101
+ class Block(nn.Module):
102
+ def __init__(self, n_embd, n_head):
103
+ super().__init__()
104
+ self.sa = MultiHeadAttention(n_head, n_embd // n_head)
105
+ self.ffwd = FeedFoward(n_embd)
106
+ self.ln1, self.ln2 = nn.LayerNorm(n_embd), nn.LayerNorm(n_embd)
107
+ def forward(self, x):
108
+ x = x + self.sa(self.ln1(x))
109
+ return x + self.ffwd(self.ln2(x))
110
+
111
+ class GPTLanguageModel(nn.Module):
112
+ def __init__(self):
113
+ super().__init__()
114
+ self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
115
+ self.position_embedding_table = nn.Embedding(block_size, n_embd)
116
+ self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
117
+ self.ln_f = nn.LayerNorm(n_embd)
118
+ self.lm_head = nn.Linear(n_embd, vocab_size)
119
+
120
+ def forward(self, idx, targets=None):
121
+ B, T = idx.shape
122
+ tok_emb = self.token_embedding_table(idx)
123
+ pos_emb = self.position_embedding_table(torch.arange(T, device=device))
124
+ x = self.blocks(tok_emb + pos_emb)
125
+ logits = self.lm_head(self.ln_f(x))
126
+ return logits, None
127
+
128
+ def generate(self, idx, max_new_tokens):
129
+ for _ in range(max_new_tokens):
130
+ idx_cond = idx[:, -block_size:]
131
+ logits, _ = self(idx_cond)
132
+ probs = F.softmax(logits[:, -1, :], dim=-1)
133
+ idx_next = torch.multinomial(probs, num_samples=1)
134
+ idx = torch.cat((idx, idx_next), dim=1)
135
+ return idx
136
+
137
+ # ==========================================
138
+ # EXECUTION
139
+ # ==========================================
140
+ model = GPTLanguageModel().to(device)
141
+
142
+ if os.path.exists(weights_path):
143
+ model.load_state_dict(torch.load(weights_path, map_location=device))
144
+ model.eval()
145
+ print(f"Loaded weights from {weights_path}")
146
+ else:
147
+ print("Error: Train the model first.")
148
+ exit()
149
+
150
+ num_tokens = int(input("Tokens to generate: ") or 100)
151
+ with torch.no_grad():
152
+ context = torch.zeros((1, 1), dtype=torch.long, device=device)
153
+ print("\n--- GENERATED ---\n" + decode(model.generate(context, max_new_tokens=num_tokens)[0].tolist()))
154
+ ```