shwethd committed on
Commit
c1e7837
Β·
verified Β·
1 Parent(s): 0634381

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -27
app.py CHANGED
@@ -114,35 +114,97 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
114
  config = GPTConfig()
115
  model = GPT(config)
116
 
 
 
117
  # Try to load model from HuggingFace Model Hub first, then local file
118
  try:
119
  from huggingface_hub import hf_hub_download
120
  import os
121
 
122
  # Try to get model path from environment variable or use default
123
- repo_id = os.getenv('HF_MODEL_REPO', 'YOUR_USERNAME/gpt2-shakespeare-124m') # Update with your repo
124
 
125
  try:
126
- model_path = hf_hub_download(
127
- repo_id=repo_id,
128
- filename="model_checkpoint_final.pt",
129
- cache_dir=None
130
- )
131
- checkpoint = torch.load(model_path, map_location=device)
132
- model.load_state_dict(checkpoint['model_state_dict'])
133
- print(f"Model loaded from HuggingFace Hub: {repo_id}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  except Exception as e:
135
- print(f"Could not load from Hub ({e}), trying local file...")
136
- # Fallback to local file
137
- checkpoint = torch.load('model_checkpoint_final.pt', map_location=device)
138
- model.load_state_dict(checkpoint['model_state_dict'])
139
- print("Model loaded from local checkpoint")
 
 
 
 
 
 
 
 
 
140
  except FileNotFoundError:
141
- print("Warning: Model checkpoint not found. Using untrained model.")
142
- # Model will be randomly initialized - not ideal but won't crash
143
  except Exception as e:
144
- print(f"Error loading model: {e}")
145
- print("Using untrained model as fallback.")
 
 
 
 
146
 
147
  model.to(device)
148
  model.eval()
@@ -154,21 +216,47 @@ enc = tiktoken.get_encoding('gpt2')
154
  def generate_text(prompt, max_new_tokens=100, temperature=0.8, top_k=50):
155
  """Generate text from prompt"""
156
  try:
 
 
 
 
 
 
 
 
 
 
 
157
  # Encode prompt
158
  tokens = enc.encode(prompt)
 
 
 
159
  tokens = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)
160
 
161
  # Generate
162
  with torch.no_grad():
163
- for _ in range(max_new_tokens):
164
  # Forward pass
165
  logits, _ = model(tokens)
166
- logits = logits[:, -1, :] / temperature
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
- # Top-k sampling
169
- topk_probs, topk_indices = torch.topk(F.softmax(logits, dim=-1), top_k, dim=-1)
170
- ix = torch.multinomial(topk_probs, 1)
171
- next_token = torch.gather(topk_indices, -1, ix)
172
 
173
  # Append to sequence
174
  tokens = torch.cat([tokens, next_token], dim=1)
@@ -181,14 +269,21 @@ def generate_text(prompt, max_new_tokens=100, temperature=0.8, top_k=50):
181
  generated_text = enc.decode(tokens[0].tolist())
182
  return generated_text
183
  except Exception as e:
184
- return f"Error: {str(e)}"
 
185
 
186
 
187
  # Create Gradio interface
188
  with gr.Blocks(title="GPT-2 124M Shakespeare Model") as demo:
189
- gr.Markdown("""
 
 
 
 
190
  # 🎭 GPT-2 124M Shakespeare Language Model
191
 
 
 
192
  This is a 124M parameter decoder-only transformer model trained on Shakespeare's complete works.
193
 
194
  **Training Results:**
@@ -197,6 +292,8 @@ with gr.Blocks(title="GPT-2 124M Shakespeare Model") as demo:
197
  - Training Steps: 1,637
198
 
199
  Enter a prompt below to generate Shakespeare-style text!
 
 
200
  """)
201
 
202
  with gr.Row():
@@ -238,7 +335,7 @@ with gr.Blocks(title="GPT-2 124M Shakespeare Model") as demo:
238
  )
239
 
240
  # Example prompts
