|
|
import yaml, os, time |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.optim as optim |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from collections import defaultdict |
|
|
from typing import Optional |
|
|
|
|
|
try: |
|
|
import wandb |
|
|
WANDB_AVAILABLE = True |
|
|
except ImportError: |
|
|
WANDB_AVAILABLE = False |
|
|
|
|
|
from model_base import EmbedMLP |
|
|
from utils import Config, gen_train_test, full_loss, acc, cross_entropy_high_precision |
|
|
|
|
|
|
|
|
class Trainer:
    '''Manage the training lifecycle of an EmbedMLP model.

    Builds (or accepts) the model, constructs the optimizer and a linear
    warmup LR scheduler, generates and persists the train/test split, runs
    training steps while tracking losses/accuracies/norms, and checkpoints
    state locally and (optionally) to Weights & Biases.
    '''

    def __init__(self, config: Config, model: Optional[EmbedMLP] = None, use_wandb: bool = True) -> None:
        '''Set up model, optimizer, scheduler, data split and logging.

        Args:
            config: run configuration (hyperparameters, device, save flags).
            model: optional pre-built model; a fresh EmbedMLP is created
                from ``config`` when omitted.
            use_wandb: enable Weights & Biases logging (silently disabled
                when the wandb package is not installed).

        Raises:
            ValueError: if ``config.optimizer`` is not 'AdamW' or 'SGD'.
        '''
        # Short-circuit keeps this safe even when wandb was never imported.
        self.use_wandb = use_wandb and WANDB_AVAILABLE

        self.model = model if model is not None else EmbedMLP(
            d_vocab=config.d_vocab,
            d_model=config.d_model,
            d_mlp=config.d_mlp,
            act_type=config.act_type,
            use_cache=False,
            init_type=config.init_type,
            # init_scale is optional on Config; 0.1 mirrors the old default.
            init_scale=getattr(config, 'init_scale', 0.1),
            embed_type=config.embed_type,
        )
        self.model.to(config.device)

        if config.optimizer == 'AdamW':
            self.optimizer = optim.AdamW(
                self.model.parameters(),
                lr=config.lr,
                weight_decay=config.weight_decay,
                betas=(0.9, 0.98),
            )
        elif config.optimizer == 'SGD':
            self.optimizer = optim.SGD(
                self.model.parameters(),
                lr=config.lr,
                weight_decay=config.weight_decay,
            )
        else:
            # Fail fast: previously an unrecognized optimizer left
            # self.optimizer/self.scheduler unset, surfacing later as a
            # confusing AttributeError during the first training step.
            raise ValueError(f"Unsupported optimizer: {config.optimizer!r}")

        # Linear LR warmup over the first 10 steps, constant afterwards.
        # (Hoisted out of the branch; both branches built the same one.)
        self.scheduler = optim.lr_scheduler.LambdaLR(
            self.optimizer, lr_lambda=lambda step: min(step / 10, 1)
        )

        formatted_time = time.strftime("%m%d%H%M", time.localtime())
        init_scale_str = f"scale_{config.init_scale}" if hasattr(config, 'init_scale') else ""
        self.run_name = f"p_{config.p}_dmlp_{config.d_mlp}_{config.act_type}_{config.init_type}_{init_scale_str}_decay_{config.weight_decay}_{formatted_time}"

        if self.use_wandb:
            wandb.init(project="modular_addition", config=config, name=self.run_name)

        self.save_dir = "saved_models"
        os.makedirs(os.path.join(self.save_dir, self.run_name), exist_ok=True)

        # Persist the exact split so the run is reproducible and analyzable.
        self.train, self.test = gen_train_test(config=config)
        torch.save(self.train, os.path.join(self.save_dir, self.run_name, "train_data.pth"))
        torch.save(self.test, os.path.join(self.save_dir, self.run_name, "test_data.pth"))

        # Per-epoch metrics, keyed by epoch number.
        self.metrics_dictionary = defaultdict(dict)

        # gen_train_test may return an (inputs, labels) tuple or a flat
        # dataset — TODO confirm against utils; both shapes handled here.
        train_len = len(self.train[0]) if isinstance(self.train, tuple) else len(self.train)
        test_len = len(self.test[0]) if isinstance(self.test, tuple) else len(self.test)
        print('training length = ', train_len)
        print('testing length = ', test_len)

        # Per-step metric histories, appended by do_a_training_step.
        self.train_losses = []
        self.test_losses = []
        self.grad_norms = []
        self.param_norms = []
        self.test_accs = []
        self.train_accs = []
        self.config = config

    def save_epoch(self, epoch, save_to_wandb=True, local_save=False):
        '''Save model and training state at the specified epoch.

        Logs the latest metrics (and a serializable copy of the config) to
        wandb when enabled, writes a local checkpoint when configured, and
        records everything in ``self.metrics_dictionary``.
        '''
        save_dict = {
            'model': self.model.state_dict(),
            'train_loss': self.train_losses[-1],
            'test_loss': self.test_losses[-1],
            'grad_norm': self.grad_norms[-1],
            'param_norm': self.param_norms[-1],
            'test_accuracy': self.test_accs[-1],
            'train_accuracy': self.train_accs[-1],
            'epoch': epoch,
        }
        if save_to_wandb and self.use_wandb:
            wandb.log(save_dict)
            # torch.device is not serializable for wandb; stringify it.
            config_dict = {
                k: (str(v) if isinstance(v, torch.device) else v)
                for k, v in self.config.__dict__.items()
            }
            wandb.log(config_dict)
            print("Saved epoch to wandb")
        if self.config.save_models or local_save:
            save_path = os.path.join(self.save_dir, self.run_name, f"{epoch}.pth")
            torch.save(save_dict, save_path)
            print(f"Saved model to {save_path}")
        self.metrics_dictionary[epoch].update(save_dict)

    def do_a_training_step(self, epoch: int):
        '''Perform a single full-batch training step.

        Returns:
            (train_loss, test_loss) — the training loss retains its autograd
            graph (already backpropagated); the test loss is gradient-free.
        '''
        # Training loss keeps its graph for the backward() call below.
        train_loss = full_loss(config=self.config, model=self.model, data=self.train)

        # Evaluation passes need no gradients; no_grad avoids building and
        # retaining a second autograd graph every step.
        with torch.no_grad():
            test_loss = full_loss(config=self.config, model=self.model, data=self.test)
            train_acc = acc(config=self.config, model=self.model, data=self.train)
            test_acc = acc(config=self.config, model=self.model, data=self.test)

        self.train_losses.append(train_loss.item())
        self.test_losses.append(test_loss.item())
        self.train_accs.append(train_acc)
        self.test_accs.append(test_acc)

        if epoch % 100 == 0:
            print(f'Epoch {epoch}, train loss {train_loss.item():.4f}, test loss {test_loss.item():.4f}')

        train_loss.backward()

        # Global L2 norms of gradients and parameters, for diagnostics.
        # Parameters without gradients are skipped from both sums.
        grad_norm = 0.0
        param_norm = 0.0
        for param in self.model.parameters():
            if param.grad is not None:
                grad_norm += param.grad.norm(2).item() ** 2
                param_norm += param.norm(2).item() ** 2
        self.grad_norms.append(grad_norm ** 0.5)
        self.param_norms.append(param_norm ** 0.5)

        self.optimizer.step()
        self.scheduler.step()
        self.optimizer.zero_grad()
        return train_loss, test_loss

    def initial_save_if_appropriate(self):
        '''Save the initial model state and data split if configured to.'''
        if self.config.save_models:
            save_path = os.path.join(self.save_dir, self.run_name, 'init.pth')
            save_dict = {
                'model': self.model.state_dict(),
                'train_data': self.train,
                'test_data': self.test
            }
            torch.save(save_dict, save_path)

    def post_training_save(self, save_optimizer_and_scheduler=True, log_to_wandb=True):
        '''Save final model state and full metric histories after training.'''
        save_path = os.path.join(self.save_dir, self.run_name, "final.pth")
        save_dict = {
            'model': self.model.state_dict(),
            'train_loss': self.train_losses[-1],
            'test_loss': self.test_losses[-1],
            'train_losses': self.train_losses,
            'test_losses': self.test_losses,
            'grad_norms': self.grad_norms,
            'param_norms': self.param_norms,
            'epoch': self.config.num_epochs,
        }
        if save_optimizer_and_scheduler:
            # Included so training can be resumed exactly where it stopped.
            save_dict['optimizer'] = self.optimizer.state_dict()
            save_dict['scheduler'] = self.scheduler.state_dict()
        if log_to_wandb and self.use_wandb:
            wandb.log(save_dict)
        torch.save(save_dict, save_path)
        print(f"Saved model to {save_path}")
        self.metrics_dictionary[save_dict['epoch']].update(save_dict)
|
|
|