import os
import random

import numpy as np
import torch
import torch.nn.functional as F
import yaml

from functools import cached_property

try:
    # For running from within src/ directory (script execution)
    from mechanism_base import *
    from model_base import EmbedMLP
except ImportError:
    # For running from parent directory (Jupyter notebooks)
    from src.mechanism_base import *
    from src.model_base import EmbedMLP


########## Configuration Managers ##########
def set_all_seeds(seed):
    """Set all random seeds for reproducibility"""
    random.seed(seed)                          # Python random module
    np.random.seed(seed)                       # NumPy random
    torch.manual_seed(seed)                    # PyTorch CPU random
    torch.cuda.manual_seed(seed)               # PyTorch GPU random (current GPU)
    torch.cuda.manual_seed_all(seed)           # PyTorch GPU random (all GPUs)
    torch.backends.cudnn.deterministic = True  # Make CuDNN deterministic
    torch.backends.cudnn.benchmark = False     # Disable CuDNN benchmarking for reproducibility
    
def read_config():
    current_dir = os.path.dirname(__file__)
    config_path = os.path.join(current_dir, "configs.yaml")
    with open(config_path, 'r') as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            # Raise instead of returning None, which would otherwise surface
            # as a confusing ValueError inside Config.__init__
            raise ValueError(f"Failed to parse {config_path}: {exc}") from exc

class Config:
    def __init__(self, config):
        # Ensure that the config dictionary is provided
        if not config:
            raise ValueError("Configuration dictionary cannot be None or empty.")
        
        # Load configurations from the nested dictionary structure and flatten them
        self._flatten_config(config)
        
        # Ensure numeric types are properly converted
        if hasattr(self, 'lr') and isinstance(self.lr, str):
            self.lr = float(self.lr)
        if hasattr(self, 'weight_decay') and isinstance(self.weight_decay, str):
            self.weight_decay = float(self.weight_decay)
        if hasattr(self, 'stopping_thresh') and isinstance(self.stopping_thresh, str):
            self.stopping_thresh = float(self.stopping_thresh)

        # Set d_vocab equal to p if not explicitly specified
        if not hasattr(self, 'd_vocab') or self.d_vocab is None:
            self.d_vocab = self.p
        
        # Set d_model for one_hot embedding type
        if not hasattr(self, 'd_model') or self.d_model is None:
            if hasattr(self, 'embed_type') and self.embed_type == 'one_hot':
                self.d_model = self.d_vocab
            else:
                # For learned embeddings, use a default dimension
                self.d_model = 128
        
        # Set all random seeds for reproducibility
        if hasattr(self, 'seed'):
            set_all_seeds(self.seed)
            print(f"All random seeds set to: {self.seed}")
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

    def _flatten_config(self, config_dict):
        """Flatten a nested configuration dictionary into flat attributes.

        Note: leaf keys are set without any prefix, so duplicate leaf names
        in different sections will overwrite one another."""
        for key, value in config_dict.items():
            if isinstance(value, dict):
                # Recursively flatten nested dictionaries
                self._flatten_config(value)
            else:
                # Set the attribute directly
                setattr(self, key, value)

    # Cached matrix of random answers (used for the 'rand' function). Caching
    # matters: a plain @property would resample the table on every access, so
    # the 'rand' labels would differ between calls.
    @cached_property
    def random_answers(self):
        return np.random.randint(low=0, high=self.p, size=(self.p, self.p))

    # Property to map function names to their corresponding mathematical operations
    @property 
    def fns_dict(self):
        return {
            'add': lambda x, y: (x + y) % self.p, # Addition modulo p
            'subtract': lambda x, y: (x - y) % self.p, # Subtraction modulo p
            'x2xyy2': lambda x, y: (x**2 + x * y + y**2) % self.p, # Polynomial function modulo p
            'rand': lambda x, y: self.random_answers[x][y] # Random value from a precomputed table
        }

    # Property to access the selected function based on 'fn_name'
    @property
    def fn(self):
        return self.fns_dict[self.fn_name]
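    # Example (a quick illustration, assuming p = 113 and fn_name = 'add'):
    #   config.fn(110, 5) == (110 + 5) % 113 == 2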

    # Function to create Boolean arrays indicating if a data point is in the training or test set
    def is_train_is_test(self, train):
        '''Creates an array of Boolean indices according to whether each data point is in train or test.
        Used to index into the big batch of all possible data'''
        # Initialize empty lists for training and test indices
        is_train = []
        is_test = []
        # Iterate over all possible data points (0 <= x, y < p)
        for x in range(self.p):
            for y in range(self.p):
                if (x, y, self.p) in train: # Data points are stored as (x, y, p) triples
                    is_train.append(True)
                    is_test.append(False)
                else: # Otherwise, it's in the test set
                    is_train.append(False)
                    is_test.append(True)
        # Convert lists to NumPy arrays for efficient indexing
        is_train = np.array(is_train)
        is_test = np.array(is_test)
        return (is_train, is_test)
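
    # Usage sketch (illustrative names): mask the batch of all p*p pairs, e.g.
    #   is_train, is_test = config.is_train_is_test(train_set)
    #   train_inputs, test_inputs = all_pairs[is_train], all_pairs[is_test]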

    # Function to determine if it's time to save the model (based on epoch number)
    def is_it_time_to_save(self, epoch):
        return (epoch % self.save_every == 0)

    # Function to determine if it's time to take metrics (based on epoch number)
    def is_it_time_to_take_metrics(self, epoch):
        return epoch % self.take_metrics_every_n_epochs == 0

    def update_param(self, param_name, value):
        setattr(self, param_name, value)

########## Data Manager ##########
def gen_train_test(config: Config):
    '''Generate train and test split with precomputed labels as tensors'''
    num_to_generate = config.p

    # Create data and labels as lists first
    all_pairs = []
    all_labels = []
    for i in range(num_to_generate):
        for j in range(num_to_generate):
            all_pairs.append((i, j))
            all_labels.append(config.fn(i, j))

    # Convert to tensors on the appropriate device (GPU if available, CPU otherwise)
    device = config.device if hasattr(config, 'device') else torch.device('cpu')
    data_tensor = torch.tensor(all_pairs, device=device, dtype=torch.long)
    labels_tensor = torch.tensor(all_labels, device=device, dtype=torch.long)

    # Shuffle with a seeded torch generator so the permutation is reproducible
    # (seeding Python's `random` module has no effect on torch.randperm)
    generator = torch.Generator()
    generator.manual_seed(config.seed)
    indices = torch.randperm(len(all_pairs), generator=generator).to(device)

    data_tensor = data_tensor[indices]
    labels_tensor = labels_tensor[indices]

    # If frac_train is 1, use the whole dataset for both train and test.
    if config.frac_train == 1:
        return (data_tensor, labels_tensor), (data_tensor, labels_tensor)

    div = int(config.frac_train * len(all_pairs))
    train_data = (data_tensor[:div], labels_tensor[:div])
    test_data = (data_tensor[div:], labels_tensor[div:])

    return train_data, test_data
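
# Usage sketch (assumes configs.yaml defines p, seed, frac_train and fn_name):
#   config = Config(read_config())
#   (train_x, train_y), (test_x, test_y) = gen_train_test(config)
# train_x has shape (int(frac_train * p**2), 2); labels are class indices in [0, p).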

########## Training Managers ##########
# Trainer class has been moved to NNTrainer.py

########## Loss Definition ##########
def cross_entropy_high_precision(logits, labels):
    # Shapes: batch x vocab, batch
    # Cast logits to float64 because log_softmax has a float32 underflow on overly
    # confident data and can only return multiples of 1.2e-7 (the smallest float x
    # such that 1+x is different from 1 in float32). This leads to loss spikes
    # and dodgy gradients.
    logprobs = F.log_softmax(logits.to(torch.float64), dim=-1)
    prediction_logprobs = torch.gather(logprobs, index=labels[:, None], dim=-1)
    loss = -torch.mean(prediction_logprobs)
    return loss
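
# Illustration of the underflow (approximate values): for a very confident
# prediction, float32 log_softmax rounds the correct-class log-prob to zero,
# while float64 still resolves it.
#   >>> logits = torch.tensor([[20.0, 0.0]])
#   >>> F.log_softmax(logits, dim=-1)[0, 0]           # ~tensor(0.), loss vanishes
#   >>> F.log_softmax(logits.double(), dim=-1)[0, 0]  # ~tensor(-2.06e-09)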

def full_loss(config: Config, model: EmbedMLP, data):
    '''Takes the cross entropy loss of the model on the data'''
    # Handle new format: data is (data_tensor, labels_tensor)
    if isinstance(data, tuple) and len(data) == 2:
        data_tensor, labels = data
    else:
        # Fallback for old format (list of pairs)
        if not isinstance(data, torch.Tensor):
            data_tensor = torch.tensor(data, device=config.device)
        elif data.device != config.device:
            data_tensor = data.to(config.device)
        else:
            data_tensor = data
        # Compute labels (slow path, kept for backward compatibility)
        labels = torch.tensor([config.fn(i, j) for i, j in data_tensor.tolist()],
                              device=config.device)

    # Compute logits
    logits = model(data_tensor)
    return cross_entropy_high_precision(logits, labels)

def acc_rate(logits, labels):
    predictions = torch.argmax(logits, dim=1)  # Get predicted class indices
    correct = (predictions == labels).sum().item()  # Count correct predictions
    accuracy = correct / labels.size(0)  # Calculate accuracy
    return accuracy

def acc(config: Config, model: EmbedMLP, data):
    '''Compute accuracy of the model on the data'''
    # Handle new format: data is (data_tensor, labels_tensor)
    if isinstance(data, tuple) and len(data) == 2:
        data_tensor, labels = data
    else:
        # Fallback for old format (list of pairs)
        if not isinstance(data, torch.Tensor):
            data_tensor = torch.tensor(data, device=config.device)
        elif data.device != config.device:
            data_tensor = data.to(config.device)
        else:
            data_tensor = data
        # Compute labels (slow path, kept for backward compatibility)
        labels = torch.tensor([config.fn(i, j) for i, j in data_tensor.tolist()],
                              device=config.device)

    logits = model(data_tensor)
    return acc_rate(logits, labels)  # Reuse acc_rate instead of duplicating the logic
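

if __name__ == "__main__":
    # Minimal smoke test of the loss/accuracy helpers on synthetic logits
    # (a sketch only; it makes no assumptions about EmbedMLP's constructor).
    _logits = torch.randn(8, 5, dtype=torch.float64)
    _labels = torch.randint(0, 5, (8,))
    print("loss:", cross_entropy_high_precision(_logits, _labels).item())
    print("acc :", acc_rate(_logits, _labels))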