File size: 43,785 Bytes
3216812
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1e7837
 
0634381
3216812
0634381
 
 
 
c1e7837
0634381
 
c1e7837
 
 
 
 
 
 
 
 
 
 
 
 
adc8386
 
 
c1e7837
 
 
 
 
 
 
 
 
 
adc8386
 
 
c1e7837
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
adc8386
 
c1e7837
 
 
 
 
 
 
 
 
 
 
 
0634381
c1e7837
 
 
adc8386
 
c1e7837
 
 
 
 
 
 
 
 
 
3216812
c1e7837
0634381
c1e7837
 
 
 
 
 
3216812
 
 
 
 
 
 
 
227301c
 
3216812
c1e7837
 
 
 
 
 
 
 
 
227301c
 
c1e7837
 
3216812
 
c1e7837
 
 
3216812
 
227301c
3216812
227301c
 
 
c1e7837
3216812
 
227301c
c1e7837
227301c
 
 
 
 
 
 
c1e7837
227301c
c1e7837
 
227301c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1e7837
3216812
227301c
c1e7837
3216812
227301c
 
 
 
 
 
 
3216812
 
 
227301c
 
3216812
 
 
 
 
4e5f1e6
 
 
7360b49
227301c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3bc9884
 
227301c
 
 
3bc9884
 
227301c
4e5f1e6
 
 
3bc9884
 
82f907e
 
1e393db
82f907e
 
 
 
 
 
 
1e393db
 
3bc9884
a845bcb
 
 
 
82f907e
c037b52
82f907e
3bc9884
82f907e
 
 
1e393db
3bc9884
 
 
 
1e393db
3bc9884
 
 
1e393db
3bc9884
 
 
1e393db
3bc9884
1e393db
3bc9884
 
 
 
 
 
 
 
 
c037b52
4e5f1e6
 
c037b52
4e5f1e6
 
 
 
 
c037b52
 
 
 
 
 
a845bcb
3bc9884
a845bcb
3bc9884
 
 
a845bcb
3bc9884
a845bcb
3bc9884
 
 
 
 
 
 
 
 
 
 
 
 
a845bcb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3bc9884
1e393db
82f907e
 
1e393db
82f907e
 
 
 
 
 
 
 
1e393db
 
 
3bc9884
1e393db
3bc9884
 
 
 
 
 
 
 
 
 
 
 
82f907e
 
 
 
 
 
 
3bc9884
1e393db
 
227301c
 
 
 
 
 
 
 
a845bcb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3bc9884
 
 
 
4e5f1e6
 
227301c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7360b49
 
fe180fa
7360b49
fe180fa
7360b49
fe180fa
7360b49
 
 
 
fe180fa
227301c
fe180fa
227301c
fe180fa
 
 
 
 
 
 
 
 
 
 
 
 
 
7360b49
 
 
 
 
 
227301c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7360b49
 
227301c
fe180fa
 
 
 
 
 
7360b49
3bc9884
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3216812
 
c1e7837
 
3216812
 
 
 
c1e7837
 
 
 
 
3216812
 
c1e7837
 
3216812
 
 
 
 
 
 
 
c1e7837
 
3216812
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227301c
 
 
3216812
 
 
 
 
 
227301c
 
 
 
 
 
 
c037b52
227301c
c037b52
227301c
 
 
 
 
 
 
 
3216812
 
 
 
 
 
 
7360b49
 
3216812
 
c037b52
1e393db
3216812
 
c037b52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3216812
c037b52
3216812
 
 
 
227301c
3216812
 
 
 
 
 
 
 
 
 
adc8386
 
3216812
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
"""
HuggingFace Spaces App for GPT-2 124M Shakespeare Model
"""
import math
import os
from dataclasses import dataclass

import gradio as gr
import tiktoken
import torch
import torch.nn as nn
from torch.nn import functional as F


class CausalSelfAttention(nn.Module):
    """Multi-head masked (causal) self-attention.

    One linear layer (``c_attn``) produces queries, keys and values for every
    head at once. After attention the heads are merged and passed through the
    output projection ``c_proj``. The attribute names ``c_attn``, ``c_proj``
    and the ``bias`` buffer are state-dict keys, so they must not be renamed.
    """

    def __init__(self, config):
        super().__init__()
        # Heads split the embedding evenly, so n_embd must divide by n_head.
        assert config.n_embd % config.n_head == 0
        width = config.n_embd
        self.c_attn = nn.Linear(width, 3 * width)  # fused Q, K, V projection
        self.c_proj = nn.Linear(width, width)      # output projection
        # Marker flag; not read anywhere in this file's visible code
        # (presumably consumed by a training-time weight-init routine).
        self.c_proj.NANOGPT_SCALE_INIT = 1
        self.n_head = config.n_head
        self.n_embd = width
        # Lower-triangular ones: row i allows attention to columns 0..i only.
        span = config.block_size
        lower_tri = torch.tril(torch.ones(span, span))
        self.register_buffer("bias", lower_tri.view(1, 1, span, span))

    def forward(self, x):
        """Apply causal self-attention to ``x``.

        Args:
            x: tensor of shape (B, T, C) with C == n_embd and
                T <= block_size (the mask buffer is only block_size wide).

        Returns:
            Tensor of shape (B, T, C).
        """
        batch, seq_len, channels = x.size()
        head_dim = channels // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        query, key, value = map(to_heads, self.c_attn(x).split(self.n_embd, dim=2))

        # Scaled dot-product scores. Future positions are set to -inf so the
        # softmax assigns them zero weight.
        scores = (query @ key.transpose(-2, -1)) * (1.0 / math.sqrt(key.size(-1)))
        visible = self.bias[:, :, :seq_len, :seq_len]
        scores = scores.masked_fill(visible == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)

        # Merge heads back: (B, n_head, T, head_dim) -> (B, T, C)
        merged = (weights @ value).transpose(1, 2).contiguous().view(batch, seq_len, channels)
        return self.c_proj(merged)


