File size: 11,578 Bytes
9dd056d
 
d7d2fb2
326b359
9dd056d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
838299a
 
 
 
 
9dd056d
 
 
 
 
 
 
838299a
 
 
 
 
 
 
 
 
9dd056d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326b359
9dd056d
 
 
 
326b359
 
 
 
 
 
9dd056d
 
 
 
 
 
 
 
 
 
326b359
9dd056d
 
 
 
 
 
 
 
 
 
 
 
 
326b359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9dd056d
 
 
 
18a94f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9dd056d
7557c9f
 
18a94f8
7557c9f
326b359
2120bf6
326b359
18a94f8
 
 
 
 
 
 
 
 
 
9dd056d
7557c9f
 
 
 
 
 
 
 
 
 
326b359
 
2120bf6
 
 
326b359
2120bf6
 
326b359
 
 
2120bf6
 
 
 
 
9dd056d
326b359
18a94f8
9dd056d
d7d2fb2
18a94f8
 
 
 
 
 
 
 
 
 
 
 
9dd056d
 
 
 
 
 
 
326b359
 
 
 
 
9dd056d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326b359
 
 
 
9dd056d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326b359
 
 
 
9dd056d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
import torch
import tiktoken
from model import ismail, ModelArgs
from data import TurkishTokenizerWrapper, TURKISH_TOKENIZER_AVAILABLE


#####################################
# TEXT GENERATION FUNCTIONS
#####################################

def generate_text_simple(model, idx, max_new_tokens, context_size):
    """
    Autoregressively extend a token sequence using greedy (argmax) decoding.

    Args:
        model: The transformer model
        idx: Input token indices of shape (batch_size, seq_len)
        max_new_tokens: Number of new tokens to generate
        context_size: Maximum context size the model can handle

    Returns:
        Generated token indices of shape (batch_size, seq_len + max_new_tokens)
    """
    for _ in range(max_new_tokens):
        # The model can only attend to the most recent `context_size` tokens,
        # so truncate the running sequence from the left before each forward pass.
        idx_cond = idx[:, -context_size:]

        # Forward pass without gradient tracking — inference only.
        with torch.no_grad():
            logits = model(idx_cond)

        # Keep only the prediction for the final position:
        # (batch, n_tokens, vocab_size) -> (batch, vocab_size)
        last_logits = logits[:, -1, :]

        # Greedy choice: the highest-scoring vocabulary entry at each batch row.
        next_token = last_logits.argmax(dim=-1, keepdim=True)  # (batch, 1)

        # Grow the running sequence by one token.
        idx = torch.cat((idx, next_token), dim=1)  # (batch, n_tokens+1)

    return idx


def generate_text_with_sampling(model, idx, max_new_tokens, context_size, temperature=1.0, top_k=None):
    """
    Generate text using sampling with temperature and optional top-k filtering.

    Args:
        model: The transformer model
        idx: Input token indices of shape (batch_size, seq_len)
        max_new_tokens: Number of new tokens to generate
        context_size: Maximum context size the model can handle
        temperature: Sampling temperature (higher = more random, lower = more deterministic)
        top_k: If set, only sample from the top k most likely tokens

    Returns:
        Generated token indices of shape (batch_size, seq_len + max_new_tokens)
    """
    # Clamp once, outside the loop (it is loop-invariant): protects against
    # division by zero / very small temperatures blowing up the logits.
    temperature = max(temperature, 1e-8)

    for _ in range(max_new_tokens):
        # Crop current context if it exceeds the supported context size
        idx_cond = idx[:, -context_size:]

        # Get the predictions (inference only — no gradient tracking)
        with torch.no_grad():
            logits = model(idx_cond)

        # Focus only on the last time step:
        # (batch, n_token, vocab_size) becomes (batch, vocab_size)
        logits = logits[:, -1, :]

        # Temperature scaling: <1 sharpens the distribution, >1 flattens it.
        logits = logits / temperature

        # Optional: apply top-k filtering by masking everything below the
        # k-th largest logit to -inf (softmax then assigns it probability 0).
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')

        # Apply softmax to get probabilities (float32 for numeric stability).
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)

        # Handle edge cases: check for invalid probabilities.
        # NOTE: this uniform fallback deliberately ignores any top-k filtering
        # applied above — it is a last-resort recovery path, not normal flow.
        if torch.isnan(probs).any() or torch.isinf(probs).any():
            # Fallback to uniform distribution over valid tokens
            probs = torch.ones_like(probs) / probs.size(-1)

        # Ensure probabilities sum to 1
        probs = probs / probs.sum(dim=-1, keepdim=True)

        # Sample one next token per batch row from the distribution
        idx_next = torch.multinomial(probs, num_samples=1)

        # Append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)

    return idx


