|
|
import yaml, os, time, wandb, random |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.optim as optim |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from dataclasses import dataclass |
|
|
from collections import defaultdict |
|
|
from typing import Optional |
|
|
|
|
|
try: |
|
|
|
|
|
from mechanism_base import * |
|
|
from model_base import EmbedMLP |
|
|
except ImportError: |
|
|
|
|
|
from src.mechanism_base import * |
|
|
from src.model_base import EmbedMLP |
|
|
|
|
|
|
|
|
|
|
|
def set_all_seeds(seed):
    """Seed every RNG in play (Python, NumPy, torch CPU + CUDA) and force
    cuDNN into deterministic mode for reproducible runs."""
    for seed_fn in (random.seed,
                    np.random.seed,
                    torch.manual_seed,
                    torch.cuda.manual_seed,
                    torch.cuda.manual_seed_all):
        seed_fn(seed)
    # Determinism costs some speed: disable cuDNN autotuning.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
|
|
|
def read_config():
    """Load and return ``configs.yaml`` sitting next to this file.

    Returns:
        dict: the parsed YAML configuration.

    Raises:
        yaml.YAMLError: if the file is not valid YAML. (Bug fix: the original
            printed the error and fell through, implicitly returning ``None``,
            which surfaced later as a confusing "empty config" failure.)
        OSError: if the file cannot be opened.
    """
    current_dir = os.path.dirname(__file__)
    config_path = os.path.join(current_dir, "configs.yaml")
    with open(config_path, 'r') as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)
            raise
|
|
|
|
|
class Config:
    """Flat configuration object built from a (possibly nested) config dict.

    Nested dictionaries are flattened into top-level attributes, a handful of
    numeric fields that YAML may parse as strings (scientific notation) are
    coerced back to float, and defaults are derived for ``d_vocab`` /
    ``d_model``. Constructing a Config also seeds all RNGs (if ``seed`` is
    given) and selects the torch device.

    NOTE(review): the original class carried a spurious ``@dataclass``
    decorator with no declared fields; since the class defines its own
    ``__init__``, the decorator contributed only a field-less generated
    ``__eq__``/``__repr__`` — under which *every* pair of Config instances
    compared equal. The decorator has been removed so equality falls back to
    identity.
    """

    def __init__(self, config):
        if not config:
            raise ValueError("Configuration dictionary cannot be None or empty.")

        self._flatten_config(config)

        # YAML parses values like "1e-4" as strings; coerce the known
        # numeric hyperparameters back to float.
        for attr in ('lr', 'weight_decay', 'stopping_thresh'):
            if hasattr(self, attr) and isinstance(getattr(self, attr), str):
                setattr(self, attr, float(getattr(self, attr)))

        # Vocabulary defaults to the modulus p (tokens are residues mod p).
        if not hasattr(self, 'd_vocab') or self.d_vocab is None:
            self.d_vocab = self.p

        # One-hot embeddings need d_model == d_vocab; otherwise use a
        # default hidden width of 128.
        if not hasattr(self, 'd_model') or self.d_model is None:
            if hasattr(self, 'embed_type') and self.embed_type == 'one_hot':
                self.d_model = self.d_vocab
            else:
                self.d_model = 128

        if hasattr(self, 'seed'):
            set_all_seeds(self.seed)
            print(f"All random seeds set to: {self.seed}")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(self.device)

    def _flatten_config(self, config_dict, parent_key=''):
        """Flatten nested configuration dictionary into flat attributes."""
        for key, value in config_dict.items():
            if isinstance(value, dict):
                # Recurse into sub-dicts; leaf keys become attributes
                # regardless of nesting depth (later duplicates win).
                self._flatten_config(value, parent_key)
            else:
                setattr(self, key, value)

    @property
    def random_answers(self):
        """A fixed p-by-p table of uniformly random labels in [0, p).

        Bug fix: the original drew a *fresh* random table on every property
        access, so the 'rand' task returned different labels for the same
        (x, y) across calls. The table is now drawn once and cached.
        """
        if not hasattr(self, '_random_answers'):
            self._random_answers = np.random.randint(
                low=0, high=self.p, size=(self.p, self.p))
        return self._random_answers

    @property
    def fns_dict(self):
        """Map task name -> label function over Z_p x Z_p."""
        return {
            'add': lambda x, y: (x + y) % self.p,
            'subtract': lambda x, y: (x - y) % self.p,
            'x2xyy2': lambda x, y: (x**2 + x * y + y**2) % self.p,
            'rand': lambda x, y: self.random_answers[x][y]
        }

    @property
    def fn(self):
        """The label function selected by ``fn_name``."""
        return self.fns_dict[self.fn_name]

    def is_train_is_test(self, train):
        '''Creates an array of Boolean indices according to whether each data point is in train or test.
        Used to index into the big batch of all possible data'''
        is_train = []
        is_test = []
        for x in range(self.p):
            for y in range(self.p):
                # Bug fix: the membership test hard-coded the equals-token
                # 113 (i.e. p for one specific experiment); use self.p so
                # the split works for any modulus.
                if (x, y, self.p) in train:
                    is_train.append(True)
                    is_test.append(False)
                else:
                    is_train.append(False)
                    is_test.append(True)
        is_train = np.array(is_train)
        is_test = np.array(is_test)
        return (is_train, is_test)

    def is_it_time_to_save(self, epoch):
        """True on epochs where a checkpoint should be written."""
        return (epoch % self.save_every == 0)

    def is_it_time_to_take_metrics(self, epoch):
        """True on epochs where metrics should be collected."""
        return epoch % self.take_metrics_every_n_epochs == 0

    def update_param(self, param_name, value):
        """Set or override a single config attribute by name."""
        setattr(self, param_name, value)
