File size: 36,443 Bytes
48b48f8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 |
# ==============================================================================
# Single-File Script ~221M Model - Resume Training for ~4 Hours
# ==============================================================================
# --- Necessary Imports ---
import torch
import torch.nn as nn
from dataclasses import dataclass, field
import math
import torch.nn.functional as F
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
import os
from tqdm import tqdm
import traceback
# Corrected import: Added IterableDataset AND Dataset
from torch.utils.data import Dataset, DataLoader, IterableDataset
import torch.optim as optim
# Use torch.amp imports (recommended over torch.cuda.amp)
from torch.amp import GradScaler, autocast
from datasets import load_dataset, IterableDataset as HFIterableDataset
import datetime
import random
import matplotlib.pyplot as plt
import glob
import time
import dataclasses # Make sure this is imported
# --- Model Configuration ---
@dataclass
class ModelArgs:
    """Hyper-parameters of the Llama-style decoder (defaults: the ~221M config).

    ``head_dim`` is not a constructor argument: ``__post_init__`` derives it as
    ``hidden_size // num_attention_heads`` and then validates that the head
    counts divide evenly, raising ``ValueError`` otherwise.
    """

    # --- ~221M Config for 4GB VRAM ---
    hidden_size: int = 768
    num_hidden_layers: int = 12
    num_attention_heads: int = 12
    num_key_value_heads: int = 12
    intermediate_size: int = 2048
    vocab_size: int = 128000
    rms_norm_eps: float = 1e-5
    rope_theta: float = 500000.0
    max_position_embeddings: int = 4096
    head_dim: int = field(init=False)  # derived in __post_init__
    add_recency_bias: bool = False  # Keep this option if desired

    def __post_init__(self):
        self.head_dim = self.hidden_size // self.num_attention_heads
        divisibility_rules = (
            (self.hidden_size, "num_attention_heads", "hidden_size % num_attention_heads != 0"),
            (self.num_attention_heads, "num_key_value_heads", "num_attention_heads % num_key_value_heads != 0"),
        )
        for dividend, divisor_name, message in divisibility_rules:
            # divisor is read lazily so the second rule is only evaluated if the first passes
            if dividend % getattr(self, divisor_name):
                raise ValueError(message)
# --- Model Components (RMSNorm, RoPE funcs, Attention, FeedForward, TransformerBlock, Llama) ---
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable gain.

    The statistic is computed in float32 and the normalised tensor is cast back to
    the input dtype before being multiplied by ``weight`` (initialised to ones).
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        input_dtype = x.dtype
        normalised = self._norm(x.float()).to(input_dtype)
        return normalised * self.weight
def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str | torch.device, theta: float = 10000.0):
if head_dim % 2 != 0: raise ValueError("head_dim must be even for RoPE")
theta_indices = torch.arange(0, head_dim, 2).float(); theta_freqs = 1.0 / (theta**(theta_indices / head_dim))
target_device = torch.device(device) if isinstance(device, str) else device; theta_freqs = theta_freqs.to(target_device)
positions = torch.arange(seq_len, device=target_device).float(); freqs = torch.outer(positions, theta_freqs).float(); return freqs, positions
def apply_rotary_embeddings(x: torch.Tensor, freqs_cis_full: torch.Tensor, positions: torch.Tensor):
    """Rotate ``x`` by the RoPE angles of ``positions`` ("rotate-half" layout).

    ``x`` must be 4-D ``(bsz, seq_len, n_heads, head_dim)``; element ``j`` of the
    first half of the last dim is rotated together with element ``j`` of the
    second half by angle ``freqs_cis_full[pos, j]``. ``positions`` presumably has
    shape ``(seq_len,)`` -- TODO confirm for other callers. Positions past the end
    of the table are silently clamped to its last row. Returns a tensor with the
    shape and dtype of ``x``.
    """
    pos_idx = positions.long()
    table_len = freqs_cis_full.shape[0]
    if torch.max(pos_idx) >= table_len:
        pos_idx = torch.clamp(pos_idx, max=table_len - 1)
    # (seq, head_dim/2) -> (1, seq, 1, head_dim/2) so it broadcasts over batch and heads
    angles = freqs_cis_full[pos_idx].unsqueeze(0).unsqueeze(2)
    _, _, _, head_dim = x.shape
    half = head_dim // 2
    first_half, second_half = x[..., :half], x[..., half:]
    cos = torch.cos(angles).type_as(x)
    sin = torch.sin(angles).type_as(x)
    rotated = torch.cat(
        [first_half * cos - second_half * sin, first_half * sin + second_half * cos],
        dim=-1,
    )
    return rotated.type_as(x)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__(); self.args = args; self.num_heads = args.num_attention_heads; self.num_kv_heads = args.num_key_value_heads
