shwethd committed (verified)
Commit 227301c · 1 parent: fe180fa

Upload app.py

Files changed (1):
  1. app.py +202 -27
app.py CHANGED
@@ -220,8 +220,8 @@ print(f"Model ready on {device}")
 enc = tiktoken.get_encoding('gpt2')
 
 
-def generate_text(prompt, max_new_tokens=100, temperature=0.8, top_k=50):
-    """Generate text from prompt"""
+def generate_text(prompt, max_new_tokens=100, temperature=0.7, top_k=50, top_p=0.9, repetition_penalty=1.1):
+    """Generate text from prompt with improved sampling"""
     try:
         if not model_loaded:
             return "❌ Error: Model not loaded correctly. Please check that model_checkpoint_final.pt is uploaded to HuggingFace Model Hub (shwethd/gpt2-shakespeare-124m)."
@@ -232,6 +232,8 @@ def generate_text(prompt, max_new_tokens=100, temperature=0.8, top_k=50):
 
         temperature = max(0.1, min(2.0, temperature)) # Clamp temperature
         top_k = max(1, min(100, int(top_k))) # Clamp top_k
+        top_p = max(0.1, min(1.0, float(top_p))) # Clamp top_p (nucleus sampling)
+        repetition_penalty = max(1.0, min(1.5, float(repetition_penalty))) # Clamp repetition penalty
         max_new_tokens = max(1, min(200, int(max_new_tokens))) # Clamp max tokens
 
         # Encode prompt
@@ -241,34 +243,73 @@ def generate_text(prompt, max_new_tokens=100, temperature=0.8, top_k=50):
 
         tokens = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)
 
-        # Generate
+        # Generate with improved sampling strategy
         with torch.no_grad():
+            # Track recent tokens for repetition penalty
+            recent_tokens = set()
+
             for i in range(max_new_tokens):
                 # Forward pass
                 logits, _ = model(tokens)
-                logits = logits[:, -1, :] / max(temperature, 0.1) # Avoid division by zero
+                logits = logits[:, -1, :] / max(temperature, 0.1) # Apply temperature
 
-                # Apply top-k filtering
-                if top_k < logits.size(-1):
-                    topk_logits, topk_indices = torch.topk(logits, top_k, dim=-1)
-                    # Create filtered logits
-                    filtered_logits = torch.full_like(logits, float('-inf'))
-                    filtered_logits.scatter_(-1, topk_indices, topk_logits)
-                    logits = filtered_logits
+                # Apply repetition penalty to reduce loops
+                if repetition_penalty > 1.0 and len(recent_tokens) > 0:
+                    for token_id in recent_tokens:
+                        if logits[0, token_id] > 0:
+                            logits[0, token_id] /= repetition_penalty
+                        else:
+                            logits[0, token_id] *= repetition_penalty
 
-                # Sample from distribution
+                # Convert to probabilities
                 probs = F.softmax(logits, dim=-1)
 
-                # Avoid NaN
-                if torch.isnan(probs).any():
+                # Apply top-p (nucleus) sampling first - often better than just top-k
+                if top_p < 1.0:
+                    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
+                    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+
+                    # Remove tokens with cumulative probability above threshold
+                    sorted_indices_to_remove = cumulative_probs > top_p
+                    # Keep at least one token
+                    sorted_indices_to_remove[..., 0] = False
+
+                    # Create mask
+                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+                    probs[indices_to_remove] = 0
+
+                    # Renormalize
+                    probs = probs / probs.sum()
+
+                # Apply top-k filtering (after top-p for better quality)
+                if top_k < logits.size(-1):
+                    topk_probs, topk_indices = torch.topk(probs, top_k, dim=-1)
+                    # Create filtered probabilities
+                    filtered_probs = torch.zeros_like(probs)
+                    filtered_probs.scatter_(-1, topk_indices, topk_probs)
+                    # Renormalize
+                    filtered_probs = filtered_probs / filtered_probs.sum()
+                    probs = filtered_probs
+
+                # Avoid NaN or zero probabilities
+                if torch.isnan(probs).any() or (probs.sum() == 0):
                     probs = torch.ones_like(probs) / probs.size(-1)
 
+                # Sample from distribution
                 next_token = torch.multinomial(probs, 1)
 
+                # Update recent tokens for repetition penalty (keep last 20 tokens)
+                token_id = next_token.item()
+                recent_tokens.add(token_id)
+                if len(recent_tokens) > 20:
+                    # Remove oldest tokens (simple approach: keep last 20)
+                    recent_tokens = set(list(recent_tokens)[-20:])
+
                 # Append to sequence
                 tokens = torch.cat([tokens, next_token], dim=1)
 
-                # Stop if we hit max length
+                # Early stopping: stop if we generate end-of-text token (if present)
+                # For GPT-2 tokenizer, we can check for certain patterns
                 if tokens.size(1) >= config.block_size:
                     break
 
@@ -278,22 +319,102 @@ def generate_text(prompt, max_new_tokens=100, temperature=0.8, top_k=50):
         # Post-process to fix spacing issues (common with BPE tokenizers)
         import re
 
+        # Fix 0: Remove the prompt from the beginning if it appears as a speaker name
+        # This handles cases where user enters "Romeo and Juliet" and model treats it as speaker
+        prompt_lower = prompt.lower().strip()
+        generated_lower = generated_text.lower()
+
+        # If prompt appears at the very start and looks like it was treated as a speaker
+        if generated_lower.startswith(prompt_lower):
+            # Check if it's followed by a newline (speaker format) or dialogue
+            prompt_len = len(prompt)
+            if len(generated_text) > prompt_len:
+                next_chars = generated_text[prompt_len:prompt_len+5].strip()
+                # If prompt is followed by newline or colon-like pattern, it was treated as speaker
+                if not next_chars or ':' in next_chars or '\n' in generated_text[prompt_len:prompt_len+5]:
+                    # Remove the prompt from output (it's the input, not part of generated story)
+                    generated_text = generated_text[len(prompt):].strip()
+                    # Remove leading newlines/colons
+                    generated_text = re.sub(r'^[\s:]+', '', generated_text)
+
+                    # Check if the first line after removal is orphaned dialogue (no speaker)
+                    lines = generated_text.split('\n')
+                    if lines and lines[0].strip():
+                        first_line = lines[0].strip()
+                        # If first line is not a speaker name and looks like dialogue, add a speaker
+                        if not re.match(r'^([A-Z][A-Z\s]+?):\s*$', first_line):
+                            # Check if it's dialogue-like (starts with capital, has punctuation)
+                            if re.match(r'^[A-Z]', first_line) and ('.' in first_line or ',' in first_line or '!' in first_line or '?' in first_line):
+                                # Add a generic speaker name based on the prompt context
+                                # For story prompts like "Romeo and Juliet", use a character from the prompt
+                                prompt_words = [w.capitalize() for w in prompt_lower.split() if len(w) > 2]
+                                if len(prompt_words) >= 2:
+                                    # Use first significant word as speaker (e.g., "Romeo" from "Romeo and Juliet")
+                                    speaker_name = prompt_words[0].upper()
+                                else:
+                                    # Generic speaker
+                                    speaker_name = "NARRATOR"
+
+                                # Add speaker before the dialogue
+                                generated_text = f"{speaker_name}:\n{first_line}\n" + '\n'.join(lines[1:]) if len(lines) > 1 else f"{speaker_name}:\n{first_line}"
+
         # Fix 1: lowercase followed by uppercase (e.g., "perpetualWith" -> "perpetual With")
         generated_text = re.sub(r'([a-z])([A-Z])', r'\1 \2', generated_text)
 
         # Fix 2: Common word boundaries that got merged (e.g., "perpetualwith" -> "perpetual with")
         # Add space before common words that might have been merged
-        common_words = ['with', 'the', 'and', 'that', 'this', 'have', 'from', 'not', 'but', 'for', 'are', 'was', 'were', 'been', 'will', 'shall', 'would', 'could', 'should']
+        common_words = ['with', 'the', 'and', 'that', 'this', 'have', 'from', 'not', 'but', 'for', 'are', 'was', 'were', 'been', 'will', 'shall', 'would', 'could', 'should', 'be', 'your', 'you', 'our', 'my', 'his', 'her', 'their', 'him', 'them']
         for word in common_words:
             # Only add space if it's not already separated and follows a lowercase letter
             pattern = r'([a-z])(' + word + r'\b)'
             generated_text = re.sub(pattern, r'\1 \2', generated_text, flags=re.IGNORECASE)
 
+        # Fix 2b: Fix contractions that got merged (e.g., "You'llbe" -> "You'll be")
+        # Add space after contractions before lowercase words
+        contractions = ["'ll", "'ve", "'re", "'d", "'t", "'s", "'m"]
+        for contraction in contractions:
+            # Pattern: contraction followed by lowercase letter (e.g., "You'llbe" -> "You'll be")
+            pattern = r"(" + re.escape(contraction) + r")([a-z])"
+            generated_text = re.sub(pattern, r'\1 \2', generated_text, flags=re.IGNORECASE)
+
         # Fix 3: Add space before character names (all caps words)
         generated_text = re.sub(r'([a-z])([A-Z]{2,})', r'\1 \2', generated_text)
 
-        # Fix 4: Remove duplicate speaker names (e.g., "LEONTES:\n...\nLEONTES:" -> keep only first)
-        # More aggressive: remove same speaker if it appears within 5 lines
+        # Fix 3b: Normalize speaker names (e.g., "Romeo and juliet" -> "ROMEO AND JULIET:")
+        # Handle mixed case speaker names that should be all caps
+        lines = generated_text.split('\n')
+        normalized_lines = []
+        for i, line in enumerate(lines):
+            line_stripped = line.strip()
+
+            # Check if line is a potential speaker name (title case or mixed case, 2+ words)
+            # Pattern: "Romeo and juliet", "Romeo And Juliet", etc.
+            speaker_pattern = r'^([A-Z][a-z]+(?:\s+[a-zA-Z]+)+)\s*:?\s*$'
+            match = re.match(speaker_pattern, line_stripped)
+
+            if match:
+                # Check if next line is dialogue (not another speaker)
+                is_speaker = False
+                if i + 1 < len(lines):
+                    next_line = lines[i + 1].strip()
+                    # If next line is not empty and not a speaker name, this is likely a speaker
+                    if next_line and not re.match(r'^([A-Z][A-Z\s]+?):\s*$', next_line):
+                        is_speaker = True
+                elif i == 0: # First line is likely a speaker if it matches pattern
+                    is_speaker = True
+
+                if is_speaker:
+                    # Convert to all caps and ensure colon
+                    speaker_name = match.group(1).upper()
+                    normalized_lines.append(speaker_name + ':')
+                    continue
+
+            normalized_lines.append(line)
+
+        generated_text = '\n'.join(normalized_lines)
+
+        # Fix 4: Remove duplicate speaker names (e.g., "EDWARD IV:\n...\nEDWARD IV:" -> keep only first)
+        # More aggressive: remove same speaker if it appears within 3 lines (tighter window)
         lines = generated_text.split('\n')
         cleaned_lines = []
         speaker_history = [] # Track recent speakers with their line numbers
@@ -306,9 +427,9 @@ def generate_text(prompt, max_new_tokens=100, temperature=0.8, top_k=50):
             if speaker_match:
                 speaker = speaker_match.group(1).strip()
 
-                # Check if this speaker appeared recently (within last 5 lines)
+                # Check if this speaker appeared recently (within last 3 lines - more aggressive)
                 recent_speaker = False
-                for hist_speaker, hist_line_num in speaker_history[-5:]:
+                for hist_speaker, hist_line_num in speaker_history[-3:]:
                     if speaker == hist_speaker:
                         recent_speaker = True
                         break
@@ -329,11 +450,47 @@ def generate_text(prompt, max_new_tokens=100, temperature=0.8, top_k=50):
 
         generated_text = '\n'.join(cleaned_lines)
 
-        # Fix 5: Remove multiple empty lines between speaker and dialogue
+        # Fix 5: Remove speaker names with no dialogue (e.g., "KING:\nEDWARD IV:" -> "EDWARD IV:")
+        # A speaker name should be followed by actual dialogue, not immediately by another speaker
+        lines = generated_text.split('\n')
+        final_lines = []
+
+        for i, line in enumerate(lines):
+            line_stripped = line.strip()
+            speaker_match = re.match(r'^([A-Z][A-Z\s]+?):\s*$', line_stripped)
+
+            if speaker_match:
+                # Check if next non-empty line is another speaker (meaning this speaker has no dialogue)
+                has_dialogue = False
+                for j in range(i + 1, min(i + 3, len(lines))): # Check next 3 lines (more aggressive)
+                    next_line = lines[j].strip()
+                    if not next_line: # Skip empty lines
+                        continue
+                    # If next non-empty line is NOT a speaker, we have dialogue
+                    if not re.match(r'^([A-Z][A-Z\s]+?):\s*$', next_line):
+                        has_dialogue = True
+                        break
+                    # If next non-empty line IS a speaker, this speaker has no dialogue
+                    else:
+                        # This speaker has no dialogue - skip it
+                        break
+
+                if not has_dialogue:
+                    # This speaker has no dialogue, skip it
+                    continue
+
+            final_lines.append(line)
+
+        generated_text = '\n'.join(final_lines)
+
+        # Fix 5b: Fix merged text issues (e.g., "You?A:" -> "You? A:")
+        # Add space after question/exclamation marks before capital letters
+        generated_text = re.sub(r'([?!])([A-Z])', r'\1 \2', generated_text)
+
+        # Fix 6: Remove multiple empty lines between speaker and dialogue
         generated_text = re.sub(r'([A-Z][A-Z\s]+?):\s*\n\s*\n+', r'\1:\n', generated_text)
 
-        # Fix 6: Remove any remaining consecutive duplicate speakers (final cleanup)
-        # Pattern: Same speaker name appearing on consecutive lines (with optional whitespace)
+        # Fix 7: Remove any remaining consecutive duplicate speakers (final cleanup)
         generated_text = re.sub(
             r'^([A-Z][A-Z\s]+?):\s*\n\s*\n*\1:\s*\n',
             r'\1:\n',
@@ -389,15 +546,33 @@ with gr.Blocks(title="GPT-2 124M Shakespeare Model") as demo:
                 label="Temperature",
                 minimum=0.1,
                 maximum=2.0,
-                value=0.8,
-                step=0.1
+                value=0.7,
+                step=0.1,
+                info="Lower = more focused, Higher = more creative (0.7 recommended for better coherence)"
             )
             top_k = gr.Slider(
                 label="Top-K",
                 minimum=10,
                 maximum=100,
                 value=50,
-                step=10
+                step=10,
+                info="Number of top tokens to consider"
+            )
+            top_p = gr.Slider(
+                label="Top-P (Nucleus)",
+                minimum=0.1,
+                maximum=1.0,
+                value=0.9,
+                step=0.05,
+                info="Nucleus sampling - higher = more diverse, lower = more focused (0.9 recommended)"
+            )
+            repetition_penalty = gr.Slider(
+                label="Repetition Penalty",
+                minimum=1.0,
+                maximum=1.5,
+                value=1.1,
+                step=0.05,
+                info="Penalize repeated tokens - higher = less repetition (1.1 recommended)"
            )
             generate_btn = gr.Button("Generate", variant="primary")
 
@@ -433,7 +608,7 @@ with gr.Blocks(title="GPT-2 124M Shakespeare Model") as demo:
 
     generate_btn.click(
         fn=generate_text,
-        inputs=[prompt_input, max_tokens, temperature, top_k],
+        inputs=[prompt_input, max_tokens, temperature, top_k, top_p, repetition_penalty],
         outputs=output
     )
 
 
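To sanity-check the new wiring without launching the UI, generate_text can also be called directly with the same argument order the click handler passes. This assumes app.py's module-level model and tokenizer loaded successfully; the prompt string is just an example.

# Same positional order as the Gradio click handler wires up:
# [prompt_input, max_tokens, temperature, top_k, top_p, repetition_penalty]
sample = generate_text(
    "ROMEO:",  # example prompt
    max_new_tokens=100,
    temperature=0.7,
    top_k=50,
    top_p=0.9,
    repetition_penalty=1.1,
)
print(sample)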