import os  # For file operations
import re  # For regex operations
import torch
import math
from transformers import GPT2Tokenizer
from datasets import load_dataset
import numpy as np
import wandb  # Import W&B library

from model import minGRULM
from util import generate_text, generate_name

# ============================
# Configuration Parameters
# ============================
dataset_path = 'flpelerin/tinystories-100k'
model_name = generate_name()  # Example: "mingru-a14c"

num_epochs = 1
batch_size = 4
seq_length = 256
learning_rate = 1e-4

input_len = 50
num_predict = 250

infer_every = 200
reset_state_every = 16
validate_every = 200
save_every = 500  # Controls checkpointing frequency

# ============================
# Initialize the Device
# ============================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Total context size is {batch_size * seq_length} tokens")

# ============================
# Initialize the Tokenizer
# ============================
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
vocab_size = tokenizer.vocab_size
print(f"Tokenizer has {vocab_size} unique tokens")

# ============================
# Load and Preprocess Dataset
# ============================
dataset = load_dataset(dataset_path)

def process_function(examples):
    return tokenizer(
        examples['text'],
        padding='max_length',  # Fixed padding
        truncation=True,
        max_length=(seq_length * batch_size)  # Fixed max length
    )

tokenized_datasets = dataset.map(process_function, batched=True)
print(f"Dataset has {tokenized_datasets['train'].num_rows} rows of {batch_size} times {seq_length} tokens")

# ============================
# Split Dataset into Train and Validation
# ============================
split_dataset = tokenized_datasets['train'].train_test_split(test_size=(1/validate_every))
train_dataset = split_dataset['train']
valid_dataset = split_dataset['test']
print(f"Training set size: {len(train_dataset)}")
print(f"Validation set size: {len(valid_dataset)}")

# ============================
# Initialize the Model
# ============================
model = minGRULM(
    vocab_size = vocab_size,
    d_model = 384,
    d_inner = 768,
    n_layers = 6
)
model.to(device)

parameters_count = sum(p.numel() for p in model.parameters())
print(f"Model has {parameters_count:,} parameters")

# ============================
# Symbolic Link Configuration
# ============================
symlink_path = 'pytorch_model.bin'

def update_symlink(target_path, symlink_path):
    """
    Creates or updates a symbolic link pointing to the target path.

    Args:
        target_path (str): The file path the symlink should point to.
        symlink_path (str): The symlink's path.
    """
    try:
        if os.path.islink(symlink_path) or os.path.exists(symlink_path):
            os.remove(symlink_path)
        os.symlink(target_path, symlink_path)
        print(f"Updated symlink: {symlink_path} -> {target_path}")
    except OSError as e:
        print(f"Warning: Failed to create symlink {symlink_path} -> {target_path}. Error: {e}")

# ============================
# Load Checkpoint from pytorch_model.bin if Exists
# ============================
if os.path.exists(symlink_path):
    try:
        model.load_state_dict(torch.load(symlink_path, map_location=device))
        print(f"Loaded model weights from {symlink_path}")
    except Exception as e:
        print(f"Error loading model from {symlink_path}: {e}")
        print("Starting training from scratch.")
else:
    print("No checkpoint found. Starting training from scratch.")

# ============================
# Initialize the Weights and Biases Run
# ============================
wandb.login(key="860f8753998c6e6dc356914de07e8855aa2f9642")

wandb.init(
    project="minGRU-Training",
    name=model_name,
    config={
        "dataset_path": dataset_path,
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "seq_length": seq_length,
        "learning_rate": learning_rate,
        "input_len": input_len,
        "num_predict": num_predict,
        "infer_every": infer_every,
        "reset_state_every": reset_state_every,
        "validate_every": validate_every,
        "save_every": save_every,  # Logging the new variable
        "dataset_rows": tokenized_datasets['train'].num_rows,
        "dataset_token_count": batch_size * seq_length,
        "train_set_size": len(train_dataset),
        "valid_set_size": len(valid_dataset),
        "model_parameters": parameters_count,
        "vocab_size": vocab_size,
        "d_model": model.d_model,
        "d_inner": model.d_inner,
        "n_layers": model.n_layers,
        "device": str(device),
        "model_name": model_name  # Log model_name
    }
)