|
|
|
|
|
|
|
|
def gen_train_test(config: "Config"):
    '''Generate train and test split with precomputed labels as tensors.

    Enumerates all p*p input pairs, labels them with ``config.fn``, shuffles
    them with a seeded permutation, and splits by ``config.frac_train``.

    Args:
        config: must provide ``p``, ``fn``, ``seed``, ``frac_train`` and
            (optionally) ``device``.

    Returns:
        ((train_data, train_labels), (test_data, test_labels)) as long
        tensors on ``config.device``. When ``frac_train == 1`` the same
        full dataset is returned for both splits.
    '''
    num_to_generate = config.p

    all_pairs = []
    all_labels = []
    for i in range(num_to_generate):
        for j in range(num_to_generate):
            all_pairs.append((i, j))
            all_labels.append(config.fn(i, j))

    device = config.device if hasattr(config, 'device') else torch.device('cpu')
    data_tensor = torch.tensor(all_pairs, device=device, dtype=torch.long)
    labels_tensor = torch.tensor(all_labels, device=device, dtype=torch.long)

    # Bug fix: the original called random.seed(config.seed), which has no
    # effect on torch.randperm (it uses torch's RNG, not Python's). Use a
    # dedicated torch.Generator so the split is reproducible by config.seed
    # regardless of global RNG state.
    perm_gen = torch.Generator().manual_seed(config.seed)
    indices = torch.randperm(len(all_pairs), generator=perm_gen).to(device)

    data_tensor = data_tensor[indices]
    labels_tensor = labels_tensor[indices]

    if config.frac_train == 1:
        return (data_tensor, labels_tensor), (data_tensor, labels_tensor)

    div = int(config.frac_train * len(all_pairs))
    train_data = (data_tensor[:div], labels_tensor[:div])
    test_data = (data_tensor[div:], labels_tensor[div:])

    return train_data, test_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def cross_entropy_high_precision(logits, labels):
    """Cross-entropy loss computed in float64.

    Bug fix: the original cast logits to float32, which is a no-op for
    typical float32 model outputs and defeats the function's stated
    purpose. Grokking-style experiments need float64 because late-training
    losses fall below float32's precision floor, making float32 loss
    curves noisy/saturated.

    Args:
        logits: (batch, n_classes) unnormalized scores.
        labels: (batch,) integer class indices.

    Returns:
        Scalar tensor: mean negative log-probability of the true labels.
    """
    logprobs = F.log_softmax(logits.to(torch.float64), dim=-1)
    # Pick out each row's log-probability at its true label.
    prediction_logprobs = torch.gather(logprobs, index=labels[:, None], dim=-1)
    loss = -torch.mean(prediction_logprobs)
    return loss
|
|
|
|
|
def full_loss(config : Config, model: EmbedMLP, data):
    """Return the high-precision cross-entropy loss of `model` on `data`.

    `data` may be a (inputs, labels) tensor pair, or raw inputs alone, in
    which case labels are recomputed on the fly with `config.fn`.
    """
    if isinstance(data, tuple) and len(data) == 2:
        # Precomputed (inputs, labels) pair — use as-is.
        batch, labels = data
    else:
        # Raw inputs: normalise to a tensor on the configured device.
        if isinstance(data, torch.Tensor):
            batch = data.to(config.device) if data.device != config.device else data
        else:
            batch = torch.tensor(data, device=config.device)
        # Recompute ground-truth labels from the task function.
        labels = torch.tensor([config.fn(i, j) for i, j in batch]).to(config.device)
    return cross_entropy_high_precision(model(batch), labels)
|
|
|
|
|
def acc_rate(logits, labels):
    """Fraction of rows whose argmax prediction matches the label."""
    preds = torch.argmax(logits, dim=1)
    n_correct = torch.sum(preds == labels).item()
    return n_correct / labels.size(0)
|
|
|
|
|
def acc(config: Config, model: EmbedMLP, data):
    '''Compute accuracy of the model on the data.

    `data` may be a precomputed (inputs, labels) tensor pair, or raw inputs
    alone, in which case labels are recomputed with `config.fn`.

    Consistency fix: the final argmax/compare/divide was duplicated inline
    from `acc_rate`; it now delegates to `acc_rate` so there is a single
    definition of accuracy.
    '''
    if isinstance(data, tuple) and len(data) == 2:
        data_tensor, labels = data
    else:
        # Raw inputs: normalise to a tensor on the configured device.
        if not isinstance(data, torch.Tensor):
            data_tensor = torch.tensor(data, device=config.device)
        elif data.device != config.device:
            data_tensor = data.to(config.device)
        else:
            data_tensor = data

        # Recompute ground-truth labels from the task function.
        labels = torch.tensor([config.fn(i, j) for i, j in data_tensor]).to(config.device)

    logits = model(data_tensor)
    return acc_rate(logits, labels)
|
|
|
|
|
|