self.head_dim = args.head_dim; self.repeats = self.num_heads // self.num_kv_heads
self.wq = nn.Linear(args.hidden_size, args.num_attention_heads * args.head_dim, bias=False)
self.wk = nn.Linear(args.hidden_size, args.num_key_value_heads * args.head_dim, bias=False)
self.wv = nn.Linear(args.hidden_size, args.num_key_value_heads * args.head_dim, bias=False)
self.wo = nn.Linear(args.num_attention_heads * args.head_dim, args.hidden_size, bias=False)
def _repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
bsz, n_kv_heads, seqlen, head_dim = x.shape;
if n_rep == 1: return x
return (x[:, :, None, :, :].expand(bsz, n_kv_heads, n_rep, seqlen, head_dim).reshape(bsz, n_kv_heads * n_rep, seqlen, head_dim))
def _create_recency_bias(self, seqlen, full_seqlen, device, dtype, bias_strength=0.1, decay_rate=0.9):
bias = torch.zeros((1, 1, seqlen, full_seqlen), device=device, dtype=dtype); indices = torch.arange(full_seqlen, device=device)
rel_pos = torch.arange(seqlen, device=device).unsqueeze(1) - indices.unsqueeze(0); mask = rel_pos >= 0
decaying_bias = bias_strength * (decay_rate ** (-rel_pos[mask])); bias[:, :, mask] = decaying_bias.type_as(bias); return bias
def forward(self, x: torch.Tensor, freqs_cis_full: torch.Tensor, positions: torch.Tensor, mask: torch.Tensor | None = None, cache: tuple[torch.Tensor, torch.Tensor] | None = None) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
bsz, seqlen, _ = x.shape; xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.num_heads, self.head_dim); xk = xk.view(bsz, seqlen, self.num_kv_heads, self.head_dim); xv = xv.view(bsz, seqlen, self.num_kv_heads, self.head_dim)
xq = apply_rotary_embeddings(xq, freqs_cis_full, positions); xk = apply_rotary_embeddings(xk, freqs_cis_full, positions)
xk = xk.transpose(1, 2); xv = xv.transpose(1, 2)
if cache is not None: cache_k, cache_v = cache; keys = torch.cat((cache_k.to(xk.device), xk), dim=2); values = torch.cat((cache_v.to(xv.device), xv), dim=2)
else: keys = xk; values = xv
updated_cache = (keys.detach(), values.detach()); keys_repeated = self._repeat_kv(keys, self.repeats); values_repeated = self._repeat_kv(values, self.repeats)
xq = xq.transpose(1, 2); scores = torch.matmul(xq.float(), keys_repeated.transpose(-2, -1).float()) / math.sqrt(self.head_dim)
if self.args.add_recency_bias:
full_seqlen = keys_repeated.shape[-2]; recency_bias = self._create_recency_bias(seqlen, full_seqlen, device=scores.device, dtype=scores.dtype); scores = scores + recency_bias
if mask is not None:
full_seqlen = keys_repeated.shape[-2]; expected_mask_shape_end = (seqlen, full_seqlen)
if mask.shape[-2:] != expected_mask_shape_end:
try: mask_slice = mask[:, :, -seqlen:, :full_seqlen]; scores = scores + mask_slice.float()
except Exception: pass
else: scores = scores + mask.float()
scores = nn.functional.softmax(scores, dim=-1).type_as(xq); output = torch.matmul(scores, values_repeated)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1); output = self.wo(output); return output, updated_cache
class FeedForward(nn.Module):
    """Gated (SwiGLU) MLP: ``w2(silu(w1(x)) * w3(x))``, all projections bias-free."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        d_model = args.hidden_size
        d_inner = args.intermediate_size
        # Creation order (w1, w3, w2) is kept so parameter init consumes the RNG identically.
        self.w1 = nn.Linear(d_model, d_inner, bias=False)  # gated by silu
        self.w3 = nn.Linear(d_model, d_inner, bias=False)  # linear branch
        self.w2 = nn.Linear(d_inner, d_model, bias=False)  # back to hidden size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = nn.functional.silu(self.w1(x))
        return self.w2(gated * self.w3(x))
class TransformerBlock(nn.Module):
    """Pre-norm decoder layer: ``h = x + attn(norm(x))``, ``out = h + ffn(norm(h))``."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.attention_norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.attention = Attention(args)
        self.ffn_norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.feed_forward = FeedForward(args)

    def forward(self, x: torch.Tensor, freqs_cis_full: torch.Tensor, positions: torch.Tensor, mask: torch.Tensor | None = None, cache: tuple[torch.Tensor, torch.Tensor] | None = None) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Return the layer output and the attention's updated ``(keys, values)`` cache."""
        attn_out, new_cache = self.attention(self.attention_norm(x), freqs_cis_full, positions, mask, cache)
        hidden = x + attn_out
        ffn_out = self.feed_forward(self.ffn_norm(hidden))
        return hidden + ffn_out, new_cache
