shwethd committed
Commit 3216812 · verified · 1 Parent(s): 95db43d

Upload 2 files

Files changed (2)
  1. app.py +246 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,246 @@
+ """
+ HuggingFace Spaces App for GPT-2 124M Shakespeare Model
+ """
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ import tiktoken
+ import gradio as gr
+ import math
+ from dataclasses import dataclass
+
+
+ class CausalSelfAttention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         assert config.n_embd % config.n_head == 0
+         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
+         self.c_proj = nn.Linear(config.n_embd, config.n_embd)
+         self.c_proj.NANOGPT_SCALE_INIT = 1
+         self.n_head = config.n_head
+         self.n_embd = config.n_embd
+         self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))
+
+     def forward(self, x):
+         B, T, C = x.size()
+         qkv = self.c_attn(x)
+         q, k, v = qkv.split(self.n_embd, dim=2)
+         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+
+         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+         att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
+         att = F.softmax(att, dim=-1)
+         y = att @ v
+
+         y = y.transpose(1, 2).contiguous().view(B, T, C)
+         y = self.c_proj(y)
+         return y
+
+
+ class MLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
+         self.gelu = nn.GELU(approximate='tanh')
+         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
+         self.c_proj.NANOGPT_SCALE_INIT = 1
+
+     def forward(self, x):
+         x = self.c_fc(x)
+         x = self.gelu(x)
+         x = self.c_proj(x)
+         return x
+
+
+ class Block(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.ln_1 = nn.LayerNorm(config.n_embd)
+         self.attn = CausalSelfAttention(config)
+         self.ln_2 = nn.LayerNorm(config.n_embd)
+         self.mlp = MLP(config)
+
+     def forward(self, x):
+         x = x + self.attn(self.ln_1(x))
+         x = x + self.mlp(self.ln_2(x))
+         return x
+
+
+ @dataclass
+ class GPTConfig:
+     block_size: int = 1024
+     vocab_size: int = 50257
+     n_layer: int = 12
+     n_head: int = 12
+     n_embd: int = 768
+
+
+ class GPT(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+
+         self.transformer = nn.ModuleDict(dict(
+             wte=nn.Embedding(config.vocab_size, config.n_embd),
+             wpe=nn.Embedding(config.block_size, config.n_embd),
+             h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+             ln_f=nn.LayerNorm(config.n_embd),
+         ))
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+         self.transformer.wte.weight = self.lm_head.weight
+
+     def forward(self, idx, targets=None):
+         B, T = idx.size()
+         assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
+         pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
+         pos_emb = self.transformer.wpe(pos)
+         tok_emb = self.transformer.wte(idx)
+         x = tok_emb + pos_emb
+         for block in self.transformer.h:
+             x = block(x)
+         x = self.transformer.ln_f(x)
+         logits = self.lm_head(x)
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+         return logits, loss
+
+
+ # Load model
+ print("Loading model...")
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ config = GPTConfig()
+ model = GPT(config)
+
+ # Try to load model (works both locally and on HuggingFace)
+ try:
+     checkpoint = torch.load('model_checkpoint_final.pt', map_location=device)
+     model.load_state_dict(checkpoint['model_state_dict'])
+     print("Model loaded from checkpoint")
+ except FileNotFoundError:
+     print("Warning: Model checkpoint not found. Using untrained model.")
+     # Model will be randomly initialized - not ideal but won't crash
+
+ model.to(device)
+ model.eval()
+ print(f"Model ready on {device}")
+
+ enc = tiktoken.get_encoding('gpt2')
+
+
+ def generate_text(prompt, max_new_tokens=100, temperature=0.8, top_k=50):
+     """Generate text from prompt"""
+     try:
+         # Encode prompt
+         tokens = enc.encode(prompt)
+         tokens = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)
+
+         # Generate
+         with torch.no_grad():
+             for _ in range(max_new_tokens):
+                 # Forward pass
+                 logits, _ = model(tokens)
+                 logits = logits[:, -1, :] / temperature
+
+                 # Top-k sampling
+                 topk_probs, topk_indices = torch.topk(F.softmax(logits, dim=-1), top_k, dim=-1)
+                 ix = torch.multinomial(topk_probs, 1)
+                 next_token = torch.gather(topk_indices, -1, ix)
+
+                 # Append to sequence
+                 tokens = torch.cat([tokens, next_token], dim=1)
+
+                 # Stop if we hit max length
+                 if tokens.size(1) >= config.block_size:
+                     break
+
+         # Decode
+         generated_text = enc.decode(tokens[0].tolist())
+         return generated_text
+     except Exception as e:
+         return f"Error: {str(e)}"
+
+
+ # Create Gradio interface
+ with gr.Blocks(title="GPT-2 124M Shakespeare Model") as demo:
+     gr.Markdown("""
+     # 🎭 GPT-2 124M Shakespeare Language Model
+
+     This is a 124M parameter decoder-only transformer model trained on Shakespeare's complete works.
+
+     **Training Results:**
+     - Final Loss: 0.095127 (Target: < 0.099999) ✅
+     - Model Parameters: 124.44M
+     - Training Steps: 1,637
+
+     Enter a prompt below to generate Shakespeare-style text!
+     """)
+
+     with gr.Row():
+         with gr.Column():
+             prompt_input = gr.Textbox(
+                 label="Prompt",
+                 placeholder="Enter your prompt here (e.g., 'First Citizen:', 'ROMEO:', 'To be or not')",
+                 value="First Citizen:",
+                 lines=3
+             )
+             max_tokens = gr.Slider(
+                 label="Max Tokens",
+                 minimum=50,
+                 maximum=200,
+                 value=100,
+                 step=10
+             )
+             temperature = gr.Slider(
+                 label="Temperature",
+                 minimum=0.1,
+                 maximum=2.0,
+                 value=0.8,
+                 step=0.1
+             )
+             top_k = gr.Slider(
+                 label="Top-K",
+                 minimum=10,
+                 maximum=100,
+                 value=50,
+                 step=10
+             )
+             generate_btn = gr.Button("Generate", variant="primary")
+
+         with gr.Column():
+             output = gr.Textbox(
+                 label="Generated Text",
+                 lines=10,
+                 interactive=False
+             )
+
+     # Example prompts
+     gr.Markdown("### Example Prompts:")
+     examples = gr.Examples(
+         examples=[
+             ["First Citizen:"],
+             ["ROMEO:"],
+             ["To be or not"],
+             ["HAMLET:"],
+             ["MACBETH:"],
+         ],
+         inputs=prompt_input
+     )
+
+     generate_btn.click(
+         fn=generate_text,
+         inputs=[prompt_input, max_tokens, temperature, top_k],
+         outputs=output
+     )
+
+     gr.Markdown("""
+     ---
+     **Note:** The model was trained on Shakespeare text and generates text in that style.
+     Generated text may not always be coherent but should follow Shakespearean patterns.
+     """)
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
+
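For a quick check outside the web UI, the sampling loop can be exercised directly from Python. A minimal sketch, assuming app.py and model_checkpoint_final.pt sit in the working directory (importing app runs the module-level model-loading code first):

# Smoke test for generate_text without launching Gradio.
from app import generate_text

sample = generate_text("ROMEO:", max_new_tokens=80, temperature=0.8, top_k=50)
print(sample)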
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch>=2.0.0
+ tiktoken>=0.5.0
+ gradio>=5.4.1
+
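The 124.44M parameter figure quoted in the app's description can be reproduced from GPTConfig alone. A minimal sketch, assuming the packages above are installed (importing app will also try to load the checkpoint and build the Gradio interface):

# Count parameters for the GPT-2 small config; the tied wte / lm_head weight
# is a single shared tensor, so model.parameters() counts it once.
from app import GPT, GPTConfig

m = GPT(GPTConfig())
n_params = sum(p.numel() for p in m.parameters())
print(f"{n_params / 1e6:.2f}M parameters")  # ~124.44M for the config above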