File size: 8,875 Bytes
b753304
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import yaml, os, time
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F

from collections import defaultdict
from typing import Optional

try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False

from model_base import EmbedMLP
from utils import Config, gen_train_test, full_loss, acc, cross_entropy_high_precision


class Trainer:
    '''Manage the full training lifecycle of an EmbedMLP model.

    Responsibilities: build (or accept) the model, configure the optimizer
    and LR warm-up scheduler, generate and persist train/test splits, track
    per-epoch metrics, save checkpoints, and optionally log to wandb.
    '''

    def __init__(self, config: Config, model: Optional[EmbedMLP] = None, use_wandb: bool = True) -> None:
        '''Set up model, optimizer, scheduler, data, and logging.

        Args:
            config: Experiment configuration (hyperparameters, device, paths).
            model: Optional pre-built model; if None, an EmbedMLP is created
                from `config`.
            use_wandb: Enable wandb logging (only if wandb is importable).

        Raises:
            ValueError: If `config.optimizer` is not 'AdamW' or 'SGD'.
        '''
        self.use_wandb = use_wandb and WANDB_AVAILABLE

        # Use the given model, or build a fresh EmbedMLP from the config.
        self.model = model if model is not None else EmbedMLP(
            d_vocab=config.d_vocab,
            d_model=config.d_model,
            d_mlp=config.d_mlp,
            act_type=config.act_type,
            use_cache=False,
            init_type=config.init_type,
            # init_scale is optional on Config; default matches prior behavior.
            init_scale=config.init_scale if hasattr(config, 'init_scale') else 0.1,
            embed_type=config.embed_type
        )
        self.model.to(config.device)  # Move model to the configured device (e.g., GPU)

        if config.optimizer == 'AdamW':
            self.optimizer = optim.AdamW(
                self.model.parameters(),
                lr=config.lr,
                weight_decay=config.weight_decay,
                betas=(0.9, 0.98)
            )
        elif config.optimizer == 'SGD':
            self.optimizer = optim.SGD(
                self.model.parameters(),
                lr=config.lr,
                weight_decay=config.weight_decay  # L2 regularization, equivalent to weight decay for plain GD
            )
        else:
            # Fail fast: previously an unknown optimizer left self.optimizer
            # unset and surfaced later as a confusing AttributeError.
            raise ValueError(
                f"Unsupported optimizer {config.optimizer!r}; expected 'AdamW' or 'SGD'"
            )

        # Linear LR warm-up over the first 10 steps, constant afterwards.
        # (Identical schedule for both optimizers, so built once here.)
        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda step: min(step / 10, 1))

        # Generate a unique run name for this training session.
        formatted_time = time.strftime("%m%d%H%M", time.localtime())
        init_scale_str = f"scale_{config.init_scale}" if hasattr(config, 'init_scale') else ""
        self.run_name = f"p_{config.p}_dmlp_{config.d_mlp}_{config.act_type}_{config.init_type}_{init_scale_str}_decay_{config.weight_decay}_{formatted_time}"

        # Initialize experiment logging with wandb (Weights and Biases).
        if self.use_wandb:
            wandb.init(project="modular_addition", config=config, name=self.run_name)

        # Directory where model checkpoints for this run will be saved.
        self.save_dir = "saved_models"
        os.makedirs(os.path.join(self.save_dir, self.run_name), exist_ok=True)

        # Generate training and testing datasets.
        self.train, self.test = gen_train_test(config=config)

        # Persist the exact splits alongside the checkpoints for reproducibility.
        train_path = os.path.join(self.save_dir, self.run_name, "train_data.pth")
        test_path = os.path.join(self.save_dir, self.run_name, "test_data.pth")
        torch.save(self.train, train_path)
        torch.save(self.test, test_path)

        # Per-epoch metrics, keyed by epoch number.
        self.metrics_dictionary = defaultdict(dict)
        # Datasets may be tuples of (data_tensor, labels_tensor) or plain sequences.
        train_len = len(self.train[0]) if isinstance(self.train, tuple) else len(self.train)
        test_len = len(self.test[0]) if isinstance(self.test, tuple) else len(self.test)
        print('training length = ', train_len)
        print('testing length = ', test_len)

        # Running histories, appended to once per training step.
        self.train_losses = []
        self.test_losses = []
        self.grad_norms = []
        self.param_norms = []
        self.test_accs = []
        self.train_accs = []
        self.config = config

    def save_epoch(self, epoch, save_to_wandb=True, local_save=False):
        '''Record the latest metrics for `epoch` and optionally checkpoint.

        Args:
            epoch: Epoch number used as the checkpoint filename and metrics key.
            save_to_wandb: Log the metrics (and config) to wandb if enabled.
            local_save: Force a local checkpoint even if config.save_models is off.
        '''
        save_dict = {
            'model': self.model.state_dict(),
            'train_loss': self.train_losses[-1],
            'test_loss': self.test_losses[-1],
            'grad_norm': self.grad_norms[-1],
            'param_norm': self.param_norms[-1],
            'test_accuracy': self.test_accs[-1],
            'train_accuracy': self.train_accs[-1],
            'epoch': epoch,
        }
        if save_to_wandb and self.use_wandb:
            wandb.log(save_dict)
            # torch.device is not JSON-serializable; stringify it for wandb.
            config_dict = {
                k: (str(v) if isinstance(v, torch.device) else v)
                for k, v in self.config.__dict__.items()
            }
            wandb.log(config_dict)
            print("Saved epoch to wandb")
        if self.config.save_models or local_save:
            # Save full training state (weights + latest metrics) to disk.
            save_path = os.path.join(self.save_dir, self.run_name, f"{epoch}.pth")
            torch.save(save_dict, save_path)
            print(f"Saved model to {save_path}")
        self.metrics_dictionary[epoch].update(save_dict)

    def do_a_training_step(self, epoch: int):
        '''Run one full-batch training step.

        Evaluates loss and accuracy on both splits, backpropagates the
        training loss, records gradient/parameter norms, and steps the
        optimizer and scheduler.

        Args:
            epoch: Current epoch number (used for periodic logging).

        Returns:
            Tuple of (train_loss, test_loss) as tensors.
        '''
        # Loss on the training and testing splits (full batch).
        train_loss = full_loss(config=self.config, model=self.model, data=self.train)
        test_loss = full_loss(config=self.config, model=self.model, data=self.test)

        # Accuracy on the training and testing splits.
        train_acc = acc(config=self.config, model=self.model, data=self.train)
        test_acc = acc(config=self.config, model=self.model, data=self.test)

        # Append this step's values to the running histories.
        self.train_losses.append(train_loss.item())
        self.test_losses.append(test_loss.item())
        self.train_accs.append(train_acc)
        self.test_accs.append(test_acc)

        if epoch % 100 == 0:
            # Log progress every 100 epochs.
            print(f'Epoch {epoch}, train loss {train_loss.item():.4f}, test loss {test_loss.item():.4f}')

        # Backpropagation, then norm bookkeeping before the optimizer step
        # so the recorded gradient norm reflects this step's update.
        train_loss.backward()
        grad_norm = 0.0
        param_norm = 0.0
        for param in self.model.parameters():
            if param.grad is not None:
                grad_norm += param.grad.norm(2).item()**2  # Sum of squared gradient norms
            param_norm += param.norm(2).item()**2  # Sum of squared parameter norms
        self.grad_norms.append(grad_norm**0.5)  # Global L2 norm of gradients
        self.param_norms.append(param_norm**0.5)  # Global L2 norm of parameters

        self.optimizer.step()  # Update model parameters
        self.scheduler.step()  # Update learning rate (warm-up schedule)
        self.optimizer.zero_grad()  # Clear gradients for the next step
        return train_loss, test_loss

    def initial_save_if_appropriate(self):
        '''Save the initial model weights and data splits if config.save_models is set.'''
        if self.config.save_models:
            save_path = os.path.join(self.save_dir, self.run_name, 'init.pth')
            save_dict = {
                'model': self.model.state_dict(),
                'train_data': self.train,  # Tuple of (data_tensor, labels_tensor)
                'test_data': self.test     # Tuple of (data_tensor, labels_tensor)
            }
            torch.save(save_dict, save_path)

    def post_training_save(self, save_optimizer_and_scheduler=True, log_to_wandb=True):
        '''Save the final model state and full metric histories after training.

        Args:
            save_optimizer_and_scheduler: Include optimizer/scheduler state
                in the checkpoint (needed to resume training).
            log_to_wandb: Also log the final save_dict to wandb if enabled.
        '''
        save_path = os.path.join(self.save_dir, self.run_name, "final.pth")
        save_dict = {
            'model': self.model.state_dict(),
            'train_loss': self.train_losses[-1],
            'test_loss': self.test_losses[-1],
            'train_losses': self.train_losses,
            'test_losses': self.test_losses,
            'grad_norms': self.grad_norms,
            'param_norms': self.param_norms,
            'epoch': self.config.num_epochs,
        }
        if save_optimizer_and_scheduler:
            save_dict['optimizer'] = self.optimizer.state_dict()
            save_dict['scheduler'] = self.scheduler.state_dict()
        if log_to_wandb and self.use_wandb:
            wandb.log(save_dict)
        torch.save(save_dict, save_path)
        print(f"Saved model to {save_path}")
        self.metrics_dictionary[save_dict['epoch']].update(save_dict)