class Llama(nn.Module):
    """Llama-style decoder-only language model with tied input/output embeddings.

    ``forward`` runs the whole sequence without a KV cache and returns logits of
    shape ``(bsz, seqlen, vocab_size)``, computed by projecting onto the token
    embedding matrix.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = nn.ModuleList(
            TransformerBlock(args) for _ in range(args.num_hidden_layers)
        )
        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.tok_embeddings.weight.requires_grad = True
        # RoPE angle table for every position up to max_position_embeddings.
        # persistent=False: not stored in state_dict; it is rebuilt here from args.
        rope_angles, _unused_positions = precompute_theta_pos_frequencies(
            args.head_dim, args.max_position_embeddings, device='cpu', theta=args.rope_theta
        )
        self.register_buffer("freqs_cis", rope_angles, persistent=False)

    def forward(self, tokens: torch.Tensor, positions: torch.Tensor):
        """Map token ids ``(bsz, seqlen)`` at absolute ``positions`` to vocabulary logits."""
        _, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        rope_angles = self.freqs_cis.to(h.device)
        causal_mask = None
        if seqlen > 1:
            # -inf strictly above the diagonal blocks attention to later tokens.
            causal_mask = torch.triu(
                torch.full((1, 1, seqlen, seqlen), float("-inf"), device=h.device), diagonal=1
            ).type_as(h)
        positions = positions.to(h.device)
        for block in self.layers:
            h, _ = block(h, rope_angles, positions, causal_mask, cache=None)
        return F.linear(self.norm(h), self.tok_embeddings.weight)
# --- Generate function (Added Top-P Sampling) ---
@torch.no_grad()
def generate(model: Llama, tokenizer: AutoTokenizer, prompt: str, max_new_tokens: int, temperature: float = 1.0, top_k: int | None = None, top_p: float | None = None):
model.eval()
try: model_device = next(model.parameters()).device; model_dtype = next(model.parameters()).dtype
except StopIteration: model_device = torch.device("cpu"); model_dtype = torch.float32; print("Warning: Model has no parameters.")
prompt_ids = tokenizer.encode(prompt, add_special_tokens=True); tokens = torch.tensor([prompt_ids], dtype=torch.long, device=model_device)
cache = [(torch.zeros((1, model.args.num_key_value_heads, 0, model.args.head_dim), device=model_device, dtype=model_dtype),
torch.zeros((1, model.args.num_key_value_heads, 0, model.args.head_dim), device=model_device, dtype=model_dtype))
for _ in range(model.args.num_hidden_layers)]
generated_token_ids = []; current_tokens = tokens; print(f"Generating {max_new_tokens} tokens from prompt: '{prompt}'"); print("Output: ", end='')
full_freqs_cis = model.freqs_cis.to(model_device)
for i in range(max_new_tokens):
current_seq_len = current_tokens.shape[1]; start_pos = cache[0][0].shape[2]; positions = torch.arange(start_pos, start_pos + current_seq_len, device=model_device)
current_mask = None;
if i == 0 and current_seq_len > 1: current_mask = torch.full((1, 1, current_seq_len, current_seq_len), float("-inf"), device=model_device); current_mask = torch.triu(current_mask, diagonal=1).type(model_dtype)
h = model.tok_embeddings(current_tokens); updated_cache_list = []
for layer_idx, layer in enumerate(model.layers): h, updated_layer_cache = layer(h, full_freqs_cis, positions, current_mask, cache[layer_idx]); updated_cache_list.append(updated_layer_cache)
cache = updated_cache_list; h = model.norm(h); logits = F.linear(h, model.tok_embeddings.weight)
next_token_logits = logits[:, -1, :]
if temperature == 0: next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
else:
next_token_logits = next_token_logits / temperature
if top_k is not None and top_k > 0: v, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1))); next_token_logits[next_token_logits < v[:, [-1]]] = float('-inf')
if top_p is not None and 0.0 < top_p < 1.0:
probs_for_filter = F.softmax(next_token_logits, dim=-1); probs_sort, probs_idx = torch.sort(probs_for_filter, descending=True); probs_sum = torch.cumsum(probs_sort, dim=-1)
mask_top_p = probs_sum > top_p; mask_top_p[..., 0] = False; mask_top_p[..., 1:] = mask_top_p[..., :-1].clone(); indices_to_remove = mask_top_p.scatter(1, probs_idx, mask_top_p); next_token_logits[indices_to_remove] = float('-inf')
probs = F.softmax(next_token_logits, dim=-1); next_token_id = torch.multinomial(probs, num_samples=1)
if tokenizer.eos_token_id is not None and next_token_id.item() == tokenizer.eos_token_id: print("\n[EOS token reached]"); break
next_token_id_item = next_token_id.item(); generated_token_ids.append(next_token_id_item); current_tokens = next_token_id.clone()
print(tokenizer.decode([next_token_id_item]), end='', flush=True)
if len(generated_token_ids) >= max_new_tokens: break
print("\n--- Generation Complete ---"); final_token_ids = prompt_ids + generated_token_ids; full_generated_text = tokenizer.decode(final_token_ids, skip_special_tokens=False)
print(f"\nFull generated text:\n{full_generated_text}"); return full_generated_text
# --- Dataset Class (Map Style for WikiText) ---
class SimpleLMDataset(Dataset):
    """Map-style dataset of overlapping next-token windows over a single token list.

    With ``L = sequence_length``, item ``i`` is the pair
    ``(token_ids[i : i + L], token_ids[i + 1 : i + L + 1])`` as int64 tensors.
    The stride is 1, so there are ``len(token_ids) - L`` items. An index whose
    slice comes up short (e.g. past the end) falls back to the last full window.

    Raises:
        ValueError: if ``token_ids`` is not longer than ``sequence_length``.
    """

    def __init__(self, token_ids: list[int], sequence_length: int):
        self.token_ids = token_ids
        self.sequence_length = sequence_length
        self.num_sequences = max(0, len(token_ids) - sequence_length)
        if not self.num_sequences:
            raise ValueError(f"Dataset token count ({len(token_ids)}) not > sequence length ({sequence_length}).")

    def __len__(self):
        return self.num_sequences

    def __getitem__(self, idx):
        window = self.sequence_length + 1
        chunk = self.token_ids[idx : idx + window]
        if len(chunk) < window:
            last_start = len(self.token_ids) - window
            chunk = self.token_ids[last_start : last_start + window]
        inputs = torch.tensor(chunk[:-1], dtype=torch.long)
        targets = torch.tensor(chunk[1:], dtype=torch.long)
        return inputs, targets
# --- Dataset Class (Iterable for SlimPajama - Kept for reference/fallback) ---
class TokenizedSequenceDataset(IterableDataset):
    """Stream next-token windows from a Hugging Face streaming dataset.

    Each record's ``'text'`` field (skipped when empty or whitespace) is tokenized
    without special tokens and appended to ``self.buffer``. With
    ``L = sequence_length``, every yielded pair is ``(buffer[:L], buffer[1:L+1])``
    as int64 tensors. The buffer then advances by ONE token, so consecutive
    samples overlap by ``L - 1`` tokens. Iteration ends when the source is
    exhausted and fewer than ``L + 1`` tokens remain.

    The source iterator and the buffer are stored on the instance, so a second
    ``iter()`` continues where the previous one stopped instead of restarting.
    ``buffer_size`` is stored but not used.
    """

    def __init__(self, dataset_name, dataset_config, split, tokenizer, sequence_length, buffer_size=10000):
        try:
            self.dataset = load_dataset(dataset_name, dataset_config, split=split, streaming=True)
            print(f"Successfully loaded streaming dataset: {dataset_name} ({split})")
        except Exception as e:
            raise RuntimeError(f"Failed to load streaming dataset {dataset_name} ({split}): {e}") from e
        self.tokenizer = tokenizer
        self.sequence_length = sequence_length
        self.buffer_size = buffer_size
        self.buffer = []
        try:
            self.iter_dataset = iter(self.dataset)
        except Exception as e:
            raise RuntimeError(f"Failed to create iterator for dataset {dataset_name} ({split}): {e}") from e

    def __iter__(self):
        from collections import deque
        from itertools import islice

        window = self.sequence_length + 1
        # A deque makes advancing the window O(1). The old `buffer = buffer[1:]`
        # copied the entire buffer for every sample produced.
        if not isinstance(self.buffer, deque):
            self.buffer = deque(self.buffer)
        while True:
            while len(self.buffer) < window:
                try:
                    item = next(self.iter_dataset)
                except StopIteration:
                    return  # source exhausted and not enough tokens for another window
                text = item.get('text', '')
                if text and text.strip():
                    self.buffer.extend(self.tokenizer.encode(text, add_special_tokens=False))
            chunk = list(islice(self.buffer, window))
            input_ids = torch.tensor(chunk[:-1], dtype=torch.long)
            target_ids = torch.tensor(chunk[1:], dtype=torch.long)
            yield input_ids, target_ids
            self.buffer.popleft()
# --- Checkpoint Loading Function ---
def load_checkpoint(checkpoint_dir: str, model: Llama, optimizer, scaler, scheduler, device):
latest_checkpoint_path = None; highest_step = -1
if os.path.isdir(checkpoint_dir):
checkpoints = glob.glob(os.path.join(checkpoint_dir, "step_*.pt"))
for ckpt_path in checkpoints:
try: step = int(os.path.basename(ckpt_path).split('_')[-1].split('.')[0]);
except ValueError: continue
if step > highest_step: highest_step = step; latest_checkpoint_path = ckpt_path
if latest_checkpoint_path:
print(f"Loading checkpoint from: {latest_checkpoint_path}")
try:
checkpoint = torch.load(latest_checkpoint_path, map_location='cpu', weights_only=False) # Use False for safety
current_args_dict = model.args.__dict__
saved_args_data = checkpoint.get('model_args', checkpoint.get('model_args_dict'))
if not saved_args_data: print("Warning: Checkpoint missing model_args."); saved_args_dict=None; args_match=False
elif not isinstance(saved_args_data, dict): saved_args_dict = dataclasses.asdict(saved_args_data) # Use imported module
else: saved_args_dict = saved_args_data
args_match = True
if saved_args_dict:
for f in dataclasses.fields(ModelArgs): # Use dataclasses.fields
if f.init and f.name != 'head_dim':
current_val = current_args_dict.get(f.name); saved_val = saved_args_dict.get(f.name)
if current_val != saved_val: print(f"Mismatch in arg '{f.name}': Current={current_val}, Saved={saved_val}"); args_match = False; break
else: args_match = False
if not args_match: print("ERROR: Model args mismatch. Cannot load checkpoint."); return 0
model.load_state_dict(checkpoint['model_state_dict']); model.to(device)
if optimizer is not None:
try: optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
except Exception as e: print(f"Warning: Could not load optimizer state dict: {e}")
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor): state[k] = v.to(device)
if scaler is not None:
try: scaler.load_state_dict(checkpoint['scaler_state_dict'])
except Exception as e: print(f"Warning: Could not load scaler state dict: {e}")
if scheduler is not None:
try: scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
except Exception as e: print(f"Warning: Could not load scheduler state dict: {e}")
start_step = checkpoint['step']; print(f"Resuming training from step {start_step + 1}"); return start_step
except Exception as e: print(f"Error loading checkpoint {latest_checkpoint_path}: {e}"); traceback.print_exc(); return 0
else: print("No checkpoint found. Starting training from scratch."); return 0
# --- Plotting Function ---
def plot_loss(train_losses, val_losses, val_steps_list, checkpoint_dir, start_step=0):
    """Save a loss plot (log-scale y axis) as ``loss_plot_step_<start>_to_<end>.png`` in ``checkpoint_dir``.

    Args:
        train_losses: per-optimizer-step training losses of this run, plotted at
            steps ``start_step + 1 .. start_step + len(train_losses)``. With more
            than 10 points, a trailing 10-step moving average is overlaid.
        val_losses: validation losses, plotted only if both lists are non-empty.
        val_steps_list: the absolute steps at which ``val_losses`` were measured.
        checkpoint_dir: output directory.
        start_step: the absolute step preceding the first entry of ``train_losses``.
    """
    smoothing_window = 10
    train_steps = list(range(start_step + 1, start_step + len(train_losses) + 1))
    fig = plt.figure(figsize=(12, 6))
    try:
        plt.plot(train_steps, train_losses, label='Training Loss (Raw)', alpha=0.3)
        if len(train_losses) > smoothing_window:
            # Mean of the last `smoothing_window` points (fewer at the very start).
            # The slice must start at i - window + 1; starting at i - window summed
            # window + 1 values while dividing by window, inflating the curve.
            train_losses_smoothed = [
                sum(train_losses[max(0, i - smoothing_window + 1):i + 1]) / min(i + 1, smoothing_window)
                for i in range(len(train_losses))
            ]
            plt.plot(train_steps, train_losses_smoothed, label=f'Training Loss (Smoothed {smoothing_window} steps)', alpha=0.9)
        if val_losses and val_steps_list:
            plt.plot(val_steps_list, val_losses, label='Validation Loss', marker='o', linestyle='--')
        plt.xlabel("Optimizer Steps")
        plt.ylabel("Loss")
        plt.yscale('log')
        plt.title("Training and Validation Loss Over Steps")
        plt.legend()
        plt.grid(True)
        plot_filename = f"loss_plot_step_{start_step}_to_{start_step + len(train_losses)}.png"
        plot_path = os.path.join(checkpoint_dir, plot_filename)
        plt.savefig(plot_path)
    finally:
        # pyplot keeps every open figure alive until closed explicitly.
        plt.close(fig)
    print(f"Loss plot saved to {plot_path}")
# --- Basic Training Function (Single GPU, AMP, LR Schedule, Validation, Checkpointing, Plotting) ---
def _run_validation(model, val_loader, criterion, device, amp_enabled, val_steps):
    """Return the mean loss over at most ``val_steps`` batches from a fresh iterator over ``val_loader``.

    Runs under ``torch.no_grad()`` in eval mode and always restores train mode.
    Returns ``inf`` when the loader yields no batch.
    """
    model.eval()
    val_loss, val_batches = 0.0, 0
    val_pbar = tqdm(total=val_steps, desc="Validation")
    try:
        with torch.no_grad():
            val_iter = iter(val_loader)
            for _ in range(val_steps):
                try:
                    val_input_ids, val_target_ids = next(val_iter)
                except StopIteration:
                    print("Validation loader exhausted early.")
                    break
                val_input_ids = val_input_ids.to(device)
                val_target_ids = val_target_ids.to(device)
                val_positions = torch.arange(val_input_ids.shape[1], device=device)
                with autocast(device_type=device.type, dtype=torch.float16, enabled=amp_enabled):
                    val_logits = model(val_input_ids, val_positions)
                    v_loss = criterion(val_logits.view(-1, val_logits.size(-1)).float(), val_target_ids.view(-1))
                val_loss += v_loss.item()
                val_batches += 1
                val_pbar.update(1)
                val_pbar.set_postfix({"Val Loss": f"{val_loss / val_batches:.4f}"})
    finally:
        val_pbar.close()
        model.train()
    return val_loss / val_batches if val_batches > 0 else float('inf')


def _save_checkpoint(save_path, step, model, optimizer, scaler, scheduler):
    """Best-effort save of the training state to ``save_path``; errors are printed, not raised.

    The state is written to ``save_path + '.tmp'`` and then renamed, so an
    interrupted save never leaves a truncated ``step_*.pt`` behind for a later
    resume to pick up.
    """
    tmp_path = save_path + ".tmp"
    try:
        save_content = {
            'step': step,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'model_args_dict': dataclasses.asdict(model.args),
        }
        torch.save(save_content, tmp_path)
        os.replace(tmp_path, save_path)
        print(f"\nCheckpoint saved to {save_path}")
    except Exception as e:
        print(f"\nError saving checkpoint: {e}")


def simple_train(
    model: Llama, tokenizer: AutoTokenizer, train_dataset: IterableDataset | Dataset, val_dataset: IterableDataset | Dataset | None,
    optimizer: torch.optim.Optimizer, criterion: nn.Module, scheduler,
    num_epochs: int, device: torch.device, gradient_accumulation_steps: int = 1,
    use_amp: bool = False, max_train_steps: int | None = None, start_step: int = 0,
    save_interval: int = 1000, checkpoint_dir: str = ".",
    validation_interval: int = 500, val_steps: int = 50,
    scaler: GradScaler | None = None,
):
    """Train ``model`` from ``start_step`` up to ``max_train_steps`` optimizer steps on one device.

    Each optimizer step accumulates gradients over ``gradient_accumulation_steps``
    micro-batches of batch size 1. fp16 autocast + GradScaler are used when
    ``use_amp`` is set and ``device`` is CUDA. Training stops early if the
    training DataLoader is exhausted; a partially accumulated step is discarded.
    Every ``validation_interval`` steps, the first ``val_steps`` batches of an
    unshuffled validation loader are evaluated. Every ``save_interval`` steps,
    ``step_<N>.pt`` is written to ``checkpoint_dir``. ``tokenizer`` and
    ``num_epochs`` are accepted but unused.

    Args:
        scaler: optional GradScaler (e.g. one restored from a checkpoint). When
            None, a fresh scaler is created.

    Returns:
        ``(train_loss_history, val_loss_history, val_steps_history)``:
        - the mean micro-batch loss of each optimizer step run here;
        - the mean validation losses;
        - the absolute steps at which validation ran.
    """
    amp_enabled = use_amp and device.type == 'cuda'
    if scaler is None:
        scaler = GradScaler(enabled=amp_enabled)
    model.train()
    total_steps = start_step
    os.makedirs(checkpoint_dir, exist_ok=True)
    train_loss_history, val_loss_history, val_steps_history = [], [], []
    print(f"\n--- Starting Training (Resuming from step {start_step}, Target Steps: {max_train_steps if max_train_steps else 'N/A'}) ---")
    print(f"--- (AMP: {amp_enabled}) ---")
    is_iterable = isinstance(train_dataset, IterableDataset)
    train_loader = DataLoader(train_dataset, batch_size=1, num_workers=0, shuffle=(not is_iterable))
    val_loader = DataLoader(val_dataset, batch_size=1, num_workers=0) if val_dataset else None
    # Progress is counted in optimizer steps; the total is unknown without max_train_steps.
    tqdm_total = (max_train_steps - start_step) if max_train_steps is not None else None
    print(f"Starting loop, aiming for {max_train_steps} total steps...")
    pbar = tqdm(total=tqdm_total, desc=f"Optim Steps ({start_step}...)")
    data_iterator = iter(train_loader)
    try:
        while max_train_steps is None or total_steps < max_train_steps:
            # --- Accumulation loop: one optimizer step = gradient_accumulation_steps micro-batches ---
            optimizer.zero_grad()
            accum_loss = 0.0
            data_exhausted = False
            for _ in range(gradient_accumulation_steps):
                try:
                    input_ids, target_ids = next(data_iterator)
                except StopIteration:
                    print("\nDataLoader exhausted within accumulation cycle or epoch.")
                    data_exhausted = True
                    break
                input_ids = input_ids.to(device)
                target_ids = target_ids.to(device)
                positions = torch.arange(input_ids.shape[1], device=device)
                with autocast(device_type=device.type, dtype=torch.float16, enabled=amp_enabled):
                    logits = model(input_ids, positions)
                    loss = criterion(logits.view(-1, logits.size(-1)).float(), target_ids.view(-1))
                    loss = loss / gradient_accumulation_steps  # so summed grads average over micro-batches
                scaler.scale(loss).backward()
                accum_loss += loss.item()
            if data_exhausted:
                break

            # --- Optimizer step ---
            scaler.unscale_(optimizer)
            # Optional: torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)
            total_steps += 1
            pbar.update(1)

            # --- Logging ---
            # accum_loss sums loss/G over G micro-batches, so it already IS the mean
            # loss of this step; multiplying by G (as before) over-reported it G-fold.
            current_loss = accum_loss
            train_loss_history.append(current_loss)
            pbar.set_postfix({"Loss": f"{current_loss:.4f}", "LR": f"{scheduler.get_last_lr()[0]:.6f}", "Steps": total_steps})

            # --- Validation ---
            if val_loader is not None and total_steps % validation_interval == 0 and total_steps > 0:
                print(f"\nRunning validation at step {total_steps}...")
                avg_val_loss = _run_validation(model, val_loader, criterion, device, amp_enabled, val_steps)
                val_loss_history.append(avg_val_loss)
                val_steps_history.append(total_steps)
                print(f"Validation finished. Average Val Loss: {avg_val_loss:.4f}")

            # --- Checkpointing ---
            if total_steps % save_interval == 0 and total_steps > 0:
                save_path = os.path.join(checkpoint_dir, f"step_{total_steps}.pt")
                _save_checkpoint(save_path, total_steps, model, optimizer, scaler, scheduler)

            # --- Check Max Steps ---
            if max_train_steps is not None and total_steps >= max_train_steps:
                print(f"\nReached max_train_steps ({max_train_steps}). Stopping training.")
                break
    finally:
        pbar.close()
    print("--- Training Finished ---")
    return train_loss_history, val_loss_history, val_steps_history
# --- Main Execution Block ---
if __name__ == "__main__":
    # --- Configuration ---
    config = ModelArgs(add_recency_bias=False)  # Use ~221M config
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Model Configuration:\n{config}")
    print(f"Calculated Head Dimension: {config.head_dim}")
    print(f"\nUsing device: {device}")

    # --- Tokenizer ---
    print("\n--- Tokenizer Loading ---")
    tokenizer_name = "deepseek-ai/DeepSeek-R1"
    print(f"Loading tokenizer: {tokenizer_name}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)
        print("Tokenizer loaded successfully.")
        if tokenizer.vocab_size != config.vocab_size:
            raise SystemExit("FATAL: Tokenizer vocab size mismatch!")
        print(f"Tokenizer vocab size ({tokenizer.vocab_size}) matches model config.")
        if tokenizer.pad_token is None:
            if tokenizer.eos_token:
                tokenizer.pad_token = tokenizer.eos_token
                print(f"Set PAD token to EOS token: {tokenizer.pad_token}")
            else:
                tokenizer.add_special_tokens({'pad_token': '[PAD]'})
                print("Added a generic [PAD] token.")
    except Exception as e:
        raise SystemExit(f"Error loading tokenizer '{tokenizer_name}': {e}")

    # --- Training Setup ---
    # NOTE: simple_train builds its own batch-size-1 DataLoaders.
    print("\n--- Training Setup ---")
    train_seq_len = 256
    grad_accum_steps = 16
    use_amp_training = device.type == 'cuda'
    learning_rate = 1e-4  # Lower LR
    num_epochs = 1
    # Absolute step to stop at in this run (~4 hour budget): trains start_step+1 .. max_steps_for_run.
    max_steps_for_run = 1200
    # Scheduler horizon: a longer-term goal than this single run.
    total_scheduler_steps = 10000
    warmup_steps = 100
    checkpoint_dir = "."  # save to current directory
    save_interval = 200
    validation_interval = 100
    val_steps = 20

    # --- Dataset (WikiText-2, tokenized once and cached to disk) ---
    print("\nLoading and preparing WikiText-2 dataset...")
    train_dataset, val_dataset = None, None
    try:
        token_file = "./wikitext2_tokens_128k.pt"
        val_token_file = "./wikitext2_val_tokens_128k.pt"
        force_remake_dataset = False
        if os.path.exists(token_file) and os.path.exists(val_token_file) and not force_remake_dataset:
            print(f"Loading tokenized data from {token_file} and {val_token_file}...")
            all_token_ids = torch.load(token_file)
            all_val_token_ids = torch.load(val_token_file)
            print("Tokenized data loaded.")
        else:
            print("Token files not found or remake forced, processing WikiText-2...")
            print("Processing train split...")
            train_raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
            train_full_text = "\n".join(item['text'] for item in train_raw_dataset if item['text'].strip())
            all_token_ids = tokenizer.encode(train_full_text)
            torch.save(all_token_ids, token_file)
            print(f"Saved tokenized train data ({len(all_token_ids)} tokens) to {token_file}")
            print("Processing validation split...")
            val_raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")
            val_full_text = "\n".join(item['text'] for item in val_raw_dataset if item['text'].strip())
            all_val_token_ids = tokenizer.encode(val_full_text)
            torch.save(all_val_token_ids, val_token_file)
            print(f"Saved tokenized validation data ({len(all_val_token_ids)} tokens) to {val_token_file}")
        if len(all_token_ids) <= train_seq_len:
            raise SystemExit("Train dataset too short.")
        if len(all_val_token_ids) <= train_seq_len:
            raise SystemExit("Validation dataset too short.")
        train_dataset = SimpleLMDataset(all_token_ids, sequence_length=train_seq_len)
        val_dataset = SimpleLMDataset(all_val_token_ids, sequence_length=train_seq_len)
        print("Using WikiText-2 dataset.")
    except Exception as e:
        raise SystemExit(f"Dataset error: {e}")
    print(f"Datasets ready. Training Seq Len: {train_seq_len}")
    print(f"Train sequences: {len(train_dataset)}, Val sequences: {len(val_dataset) if val_dataset else 0}")

    # --- Model, Optimizer, Scheduler, Loss ---
    train_model = Llama(config).to(device)
    print(f"Training model instantiated ({'float32' if not use_amp_training else 'mixed precision'}). Recency Bias: {config.add_recency_bias}")
    total_params_train = sum(p.numel() for p in train_model.parameters() if p.requires_grad)
    print(f"Total Trainable Parameters: {total_params_train / 1e6:.2f} Million")
    optimizer = optim.AdamW(train_model.parameters(), lr=learning_rate, weight_decay=0.1)
    criterion = nn.CrossEntropyLoss()
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_scheduler_steps
    )
    # NOTE(review): simple_train creates its own GradScaler unless one is passed in,
    # so the scaler state restored below is not used for training here.
    scaler = GradScaler(enabled=use_amp_training and device.type == 'cuda')
    print(f"Optimizer: AdamW, Base LR: {learning_rate}")
    print(f"Scheduler: Cosine with {warmup_steps} warmup steps up to {total_scheduler_steps} steps")
    print("Loss Function: CrossEntropyLoss")

    # --- Load Checkpoint (restores model, optimizer, scaler, scheduler) ---
    start_step = load_checkpoint(checkpoint_dir, train_model, optimizer, scaler, scheduler, device)
    steps_to_run_this_session = max(0, max_steps_for_run - start_step)
    current_run_target_step = start_step + steps_to_run_this_session  # absolute stop step
    if steps_to_run_this_session <= 0:
        print(f"Already completed or exceeded target steps ({max_steps_for_run}). Exiting.")
        raise SystemExit

    # --- Run Training ---
    print(f"\n--- Running Training (Will run for {steps_to_run_this_session} steps in this session, target total: {max_steps_for_run}) ---")
    start_time = time.time()
    train_loss_hist, val_loss_hist, val_steps_hist = [], [], []
    try:
        train_loss_hist, val_loss_hist, val_steps_hist = simple_train(
            model=train_model, tokenizer=tokenizer, train_dataset=train_dataset, val_dataset=val_dataset,
            optimizer=optimizer, criterion=criterion, scheduler=scheduler,
            num_epochs=num_epochs, device=device, gradient_accumulation_steps=grad_accum_steps,
            use_amp=use_amp_training, max_train_steps=current_run_target_step, start_step=start_step,
            save_interval=save_interval, checkpoint_dir=checkpoint_dir,
            validation_interval=validation_interval, val_steps=val_steps
        )
        print("\nTraining loop finished.")
        end_time = time.time()
        print(f"Training duration for this session: {datetime.timedelta(seconds=int(end_time - start_time))}")

        # --- Plotting (validation points from this run only) ---
        # Uses the histories returned above; the old code referenced the undefined
        # names val_steps_history / val_loss_history and crashed here.
        if train_loss_hist:
            plot_val_steps = [s for s in val_steps_hist if s >= start_step]
            plot_val_loss = [loss for loss, s in zip(val_loss_hist, val_steps_hist) if s >= start_step]
            plot_loss(train_loss_hist, plot_val_loss, plot_val_steps, checkpoint_dir, start_step=start_step)

        # --- Generation After Training ---
        print("\n--- Generation After Training ---")
        train_model.eval()
        if device.type == 'cuda':
            try:
                train_model = train_model.half()
                print("Trained model converted to float16 for generation.")
            except Exception as e:
                print(f"Could not convert trained model to float16: {e}.")
        test_prompt_after = "The meaning of life is"
        _ = generate(model=train_model, tokenizer=tokenizer, prompt=test_prompt_after, max_new_tokens=60, temperature=0.7, top_k=50, top_p=0.9)
        print("\n(Check if output shows more structure than random)")
    except torch.cuda.OutOfMemoryError:
        print("\n--- CUDA Out of Memory during Training ---")
        print("Try reducing train_seq_len or gradient_accumulation_steps further.")
    except Exception as e:
        print(f"\nAn error occurred during training: {e}")
        traceback.print_exc()
    print("\n--- Script Finished ---")