def text_to_token_ids(text, tokenizer):
    """
    Convert text to token IDs.

    Args:
        text: Input text string
        tokenizer: Tokenizer instance (tiktoken or TurkishTokenizerWrapper)

    Returns:
        Tensor of token IDs with shape (1, seq_len)
    """
    if isinstance(tokenizer, TurkishTokenizerWrapper):
        # The Turkish tokenizer's encode() has no allowed_special parameter.
        token_ids = tokenizer.encode(text)
    else:
        token_ids = tokenizer.encode(text, allowed_special={"<|endoftext|>"})

    # Add a leading batch dimension: (seq_len,) -> (1, seq_len).
    return torch.tensor(token_ids).unsqueeze(0)


def token_ids_to_text(token_ids, tokenizer):
    """
    Convert token IDs to text.

    Args:
        token_ids: Tensor of token IDs, can be 1D or 2D
        tokenizer: Tokenizer instance (tiktoken or TurkishTokenizerWrapper)

    Returns:
        Decoded text string
    """
    # Drop the batch dimension if present so decoding always sees a flat list.
    ids = token_ids.squeeze(0) if token_ids.dim() == 2 else token_ids
    return tokenizer.decode(ids.tolist())


def get_tokenizer(use_turkish=False, tokenizer_name="gpt2"):
    """
    Get the appropriate tokenizer based on user preference.

    Args:
        use_turkish: Whether to use Turkish tokenizer
        tokenizer_name: Name of tiktoken tokenizer to use if not using Turkish

    Returns:
        Tokenizer instance (TurkishTokenizerWrapper or tiktoken tokenizer)
    """
    # Default path: a named tiktoken encoding (e.g. "gpt2").
    if not use_turkish:
        tokenizer = tiktoken.get_encoding(tokenizer_name)
        print(f"📚 Using tiktoken tokenizer: {tokenizer_name} (vocab size: {tokenizer.n_vocab:,})")
        return tokenizer

    # Turkish path requires the optional turkish-tokenizer package.
    if not TURKISH_TOKENIZER_AVAILABLE:
        raise ImportError(
            "Turkish tokenizer requested but not available. "
            "Install it with: pip install turkish-tokenizer"
        )
    tokenizer = TurkishTokenizerWrapper()
    print(f"🇹🇷 Using Turkish Tokenizer (vocab size: {tokenizer.n_vocab:,})")
    return tokenizer


#####################################
# EXAMPLE USAGE
#####################################

def load_checkpoint(model, checkpoint_path):
    """
    Load a trained checkpoint into the model.

    Args:
        model: The model instance
        checkpoint_path: Path to the checkpoint file (.pt)

    Returns:
        The loaded checkpoint dictionary with metadata
    """
    print(f"\n📦 Loading checkpoint: {checkpoint_path}")
    # Always materialize on CPU; the caller can move the model afterwards.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Format 1: a bare state dict saved directly with torch.save(model.state_dict()).
    if 'model_state_dict' not in checkpoint:
        model.load_state_dict(checkpoint)
        print(f"✅ Loaded model state (direct)")
        return checkpoint

    # Format 2: a wrapper dict carrying the state dict plus training metadata.
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"✅ Loaded model state from checkpoint")
    if 'step' in checkpoint:
        print(f"   Training step: {checkpoint['step']:,}")
    if 'loss' in checkpoint:
        print(f"   Loss: {checkpoint['loss']:.4f}")
    return checkpoint


