# ==============================================================================
# Single-File Script ~221M Model - Resume Training for ~4 Hours
# ==============================================================================
# --- Necessary Imports ---
import torch
import torch.nn as nn
from dataclasses import dataclass, field
import math
import torch.nn.functional as F
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
import os
from tqdm import tqdm
import traceback
# Corrected import: Added IterableDataset AND Dataset
from torch.utils.data import Dataset, DataLoader, IterableDataset
import torch.optim as optim
# Use torch.amp imports (recommended over torch.cuda.amp)
from torch.amp import GradScaler, autocast
from datasets import load_dataset, IterableDataset as HFIterableDataset
import datetime
import random
import matplotlib.pyplot as plt
import glob
import time
import dataclasses # Make sure this is imported

# --- Model Configuration ---
@dataclass
class ModelArgs:
    # --- ~221M Config for 4GB VRAM ---
    hidden_size: int = 768; num_hidden_layers: int = 12; num_attention_heads: int = 12
    num_key_value_heads: int = 12; intermediate_size: int = 2048; vocab_size: int = 128000
    rms_norm_eps: float = 1e-5; rope_theta: float = 500000.0; max_position_embeddings: int = 4096
    head_dim: int = field(init=False)
    add_recency_bias: bool = False # Keep this option if desired

    def __post_init__(self):
        self.head_dim = self.hidden_size // self.num_attention_heads
        if self.hidden_size % self.num_attention_heads != 0: raise ValueError("hidden_size % num_attention_heads != 0")
        if self.num_attention_heads % self.num_key_value_heads != 0: raise ValueError("num_attention_heads % num_key_value_heads != 0")
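
# Quick arithmetic with the defaults above (illustrative): head_dim = 768 // 12 = 64. Because
# num_key_value_heads equals num_attention_heads here, the grouped-query attention below behaves
# as standard multi-head attention; grouping only kicks in with fewer KV heads than query heads.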

# --- Model Components (RMSNorm, RoPE funcs, Attention, FeedForward, TransformerBlock, Llama) ---
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6): super().__init__(); self.eps = eps; self.weight = nn.Parameter(torch.ones(dim))
    def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
    def forward(self, x): original_dtype = x.dtype; output = self._norm(x.float()).to(original_dtype); return output * self.weight

def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str | torch.device, theta: float = 10000.0):
    if head_dim % 2 != 0: raise ValueError("head_dim must be even for RoPE")
    theta_indices = torch.arange(0, head_dim, 2).float(); theta_freqs = 1.0 / (theta**(theta_indices / head_dim))
    target_device = torch.device(device) if isinstance(device, str) else device; theta_freqs = theta_freqs.to(target_device)
    positions = torch.arange(seq_len, device=target_device).float(); freqs = torch.outer(positions, theta_freqs).float(); return freqs, positions

def apply_rotary_embeddings(x: torch.Tensor, freqs_cis_full: torch.Tensor, positions: torch.Tensor):
    positions = positions.long(); max_pos = freqs_cis_full.shape[0]
    if torch.max(positions) >= max_pos: positions = torch.clamp(positions, max=max_pos - 1)
    freqs = freqs_cis_full[positions]; freqs = freqs.unsqueeze(0).unsqueeze(2)
    bsz, seq_len, n_part_heads, head_dim = x.shape; x1 = x[..., : head_dim // 2]; x2 = x[..., head_dim // 2 :]
    cos_freqs = torch.cos(freqs).type_as(x); sin_freqs = torch.sin(freqs).type_as(x)
    rotated_x1 = x1 * cos_freqs - x2 * sin_freqs; rotated_x2 = x1 * sin_freqs + x2 * cos_freqs
    rotated_x = torch.cat([rotated_x1, rotated_x2], dim=-1); return rotated_x.type_as(x)
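
# A minimal shape/sanity sketch for the two RoPE helpers above (an illustration only, never
# called during training; sizes are assumptions matching the ModelArgs defaults, head_dim=64):
# frequencies are precomputed once for the full context window, then indexed by absolute token
# positions at apply time, and the rotation preserves the input shape.
def _rope_shape_sketch():
    freqs, _ = precompute_theta_pos_frequencies(head_dim=64, seq_len=32, device='cpu', theta=500000.0)
    x = torch.randn(1, 16, 12, 64)        # (bsz, seqlen, n_heads, head_dim)
    positions = torch.arange(16)           # absolute positions 0..15
    rotated = apply_rotary_embeddings(x, freqs, positions)
    assert rotated.shape == x.shape        # rotation is shape-preserving
    # Position 0 corresponds to zero rotation angles, so that slice is unchanged.
    assert torch.allclose(rotated[:, 0], x[:, 0], atol=1e-5)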

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__(); self.args = args; self.num_heads = args.num_attention_heads; self.num_kv_heads = args.num_key_value_heads
        self.head_dim = args.head_dim; self.repeats = self.num_heads // self.num_kv_heads
        self.wq = nn.Linear(args.hidden_size, args.num_attention_heads * args.head_dim, bias=False)
        self.wk = nn.Linear(args.hidden_size, args.num_key_value_heads * args.head_dim, bias=False)
        self.wv = nn.Linear(args.hidden_size, args.num_key_value_heads * args.head_dim, bias=False)
        self.wo = nn.Linear(args.num_attention_heads * args.head_dim, args.hidden_size, bias=False)
    def _repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
        bsz, n_kv_heads, seqlen, head_dim = x.shape;
        if n_rep == 1: return x
        return (x[:, :, None, :, :].expand(bsz, n_kv_heads, n_rep, seqlen, head_dim).reshape(bsz, n_kv_heads * n_rep, seqlen, head_dim))
    def _create_recency_bias(self, seqlen, full_seqlen, device, dtype, bias_strength=0.1, decay_rate=0.9):
        bias = torch.zeros((1, 1, seqlen, full_seqlen), device=device, dtype=dtype); indices = torch.arange(full_seqlen, device=device)
        rel_pos = torch.arange(seqlen, device=device).unsqueeze(1) - indices.unsqueeze(0); mask = rel_pos >= 0
        decaying_bias = bias_strength * (decay_rate ** rel_pos[mask].float()); bias[:, :, mask] = decaying_bias.type_as(bias); return bias  # bias is largest for the most recent keys and decays with distance
    def forward(self, x: torch.Tensor, freqs_cis_full: torch.Tensor, positions: torch.Tensor, mask: torch.Tensor | None = None, cache: tuple[torch.Tensor, torch.Tensor] | None = None) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        bsz, seqlen, _ = x.shape; xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.num_heads, self.head_dim); xk = xk.view(bsz, seqlen, self.num_kv_heads, self.head_dim); xv = xv.view(bsz, seqlen, self.num_kv_heads, self.head_dim)
        xq = apply_rotary_embeddings(xq, freqs_cis_full, positions); xk = apply_rotary_embeddings(xk, freqs_cis_full, positions)
        xk = xk.transpose(1, 2); xv = xv.transpose(1, 2)
        if cache is not None: cache_k, cache_v = cache; keys = torch.cat((cache_k.to(xk.device), xk), dim=2); values = torch.cat((cache_v.to(xv.device), xv), dim=2)
        else: keys = xk; values = xv
        updated_cache = (keys.detach(), values.detach()); keys_repeated = self._repeat_kv(keys, self.repeats); values_repeated = self._repeat_kv(values, self.repeats)
        xq = xq.transpose(1, 2); scores = torch.matmul(xq.float(), keys_repeated.transpose(-2, -1).float()) / math.sqrt(self.head_dim)
        if self.args.add_recency_bias:
             full_seqlen = keys_repeated.shape[-2]; recency_bias = self._create_recency_bias(seqlen, full_seqlen, device=scores.device, dtype=scores.dtype); scores = scores + recency_bias
        if mask is not None:
            full_seqlen = keys_repeated.shape[-2]; expected_mask_shape_end = (seqlen, full_seqlen)
            if mask.shape[-2:] != expected_mask_shape_end:
                 try: mask_slice = mask[:, :, -seqlen:, :full_seqlen]; scores = scores + mask_slice.float()
                 except Exception: pass
            else: scores = scores + mask.float()
        scores = nn.functional.softmax(scores, dim=-1).type_as(xq); output = torch.matmul(scores, values_repeated)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1); output = self.wo(output); return output, updated_cache
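
# A hedged sketch (not called anywhere) of what _repeat_kv does for grouped-query attention:
# each KV head is tiled to serve a group of query heads. The 12-query/4-KV config below is a
# hypothetical chosen only for illustration; the default ModelArgs uses 12/12, in which case
# n_rep == 1 and the tensor passes through unchanged.
def _gqa_repeat_sketch():
    attn = Attention(ModelArgs(num_attention_heads=12, num_key_value_heads=4))
    kv = torch.randn(2, 4, 5, 64)                 # (bsz, n_kv_heads, seqlen, head_dim)
    expanded = attn._repeat_kv(kv, n_rep=3)       # 4 KV heads * 3 repeats = 12 effective heads
    assert expanded.shape == (2, 12, 5, 64)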

class FeedForward(nn.Module):
    def __init__(self, args: ModelArgs): super().__init__(); self.w1 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False); self.w3 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False); self.w2 = nn.Linear(args.intermediate_size, args.hidden_size, bias=False)
    def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))

class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs): super().__init__(); self.args = args; self.attention_norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps); self.attention = Attention(args); self.ffn_norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps); self.feed_forward = FeedForward(args)
    def forward(self, x: torch.Tensor, freqs_cis_full: torch.Tensor, positions: torch.Tensor, mask: torch.Tensor | None = None, cache: tuple[torch.Tensor, torch.Tensor] | None = None) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        r, cache = self.attention(self.attention_norm(x), freqs_cis_full, positions, mask, cache); h = x + r; r = self.feed_forward(self.ffn_norm(h)); out = h + r; return out, cache

class Llama(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__(); self.args = args; self.tok_embeddings = nn.Embedding(args.vocab_size, args.hidden_size); self.layers = nn.ModuleList([TransformerBlock(args) for _ in range(args.num_hidden_layers)])
        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps); self.tok_embeddings.weight.requires_grad = True
        freqs_cis, _ = precompute_theta_pos_frequencies(args.head_dim, args.max_position_embeddings, device='cpu', theta=args.rope_theta)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)
    def forward(self, tokens: torch.Tensor, positions: torch.Tensor):
        bsz, seqlen = tokens.shape; h = self.tok_embeddings(tokens); freqs_cis_full = self.freqs_cis.to(h.device); mask = None
        if seqlen > 1: mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=h.device); mask = torch.triu(mask, diagonal=1).type_as(h)
        positions = positions.to(h.device)
        for layer in self.layers: h, _ = layer(h, freqs_cis_full, positions, mask, cache=None)
        h = self.norm(h); output = F.linear(h, self.tok_embeddings.weight); return output

# --- Generate function (Added Top-P Sampling) ---
@torch.no_grad()
def generate(model: Llama, tokenizer: AutoTokenizer, prompt: str, max_new_tokens: int, temperature: float = 1.0, top_k: int | None = None, top_p: float | None = None):
    model.eval()
    try: model_device = next(model.parameters()).device; model_dtype = next(model.parameters()).dtype
    except StopIteration: model_device = torch.device("cpu"); model_dtype = torch.float32; print("Warning: Model has no parameters.")
    prompt_ids = tokenizer.encode(prompt, add_special_tokens=True); tokens = torch.tensor([prompt_ids], dtype=torch.long, device=model_device)
    cache = [(torch.zeros((1, model.args.num_key_value_heads, 0, model.args.head_dim), device=model_device, dtype=model_dtype),
              torch.zeros((1, model.args.num_key_value_heads, 0, model.args.head_dim), device=model_device, dtype=model_dtype))
             for _ in range(model.args.num_hidden_layers)]
    generated_token_ids = []; current_tokens = tokens; print(f"Generating {max_new_tokens} tokens from prompt: '{prompt}'"); print("Output: ", end='')
    full_freqs_cis = model.freqs_cis.to(model_device)
    for i in range(max_new_tokens):
        current_seq_len = current_tokens.shape[1]; start_pos = cache[0][0].shape[2]; positions = torch.arange(start_pos, start_pos + current_seq_len, device=model_device)
        current_mask = None;
        if i == 0 and current_seq_len > 1: current_mask = torch.full((1, 1, current_seq_len, current_seq_len), float("-inf"), device=model_device); current_mask = torch.triu(current_mask, diagonal=1).type(model_dtype)
        h = model.tok_embeddings(current_tokens); updated_cache_list = []
        for layer_idx, layer in enumerate(model.layers): h, updated_layer_cache = layer(h, full_freqs_cis, positions, current_mask, cache[layer_idx]); updated_cache_list.append(updated_layer_cache)
        cache = updated_cache_list; h = model.norm(h); logits = F.linear(h, model.tok_embeddings.weight)
        next_token_logits = logits[:, -1, :]
        if temperature == 0: next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        else:
            next_token_logits = next_token_logits / temperature
            if top_k is not None and top_k > 0: v, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1))); next_token_logits[next_token_logits < v[:, [-1]]] = float('-inf')
            if top_p is not None and 0.0 < top_p < 1.0:
                probs_for_filter = F.softmax(next_token_logits, dim=-1); probs_sort, probs_idx = torch.sort(probs_for_filter, descending=True); probs_sum = torch.cumsum(probs_sort, dim=-1)
                mask_top_p = probs_sum > top_p; mask_top_p[..., 1:] = mask_top_p[..., :-1].clone(); mask_top_p[..., 0] = False  # shift right so the first token that crosses top_p is still kept
                indices_to_remove = mask_top_p.scatter(1, probs_idx, mask_top_p); next_token_logits[indices_to_remove] = float('-inf')
            probs = F.softmax(next_token_logits, dim=-1); next_token_id = torch.multinomial(probs, num_samples=1)
        if tokenizer.eos_token_id is not None and next_token_id.item() == tokenizer.eos_token_id: print("\n[EOS token reached]"); break
        next_token_id_item = next_token_id.item(); generated_token_ids.append(next_token_id_item); current_tokens = next_token_id.clone()
        print(tokenizer.decode([next_token_id_item]), end='', flush=True)
        if len(generated_token_ids) >= max_new_tokens: break
    print("\n--- Generation Complete ---"); final_token_ids = prompt_ids + generated_token_ids; full_generated_text = tokenizer.decode(final_token_ids, skip_special_tokens=False)
    print(f"\nFull generated text:\n{full_generated_text}"); return full_generated_text

# --- Dataset Class (Map Style for WikiText) ---
class SimpleLMDataset(Dataset):
    def __init__(self, token_ids: list[int], sequence_length: int):
        self.token_ids = token_ids; self.sequence_length = sequence_length
        self.num_sequences = max(0, len(token_ids) - sequence_length)
        if self.num_sequences == 0: raise ValueError(f"Dataset token count ({len(token_ids)}) not > sequence length ({sequence_length}).")
    def __len__(self): return self.num_sequences
    def __getitem__(self, idx):
        chunk = self.token_ids[idx : idx + self.sequence_length + 1]
        if len(chunk) < self.sequence_length + 1:
             last_valid_idx = len(self.token_ids) - self.sequence_length - 1
             chunk = self.token_ids[last_valid_idx : last_valid_idx + self.sequence_length + 1]
        input_ids = torch.tensor(chunk[:-1], dtype=torch.long); target_ids = torch.tensor(chunk[1:], dtype=torch.long)
        return input_ids, target_ids
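
# Usage sketch for SimpleLMDataset (illustrative values only, not executed): the dataset slides a
# window of sequence_length + 1 tokens over the corpus, and each item is an (input, target) pair
# where the targets are the inputs shifted left by one token, i.e. standard next-token prediction.
#
#   ds = SimpleLMDataset(list(range(1000)), sequence_length=8)
#   x, y = ds[0]                       # x = tokens[0:8], y = tokens[1:9]
#   assert torch.equal(x[1:], y[:-1])  # targets are inputs shifted by one position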

# --- Dataset Class (Iterable for SlimPajama - Kept for reference/fallback) ---
class TokenizedSequenceDataset(IterableDataset):
    def __init__(self, dataset_name, dataset_config, split, tokenizer, sequence_length, buffer_size=10000):
        try: self.dataset = load_dataset(dataset_name, dataset_config, split=split, streaming=True); print(f"Successfully loaded streaming dataset: {dataset_name} ({split})")
        except Exception as e: raise RuntimeError(f"Failed to load streaming dataset {dataset_name} ({split}): {e}") from e
        self.tokenizer = tokenizer; self.sequence_length = sequence_length; self.buffer_size = buffer_size; self.buffer = []
        try: self.iter_dataset = iter(self.dataset)
        except Exception as e: raise RuntimeError(f"Failed to create iterator for dataset {dataset_name} ({split}): {e}") from e
    def __iter__(self):
        while True:
            while len(self.buffer) < self.sequence_length + 1:
                try:
                    item = next(self.iter_dataset); text = item.get('text', '')
                    if text and text.strip(): token_ids = self.tokenizer.encode(text, add_special_tokens=False); self.buffer.extend(token_ids)
                except StopIteration:
                    if len(self.buffer) < self.sequence_length + 1: return
                    else: break
            if len(self.buffer) < self.sequence_length + 1: return
            chunk = self.buffer[:self.sequence_length + 1]; input_ids = torch.tensor(chunk[:-1], dtype=torch.long); target_ids = torch.tensor(chunk[1:], dtype=torch.long)
            yield input_ids, target_ids; self.buffer = self.buffer[1:]

# --- Checkpoint Loading Function ---
def load_checkpoint(checkpoint_dir: str, model: Llama, optimizer, scaler, scheduler, device):
    latest_checkpoint_path = None; highest_step = -1
    if os.path.isdir(checkpoint_dir):
        checkpoints = glob.glob(os.path.join(checkpoint_dir, "step_*.pt"))
        for ckpt_path in checkpoints:
            try: step = int(os.path.basename(ckpt_path).split('_')[-1].split('.')[0]);
            except ValueError: continue
            if step > highest_step: highest_step = step; latest_checkpoint_path = ckpt_path
    if latest_checkpoint_path:
        print(f"Loading checkpoint from: {latest_checkpoint_path}")
        try:
            checkpoint = torch.load(latest_checkpoint_path, map_location='cpu', weights_only=False)  # weights_only=False because the checkpoint also stores non-tensor objects (e.g. the model args dict)
            current_args_dict = model.args.__dict__
            saved_args_data = checkpoint.get('model_args', checkpoint.get('model_args_dict'))
            if not saved_args_data: print("Warning: Checkpoint missing model_args."); saved_args_dict=None; args_match=False
            elif not isinstance(saved_args_data, dict): saved_args_dict = dataclasses.asdict(saved_args_data) # Use imported module
            else: saved_args_dict = saved_args_data
            args_match = True
            if saved_args_dict:
                 for f in dataclasses.fields(ModelArgs): # Use dataclasses.fields
                      if f.init and f.name != 'head_dim':
                           current_val = current_args_dict.get(f.name); saved_val = saved_args_dict.get(f.name)
                           if current_val != saved_val: print(f"Mismatch in arg '{f.name}': Current={current_val}, Saved={saved_val}"); args_match = False; break
            else: args_match = False
            if not args_match: print("ERROR: Model args mismatch. Cannot load checkpoint."); return 0
            model.load_state_dict(checkpoint['model_state_dict']); model.to(device)
            if optimizer is not None:
                try: optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                except Exception as e: print(f"Warning: Could not load optimizer state dict: {e}")
                for state in optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor): state[k] = v.to(device)
            if scaler is not None:
                try: scaler.load_state_dict(checkpoint['scaler_state_dict'])
                except Exception as e: print(f"Warning: Could not load scaler state dict: {e}")
            if scheduler is not None:
                try: scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
                except Exception as e: print(f"Warning: Could not load scheduler state dict: {e}")
            start_step = checkpoint['step']; print(f"Resuming training from step {start_step + 1}"); return start_step
        except Exception as e: print(f"Error loading checkpoint {latest_checkpoint_path}: {e}"); traceback.print_exc(); return 0
    else: print("No checkpoint found. Starting training from scratch."); return 0

# --- Plotting Function ---
def plot_loss(train_losses, val_losses, val_steps_list, checkpoint_dir, start_step=0):
    plt.figure(figsize=(12, 6)); smoothing_window = 10
    train_steps = list(range(start_step + 1, start_step + len(train_losses) + 1))
    plt.plot(train_steps, train_losses, label='Training Loss (Raw)', alpha=0.3)
    if len(train_losses) > smoothing_window:
        train_losses_smoothed = [sum(train_losses[max(0, i - smoothing_window + 1):i + 1]) / min(i + 1, smoothing_window) for i in range(len(train_losses))]
        plt.plot(train_steps, train_losses_smoothed, label=f'Training Loss (Smoothed {smoothing_window} steps)', alpha=0.9)
    if val_losses and val_steps_list: plt.plot(val_steps_list, val_losses, label='Validation Loss', marker='o', linestyle='--')
    plt.xlabel("Optimizer Steps"); plt.ylabel("Loss"); plt.yscale('log'); plt.title("Training and Validation Loss Over Steps")
    plt.legend(); plt.grid(True); plot_filename = f"loss_plot_step_{start_step}_to_{start_step+len(train_losses)}.png"
    plot_path = os.path.join(checkpoint_dir, plot_filename); plt.savefig(plot_path)
    print(f"Loss plot saved to {plot_path}")

# --- Basic Training Function (Single GPU, AMP, LR Schedule, Validation, Checkpointing, Plotting) ---
def simple_train(
    model: Llama, tokenizer: AutoTokenizer, train_dataset: IterableDataset | Dataset, val_dataset: IterableDataset | Dataset | None,
    optimizer: torch.optim.Optimizer, criterion: nn.Module, scheduler,
    num_epochs: int, device: torch.device, gradient_accumulation_steps: int = 1,
    use_amp: bool = False, max_train_steps: int | None = None, start_step: int = 0,
    save_interval: int = 1000, checkpoint_dir: str = ".",
    validation_interval: int = 500, val_steps: int = 50, scaler: GradScaler | None = None
):
    model.train(); total_steps = start_step; global_step_this_run = 0
    # Reuse the (possibly checkpoint-restored) scaler when one is passed in; otherwise create a fresh one.
    if scaler is None: scaler = GradScaler(enabled=use_amp and device.type == 'cuda')
    os.makedirs(checkpoint_dir, exist_ok=True)
    train_loss_history = []; val_loss_history = []; val_steps_history = []
    print(f"\n--- Starting Training (Resuming from step {start_step}, Target Steps: {max_train_steps if max_train_steps else 'N/A'}) ---")
    print(f"--- (AMP: {use_amp and device.type == 'cuda'}) ---")
    is_iterable = isinstance(train_dataset, IterableDataset)
    train_loader = DataLoader(train_dataset, batch_size=1, num_workers=0, shuffle=(not is_iterable))
    if val_dataset: val_loader = DataLoader(val_dataset, batch_size=1, num_workers=0)
    training_complete = False
    # Adjust tqdm total based on remaining steps
    tqdm_total = (max_train_steps - start_step) if max_train_steps is not None else None
    print(f"Starting loop, aiming for {max_train_steps} total steps...")
    # Use total=None for iterable datasets if max_steps not set, as length is unknown
    pbar = tqdm(total=tqdm_total, desc=f"Optim Steps ({start_step}...)")

    # Need to manually track iterations vs optimizer steps
    data_iterator = iter(train_loader)
    accum_count = 0 # Counter for gradient accumulation steps

    while not training_complete:
        # Check if we need to stop before starting the next optimizer step
        if max_train_steps is not None and total_steps >= max_train_steps:
            training_complete = True; break

        # --- Accumulation Loop ---
        accum_loss = 0.0
        optimizer.zero_grad() # Zero gradients at start of accumulation cycle

        for _ in range(gradient_accumulation_steps):
            try:
                input_ids, target_ids = next(data_iterator)
            except StopIteration:
                print("\nDataLoader exhausted within accumulation cycle or epoch.")
                # If loader exhausted before completing max_steps, stop training
                training_complete = True; break # Break inner accum loop

            input_ids = input_ids.to(device); target_ids = target_ids.to(device)
            seqlen = input_ids.shape[1]; positions = torch.arange(seqlen, device=device)

            # Use torch.amp.autocast
            with autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp and device.type == 'cuda'):
                logits = model(input_ids, positions)
                loss = criterion(logits.view(-1, logits.size(-1)).float(), target_ids.view(-1))
                loss = loss / gradient_accumulation_steps # Normalize loss for accumulation

            scaler.scale(loss).backward()
            accum_loss += loss.item() # Accumulate *normalized* loss item

        if training_complete: break # Exit outer loop if data exhausted

        # --- Optimizer Step ---
        scaler.unscale_(optimizer)
        # Optional: torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer); scaler.update(); scheduler.step(); optimizer.zero_grad(set_to_none=True)
        total_steps += 1; global_step_this_run += 1
        pbar.update(1) # Update progress bar by one optimizer step

        # --- Logging ---
        current_loss = accum_loss * gradient_accumulation_steps # Log un-normalized loss for this step
        train_loss_history.append(current_loss)
        # Note: epoch_loss calculation might be less meaningful with iterable dataset and max_steps
        # avg_loss_so_far = sum(train_loss_history[-50:]) / min(len(train_loss_history), 50) # Example: rolling average
        pbar.set_postfix({"Loss": f"{current_loss:.4f}", "LR": f"{scheduler.get_last_lr()[0]:.6f}", "Steps": total_steps})

        # --- Validation ---
        if val_dataset and total_steps % validation_interval == 0 and total_steps > 0:
            model.eval(); val_loss = 0.0; val_batches = 0; print(f"\nRunning validation at step {total_steps}...")
            val_pbar = tqdm(total=val_steps, desc="Validation")  # advanced manually below
            with torch.no_grad():
                val_iter = iter(val_loader)
                for val_step in range(val_steps):
                    try:
                        val_input_ids, val_target_ids = next(val_iter)
                        val_input_ids = val_input_ids.to(device); val_target_ids = val_target_ids.to(device)
                        val_seqlen = val_input_ids.shape[1]; val_positions = torch.arange(val_seqlen, device=device)
                        with autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp and device.type == 'cuda'):
                            val_logits = model(val_input_ids, val_positions)
                            v_loss = criterion(val_logits.view(-1, val_logits.size(-1)).float(), val_target_ids.view(-1))
                        val_loss += v_loss.item(); val_batches += 1; val_pbar.update(1); val_pbar.set_postfix({"Val Loss": f"{val_loss/val_batches:.4f}"})
                    except StopIteration: print("Validation loader exhausted early."); break
            val_pbar.close()
            avg_val_loss = val_loss / val_batches if val_batches > 0 else float('inf')
            val_loss_history.append(avg_val_loss); val_steps_history.append(total_steps)
            print(f"Validation finished. Average Val Loss: {avg_val_loss:.4f}"); model.train()

        # --- Checkpointing ---
        if total_steps % save_interval == 0 and total_steps > 0:
            save_path = os.path.join(checkpoint_dir, f"step_{total_steps}.pt")
            try:
                model_args_dict = dataclasses.asdict(model.args)
                save_content = { 'step': total_steps, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(),
                                 'scaler_state_dict': scaler.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), 'model_args_dict': model_args_dict }
                torch.save(save_content, save_path); print(f"\nCheckpoint saved to {save_path}")
            except Exception as e: print(f"\nError saving checkpoint: {e}")

        # --- Check Max Steps ---
        if max_train_steps is not None and total_steps >= max_train_steps:
            print(f"\nReached max_train_steps ({max_train_steps}). Stopping training."); training_complete = True; break


    pbar.close() # Close pbar if loop finishes naturally
    print("--- Training Finished ---")
    return train_loss_history, val_loss_history, val_steps_history


# --- Main Execution Block ---
if __name__ == "__main__":
    # --- Configuration ---
    config = ModelArgs(add_recency_bias=False) # Use ~221M config
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Model Configuration:\n{config}")
    print(f"Calculated Head Dimension: {config.head_dim}")
    print(f"\nUsing device: {device}")

    # --- Component Tests (Commented out) ---
    """ """

    # --- Tokenizer ---
    print("\n--- Tokenizer Loading ---")
    tokenizer_name = "deepseek-ai/DeepSeek-R1"
    print(f"Loading tokenizer: {tokenizer_name}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)
        print("Tokenizer loaded successfully.")
        if tokenizer.vocab_size != config.vocab_size: exit(f"FATAL: Tokenizer vocab size ({tokenizer.vocab_size}) does not match model config ({config.vocab_size}).")
        else: print(f"Tokenizer vocab size ({tokenizer.vocab_size}) matches model config.")
        if tokenizer.pad_token is None:
            if tokenizer.eos_token: tokenizer.pad_token = tokenizer.eos_token; print(f"Set PAD token to EOS token: {tokenizer.pad_token}")
            else: tokenizer.add_special_tokens({'pad_token': '[PAD]'}); print("Added a generic [PAD] token.")
    except Exception as e: exit(f"Error loading tokenizer '{tokenizer_name}': {e}")

    # --- Training Setup ---
    print("\n--- Training Setup ---")
    train_batch_size = 1
    train_seq_len = 256
    grad_accum_steps = 16
    use_amp_training = device.type == 'cuda'
    learning_rate = 1e-4 # Lower LR
    num_epochs = 1
    # --- ADJUSTED MAX STEPS for ~4 hour run ---
    max_steps_for_run = 1200 # Absolute target step for this run (start_step + new_steps)
    # --- ADJUSTED Total Scheduler Steps (longer term goal) ---
    total_scheduler_steps = 10000 # Example longer goal
    warmup_steps = 100
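    # Note: the scheduler horizon (total_scheduler_steps) is intentionally longer than this run's
    # max_steps_for_run, so the cosine decay will not complete within this session and the LR is
    # still well above its floor when the run stops; resumed runs continue along the same schedule
    # via the checkpointed scheduler state.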
    # --- Save to current directory ---
    checkpoint_dir = "."
    save_interval = 200 # Save less frequently
    validation_interval = 100 # Validate less frequently
    val_steps = 20
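
    # Rough effective-batch arithmetic for the settings above (a sanity print, not a tuned value):
    # 1 sequence/batch * 256 tokens * 16 accumulation steps ~= 4096 tokens per optimizer step.
    effective_tokens_per_step = train_batch_size * train_seq_len * grad_accum_steps
    print(f"Approx. tokens per optimizer step: {effective_tokens_per_step}")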

    # --- Dataset ---
    print("\nLoading and preparing WikiText-2 dataset...")
    train_dataset, val_dataset = None, None
    try:
        # Using WikiText-2 directly
        token_file = "./wikitext2_tokens_128k.pt"
        val_token_file = "./wikitext2_val_tokens_128k.pt"
        force_remake_dataset = False
        if os.path.exists(token_file) and os.path.exists(val_token_file) and not force_remake_dataset:
             print(f"Loading tokenized data from {token_file} and {val_token_file}...")
             all_token_ids = torch.load(token_file)
             all_val_token_ids = torch.load(val_token_file)
             print("Tokenized data loaded.")
        else:
             print("Token files not found or remake forced, processing WikiText-2...")
             print("Processing train split...")
             train_raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
             train_full_text = "\n".join([item['text'] for item in train_raw_dataset if item['text'].strip()])
             all_token_ids = tokenizer.encode(train_full_text)
             torch.save(all_token_ids, token_file)
             print(f"Saved tokenized train data ({len(all_token_ids)} tokens) to {token_file}")
             print("Processing validation split...")
             val_raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")
             val_full_text = "\n".join([item['text'] for item in val_raw_dataset if item['text'].strip()])
             all_val_token_ids = tokenizer.encode(val_full_text)
             torch.save(all_val_token_ids, val_token_file)
             print(f"Saved tokenized validation data ({len(all_val_token_ids)} tokens) to {val_token_file}")

        if len(all_token_ids) <= train_seq_len: exit("Train dataset too short.")
        if len(all_val_token_ids) <= train_seq_len: exit("Validation dataset too short.")
        train_dataset = SimpleLMDataset(all_token_ids, sequence_length=train_seq_len)
        val_dataset = SimpleLMDataset(all_val_token_ids, sequence_length=train_seq_len)
        print("Using WikiText-2 dataset.")
    except Exception as e: exit(f"Dataset error: {e}")

    # DataLoaders (note: simple_train constructs its own loaders from the datasets, so these are not used by the training loop)
    train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=train_batch_size, num_workers=2, pin_memory=True) if val_dataset else None
    print(f"DataLoaders created. Training Seq Len: {train_seq_len}")
    print(f"Train sequences: {len(train_dataset)}, Val sequences: {len(val_dataset) if val_dataset else 0}")

    # --- Model, Optimizer, Scheduler, Loss ---
    train_model = Llama(config).to(device)
    print(f"Training model instantiated ({'float32' if not use_amp_training else 'mixed precision'}). Recency Bias: {config.add_recency_bias}")
    total_params_train = sum(p.numel() for p in train_model.parameters() if p.requires_grad)
    print(f"Total Trainable Parameters: {total_params_train / 1e6:.2f} Million")

    optimizer = optim.AdamW(train_model.parameters(), lr=learning_rate, weight_decay=0.1)
    criterion = nn.CrossEntropyLoss()
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_scheduler_steps
    )
    scaler = GradScaler(enabled=use_amp_training and device.type == 'cuda')
    print(f"Optimizer: AdamW, Base LR: {learning_rate}")
    print(f"Scheduler: Cosine with {warmup_steps} warmup steps up to {total_scheduler_steps} steps")
    print(f"Loss Function: CrossEntropyLoss")

    # --- Load Checkpoint ---
    # Pass optimizer, scaler, scheduler to be loaded
    start_step = load_checkpoint(checkpoint_dir, train_model, optimizer, scaler, scheduler, device)

    # Calculate steps to run in this session
    steps_to_run_this_session = max(0, max_steps_for_run - start_step)
    # The absolute step number to stop at in this run
    current_run_target_step = start_step + steps_to_run_this_session

    if steps_to_run_this_session <= 0:
        print(f"Already completed or exceeded target steps ({max_steps_for_run}). Exiting.")
        exit()

    # --- Run Training ---
    print(f"\n--- Running Training (Will run for {steps_to_run_this_session} steps in this session, target total: {max_steps_for_run}) ---")
    start_time = time.time()
    train_loss_hist, val_loss_hist, val_steps_hist = [], [], []
    try:
        # Pass the absolute target step for this run
        train_loss_hist, val_loss_hist, val_steps_hist = simple_train(
            model=train_model, tokenizer=tokenizer, train_dataset=train_dataset, val_dataset=val_dataset,
            optimizer=optimizer, criterion=criterion, scheduler=scheduler, scaler=scaler,
            num_epochs=num_epochs, device=device, gradient_accumulation_steps=grad_accum_steps,
            use_amp=use_amp_training, max_train_steps=current_run_target_step, start_step=start_step,
            save_interval=save_interval, checkpoint_dir=checkpoint_dir,
            validation_interval=validation_interval, val_steps=val_steps
        )
        print("\nTraining loop finished.")
        end_time = time.time(); print(f"Training duration for this session: {datetime.timedelta(seconds=int(end_time - start_time))}")

        # --- Plotting ---
        if train_loss_hist:
             # Adjust steps for plotting if resuming
             plot_train_steps = list(range(start_step + 1, start_step + len(train_loss_hist) + 1))
             # Filter validation steps/losses that occurred *during this run*
             plot_val_steps = [s for s in val_steps_hist if s >= start_step]
             plot_val_loss = [val_loss_hist[i] for i, s in enumerate(val_steps_hist) if s >= start_step]
             plot_loss(train_loss_hist, plot_val_loss, plot_val_steps, checkpoint_dir, start_step=start_step)


        # --- Generation After Training ---
        print("\n--- Generation After Training ---")
        train_model.eval()
        if device.type == 'cuda':
             try: train_model = train_model.half(); print("Trained model converted to float16 for generation.")
             except Exception as e: print(f"Could not convert trained model to float16: {e}.")
        test_prompt_after = "The meaning of life is"
        _ = generate(model=train_model, tokenizer=tokenizer, prompt=test_prompt_after, max_new_tokens=60, temperature=0.7, top_k=50, top_p=0.9)
        print("\n(Check if output shows more structure than random)")

    except torch.cuda.OutOfMemoryError: print("\n--- CUDA Out of Memory during Training ---"); print("Try reducing train_seq_len (each micro-batch is already size 1, so lowering gradient_accumulation_steps will not reduce peak memory).")
    except Exception as e: print(f"\nAn error occurred during training: {e}"); traceback.print_exc()

    print("\n--- Script Finished ---")