# mingru / train.py
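"""
Training script for the minGRU language model (minGRULM, defined in model.py).

It tokenizes the flpelerin/tinystories-100k dataset with the GPT-2 tokenizer,
trains with AdamW while carrying recurrent hidden states across batches, and
periodically validates, samples text, and saves checkpoints (with a
pytorch_model.bin symlink pointing at the latest one). Metrics are logged to
Weights & Biases.
"""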
import os      # File, symlink, and environment-variable handling
import torch
from transformers import GPT2Tokenizer
from datasets import load_dataset
import wandb   # Weights & Biases experiment tracking

# Local modules: model.py defines minGRULM, util.py provides generation helpers
from model import minGRULM
from util import generate_text, generate_name
# ============================
# Configuration Parameters
# ============================
dataset_path = 'flpelerin/tinystories-100k'
model_name = generate_name() # Example: "mingru-a14c"
num_epochs = 1
batch_size = 4
seq_length = 256
learning_rate = 1e-4
input_len = 50
num_predict = 250
infer_every = 200
reset_state_every = 16
validate_every = 200
save_every = 500 # Controls checkpointing frequency
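# Note: the *_every intervals above are counted in optimizer steps
# (one step = one batch of batch_size rows), checked with `step % ... == 0` below.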
# ============================
# Initialize the Device
# ============================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Total context size is {batch_size * seq_length} tokens")
# ============================
# Initialize the Tokenizer
# ============================
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
vocab_size = tokenizer.vocab_size
print(f"Tokenizer has {vocab_size} unique tokens")
# ============================
# Load and Preprocess Dataset
# ============================
dataset = load_dataset(dataset_path)
def process_function(examples):
    return tokenizer(
        examples['text'],
        padding='max_length',  # Fixed padding
        truncation=True,
        max_length=(seq_length * batch_size)  # Fixed max length
    )
tokenized_datasets = dataset.map(process_function, batched=True)
print(f"Dataset has {tokenized_datasets['train'].num_rows} rows of {batch_size} times {seq_length} tokens")
# ============================
# Split Dataset into Train and Validation
# ============================
split_dataset = tokenized_datasets['train'].train_test_split(test_size=(1/validate_every))
train_dataset = split_dataset['train']
valid_dataset = split_dataset['test']
print(f"Training set size: {len(train_dataset)}")
print(f"Validation set size: {len(valid_dataset)}")
# ============================
# Initialize the Model
# ============================
model = minGRULM(
    vocab_size=vocab_size,
    d_model=384,
    d_inner=768,
    n_layers=6
)
model.to(device)
parameters_count = sum(p.numel() for p in model.parameters())
print(f"Model has {parameters_count:,} parameters")
# ============================
# Symbolic Link Configuration
# ============================
symlink_path = 'pytorch_model.bin'
def update_symlink(target_path, symlink_path):
    """
    Creates or updates a symbolic link pointing to the target path.

    Args:
        target_path (str): The file path the symlink should point to.
        symlink_path (str): The symlink's path.
    """
    try:
        if os.path.islink(symlink_path) or os.path.exists(symlink_path):
            os.remove(symlink_path)
        os.symlink(target_path, symlink_path)
        print(f"Updated symlink: {symlink_path} -> {target_path}")
    except OSError as e:
        print(f"Warning: Failed to create symlink {symlink_path} -> {target_path}. Error: {e}")
# ============================
# Load Checkpoint from pytorch_model.bin if Exists
# ============================
if os.path.exists(symlink_path):
    try:
        model.load_state_dict(torch.load(symlink_path, map_location=device))
        print(f"Loaded model weights from {symlink_path}")
    except Exception as e:
        print(f"Error loading model from {symlink_path}: {e}")
        print("Starting training from scratch.")
else:
    print("No checkpoint found. Starting training from scratch.")
# ============================
# Initialize the Weights and Biases Run
# ============================
wandb.login(key=os.environ.get("WANDB_API_KEY"))  # Read the API key from the environment instead of hardcoding it
wandb.init(
    project="minGRU-Training",
    name=model_name,
    config={
        "dataset_path": dataset_path,
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "seq_length": seq_length,
        "learning_rate": learning_rate,
        "input_len": input_len,
        "num_predict": num_predict,
        "infer_every": infer_every,
        "reset_state_every": reset_state_every,
        "validate_every": validate_every,
        "save_every": save_every,
        "dataset_rows": tokenized_datasets['train'].num_rows,
        "dataset_token_count": batch_size * seq_length,  # Tokens per dataset row
        "train_set_size": len(train_dataset),
        "valid_set_size": len(valid_dataset),
        "model_parameters": parameters_count,
        "vocab_size": vocab_size,
        "d_model": model.d_model,
        "d_inner": model.d_inner,
        "n_layers": model.n_layers,
        "device": str(device),
        "model_name": model_name
    }
)
# ============================
# Training Loop with Validation and Checkpointing
# ============================
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
h_states = None
step = 0
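# h_states carries the recurrent hidden state from one batch to the next and is
# reset every reset_state_every steps, so context can persist across batches.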
for epoch in range(num_epochs):
    print(f"Starting Epoch {epoch + 1}/{num_epochs}")

    for i in range(0, len(train_dataset), batch_size):
        batch = train_dataset[i:i + batch_size]
        input_ids = torch.tensor(batch['input_ids']).to(device)

        # Reset hidden states at the configured interval; otherwise keep carrying them
        if step % reset_state_every == 0:
            h_states = None

        optimizer.zero_grad()
        try:
            _, h_states, loss = model.forward(input_ids, h_states)
            loss.backward()
            optimizer.step()
            # Detach the carried states so the next step's backward pass does not
            # retrace this step's freed graph (assumes minGRULM returns a list/tuple
            # of per-layer state tensors, as the statistics below rely on).
            if h_states is not None:
                h_states = [h.detach() for h in h_states]
        except Exception as e:
            print(f"Error during training step {step + 1}: {e}")
            continue  # Skip to the next batch

        step += 1
        # Compute statistics of hidden states
        if h_states is not None:
            try:
                avg_states = sum(torch.mean(h).item() for h in h_states) / len(h_states)
                var_states = torch.var(torch.cat(h_states, dim=0)).item()
            except Exception as e:
                avg_states = None
                var_states = None
        else:
            avg_states = None
            var_states = None

        # Log step information
        wandb.log({
            "loss": loss.item(),
            "average_hidden_state": avg_states,
            "variance_hidden_state": var_states,
            "step": step
        })
        print(f"Epoch: {epoch + 1}/{num_epochs}, Step: {step}, Loss: {loss.item():.4f}, "
              f"Hidden States: average = {avg_states}, variance = {var_states}")
        # Perform validation at specified intervals
        if step % validate_every == 0:
            validation_loss = 0.0
            valid_steps = 0
            with torch.no_grad():
                for vi in range(0, len(valid_dataset), batch_size):
                    val_batch = valid_dataset[vi:vi + batch_size]
                    val_input_ids = torch.tensor(val_batch['input_ids']).to(device)
                    # Forward pass with a fresh hidden state
                    _, _, val_loss = model.forward(val_input_ids, None)
                    validation_loss += val_loss.item()
                    valid_steps += 1

            avg_validation_loss = validation_loss / valid_steps if valid_steps > 0 else float('inf')

            # Log validation loss
            wandb.log({"validation_loss": avg_validation_loss, "step": step})
            print(f"----- Validation after Step {step}: Average Loss = {avg_validation_loss:.4f} -----")
        # Perform inference at specified steps
        if step % infer_every == 0:
            with torch.no_grad():
                if input_ids.size(1) < input_len:
                    print("Input length is shorter than input_len. Skipping inference.")
                else:
                    # Use the first sequence of the current batch as the prompt
                    sample_ids = input_ids[0][:input_len]
                    input_text = tokenizer.decode(sample_ids, skip_special_tokens=True)
                    print(f"Input for Inference: {input_text}")

                    prompt = sample_ids.unsqueeze(0)  # Shape: [1, input_len]
                    generated_text = generate_text(model, tokenizer, prompt, num_predict)
                    print(f"Generated Text:\n{generated_text}\n")

                    # Optionally, log generated text (e.g., as HTML to preserve formatting)
                    # wandb.log({"generated_text": wandb.Html(f"<pre>{generated_text}</pre>")}, step=step)
        # Perform checkpointing at specified steps
        if step % save_every == 0:
            # Name the checkpoint after the current step, e.g. "mingru-a14c-500.bin"
            checkpoint_path = f"{model_name}-{step}.bin"
            try:
                torch.save(model.state_dict(), checkpoint_path)
                print(f"Saved model checkpoint at step {step} to {checkpoint_path}")

                # Update the symbolic link to point to this checkpoint
                update_symlink(checkpoint_path, symlink_path)

                # Optionally, log the checkpoint to W&B
                # wandb.save(checkpoint_path)
            except Exception as e:
                print(f"Error saving checkpoint at step {step}: {e}")