if __name__ == "__main__":
    import json
    from pathlib import Path
    import sys

    # Configuration: Set to True to use Turkish tokenizer, False for tiktoken
    USE_TURKISH_TOKENIZER = True  # Change this to False for English text generation

    # ===== CHECKPOINT LOADING =====
    # Set this to the path of your trained checkpoint
    # Example: CHECKPOINT_PATH = "./checkpoints/step_55000_expert_2.pt"
    CHECKPOINT_PATH = None  # Set to None to use random initialization

    # You can also pass checkpoint path as command line argument
    if len(sys.argv) > 1:
        CHECKPOINT_PATH = sys.argv[1]
        print(f"🔧 Using checkpoint from command line: {CHECKPOINT_PATH}")

    # Example configuration - smaller model for testing
    config_path = Path("config.json")
    if config_path.exists():
        with open(config_path) as f:
            config = json.load(f)
        print(f"✅ Loaded config from {config_path}")
        args = ModelArgs(**config["model"])
    else:
        print("⚠️ config.json not found, using default ModelArgs")
        args = ModelArgs()

    # Initialize tokenizer
    tokenizer_name = getattr(args, "tokenizer_name", "gpt2")
    # Auto-detect Turkish tokenizer from config
    use_turkish = (tokenizer_name.lower() == "turkish") or USE_TURKISH_TOKENIZER

    tokenizer = get_tokenizer(
        use_turkish=use_turkish,
        tokenizer_name="gpt2" if use_turkish else tokenizer_name
    )

    # Update vocab size if using Turkish tokenizer
    if use_turkish and isinstance(tokenizer, TurkishTokenizerWrapper):
        if args.vocab_size != tokenizer.n_vocab:
            print(f"⚠️  Config vocab_size ({args.vocab_size:,}) doesn't match tokenizer ({tokenizer.n_vocab:,})")
            args.vocab_size = tokenizer.n_vocab
            print(f"📊 Updated vocab_size to {args.vocab_size:,} for Turkish tokenizer")

    # Initialize model
    print("\n🚀 Initializing model...")
    torch.manual_seed(123)
    model = ismail(args)

    # Load checkpoint if specified
    if CHECKPOINT_PATH:
        checkpoint_file = Path(CHECKPOINT_PATH)
        if checkpoint_file.exists():
            load_checkpoint(model, checkpoint_file)
        else:
            print(f"❌ Checkpoint not found: {CHECKPOINT_PATH}")
            print("   Using random initialization instead")
    else:
        print("ℹ️  No checkpoint specified, using random initialization")

    model.eval()

    # Example 1: Greedy generation (argmax)
    print(f"\n{'='*60}")
    print("EXAMPLE 1: GREEDY GENERATION (ARGMAX)")
    print(f"{'='*60}")

    # Use Turkish or English prompts based on tokenizer
    if USE_TURKISH_TOKENIZER:
        start_context = "Merhaba, ben"
    else:
        start_context = "Hello, I am"
    print(f"\nInput: '{start_context}'")

    token_ids = text_to_token_ids(start_context, tokenizer)
    print(f"Token IDs shape: {token_ids.shape}")

    generated_ids = generate_text_simple(
        model=model,
        idx=token_ids,
        max_new_tokens=20,
        context_size=args.max_seq_len
    )

    generated_text = token_ids_to_text(generated_ids, tokenizer)
    print(f"\nGenerated: '{generated_text}'")
    print(f"Total tokens: {generated_ids.shape[1]}")

    # Example 2: Sampling with temperature
    print(f"\n{'='*60}")
    print("EXAMPLE 2: SAMPLING WITH TEMPERATURE")
    print(f"{'='*60}")

    if USE_TURKISH_TOKENIZER:
        start_context = "Bir varmış bir yokmuş"
    else:
        start_context = "Once upon a time"
    print(f"\nInput: '{start_context}'")

    token_ids = text_to_token_ids(start_context, tokenizer)

    # Generate with different temperatures
    for temp in [0.5, 1.0, 1.5]:
        print(f"\n--- Temperature: {temp} ---")
        generated_ids = generate_text_with_sampling(
            model=model,
            idx=token_ids.clone(),
            max_new_tokens=20,
            context_size=args.max_seq_len,
            temperature=temp
        )
        generated_text = token_ids_to_text(generated_ids, tokenizer)
        print(f"Generated: '{generated_text}'")

    # Example 3: Top-k sampling
    print(f"\n{'='*60}")
    print("EXAMPLE 3: TOP-K SAMPLING")
    print(f"{'='*60}")

    if USE_TURKISH_TOKENIZER:
        start_context = "Yapay zekanın geleceği"
    else:
        start_context = "The future of AI is"
    print(f"\nInput: '{start_context}'")

    token_ids = text_to_token_ids(start_context, tokenizer)

    generated_ids = generate_text_with_sampling(
        model=model,
        idx=token_ids,
        max_new_tokens=30,
        context_size=args.max_seq_len,
        temperature=0.8,
        top_k=50
    )

    generated_text = token_ids_to_text(generated_ids, tokenizer)
    print(f"Generated: '{generated_text}'")

    print(f"\n{'='*60}")
    print("Generation examples completed!")
    print(f"{'='*60}\n")