241
- gr.Markdown("### Example Prompts:")
242
  examples = gr.Examples(
243
  examples=[
244
  ["First Citizen:"],
@@ -246,6 +343,15 @@ with gr.Blocks(title="GPT-2 124M Shakespeare Model") as demo:
246
  ["To be or not"],
247
  ["HAMLET:"],
248
  ["MACBETH:"],
 
 
 
 
 
 
 
 
 
249
  ],
250
  inputs=prompt_input
251
  )
 
114
config = GPTConfig()
model = GPT(config)

# Tracks whether real trained weights were loaded; generate_text checks this
# so a randomly initialized model never silently serves output.
model_loaded = False


def _apply_checkpoint(checkpoint):
    """Load *checkpoint* into the global model, handling the common formats.

    Accepts either a raw state dict or a dict that wraps it under
    'model_state_dict' or 'state_dict'. Raises whatever
    model.load_state_dict raises on a mismatched checkpoint.
    """
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    elif 'state_dict' in checkpoint:
        model.load_state_dict(checkpoint['state_dict'])
    else:
        # Checkpoint is the state dict itself
        model.load_state_dict(checkpoint)


# Try to load model from HuggingFace Model Hub first, then local file
try:
    from huggingface_hub import hf_hub_download
    import os

    # Try to get model path from environment variable or use default
    repo_id = os.getenv('HF_MODEL_REPO', 'shwethd/gpt2-shakespeare-124m')

    def _load_pt_from_hub():
        """Download the pickle (.pt) checkpoint from the Hub and load it."""
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename="model_checkpoint_final.pt",
            cache_dir=None,
        )
        # NOTE(review): torch.load unpickles arbitrary objects — only safe
        # because repo_id defaults to a trusted repo; keep it that way.
        _apply_checkpoint(torch.load(model_path, map_location=device))

    try:
        print(f"Attempting to load from HuggingFace Hub: {repo_id}")

        # Try SafeTensors first (more secure, no pickle issues)
        try:
            from safetensors.torch import load_file
            try:
                model_path = hf_hub_download(
                    repo_id=repo_id,
                    filename="model.safetensors",
                    cache_dir=None,
                )
                model.load_state_dict(load_file(model_path, device=device))
                model_loaded = True
                print(f"βœ… Model loaded successfully from SafeTensors: {repo_id}")
            except Exception as e:
                print(f"SafeTensors not found ({e}), trying .pt file...")
                # Fallback to .pt file
                _load_pt_from_hub()
                model_loaded = True
                print(f"βœ… Model loaded successfully from HuggingFace Hub: {repo_id}")
        except ImportError:
            # safetensors not installed, use .pt file
            _load_pt_from_hub()
            model_loaded = True
            print(f"βœ… Model loaded successfully from HuggingFace Hub: {repo_id}")
    except Exception as e:
        print(f"⚠️ Could not load from Hub ({e}), trying local file...")
        try:
            # Fallback to local file
            _apply_checkpoint(
                torch.load('model_checkpoint_final.pt', map_location=device)
            )
            model_loaded = True
            print("βœ… Model loaded from local checkpoint")
        except Exception as e2:
            print(f"❌ Could not load from local file either: {e2}")
except FileNotFoundError:
    print("❌ Warning: Model checkpoint not found. Using untrained model.")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    print("⚠️ Using untrained model as fallback - output will be random!")

if not model_loaded:
    print("⚠️ WARNING: Model is using random weights! Generation will be nonsensical.")
    print("Please ensure model_checkpoint_final.pt is uploaded to HuggingFace Model Hub.")

model.to(device)
model.eval()
 