class MLP(nn.Module):
    """Position-wise feed-forward network: widen 4x, GELU, project back.

    Attribute names ``c_fc`` / ``c_proj`` are state-dict keys; keep them.
    """

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden)   # expansion
        self.gelu = nn.GELU(approximate='tanh')        # tanh-approximated GELU
        self.c_proj = nn.Linear(hidden, config.n_embd)  # contraction
        # Marker flag; not read anywhere in this file's visible code.
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        """Map (..., n_embd) -> (..., n_embd) through the two projections."""
        return self.c_proj(self.gelu(self.c_fc(x)))


class Block(nn.Module):
    """One pre-LayerNorm transformer layer: attention, then MLP, both residual."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.ln_1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(width)
        self.mlp = MLP(config)

    def forward(self, x):
        # Each sub-layer reads a normalised copy of the stream and adds its
        # output back onto the un-normalised residual stream.
        attended = x + self.attn(self.ln_1(x))
        return attended + self.mlp(self.ln_2(attended))


@dataclass
class GPTConfig:
    """Model hyper-parameters; the defaults describe the 124M GPT-2 size."""

    block_size: int = 1024   # max sequence length (rows of the position embedding)
    vocab_size: int = 50257  # token vocabulary size (embedding / output rows)
    n_layer: int = 12        # number of transformer Blocks
    n_head: int = 12         # attention heads per Block
    n_embd: int = 768        # embedding / residual-stream width


class GPT(nn.Module):
    """Decoder-only transformer language model.

    Sub-module names (``transformer.wte``, ``transformer.wpe``,
    ``transformer.h``, ``transformer.ln_f``, ``lm_head``) are state-dict keys
    and must match the checkpoints this app loads.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict({
            'wte': nn.Embedding(config.vocab_size, config.n_embd),  # token embeddings
            'wpe': nn.Embedding(config.block_size, config.n_embd),  # position embeddings
            'h': nn.ModuleList(Block(config) for _ in range(config.n_layer)),
            'ln_f': nn.LayerNorm(config.n_embd),                    # final LayerNorm
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: the token embedding and the output projection are the
        # same Parameter object.
        self.transformer.wte.weight = self.lm_head.weight

    def forward(self, idx, targets=None):
        """Compute per-position logits and, optionally, the training loss.

        Args:
            idx: integer token ids, shape (B, T) with T <= block_size.
            targets: optional class indices with the same shape as ``idx``.

        Returns:
            ``(logits, loss)``. ``logits`` has shape (B, T, vocab_size).
            ``loss`` is the cross-entropy over all positions, or ``None``
            when ``targets`` is not given.
        """
        _, seq_len = idx.size()
        assert seq_len <= self.config.block_size, f"Cannot forward sequence of length {seq_len}, block size is only {self.config.block_size}"
        positions = torch.arange(0, seq_len, dtype=torch.long, device=idx.device)
        hidden = self.transformer.wte(idx) + self.transformer.wpe(positions)
        for layer in self.transformer.h:
            hidden = layer(hidden)
        logits = self.lm_head(self.transformer.ln_f(hidden))
        if targets is None:
            return logits, None
        flat_logits = logits.view(-1, logits.size(-1))
        return logits, F.cross_entropy(flat_logits, targets.view(-1))


# ---------------------------------------------------------------------------
# Model loading
#
# Weight sources are tried in this order, stopping at the first success:
#   1. The HuggingFace Hub repo named by $HF_MODEL_REPO (default below):
#      model.safetensors when the safetensors package is installed, then
#      model_checkpoint_final.pt.
#   2. A local ./model_checkpoint_final.pt.
# If every source fails, the model keeps its random initialisation and
# `model_loaded` stays False (generate_text refuses to run in that case).
# ---------------------------------------------------------------------------
_CHECKPOINT_FILENAME = 'model_checkpoint_final.pt'
# Key prefix that torch.compile()'d modules add to their state-dict keys.
_COMPILED_PREFIX = '_orig_mod.'


def _extract_state_dict(checkpoint):
    """Return the model state dict stored in a loaded checkpoint object.

    Accepts a dict with a ``'model_state_dict'`` or ``'state_dict'`` entry,
    a bare state dict, or a whole pickled ``nn.Module``. If every key carries
    the ``'_orig_mod.'`` prefix (a checkpoint saved from a compiled model),
    the prefix is stripped so the keys match this GPT's parameter names.
    """
    if isinstance(checkpoint, nn.Module):
        checkpoint = checkpoint.state_dict()
    if 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    elif 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    if state_dict and all(key.startswith(_COMPILED_PREFIX) for key in state_dict):
        cut = len(_COMPILED_PREFIX)
        state_dict = {key[cut:]: value for key, value in state_dict.items()}
    return state_dict


def _load_pt_checkpoint(target, path, map_device):
    """Load the ``torch.save`` checkpoint at *path* into *target* (strict keys)."""
    # SECURITY: weights_only=False unpickles arbitrary Python objects, so a
    # malicious file can execute code. It is used because PyTorch 2.6+
    # defaults to weights_only=True, which rejects checkpoints containing
    # custom classes. Only point HF_MODEL_REPO or the local path at
    # checkpoints you trust.
    checkpoint = torch.load(path, map_location=map_device, weights_only=False)
    target.load_state_dict(_extract_state_dict(checkpoint))


def _load_from_hub(target, repo, map_device):
    """Load weights into *target* from HuggingFace Hub repo *repo*.

    Tries ``model.safetensors`` first (only if safetensors is installed),
    then the pickled ``.pt`` checkpoint. Any failure of the last attempt
    propagates: ImportError if huggingface_hub is missing, a download error,
    or a state-dict mismatch.
    """
    from huggingface_hub import hf_hub_download

    try:
        from safetensors.torch import load_file
    except ImportError:
        load_file = None  # safetensors not installed: use the .pt file only

    if load_file is not None:
        try:
            path = hf_hub_download(repo_id=repo, filename="model.safetensors")
            target.load_state_dict(load_file(path, device=map_device))
            # The SafeTensors conversion is noted to break the sharing between
            # lm_head.weight and transformer.wte.weight; re-tie them so they
            # share memory as set up in GPT.__init__.
            target.transformer.wte.weight = target.lm_head.weight
            print(f"✅ Model loaded successfully from SafeTensors: {repo}")
            return
        except Exception as e:
            print(f"SafeTensors not found ({e}), trying .pt file...")

    path = hf_hub_download(repo_id=repo, filename=_CHECKPOINT_FILENAME)
    _load_pt_checkpoint(target, path, map_device)
    print(f"✅ Model loaded successfully from HuggingFace Hub: {repo}")


print("Loading model...")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
config = GPTConfig()
model = GPT(config)

model_loaded = False
repo_id = os.getenv('HF_MODEL_REPO', 'shwethd/gpt2-shakespeare-124m')

# Any Hub failure, including huggingface_hub not being installed, falls
# through to the local file. Previously a missing package skipped it.
try:
    print(f"Attempting to load from HuggingFace Hub: {repo_id}")
    _load_from_hub(model, repo_id, device)
    model_loaded = True
except Exception as e:
    print(f"⚠️ Could not load from Hub ({e}), trying local file...")
    try:
        _load_pt_checkpoint(model, _CHECKPOINT_FILENAME, device)
        model_loaded = True
        print("✅ Model loaded from local checkpoint")
    except Exception as e2:
        print(f"❌ Could not load from local file either: {e2}")

if not model_loaded:
    print("⚠️ WARNING: Model is using random weights! Generation will be nonsensical.")
    print("Please ensure model_checkpoint_final.pt is uploaded to HuggingFace Model Hub.")

model.to(device)
model.eval()
print(f"Model ready on {device}")

enc = tiktoken.get_encoding('gpt2')


import re
from collections import deque

# Number of most recently generated tokens that the repetition penalty covers.
_REPETITION_WINDOW = 20

# A line holding only an ALL-CAPS speaker name and a colon ("ROMEO:", "FIRST CITIZEN:").
_SPEAKER_LINE_RE = re.compile(r'^([A-Z][A-Z\s]+?):\s*$')

# Words the model sometimes emits in two fragments separated by whitespace
# ("furt her", "th at", "y our"); every two-part split of each word is re-joined.
_SPLIT_REPAIR_WORDS = [
    'further', 'this', 'that', 'there', 'where', 'here', 'their', 'your', 'our',
    'man', 'men', 'woman', 'women', 'padua', 'content', 'gentle', 'gently',
    'house', 'neck', 'car', 'made', 'lost', 'rough', 'see', 'might', 'any', 'one',
    'well', 'newly', 'too', 'him', 'her', 'them', 'they', 'the', 'and', 'but',
    'for', 'not', 'are', 'was', 'were', 'been', 'have', 'has', 'had', 'will',
    'shall', 'would', 'could', 'should', 'be', 'is', 'it', 'he', 'she', 'we',
    'you', 'me', 'my', 'his', 'hers', 'its', 'our', 'ours', 'yours', 'theirs',
    'into', 'onto', 'upon', 'within', 'without', 'through', 'though', 'although',
    'about', 'above', 'below', 'beside', 'between', 'among', 'during', 'before',
    'after', 'while', 'until', 'since', 'because', 'together', 'honour', 'honor',
    'already', 'perfect', 'soul', 'way', 'wounds', 'tears', 'raise', 'call',
    'citizens', 'senator', 'liked', 'cold', 'incold', 'incwold', 'son', 'count',
    'happen', 'happ', 'what', 'common', 'complain', 'upon', 'she', 'honour', 'honor',
    'youth', 'ports', 'impans', 'swear', 'gods', 'please', 'standing', 'tybalt',
    'sworn', 'where', 'would', 'give', 'seize', 'before', 'repair', 'lest', 'speak',
    'woman', 'gentleman', 'deed', 'better', 'virtuous', 'done', 'broke', 'art',
]

# Words detached when glued onto the end of the preceding word
# ("perpetualwith" -> "perpetual with"). Only words of four or more letters are
# used: shorter ones ("her", "our", "and", "on", "a", ...) are also the endings
# of ordinary words, and detaching them corrupted the text ("mother" ->
# "mot her", "hand" -> "h and", "honour" -> "h on our").
_GLUED_WORDS = [
    word
    for word in (
        'with', 'the', 'and', 'that', 'this', 'have', 'from', 'not', 'but', 'for',
        'are', 'was', 'were', 'been', 'will', 'shall', 'would', 'could', 'should',
        'be', 'your', 'you', 'our', 'my', 'his', 'her', 'their', 'him', 'them', 'to',
        'of', 'in', 'on', 'at', 'as', 'is', 'it', 'he', 'she', 'we', 'they', 'an', 'a',
    )
    if len(word) >= 4
]

# Words the model sometimes emits letter by letter or in pieces
# ("c o u n t", "T h is", "Wh at", "y our").
_MULTI_SPLIT_WORDS = [
    'count', 'your', 'son', 'our', 'the', 'and', 'but', 'for', 'not', 'are', 'was',
    'were', 'been', 'have', 'has', 'had', 'will', 'shall', 'would', 'could', 'should',
    'be', 'is', 'it', 'he', 'she', 'we', 'they', 'you', 'me', 'my', 'his', 'her',
    'them', 'him', 'this', 'that', 'there', 'where', 'here', 'their', 'what', 'common',
    'complain', 'upon', 'honour', 'honor', 'youth', 'ports', 'impans', 'woman',
    'gentleman', 'deed', 'better', 'virtuous', 'done', 'broke', 'art',
]

# Multi-word speaker names that must not be collapsed into a single word.
_KNOWN_MULTI_WORD_SPEAKERS = [
    'FIRST CITIZEN', 'SECOND CITIZEN', 'THIRD CITIZEN',
    'FIRST GENTLEMAN', 'SECOND GENTLEMAN', 'THIRD GENTLEMAN',
    'FIRST SERVANT', 'SECOND SERVANT', 'LADY MACBETH',
    'KING HENRY', 'PRINCE HAMLET', 'DUKE VINCENTIO',
]

# Leading title words of genuine multi-word speaker names ("LADY CAPULET",
# "KING RICHARD", "FRIAR LAURENCE"); such names are kept instead of merged.
_SPEAKER_TITLE_WORDS = frozenset({
    'FIRST', 'SECOND', 'THIRD', 'LADY', 'LORD', 'KING', 'QUEEN',
    'PRINCE', 'PRINCESS', 'DUKE', 'DUCHESS', 'FRIAR', 'SIR',
})

# A mixed-case heading without a trailing colon is only recognised up to this
# many words; longer colon-free lines are ordinary (often unpunctuated) verse.
_MAX_COLONLESS_SPEAKER_WORDS = 3


def _two_part_splits(word):
    """Return a case-insensitive regex for each way of splitting *word* in two ("t hat", "th at", "tha t")."""
    return [
        re.compile(r'\b' + re.escape(word[:i]) + r'\s+' + re.escape(word[i:]) + r'\b', re.IGNORECASE)
        for i in range(1, len(word))
    ]


def _letter_splits(word):
    """Return regexes for *word* split into single letters, split after its first two letters, and split in two."""
    letters = [re.escape(letter) for letter in word]
    every_letter = r'\b' + r'\s+'.join(letters) + r'\b'                                   # "c o u n t"
    first_two_joined = r'\b' + ''.join(letters[:2]) + r'\s+' + r'\s+'.join(letters[2:]) + r'\b'  # "wh a t"
    return [re.compile(every_letter, re.IGNORECASE), re.compile(first_two_joined, re.IGNORECASE)] + _two_part_splits(word)


# The repair patterns are compiled once at import: there are hundreds of them,
# more than the `re` module's internal pattern cache holds, so compiling them
# per call would redo that work on every generation.
_SPLIT_REPAIR_PATTERNS = [
    (word, pattern)
    for word in dict.fromkeys(_SPLIT_REPAIR_WORDS)  # de-duplicate, keep first-seen order
    for pattern in _two_part_splits(word)
]
_GLUED_WORD_PATTERNS = [
    re.compile(r'([a-z])(' + re.escape(word) + r'\b)', re.IGNORECASE) for word in _GLUED_WORDS
]
_MULTI_SPLIT_PATTERNS = [
    (word, pattern)
    for word in _MULTI_SPLIT_WORDS
    if len(word) > 2
    for pattern in _letter_splits(word)
]

# Specific merge/split artifacts. Each is matched case-insensitively and the
# replacement is re-cased to follow the matched text.
_ARTIFACT_FIXES = [
    (re.compile(pattern, re.IGNORECASE), replacement)
    for pattern, replacement in (
        # Pronoun + "t" (likely a merged "to"): "himt me" -> "him to me"
        (r'\bhimt\s+', 'him to '),
        (r'\bhert\s+', 'her to '),
        (r'\bthemt\s+', 'them to '),
        (r'\byout\s+', 'you to '),
        (r'\bhimt([,.;:!?])', r'him to\1'),
        (r'\bhert([,.;:!?])', r'her to\1'),
        (r'\bthemt([,.;:!?])', r'them to\1'),
        (r'\byout([,.;:!?])', r'you to\1'),
        # Other merged or split words
        (r'\bincwold\b', 'in cold'),
        (r'\bincold\b', 'in cold'),
        (r'\blikeled\b', 'liked'),
        (r'\bh\s+on\s+our\b', 'honour'),
        (r'\bh\s+on\s+or\b', 'honor'),
        (r"\bhapp\s+up\s+on't\b", "happen upon't"),
        (r'\bcomm\s+on\b', 'common'),
        (r'\bcompl\s+a\s+in\b', 'complain'),
        (r'\bas\s+s\s+he\b', 'as she'),
        (r'\bcontenton\b', 'content on'),
        (r'\btoget\s+her\b', 'together'),
    )
]

# Space after a contraction glued to the next word ("You'llbe" -> "You'll be").
# The lookbehind requires a letter before the apostrophe so word-initial
# elisions such as "'Tis" and "'twas" are left intact.
_CONTRACTION_PATTERNS = [
    re.compile(r"(?<=[a-z])(" + re.escape(contraction) + r")([a-z])", re.IGNORECASE)
    for contraction in ("'ll", "'ve", "'re", "'d", "'t", "'s", "'m")
]


def _match_case(template, replacement):
    """Return *replacement* re-cased like *template*: ALL CAPS, Capitalized, or lowercase."""
    if template.isupper():
        return replacement.upper()
    if template[:1].isupper():
        return replacement.capitalize()
    return replacement.lower()


def _sub_matching_case(pattern, replacement, text):
    """Replace matches of *pattern* in *text* with *replacement* (group references expanded), re-cased like each match."""
    return pattern.sub(lambda match: _match_case(match.group(0), match.expand(replacement)), text)


def _sample_tokens(tokens, max_new_tokens, temperature, top_k, top_p, repetition_penalty):
    """Extend the batch-of-one token tensor *tokens* by *max_new_tokens* sampled tokens and return it.

    Each step applies temperature, a repetition penalty over the last
    _REPETITION_WINDOW generated tokens, nucleus (top-p) filtering, then top-k
    filtering, and samples from the result.
    """
    recent = deque(maxlen=_REPETITION_WINDOW)  # a real sliding window of the latest token ids
    with torch.no_grad():
        for _step in range(max_new_tokens):
            # Feed at most config.block_size positions (presumably the model's
            # context limit) so long prompts work and generation is not cut short.
            context = tokens[:, -config.block_size:]
            logits, _loss = model(context)
            logits = logits[:, -1, :] / temperature

            # Repetition penalty: shrink positive logits, push negative ones further down.
            if repetition_penalty > 1.0 and recent:
                ids = torch.tensor(sorted(set(recent)), dtype=torch.long, device=logits.device)
                selected = logits[0, ids]
                logits[0, ids] = torch.where(
                    selected > 0, selected / repetition_penalty, selected * repetition_penalty
                )

            probs = F.softmax(logits, dim=-1)

            # Nucleus sampling: drop the tail beyond cumulative probability top_p,
            # always keeping the single most likely token.
            if top_p < 1.0:
                sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_to_remove = cumulative_probs > top_p
                sorted_to_remove[..., 0] = False
                to_remove = sorted_to_remove.scatter(1, sorted_indices, sorted_to_remove)
                probs[to_remove] = 0
                probs = probs / probs.sum()

            # Top-k filtering on the (possibly nucleus-filtered) distribution.
            if top_k < probs.size(-1):
                topk_probs, topk_indices = torch.topk(probs, top_k, dim=-1)
                probs = torch.zeros_like(probs).scatter_(-1, topk_indices, topk_probs)
                probs = probs / probs.sum()

            # Fall back to uniform sampling if filtering produced NaN or all zeros.
            if torch.isnan(probs).any() or probs.sum() == 0:
                probs = torch.ones_like(probs) / probs.size(-1)

            next_token = torch.multinomial(probs, 1)
            recent.append(next_token.item())
            tokens = torch.cat([tokens, next_token], dim=1)
    return tokens


def _strip_prompt_echo(text, prompt):
    """Remove *prompt* from the start of *text* when the model treated it as a speaker heading.

    The prompt counts as a heading when it is followed by a newline, a colon, or
    nothing. After removing it, a first line that is not a speaker heading but
    looks like dialogue (capitalised, with punctuation) is dropped as an
    orphaned line.
    """
    if not text.lower().startswith(prompt.lower().strip()):
        return text
    prompt_len = len(prompt)
    if len(text) <= prompt_len:
        return text
    window = text[prompt_len:prompt_len + 5]
    next_chars = window.strip()
    if next_chars and ':' not in next_chars and '\n' not in window:
        return text

    text = re.sub(r'^[\s:]+', '', text[prompt_len:].strip())
    lines = text.split('\n')
    first_line = lines[0].strip()
    if (
        first_line
        and not _SPEAKER_LINE_RE.match(first_line)
        and re.match(r'^[A-Z]', first_line)
        and any(mark in first_line for mark in '.,!?')
    ):
        text = '\n'.join(lines[1:])
    return text


def _repair_word_spacing(text):
    """Repair word-spacing artifacts in decoded model output.

    In order: space between a lowercase and an uppercase letter ("perpetualWith");
    re-join words split in two ("th at"); detach long function words glued to the
    previous word ("perpetualwith"); collapse a word repeated on the same line
    ("but but"); space after commas ("What,bear"); re-join words split into letters
    or pieces ("c o u n t"); fix known artifacts ("himt", "h on our", "toget her");
    and space after glued contractions ("You'llbe"). Repairs keep the case of the
    text they replace ("Th at" -> "That", "TH AT" -> "THAT").
    """
    text = re.sub(r'([a-z])([A-Z])', r'\1 \2', text)

    for word, pattern in _SPLIT_REPAIR_PATTERNS:
        text = _sub_matching_case(pattern, word, text)

    for pattern in _GLUED_WORD_PATTERNS:
        text = pattern.sub(r'\1 \2', text)

    # Only spaces/tabs between the repeats, so a word ending one line and
    # starting the next does not swallow the line break.
    text = re.sub(r'\b(\w+)[ \t]+\1\b', r'\1', text, flags=re.IGNORECASE)

    text = re.sub(r',([a-zA-Z])', r', \1', text)

    for word, pattern in _MULTI_SPLIT_PATTERNS:
        text = _sub_matching_case(pattern, word, text)

    for pattern, replacement in _ARTIFACT_FIXES:
        text = _sub_matching_case(pattern, replacement, text)

    for pattern in _CONTRACTION_PATTERNS:
        text = pattern.sub(r'\1 \2', text)

    return text


def _is_multi_word_speaker(name_line):
    """Return True if the ALL-CAPS heading *name_line* is a genuine multi-word speaker name."""
    upper = name_line.upper()
    if any(known in upper for known in _KNOWN_MULTI_WORD_SPEAKERS):
        return True
    words = upper.rstrip(':').split()
    return bool(words) and (words[0] in _SPEAKER_TITLE_WORDS or 'OF' in words)


def _merge_split_speaker_names(text):
    """Re-join ALL-CAPS speaker names split by spaces ("GENTLEM AN:" -> "GENTLEMAN:").

    Genuine multi-word names are kept, and merged results of 30 or more
    characters are rejected as implausible.
    """
    fixed_lines = []
    for line in text.split('\n'):
        stripped = line.strip()
        if re.match(r'^([A-Z]+\s+[A-Z]+\s*[A-Z]*):\s*$', stripped) and not _is_multi_word_speaker(stripped):
            merged = re.sub(r'([A-Z]+)\s+([A-Z]+)\s*([A-Z]*):', r'\1\2\3:', stripped)
            fixed_lines.append(merged if len(merged) < 30 else line)
        else:
            fixed_lines.append(line)
    return '\n'.join(fixed_lines)


def _normalize_speaker_case(text):
    """Upper-case mixed-case speaker headings ("First Citizen:" -> "FIRST CITIZEN:").

    A line of two or more letter-only words starting with a capital is a heading
    when the next line is non-empty and not itself a heading, or when it is the
    only line. Without a trailing colon it must also have at most
    _MAX_COLONLESS_SPEAKER_WORDS words, so unpunctuated verse is left alone.
    """
    lines = text.split('\n')
    normalized = []
    for index, line in enumerate(lines):
        stripped = line.strip()
        match = re.match(r'^([A-Z][a-z]+(?:\s+[a-zA-Z]+)+)\s*:?\s*$', stripped)
        if match and (
            stripped.endswith(':') or len(match.group(1).split()) <= _MAX_COLONLESS_SPEAKER_WORDS
        ):
            if index + 1 < len(lines):
                next_line = lines[index + 1].strip()
                is_speaker = bool(next_line) and not _SPEAKER_LINE_RE.match(next_line)
            else:
                is_speaker = index == 0
            if is_speaker:
                normalized.append(match.group(1).upper() + ':')
                continue
        normalized.append(line)
    return '\n'.join(normalized)


def _drop_repeated_speakers(text, window=3):
    """Drop a speaker heading if the same speaker headed a line at most *window* lines earlier.

    The distance is measured in lines, so normal alternating dialogue
    (ROMEO / JULIET / ROMEO) keeps every heading.
    """
    lines = text.split('\n')
    kept = []
    recent = deque(maxlen=window)  # (speaker, line index) of the latest kept headings
    for index, line in enumerate(lines):
        match = _SPEAKER_LINE_RE.match(line.strip())
        if match:
            speaker = match.group(1).strip()
            if any(seen == speaker and index - seen_at <= window for seen, seen_at in recent):
                continue
            recent.append((speaker, index))
        kept.append(line)
    return '\n'.join(kept)


def _drop_speakers_without_dialogue(text, lookahead=3):
    """Drop a speaker heading whose next non-blank line (within *lookahead* lines) is not dialogue."""
    lines = text.split('\n')
    kept = []
    for index, line in enumerate(lines):
        if _SPEAKER_LINE_RE.match(line.strip()):
            following = (candidate.strip() for candidate in lines[index + 1:index + 1 + lookahead])
            first_text = next((candidate for candidate in following if candidate), None)
            if first_text is None or _SPEAKER_LINE_RE.match(first_text):
                continue
        kept.append(line)
    return '\n'.join(kept)


def _clean_speaker_lines(text):
    """Tidy speaker headings and the layout around them."""
    text = _merge_split_speaker_names(text)
    # Separate an ALL-CAPS run glued to a lowercase word; before a colon,
    # also insert a period ("BarnMENENIUS:" -> "Barn. MENENIUS:").
    text = re.sub(r'([a-z]+)([A-Z]{2,}):', r'\1. \2:', text)
    text = re.sub(r'([a-z])([A-Z]{2,})', r'\1 \2', text)
    text = _normalize_speaker_case(text)
    text = _drop_repeated_speakers(text)
    text = _drop_speakers_without_dialogue(text)
    # "You?A:" -> "You? A:"
    text = re.sub(r'([?!])([A-Z])', r'\1 \2', text)
    # Remove blank lines between a heading and its dialogue.
    text = re.sub(r'([A-Z][A-Z\s]+?):\s*\n\s*\n+', r'\1:\n', text)
    # Final pass: collapse a heading repeated with only blank lines in between.
    text = re.sub(
        r'^([A-Z][A-Z\s]+?):\s*\n\s*\n*\1:\s*\n',
        r'\1:\n',
        text,
        flags=re.MULTILINE,
    )
    return text


def _trim_incomplete_ending(text):
    """Trim an ending cut off mid-sentence by the token limit.

    A short unpunctuated trailing word is removed from a last line that is not a
    speaker heading. Then the text is cut back to its last sentence ending in
    . ! or ?, provided at least 30% of the text remains.
    """
    if not text.strip():
        return text

    lines = text.split('\n')
    last_line = lines[-1].strip()
    if last_line and not _SPEAKER_LINE_RE.match(last_line):
        if not re.search(r'[.!?,;:]$', last_line):
            words = last_line.split()
            if words:
                last_word = words[-1]
                # 1-2 letter trailing words, or <4-char words ending a short
                # line, are most likely cut off.
                if (len(last_word) <= 2 and last_word.isalpha()) or (len(last_line) < 20 and len(last_word) < 4):
                    lines[-1] = ' '.join(words[:-1])
        if not lines[-1].strip():
            lines = lines[:-1]
    text = '\n'.join(lines)

    if text.strip():
        # re.split with a capture group alternates [text, terminator, text, ...];
        # pairing them keeps only sentences that have a terminator.
        pieces = re.split(r'([.!?]+)', text)
        if len(pieces) > 1:
            complete = ''.join(pieces[i] + pieces[i + 1] for i in range(0, len(pieces) - 1, 2))
            if complete.strip() and len(complete.strip()) > len(text.strip()) * 0.3:
                text = complete.strip()
    return text


def generate_text(prompt, max_new_tokens=100, temperature=0.7, top_k=50, top_p=0.9, repetition_penalty=1.1):
    """Generate Shakespeare-style text continuing *prompt*.

    Args:
        prompt: Text to continue.
        max_new_tokens: Number of tokens to generate; clamped to [1, 200].
        temperature: Softmax temperature; clamped to [0.1, 2.0].
        top_k: Number of most likely tokens kept; clamped to [1, 100].
        top_p: Cumulative probability kept by nucleus sampling; clamped to [0.1, 1.0].
        repetition_penalty: Factor applied to logits of recently generated
            tokens; clamped to [1.0, 1.5].

    Returns:
        The decoded text after post-processing (a prompt echoed as a speaker
        heading is removed), or a user-facing message for invalid input or
        failures. Exceptions are not raised; unexpected ones are logged with a
        traceback and reported in the returned message.
    """
    try:
        if not model_loaded:
            return "❌ Error: Model not loaded correctly. Please check that model_checkpoint_final.pt is uploaded to HuggingFace Model Hub (shwethd/gpt2-shakespeare-124m)."

        if not prompt or not prompt.strip():
            return "Please enter a prompt."

        # Clamp sampling parameters to safe ranges.
        temperature = max(0.1, min(2.0, float(temperature)))
        top_k = max(1, min(100, int(top_k)))
        top_p = max(0.1, min(1.0, float(top_p)))
        repetition_penalty = max(1.0, min(1.5, float(repetition_penalty)))
        max_new_tokens = max(1, min(200, int(max_new_tokens)))

        prompt_tokens = enc.encode(prompt)
        if not prompt_tokens:
            return "Error: Could not encode prompt."

        tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device).unsqueeze(0)
        tokens = _sample_tokens(tokens, max_new_tokens, temperature, top_k, top_p, repetition_penalty)

        generated_text = enc.decode(tokens[0].tolist())
        generated_text = _strip_prompt_echo(generated_text, prompt)
        generated_text = _repair_word_spacing(generated_text)
        generated_text = _clean_speaker_lines(generated_text)
        return _trim_incomplete_ending(generated_text)
    except Exception as e:
        import traceback
        # Keep the full stack in the server logs; the UI only gets a summary.
        traceback.print_exc()
        return f"❌ Error during generation: {str(e)}\n\nPlease check:\n1. Model is uploaded to HuggingFace Model Hub\n2. Repository name is correct: shwethd/gpt2-shakespeare-124m\n3. File name is exactly: model_checkpoint_final.pt"


# Create Gradio interface.
#
# Builds the Blocks UI: a status banner driven by `model_loaded` (set earlier
# in this file), sampling controls, an output box, one-click example presets
# that fill the controls, and a Generate button wired to `generate_text`.
with gr.Blocks(title="GPT-2 124M Shakespeare Model") as demo:
    # Status indicator shown in the page header.
    status_color = "🟢" if model_loaded else "🔴"
    status_text = "Model loaded successfully!" if model_loaded else "⚠️ Model not loaded - check HuggingFace Model Hub!"
    # Extra troubleshooting hint, rendered only when the model failed to load.
    load_warning = (
        ""
        if model_loaded
        else "⚠️ **Note:** If you see garbled/random text, the model may not have loaded correctly. Check the logs and ensure the model is uploaded to HuggingFace Model Hub: `shwethd/gpt2-shakespeare-124m`"
    )

    gr.Markdown(f"""
    # 🎭 GPT-2 124M Shakespeare Language Model

    {status_color} **Status:** {status_text}

    This is a 124M parameter decoder-only transformer model trained on Shakespeare's complete works.

    **Training Results:**
    - Final Loss: 0.095127 (Target: < 0.099999) ✅
    - Model Parameters: 124.44M
    - Training Steps: 1,637

    Enter a prompt below to generate Shakespeare-style text!

    {load_warning}
    """)

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here (e.g., 'First Citizen:', 'ROMEO:', 'To be or not')",
                value="First Citizen:",
                lines=3
            )
            max_tokens = gr.Slider(
                label="Max Tokens",
                minimum=50,
                maximum=200,
                value=100,
                step=10
            )
            temperature = gr.Slider(
                label="Temperature",
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                info="Lower = more focused, Higher = more creative (0.7 recommended for better coherence)"
            )
            top_k = gr.Slider(
                label="Top-K",
                minimum=10,
                maximum=100,
                value=50,
                step=10,
                info="Number of top tokens to consider"
            )
            top_p = gr.Slider(
                label="Top-P (Nucleus)",
                minimum=0.1,
                maximum=1.0,
                value=0.85,
                step=0.05,
                info="Nucleus sampling - 0.85-0.9 recommended. Lower (0.3) = too restrictive, Higher (0.95+) = too random"
            )
            repetition_penalty = gr.Slider(
                label="Repetition Penalty",
                minimum=1.0,
                maximum=1.5,
                value=1.1,
                step=0.05,
                info="Penalize repeated tokens - higher = less repetition (1.1 recommended)"
            )
            generate_btn = gr.Button("Generate", variant="primary")

        with gr.Column():
            output = gr.Textbox(
                label="Generated Text",
                lines=10,
                interactive=True,  # Make it interactive so users can select and copy
                show_copy_button=True  # Add copy button
            )

    # Components passed positionally to `generate_text`. The example rows
    # below and the Generate button share this one list, so the preset
    # columns always line up with the function's inputs.
    generation_inputs = [prompt_input, max_tokens, temperature, top_k, top_p, repetition_penalty]

    # Example prompts with suggested parameters. Clicking a row fills the
    # controls; the user still presses Generate.
    gr.Markdown("### Example Prompts (Click to try - includes optimal settings)")
    examples = gr.Examples(
        examples=[
            # Format: [prompt, max_tokens, temperature, top_k, top_p, repetition_penalty]
            ["First Citizen:", 100, 0.7, 50, 0.85, 1.1],
            ["ROMEO:", 100, 0.65, 45, 0.88, 1.15],  # Romantic - slightly lower temp
            ["To be or not", 80, 0.6, 40, 0.85, 1.2],  # Quote - more focused
            ["HAMLET:", 100, 0.7, 50, 0.85, 1.1],
            ["MACBETH:", 100, 0.7, 50, 0.85, 1.1],
            ["JULIET:", 100, 0.65, 45, 0.88, 1.15],  # Romantic
            ["KING:", 100, 0.7, 50, 0.85, 1.1],
            ["LADY MACBETH:", 100, 0.7, 50, 0.85, 1.1],
            ["OTHELLO:", 100, 0.7, 50, 0.85, 1.1],
            ["What light through yonder", 100, 0.65, 45, 0.88, 1.15],  # Romantic quote
            ["All the world's a stage", 100, 0.7, 50, 0.85, 1.1],  # Metaphorical
            ["Double, double toil and trouble", 80, 0.7, 50, 0.85, 1.15],  # Witches chant
            ["Friends, Romans, countrymen", 100, 0.7, 50, 0.85, 1.1],  # Speech
            ["A rose by any other name", 100, 0.65, 45, 0.88, 1.15],  # Romantic quote
        ],
        inputs=generation_inputs
    )

    generate_btn.click(
        fn=generate_text,
        inputs=generation_inputs,
        outputs=output
    )

    gr.Markdown("""
    ---
    **Note:** The model was trained on Shakespeare text and generates text in that style.
    Generated text may not always be coherent but should follow Shakespearean patterns.
    """)

if __name__ == "__main__":
    # Don't use share=True on HuggingFace Spaces: the Space already serves the app.
    demo.launch()