#!/usr/bin/env python3
# Copyright (C) 2024 Louis Chua Bean Chong
#
# This file is part of OpenLLM.
#
# OpenLLM is dual-licensed:
# 1. For open source use: GNU General Public License v3.0
# 2. For commercial use: Commercial License (contact for details)
#
# See LICENSE and docs/LICENSES.md for full license information.

"""

OpenLLM Text Generation Script



This script implements standalone text generation for OpenLLM models

as specified in Step 5 of the training pipeline (Text Generation Quality assessment).



Features:

- Load trained OpenLLM models from checkpoint directories

- Generate text with configurable parameters (temperature, length, etc.)

- Support multiple model formats (auto-detection)

- Quality assessment and metrics

- Batch generation capabilities

- Output formatting and saving



Usage:

    # Basic text generation

    python core/src/generate_text.py \

        --model_dir models/small-extended-4k \

        --prompt "The history of artificial intelligence" \

        --max_length 256 \

        --temperature 0.7



    # Multiple prompts with custom settings

    python core/src/generate_text.py \

        --model_dir models/small-extended-4k \

        --prompts_file prompts.txt \

        --max_length 100 \

        --temperature 0.8 \

        --top_k 40 \

        --num_samples 3



    # Save results to file

    python core/src/generate_text.py \

        --model_dir models/small-extended-4k \

        --prompt "Once upon a time" \

        --output_file generated_samples.txt



Author: Louis Chua Bean Chong

License: GPLv3

"""

import argparse
import os
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

import sentencepiece as spm
import torch

# Add current directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from model import create_model


class TextGenerator:
    """

    Comprehensive text generation engine for OpenLLM models.



    This class handles loading trained models and generating high-quality text

    with configurable sampling parameters and quality assessment.

    """

    def __init__(self, model_dir: str, device: str = "auto"):
        """

        Initialize the text generator.



        Args:

            model_dir: Directory containing trained model checkpoints

            device: Device to use ("auto", "cpu", "cuda")



        Implementation Details:

            - Auto-detects best available device if device="auto"

            - Loads model architecture based on checkpoint configuration

            - Sets up tokenizer for text processing

            - Validates model and tokenizer compatibility

        """
        self.model_dir = Path(model_dir)

        # Determine device to use
        # Auto-detection prioritizes CUDA if available for better performance
        if device == "auto":
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        print("πŸš€ OpenLLM Text Generator")
        print(f"πŸ“‚ Model directory: {model_dir}")
        print(f"πŸ–₯️  Device: {self.device}")

        # Load model and tokenizer
        # This handles the complete setup process
        self._load_model()
        self._load_tokenizer()

        # Validate setup
        # Ensure model and tokenizer are compatible
        self._validate_setup()

        print("βœ… Text generator initialized successfully!")

    def _load_model(self):
        """

        Load the trained model from checkpoint.



        Implementation Details:

            - Searches for best_model.pt or latest checkpoint

            - Auto-detects model size from configuration

            - Handles different checkpoint formats gracefully

            - Sets model to evaluation mode for inference

        """
        # Find the best model checkpoint
        # Priority: best_model.pt > latest checkpoint by step number
        best_model_path = self.model_dir / "best_model.pt"

        if best_model_path.exists():
            checkpoint_path = best_model_path
            print(f"πŸ“₯ Loading best model: {checkpoint_path}")
        else:
            # Look for step-based checkpoints
            checkpoints = list(self.model_dir.glob("checkpoint_step_*.pt"))
            if not checkpoints:
                raise FileNotFoundError(f"No model checkpoints found in {self.model_dir}")

            # Get the latest checkpoint by step number
            latest_checkpoint = max(checkpoints, key=lambda p: int(p.stem.split("_")[-1]))
            checkpoint_path = latest_checkpoint
            print(f"πŸ“₯ Loading latest checkpoint: {checkpoint_path}")

        # Load checkpoint data
        # This contains model weights, configuration, and training metadata
        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            print("βœ… Checkpoint loaded successfully")
        except Exception as e:
            raise RuntimeError(f"Failed to load checkpoint: {e}")

        # Extract model configuration
        # This tells us what architecture to create
        if "config" in checkpoint:
            config_dict = checkpoint["config"]
        else:
            # Fallback: try to infer from model state dict
            print("⚠️  No config found in checkpoint, inferring from model structure...")
            config_dict = self._infer_config_from_state_dict(
                checkpoint.get("model_state_dict", checkpoint)
            )

        # Determine model size category
        # This maps checkpoint config to our predefined model sizes
        n_layer = config_dict.get("n_layer", 12)
        n_embd = config_dict.get("n_embd", 768)

        if n_layer <= 6:
            model_size = "small"
        elif n_layer <= 12:
            model_size = "medium"
        else:
            model_size = "large"

        print(f"🎯 Detected model size: {model_size}")
        print(f"πŸ“Š Architecture: {n_layer} layers, {n_embd} embedding dim")

        # Create model architecture
        # This recreates the exact same model used during training
        try:
            self.model = create_model(model_size)
            print(f"πŸ—οΈ  Model architecture created: {self.model.get_num_params():,} parameters")
        except Exception as e:
            raise RuntimeError(f"Failed to create model architecture: {e}")

        # Load trained weights
        # This restores the model to its trained state
        try:
            if "model_state_dict" in checkpoint:
                self.model.load_state_dict(checkpoint["model_state_dict"])
            else:
                # Fallback for different checkpoint formats
                self.model.load_state_dict(checkpoint)

            print("βœ… Model weights loaded successfully")
        except Exception as e:
            raise RuntimeError(f"Failed to load model weights: {e}")

        # Move model to device and set to evaluation mode
        # Evaluation mode disables dropout and other training-specific behaviors
        self.model = self.model.to(self.device)
        self.model.eval()

        # Store model configuration for later use
        # This is useful for generation parameters and limits
        self.config = self.model.config

        # Extract training metadata if available
        # This provides context about model quality and training progress
        self.training_info = {
            "step": checkpoint.get("step", "Unknown"),
            "best_loss": checkpoint.get("best_loss", "Unknown"),
            "model_size": model_size,
        }

        print(
            f"πŸ“ˆ Training info: step {self.training_info['step']}, "
            f"best loss {self.training_info['best_loss']}"
        )

    def _infer_config_from_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """

        Infer model configuration from state dict when config is missing.



        Args:

            state_dict: Model parameter dictionary



        Returns:

            Inferred configuration dictionary



        Implementation Details:

            - Analyzes parameter shapes to determine architecture

            - Makes reasonable assumptions about standard GPT architecture

            - Provides fallback values for missing parameters

        """
        # Extract key dimensions from parameter shapes
        # This reverse-engineers the model architecture

        # Embedding layer tells us vocab size and embedding dimension
        if "transformer.wte.weight" in state_dict:
            vocab_size, n_embd = state_dict["transformer.wte.weight"].shape
        else:
            # Fallback defaults
            vocab_size, n_embd = 32000, 512

        # Count transformer blocks to get number of layers
        # Look for attention weight patterns
        n_layer = 0
        for key in state_dict.keys():
            if "attn.c_attn.weight" in key:
                # Extract layer number from key like 'transformer.h.0.attn.c_attn.weight'
                layer_num = int(key.split(".")[2])
                n_layer = max(n_layer, layer_num + 1)

        # Infer number of attention heads from attention weights
        # The c_attn weight combines query, key, value projections
        if "transformer.h.0.attn.c_attn.weight" in state_dict:
            _ = state_dict["transformer.h.0.attn.c_attn.weight"].shape
            # Shape is [n_embd, 3 * n_embd] for combined Q,K,V
            # So n_head = n_embd / head_dim, assuming head_dim = 64
            n_head = n_embd // 64  # Standard head dimension
        else:
            n_head = 8  # Fallback
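        # Worked example of the inference above (illustrative): a wte weight of
        # shape [32000, 512] implies vocab_size=32000 and n_embd=512; with the
        # assumed 64-dim heads, n_head = 512 // 64 = 8.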

        # Construct configuration dictionary
        # Use reasonable defaults for missing values
        config = {
            "vocab_size": vocab_size,
            "n_layer": n_layer,
            "n_head": n_head,
            "n_embd": n_embd,
            "block_size": 1024,  # Standard context length
            "dropout": 0.1,  # Standard dropout rate
            "bias": True,  # Most models use bias
            "model_name": f"gpt-inferred-{n_layer}L",
        }

        print(f"πŸ” Inferred config: {config}")
        return config

    def _load_tokenizer(self):
        """

        Load the SentencePiece tokenizer.



        Implementation Details:

            - Searches multiple possible tokenizer locations

            - Validates tokenizer vocabulary size against model

            - Sets up special tokens if available

        """
        # Try multiple possible tokenizer locations
        # Different training setups may store tokenizer in different places
        possible_paths = [
            self.model_dir / "tokenizer.model",
            self.model_dir.parent / "tokenizer" / "tokenizer.model",
            Path("data/tokenizer/tokenizer.model"),
            self.model_dir / ".." / "tokenizer" / "tokenizer.model",
        ]

        tokenizer_path = None
        for path in possible_paths:
            if path.exists():
                tokenizer_path = path
                break

        if tokenizer_path is None:
            raise FileNotFoundError(f"Tokenizer not found in any of: {possible_paths}")

        print(f"πŸ“ Loading tokenizer from: {tokenizer_path}")

        # Load SentencePiece tokenizer
        # This handles all text-to-token and token-to-text conversion
        try:
            self.tokenizer = spm.SentencePieceProcessor()
            self.tokenizer.load(str(tokenizer_path))
            print(f"βœ… Tokenizer loaded: {self.tokenizer.vocab_size()} vocabulary")
        except Exception as e:
            raise RuntimeError(f"Failed to load tokenizer: {e}")

    def _validate_setup(self):
        """

        Validate that model and tokenizer are compatible.



        Implementation Details:

            - Checks vocabulary size consistency

            - Tests basic tokenization and model forward pass

            - Warns about potential compatibility issues

        """
        # Check vocabulary size consistency
        # Model and tokenizer should have matching vocabulary
        model_vocab_size = self.config.vocab_size
        tokenizer_vocab_size = self.tokenizer.vocab_size()

        if model_vocab_size != tokenizer_vocab_size:
            print("⚠️  Warning: Vocabulary size mismatch!")
            print(f"   Model expects: {model_vocab_size}")
            print(f"   Tokenizer has: {tokenizer_vocab_size}")
            print("   This may cause generation issues.")

        # Test basic functionality
        # Quick validation that everything works together
        try:
            # Test tokenization
            test_text = "Hello world"
            tokens = self.tokenizer.encode(test_text)
            _ = self.tokenizer.decode(tokens)

            # Test model forward pass
            input_ids = torch.tensor([tokens[:5]], dtype=torch.long, device=self.device)
            with torch.no_grad():
                _ = self.model(input_ids)

            print("βœ… Validation passed: tokenization and model forward pass work")

        except Exception as e:
            print(f"⚠️  Validation warning: {e}")
            print("   Generation may still work, but there might be issues.")

    def generate(
        self,
        prompt: str,
        max_length: int = 100,
        temperature: float = 0.7,
        top_k: Optional[int] = 40,
        top_p: Optional[float] = 0.9,
        num_return_sequences: int = 1,
        do_sample: bool = True,
        repetition_penalty: float = 1.0,
    ) -> List[str]:
        """
        Generate text from a prompt using the loaded model.

        Args:
            prompt: Input text to continue
            max_length: Maximum number of tokens to generate
            temperature: Sampling temperature (0.1-2.0, higher = more random)
            top_k: Limit to top-k most likely tokens (None = no limit)
            top_p: Nucleus sampling threshold (None = no nucleus sampling)
            num_return_sequences: Number of sequences to generate
            do_sample: Whether to use sampling (False = greedy)
            repetition_penalty: Penalty for repeating tokens (1.0 = no penalty)

        Returns:
            List of generated text strings

        Implementation Details:
            - Uses autoregressive generation (one token at a time)
            - Supports multiple sampling strategies (greedy, top-k, nucleus)
            - Handles context length limits gracefully
            - Applies repetition penalty to improve quality
            - Returns only the generated portion (excludes input prompt)
        """
        print(f"🎯 Generating text for: '{prompt[:50]}{'...' if len(prompt) > 50 else ''}'")
        print(
            f"βš™οΈ  Parameters: max_length={max_length}, temperature={temperature}, "
            f"top_k={top_k}, top_p={top_p}"
        )

        # Tokenize input prompt
        # Convert text to token IDs for model processing
        try:
            input_tokens = self.tokenizer.encode(prompt)
            if len(input_tokens) == 0:
                raise ValueError("Empty tokenization result")
        except Exception as e:
            raise RuntimeError(f"Failed to tokenize prompt: {e}")

        # Check prompt length against model context
        # Ensure we don't exceed model's maximum sequence length
        max_context = self.config.block_size
        if len(input_tokens) >= max_context:
            print(
                f"⚠️  Warning: Prompt length ({len(input_tokens)}) approaches "
                f"context limit ({max_context})"
            )
            # Truncate the prompt, reserving room for max_length new tokens;
            # the guard keeps at least one prompt token and avoids an invalid
            # slice when max_length >= block_size
            keep_tokens = max(1, max_context - max_length)
            input_tokens = input_tokens[-keep_tokens:]
            print(f"   Truncated prompt to {len(input_tokens)} tokens")

        # Generate multiple sequences
        # Each sequence is generated independently
        generated_texts = []

        for seq_idx in range(num_return_sequences):
            if num_return_sequences > 1:
                print(f"πŸ”„ Generating sequence {seq_idx + 1}/{num_return_sequences}")

            try:
                generated_text = self._generate_single_sequence(
                    input_tokens=input_tokens,
                    max_length=max_length,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    do_sample=do_sample,
                    repetition_penalty=repetition_penalty,
                )
                generated_texts.append(generated_text)

            except Exception as e:
                print(f"⚠️  Generation failed for sequence {seq_idx + 1}: {e}")
                generated_texts.append(f"Generation error: {e}")

        return generated_texts

    def _generate_single_sequence(
        self,
        input_tokens: List[int],
        max_length: int,
        temperature: float,
        top_k: Optional[int],
        top_p: Optional[float],
        do_sample: bool,
        repetition_penalty: float,
    ) -> str:
        """
        Generate a single text sequence using autoregressive sampling.

        Args:
            input_tokens: Tokenized input prompt
            max_length: Maximum tokens to generate
            temperature: Sampling temperature
            top_k: Top-k sampling limit
            top_p: Nucleus sampling threshold
            do_sample: Whether to use sampling vs greedy
            repetition_penalty: Repetition penalty factor

        Returns:
            Generated text string (excluding input prompt)

        Implementation Details:
            - Implements autoregressive generation loop
            - Applies all specified sampling strategies
            - Handles special tokens (EOS, padding)
            - Tracks token frequencies for repetition penalty
        """
        # Initialize generation state
        # Keep track of all generated tokens and their frequencies
        generated_tokens = input_tokens.copy()
        token_frequencies = {}  # For repetition penalty

        # Count initial token frequencies
        # This helps apply repetition penalty from the start
        for token in input_tokens:
            token_frequencies[token] = token_frequencies.get(token, 0) + 1

        # Set model to evaluation mode and disable gradients
        # This ensures consistent inference behavior and saves memory
        self.model.eval()

        with torch.no_grad():
            # Main generation loop
            # Generate one token at a time until stopping condition
            for step in range(max_length):
                # Check context length limits
                # Prevent exceeding model's maximum sequence length
                if len(generated_tokens) >= self.config.block_size:
                    print(f"⚠️  Reached maximum context length ({self.config.block_size})")
                    break

                # Prepare model input
                # Use all generated tokens as context for next prediction
                input_ids = torch.tensor([generated_tokens], dtype=torch.long, device=self.device)

                try:
                    # Forward pass through model
                    # Get logits (raw predictions) for all vocabulary tokens
                    outputs = self.model(input_ids)

                    # Handle different model output formats
                    # Some models return tuples, others return tensors directly
                    if isinstance(outputs, tuple):
                        logits = outputs[0]  # First element is usually logits
                    else:
                        logits = outputs

                    # Get predictions for next token (last position in sequence)
                    next_token_logits = logits[0, -1, :].float()

                except Exception as e:
                    raise RuntimeError(f"Model forward pass failed at step {step}: {e}")

                # Apply repetition penalty
                # Reduce probability of recently used tokens
                if repetition_penalty != 1.0:
                    for token, freq in token_frequencies.items():
                        if token < len(next_token_logits):
                            penalty = repetition_penalty**freq
                            if next_token_logits[token] > 0:
                                next_token_logits[token] /= penalty
                            else:
                                next_token_logits[token] *= penalty
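                # Worked example (illustrative): with repetition_penalty=1.2 and
                # a token already seen twice, penalty = 1.2**2 = 1.44; a positive
                # logit is divided by 1.44 (a negative one is multiplied by it),
                # making a third repetition less likely.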

                # Apply sampling strategy to select next token
                # This determines the randomness and quality of generation
                if do_sample:
                    next_token = self._sample_next_token(
                        next_token_logits, temperature, top_k, top_p
                    )
                else:
                    # Greedy decoding: always pick most likely token
                    next_token = torch.argmax(next_token_logits).item()

                # Add generated token to sequence
                generated_tokens.append(next_token)

                # Update token frequency for repetition penalty
                token_frequencies[next_token] = token_frequencies.get(next_token, 0) + 1

                # Check for end-of-sequence token
                # Some models/tokenizers have special EOS tokens
                if hasattr(self.tokenizer, "eos_id") and next_token == self.tokenizer.eos_id():
                    print(f"πŸ”š Reached end-of-sequence token at step {step}")
                    break

                # Optional: Check for other stopping conditions
                # Could add custom stop words or patterns here

        # Decode generated tokens to text
        # Convert token IDs back to readable text, excluding input prompt
        try:
            # Extract only newly generated tokens (exclude input prompt)
            new_tokens = generated_tokens[len(input_tokens) :]

            if len(new_tokens) == 0:
                return "⚠️  No tokens generated"

            # Decode to text using tokenizer
            generated_text = self.tokenizer.decode(new_tokens)

            print(f"βœ… Generated {len(new_tokens)} tokens")
            return generated_text

        except Exception as e:
            raise RuntimeError(f"Failed to decode generated tokens: {e}")

    def _sample_next_token(
        self, logits: torch.Tensor, temperature: float, top_k: Optional[int], top_p: Optional[float]
    ) -> int:
        """
        Sample next token using specified sampling strategy.

        Args:
            logits: Raw model predictions for next token
            temperature: Sampling temperature
            top_k: Top-k sampling limit
            top_p: Nucleus sampling threshold

        Returns:
            Selected token ID

        Implementation Details:
            - Applies temperature scaling for randomness control
            - Implements top-k sampling to limit choices
            - Implements nucleus (top-p) sampling for quality
            - Uses multinomial sampling for final selection
        """
        # Apply temperature scaling
        # Higher temperature = more random, lower = more deterministic
        if temperature != 1.0:
            logits = logits / temperature
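        # E.g. temperature=0.5 doubles every logit (sharper, more deterministic)
        # while temperature=2.0 halves them (flatter, more random).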

        # Apply top-k filtering
        # Only consider the k most likely tokens
        if top_k is not None and top_k > 0:
            # Get indices of top-k tokens
            top_k_tokens = min(top_k, logits.size(-1))
            top_k_values, top_k_indices = torch.topk(logits, top_k_tokens)

            # Zero out non-top-k logits
            filtered_logits = torch.full_like(logits, float("-inf"))
            filtered_logits[top_k_indices] = top_k_values
            logits = filtered_logits
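        # E.g. with top_k=40 over a 32,000-token vocabulary, all but the 40
        # highest logits become -inf, so the softmax below renormalizes over
        # just those 40 candidates.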

        # Apply nucleus (top-p) sampling
        # Dynamically adjust vocabulary based on cumulative probability
        if top_p is not None and top_p < 1.0:
            # Sort logits in descending order
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)

            # Calculate cumulative probabilities
            sorted_probs = torch.softmax(sorted_logits, dim=-1)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

            # Find cutoff point where cumulative probability exceeds top_p
            sorted_indices_to_remove = cumulative_probs > top_p

            # Keep at least the top token
            sorted_indices_to_remove[0] = False

            # Zero out tokens beyond nucleus
            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            logits[indices_to_remove] = float("-inf")
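        # Worked trace (illustrative): with top_p=0.9 and sorted probabilities
        # [0.5, 0.3, 0.15, 0.05], the cumulative sums are [0.5, 0.8, 0.95, 1.0];
        # entries whose cumulative sum exceeds 0.9 are masked, leaving the top
        # two candidates. Note this variant also drops the token that first
        # crosses the threshold (except the top one), unlike implementations
        # that shift the removal mask right by one position.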

        # Convert logits to probabilities and sample
        # Use multinomial sampling for final token selection
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1).item()

        return next_token

    def generate_batch(self, prompts: List[str], **generation_kwargs) -> List[List[str]]:
        """

        Generate text for multiple prompts.



        Args:

            prompts: List of input prompts

            **generation_kwargs: Arguments passed to generate()



        Returns:

            List of lists, where each inner list contains generated texts for one prompt



        Implementation Details:

            - Processes prompts sequentially (could be parallelized)

            - Applies same generation parameters to all prompts

            - Handles errors gracefully for individual prompts

        """
        print(f"πŸ”„ Generating text for {len(prompts)} prompts...")

        all_results = []

        for i, prompt in enumerate(prompts):
            print(f"\n--- Prompt {i + 1}/{len(prompts)} ---")

            try:
                results = self.generate(prompt, **generation_kwargs)
                all_results.append(results)

            except Exception as e:
                print(f"❌ Failed to generate for prompt {i + 1}: {e}")
                all_results.append([f"Generation failed: {e}"])

        return all_results
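
# Minimal programmatic usage (illustrative sketch; assumes a trained checkpoint
# and a SentencePiece tokenizer are present under the model directory):
#
#     generator = TextGenerator("models/small-extended-4k")
#     samples = generator.generate("Once upon a time", max_length=64, top_k=40)
#     print(samples[0])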


def load_prompts_from_file(file_path: str) -> List[str]:
    """

    Load prompts from a text file.



    Args:

        file_path: Path to file containing prompts (one per line)



    Returns:

        List of prompt strings



    Implementation Details:

        - Reads file line by line

        - Strips whitespace and filters empty lines

        - Handles different text encodings gracefully

    """
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            prompts = [line.strip() for line in f if line.strip()]

        print(f"πŸ“„ Loaded {len(prompts)} prompts from {file_path}")
        return prompts

    except Exception as e:
        raise RuntimeError(f"Failed to load prompts from {file_path}: {e}")


def save_results_to_file(results: List[Any], output_path: str, prompts: Optional[List[str]] = None):
    """
    Save generation results to a text file.

    Args:
        results: Generated text results (strings, or lists of strings per prompt)
        output_path: Path to output file
        prompts: Original prompts (optional, for context)

    Implementation Details:
        - Formats output with clear separators
        - Includes prompts and metadata when available
        - Handles file creation and error reporting
    """
    try:
        with open(output_path, "w", encoding="utf-8") as f:
            f.write("# OpenLLM Text Generation Results\n")
            f.write(f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write(f"# Total samples: {len(results)}\n\n")

            for i, result in enumerate(results):
                f.write(f"--- Sample {i + 1} ---\n")

                if prompts and i < len(prompts):
                    f.write(f"Prompt: {prompts[i]}\n\n")

                if isinstance(result, list):
                    for j, text in enumerate(result):
                        f.write(f"Generated {j + 1}: {text}\n\n")
                else:
                    f.write(f"Generated: {result}\n\n")

                f.write("-" * 50 + "\n\n")

        print(f"πŸ’Ύ Results saved to: {output_path}")

    except Exception as e:
        raise RuntimeError(f"Failed to save results to {output_path}: {e}")


def main():
    """Main function for command-line text generation."""
    parser = argparse.ArgumentParser(
        description="OpenLLM Text Generation",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""

Examples:

  # Basic text generation

  python core/src/generate_text.py \\

    --model_dir ./openllm-trained \\

    --prompt "Hello, how are you?" \\

    --max_length 100



  # Advanced generation with parameters

  python core/src/generate_text.py \\

    --model_dir ./openllm-trained \\

    --prompt "The future of AI is" \\

    --max_length 200 \\

    --temperature 0.8 \\

    --top_k 50 \\

    --top_p 0.9

        """,
    )

    parser.add_argument(
        "--model_dir",
        required=True,
        help="Directory containing trained model checkpoints",
    )

    parser.add_argument(
        "--prompt",
        help="Input text prompt for generation",
    )

    parser.add_argument(
        "--prompts_file",
        help="File containing prompts, one per line (alternative to --prompt)",
    )

    parser.add_argument(
        "--num_samples",
        type=int,
        default=1,
        help="Number of sequences to generate per prompt (default: 1)",
    )

    parser.add_argument(
        "--output_file",
        help="Optional file path for saving generation results",
    )

    parser.add_argument(
        "--max_length",
        type=int,
        default=100,
        help="Maximum number of tokens to generate (default: 100)",
    )

    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature (default: 0.7)",
    )

    parser.add_argument(
        "--top_k",
        type=int,
        default=40,
        help="Top-k sampling parameter (default: 40)",
    )

    parser.add_argument(
        "--top_p",
        type=float,
        default=0.9,
        help="Nucleus sampling parameter (default: 0.9)",
    )

    parser.add_argument(
        "--device",
        default="auto",
        choices=["auto", "cpu", "cuda"],
        help="Device to use for generation (default: auto)",
    )

    args = parser.parse_args()

    print("πŸš€ OpenLLM Text Generation")
    print("=" * 50)

    try:
        # Resolve prompts: a single --prompt, or a file with one prompt per line
        if args.prompts_file:
            prompts = load_prompts_from_file(args.prompts_file)
        elif args.prompt:
            prompts = [args.prompt]
        else:
            parser.error("Provide either --prompt or --prompts_file")

        # Initialize text generator
        generator = TextGenerator(args.model_dir, args.device)

        print(f"βš™οΈ  Parameters: max_length={args.max_length}, temperature={args.temperature}")

        # Generate text for every prompt; each entry of all_results is the list
        # of sequences produced for the corresponding prompt
        all_results = generator.generate_batch(
            prompts,
            max_length=args.max_length,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            num_return_sequences=args.num_samples,
        )

        print("\n🎯 Generated text:")
        for prompt, results in zip(prompts, all_results):
            print(f"\nπŸ“ Prompt: {prompt}")
            for text in results:
                print(text)

        # Optionally persist results alongside their prompts
        if args.output_file:
            save_results_to_file(all_results, args.output_file, prompts)

    except Exception as e:
        print(f"\n❌ Error: {e}")
        import traceback

        traceback.print_exc()
        return False

    return True


def load_tokenizer(tokenizer_path: str):
    """

    Load tokenizer for testing purposes.



    This function is used by tests to load tokenizers without initializing the full generator.



    Args:

        tokenizer_path: Path to tokenizer model file



    Returns:

        SentencePieceProcessor: Loaded tokenizer

    """
    import sentencepiece as spm

    tokenizer = spm.SentencePieceProcessor()
    tokenizer.load(tokenizer_path)
    return tokenizer


if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)