# ============================
# Training Loop with Validation and Checkpointing
# ============================
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

h_states = None
step = 0

for epoch in range(num_epochs):
    print(f"Starting Epoch {epoch + 1}/{num_epochs}")

    for i in range(0, len(train_dataset), batch_size):
        batch = train_dataset[i:i + batch_size]
        input_ids = torch.tensor(batch['input_ids']).to(device)

        # Reset hidden states if needed
        if step % reset_state_every == 0:
            h_states = None
        # Otherwise, keep existing hidden states

        optimizer.zero_grad()

        try:
            _, h_states, loss = model.forward(input_ids, h_states)
            loss.backward()
            optimizer.step()
        except Exception as e:
            print(f"Error during training step {step + 1}: {e}")
            continue  # Skip to the next batch

        step += 1

        # Compute statistics of hidden states
        if h_states is not None:
            try:
                avg_states = sum([torch.mean(h).item() for h in h_states]) / len(h_states)
                var_states = torch.var(torch.cat(h_states, dim=0)).item()
            except Exception as e:
                avg_states = None
                var_states = None
        else:
            avg_states = None
            var_states = None

        # Log step information
        wandb.log({
            "loss": loss.item(),
            "average_hidden_state": avg_states,
            "variance_hidden_state": var_states,
            "step": step
        })
        print(f"Epoch: {epoch + 1}/{num_epochs}, Step: {step}, Loss: {loss.item():.4f}, Hidden States: average = {avg_states}, variance = {var_states}")

        # Perform validation at specified intervals
        if step % validate_every == 0:
            validation_loss = 0.0
            valid_steps = 0

            with torch.no_grad():
                for vi in range(0, len(valid_dataset), batch_size):
                    val_batch = valid_dataset[vi:vi + batch_size]
                    val_input_ids = torch.tensor(val_batch['input_ids']).to(device)

                    # Forward pass
                    _, _, val_loss = model.forward(val_input_ids, None)
                    validation_loss += val_loss.item()
                    valid_steps += 1

            avg_validation_loss = validation_loss / valid_steps if valid_steps > 0 else float('inf')

            # Log validation loss
            wandb.log({"validation_loss": avg_validation_loss, "step": step})
            print(f"----- Validation after Step {step}: Average Loss = {avg_validation_loss:.4f} -----")

        # Perform inference at specified steps
        if step % infer_every == 0:
            with torch.no_grad():
                if input_ids.size(1) < input_len:
                    print("Input length is shorter than input_len. Skipping inference.")
                    continue

                # Select a single input from the current batch for inference
                sample_ids = input_ids[0][:input_len]
                input_text = tokenizer.decode(sample_ids, skip_special_tokens=True)
                print(f"Input for Inference: {input_text}")

                prompt = sample_ids.unsqueeze(0)  # Shape: [1, input_len]
                generated_text = generate_text(model, tokenizer, prompt, num_predict)
                print(f"Generated Text:\n{generated_text}\n")

                # Optionally, log generated text (e.g., as HTML to preserve formatting)
                # wandb.log({"generated_text": wandb.Html(f"<pre>{generated_text}</pre>")}, step=step)
        # Perform checkpointing at specified steps
        if step % save_every == 0:
            step_str = f"{step}k"  # Append 'k' to the raw step count for the filename, e.g., step 500 -> '500k'
            checkpoint_filename = f"{model_name}-{step_str}.bin"
            checkpoint_path = checkpoint_filename

            try:
                torch.save(model.state_dict(), checkpoint_path)
                print(f"Saved model checkpoint at step {step} to {checkpoint_path}")

                # Update the symbolic link to point to this checkpoint
                update_symlink(checkpoint_path, symlink_path)

                # Optionally, log the checkpoint to W&B
                # wandb.save(checkpoint_path)
            except Exception as e:
                print(f"Error saving checkpoint at step {step}: {e}")