216
def generate_text(prompt, max_new_tokens=100, temperature=0.8, top_k=50):
    """Generate Shakespeare-style text continuing *prompt*.

    Args:
        prompt: Seed text; must be non-empty after stripping whitespace.
        max_new_tokens: Number of tokens to sample, clamped to [1, 200].
        temperature: Softmax temperature, clamped to [0.1, 2.0].
        top_k: Top-k sampling cutoff, clamped to [1, 100].

    Returns:
        The decoded prompt + continuation, or a human-readable error string
        (this function never raises; Gradio displays the returned text).
    """
    try:
        if not model_loaded:
            return "❌ Error: Model not loaded correctly. Please check that model_checkpoint_final.pt is uploaded to HuggingFace Model Hub (shwethd/gpt2-shakespeare-124m)."

        # Validate inputs
        if not prompt or len(prompt.strip()) == 0:
            return "Please enter a prompt."

        # Clamp user-supplied knobs to safe ranges before sampling.
        temperature = max(0.1, min(2.0, temperature))  # Clamp temperature
        top_k = max(1, min(100, int(top_k)))  # Clamp top_k
        max_new_tokens = max(1, min(200, int(max_new_tokens)))  # Clamp max tokens

        # Encode prompt
        tokens = enc.encode(prompt)
        if len(tokens) == 0:
            return "Error: Could not encode prompt."

        tokens = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)

        # Generate
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Forward pass
                logits, _ = model(tokens)
                # temperature is already clamped >= 0.1 above, so this
                # division is safe.
                logits = logits[:, -1, :] / temperature

                # Apply top-k filtering: mask everything outside the k
                # largest logits to -inf so softmax zeroes them out.
                if top_k < logits.size(-1):
                    topk_logits, topk_indices = torch.topk(logits, top_k, dim=-1)
                    filtered_logits = torch.full_like(logits, float('-inf'))
                    filtered_logits.scatter_(-1, topk_indices, topk_logits)
                    logits = filtered_logits

                # Sample from distribution
                probs = F.softmax(logits, dim=-1)

                # Avoid NaN: fall back to uniform sampling if softmax blew up.
                if torch.isnan(probs).any():
                    probs = torch.ones_like(probs) / probs.size(-1)

                next_token = torch.multinomial(probs, 1)

                # Append to sequence
                tokens = torch.cat([tokens, next_token], dim=1)

        generated_text = enc.decode(tokens[0].tolist())
        return generated_text
    except Exception as e:
        # Return (not raise) so the Gradio UI shows a readable message.
        return f"❌ Error during generation: {str(e)}\n\nPlease check:\n1. Model is uploaded to HuggingFace Model Hub\n2. Repository name is correct: shwethd/gpt2-shakespeare-124m\n3. File name is exactly: model_checkpoint_final.pt"
 
275
 
276
  # Create Gradio interface
277
  with gr.Blocks(title="GPT-2 124M Shakespeare Model") as demo:
278
+ # Status indicator
279
+ status_color = "🟒" if model_loaded else "πŸ”΄"
280
+ status_text = "Model loaded successfully!" if model_loaded else "⚠️ Model not loaded - check HuggingFace Model Hub!"
281
+
282
+ gr.Markdown(f"""
283
  # 🎭 GPT-2 124M Shakespeare Language Model
284
 
285
+ {status_color} **Status:** {status_text}
286
+
287
  This is a 124M parameter decoder-only transformer model trained on Shakespeare's complete works.
288
 
289
  **Training Results:**
 
292
  - Training Steps: 1,637
293
 
294
  Enter a prompt below to generate Shakespeare-style text!
295
+
296
+ {"⚠️ **Note:** If you see garbled/random text, the model may not have loaded correctly. Check the logs and ensure the model is uploaded to HuggingFace Model Hub: `shwethd/gpt2-shakespeare-124m`" if not model_loaded else ""}
297
  """)
298
 
299
  with gr.Row():
 
335
  )
336
 
337
  # Example prompts
338
+ gr.Markdown("### Example Prompts (Click to try):")
339
  examples = gr.Examples(
340
  examples=[
341
  ["First Citizen:"],
 
343
  ["To be or not"],
344
  ["HAMLET:"],
345
  ["MACBETH:"],
346
+ ["JULIET:"],
347
+ ["KING:"],
348
+ ["LADY MACBETH:"],
349
+ ["OTHELLO:"],
350
+ ["What light through yonder"],
351
+ ["All the world's a stage"],
352
+ ["Double, double toil and trouble"],
353
+ ["Friends, Romans, countrymen"],
354
+ ["A rose by any other name"],
355
  ],
356
  inputs=prompt_input
357
  )