import yaml, os, time
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
from collections import defaultdict
from typing import Optional
try:
import wandb
WANDB_AVAILABLE = True
except ImportError:
WANDB_AVAILABLE = False
from model_base import EmbedMLP
from utils import Config, gen_train_test, full_loss, acc, cross_entropy_high_precision
class Trainer:
    '''Trainer for an EmbedMLP model on the modular-addition task.

    Handles model/optimizer construction, train/test dataset generation and
    persistence, per-epoch metric tracking, and checkpointing both locally
    and (optionally) to Weights & Biases.
    '''

    def __init__(self, config: Config, model: Optional[EmbedMLP] = None, use_wandb: bool = True) -> None:
        '''Build (or adopt) the model, set up optimizer/scheduler, generate and
        save the datasets, and initialize metric tracking.

        Args:
            config: Experiment configuration (hyperparameters, device, flags).
            model: Optional pre-built model; a fresh EmbedMLP is created when None.
            use_wandb: Enable wandb logging (silently disabled if wandb is not installed).

        Raises:
            ValueError: If ``config.optimizer`` is not ``'AdamW'`` or ``'SGD'``.
        '''
        self.use_wandb = use_wandb and WANDB_AVAILABLE
        self.config = config

        # Use the given model or initialize a new EmbedMLP from the config.
        self.model = model if model is not None else EmbedMLP(
            d_vocab=config.d_vocab,
            d_model=config.d_model,
            d_mlp=config.d_mlp,
            act_type=config.act_type,
            use_cache=False,
            init_type=config.init_type,
            # init_scale is optional on older configs; fall back to 0.1.
            init_scale=getattr(config, 'init_scale', 0.1),
            embed_type=config.embed_type,
        )
        self.model.to(config.device)  # Move model to the configured device (e.g. GPU)

        if config.optimizer == 'AdamW':
            self.optimizer = optim.AdamW(
                self.model.parameters(),
                lr=config.lr,
                weight_decay=config.weight_decay,
                betas=(0.9, 0.98),
            )
        elif config.optimizer == 'SGD':
            self.optimizer = optim.SGD(
                self.model.parameters(),
                lr=config.lr,
                weight_decay=config.weight_decay,  # L2 regularization, equivalent to weight decay for plain GD
            )
        else:
            # Fail fast: otherwise self.optimizer/self.scheduler are never set
            # and training crashes later with a confusing AttributeError.
            raise ValueError(
                f"Unsupported optimizer: {config.optimizer!r} (expected 'AdamW' or 'SGD')"
            )
        # Linear warmup over the first 10 steps, then a constant learning rate.
        # (Identical for both optimizers, so it is constructed once here.)
        self.scheduler = optim.lr_scheduler.LambdaLR(
            self.optimizer, lr_lambda=lambda step: min(step / 10, 1)
        )

        # Unique, human-readable run name for this training session.
        formatted_time = time.strftime("%m%d%H%M", time.localtime())
        init_scale_str = f"scale_{config.init_scale}" if hasattr(config, 'init_scale') else ""
        self.run_name = (
            f"p_{config.p}_dmlp_{config.d_mlp}_{config.act_type}_{config.init_type}"
            f"_{init_scale_str}_decay_{config.weight_decay}_{formatted_time}"
        )

        # Initialize experiment logging with wandb (Weights and Biases).
        if self.use_wandb:
            wandb.init(project="modular_addition", config=config, name=self.run_name)

        # Directory where model checkpoints for this run are saved.
        self.save_dir = "saved_models"
        os.makedirs(os.path.join(self.save_dir, self.run_name), exist_ok=True)

        # Generate and persist the train/test split so the run is reproducible.
        self.train, self.test = gen_train_test(config=config)
        train_path = os.path.join(self.save_dir, self.run_name, "train_data.pth")
        test_path = os.path.join(self.save_dir, self.run_name, "test_data.pth")
        torch.save(self.train, train_path)
        torch.save(self.test, test_path)

        # Per-epoch metrics, keyed by epoch number.
        self.metrics_dictionary = defaultdict(dict)

        # Datasets may be a (data_tensor, labels_tensor) tuple or a plain sequence.
        train_len = len(self.train[0]) if isinstance(self.train, tuple) else len(self.train)
        test_len = len(self.test[0]) if isinstance(self.test, tuple) else len(self.test)
        print('training length = ', train_len)
        print('testing length = ', test_len)

        # Per-epoch histories, appended to by do_a_training_step().
        self.train_losses = []
        self.test_losses = []
        self.grad_norms = []
        self.param_norms = []
        self.test_accs = []
        self.train_accs = []

    def save_epoch(self, epoch, save_to_wandb=True, local_save=False):
        '''Record the latest metrics for `epoch` and optionally checkpoint.

        Must be called after at least one do_a_training_step() — it reads the
        most recent entry of each metric history.

        Args:
            epoch: Epoch index used as the checkpoint filename and metrics key.
            save_to_wandb: Log the snapshot to wandb (if wandb is enabled).
            local_save: Force a local checkpoint even if config.save_models is off.
        '''
        save_dict = {
            'model': self.model.state_dict(),
            'train_loss': self.train_losses[-1],
            'test_loss': self.test_losses[-1],
            'grad_norm': self.grad_norms[-1],
            'param_norm': self.param_norms[-1],
            'test_accuracy': self.test_accs[-1],
            'train_accuracy': self.train_accs[-1],
            'epoch': epoch,
        }
        if save_to_wandb and self.use_wandb:
            wandb.log(save_dict)
            # torch.device is not JSON-serializable; stringify it for wandb.
            config_dict = {
                k: (str(v) if isinstance(v, torch.device) else v)
                for k, v in self.config.__dict__.items()
            }
            wandb.log(config_dict)
            print("Saved epoch to wandb")
        if self.config.save_models or local_save:
            save_path = os.path.join(self.save_dir, self.run_name, f"{epoch}.pth")
            torch.save(save_dict, save_path)
            print(f"Saved model to {save_path}")
        self.metrics_dictionary[epoch].update(save_dict)

    def do_a_training_step(self, epoch: int):
        '''Run one full-batch training step.

        Computes train/test loss and accuracy, backpropagates the training
        loss, records gradient and parameter L2 norms, and steps the
        optimizer and scheduler.

        Args:
            epoch: Current epoch index (used only for periodic logging).

        Returns:
            (train_loss, test_loss) as torch tensors.
        '''
        # Training loss keeps its autograd graph for the backward pass below.
        train_loss = full_loss(config=self.config, model=self.model, data=self.train)
        # Evaluation metrics are never backpropagated; no_grad avoids building
        # a throwaway graph (same values, less memory/compute).
        with torch.no_grad():
            test_loss = full_loss(config=self.config, model=self.model, data=self.test)
            train_acc = acc(config=self.config, model=self.model, data=self.train)
            test_acc = acc(config=self.config, model=self.model, data=self.test)
        # Append metrics to the tracking histories.
        self.train_losses.append(train_loss.item())
        self.test_losses.append(test_loss.item())
        self.train_accs.append(train_acc)
        self.test_accs.append(test_acc)
        if epoch % 100 == 0:
            # Log progress every 100 epochs.
            print(f'Epoch {epoch}, train loss {train_loss.item():.4f}, test loss {test_loss.item():.4f}')
        train_loss.backward()  # Compute gradients
        # Accumulate squared L2 norms of gradients and parameters, then take
        # the square root to get the global L2 norm of each.
        grad_norm = 0.0
        param_norm = 0.0
        for param in self.model.parameters():
            if param.grad is not None:
                grad_norm += param.grad.norm(2).item() ** 2
                param_norm += param.norm(2).item() ** 2
        self.grad_norms.append(grad_norm ** 0.5)
        self.param_norms.append(param_norm ** 0.5)
        self.optimizer.step()   # Update model parameters
        self.scheduler.step()   # Update learning rate (warmup)
        self.optimizer.zero_grad()  # Clear gradients for the next step
        return train_loss, test_loss

    def initial_save_if_appropriate(self):
        '''Save the initial (pre-training) model state and datasets if configured.'''
        if self.config.save_models:
            save_path = os.path.join(self.save_dir, self.run_name, 'init.pth')
            save_dict = {
                'model': self.model.state_dict(),
                'train_data': self.train,  # Tuple of (data_tensor, labels_tensor)
                'test_data': self.test,    # Tuple of (data_tensor, labels_tensor)
            }
            torch.save(save_dict, save_path)

    def post_training_save(self, save_optimizer_and_scheduler=True, log_to_wandb=True):
        '''Save the final model state and full metric histories after training.

        Args:
            save_optimizer_and_scheduler: Include optimizer/scheduler state dicts.
            log_to_wandb: Also log the final snapshot to wandb (if enabled).
        '''
        save_path = os.path.join(self.save_dir, self.run_name, "final.pth")
        save_dict = {
            'model': self.model.state_dict(),
            'train_loss': self.train_losses[-1],
            'test_loss': self.test_losses[-1],
            'train_losses': self.train_losses,
            'test_losses': self.test_losses,
            'grad_norms': self.grad_norms,
            'param_norms': self.param_norms,
            'epoch': self.config.num_epochs,
        }
        if save_optimizer_and_scheduler:
            save_dict['optimizer'] = self.optimizer.state_dict()
            save_dict['scheduler'] = self.scheduler.state_dict()
        if log_to_wandb and self.use_wandb:
            wandb.log(save_dict)
        torch.save(save_dict, save_path)
        print(f"Saved model to {save_path}")
        self.metrics_dictionary[save_dict['epoch']].update(save_dict)