import yaml, os, time, wandb, random import numpy as np import torch import torch.optim as optim import torch.nn.functional as F from dataclasses import dataclass from collections import defaultdict from typing import Optional try: # For running from within src/ directory (script execution) from mechanism_base import * from model_base import EmbedMLP except ImportError: # For running from parent directory (Jupyter notebooks) from src.mechanism_base import * from src.model_base import EmbedMLP ########## Configuration Managers ########## def set_all_seeds(seed): """Set all random seeds for reproducibility""" random.seed(seed) # Python random module np.random.seed(seed) # NumPy random torch.manual_seed(seed) # PyTorch CPU random torch.cuda.manual_seed(seed) # PyTorch GPU random (current GPU) torch.cuda.manual_seed_all(seed) # PyTorch GPU random (all GPUs) torch.backends.cudnn.deterministic = True # Make CuDNN deterministic torch.backends.cudnn.benchmark = False # Disable CuDNN benchmarking for reproducibility def read_config(): current_dir = os.path.dirname(__file__) config_path = os.path.join(current_dir, "configs.yaml") with open(config_path, 'r') as stream: try: config = yaml.safe_load(stream) return config except yaml.YAMLError as exc: print(exc) @dataclass class Config: def __init__(self, config): # Ensure that the config dictionary is provided if not config: raise ValueError("Configuration dictionary cannot be None or empty.") # Load configurations from the nested dictionary structure and flatten them self._flatten_config(config) # Ensure numeric types are properly converted if hasattr(self, 'lr') and isinstance(self.lr, str): self.lr = float(self.lr) if hasattr(self, 'weight_decay') and isinstance(self.weight_decay, str): self.weight_decay = float(self.weight_decay) if hasattr(self, 'stopping_thresh') and isinstance(self.stopping_thresh, str): self.stopping_thresh = float(self.stopping_thresh) # Set d_vocab equal to p if not explicitly specified if not hasattr(self, 'd_vocab') or self.d_vocab is None: self.d_vocab = self.p # Set d_model for one_hot embedding type if not hasattr(self, 'd_model') or self.d_model is None: if hasattr(self, 'embed_type') and self.embed_type == 'one_hot': self.d_model = self.d_vocab else: # For learned embeddings, use a default dimension self.d_model = 128 # Set all random seeds for reproducibility if hasattr(self, 'seed'): set_all_seeds(self.seed) print(f"All random seeds set to: {self.seed}") self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(self.device) def _flatten_config(self, config_dict, parent_key=''): """Flatten nested configuration dictionary into flat attributes""" for key, value in config_dict.items(): if isinstance(value, dict): # Recursively flatten nested dictionaries self._flatten_config(value, parent_key) else: # Set the attribute directly setattr(self, key, value) # Property to generate a matrix of random answers (used for 'rand' function) @property def random_answers(self): return np.random.randint(low=0, high=self.p, size=(self.p, self.p)) # Property to map function names to their corresponding mathematical operations @property def fns_dict(self): return { 'add': lambda x, y: (x + y) % self.p, # Addition modulo p 'subtract': lambda x, y: (x - y) % self.p, # Subtraction modulo p 'x2xyy2': lambda x, y: (x**2 + x * y + y**2) % self.p, # Polynomial function modulo p 'rand': lambda x, y: self.random_answers[x][y] # Random value from a precomputed table } # Property to access the selected function based on 'fn_name' @property def fn(self): return self.fns_dict[self.fn_name] # Function to create Boolean arrays indicating if a data point is in the training or test set def is_train_is_test(self, train): '''Creates an array of Boolean indices according to whether each data point is in train or test. Used to index into the big batch of all possible data''' # Initialize empty lists for training and test indices is_train = [] is_test = [] # Iterate over all possible data points (0 <= x, y < p) for x in range(self.p): for y in range(self.p): if (x, y, 113) in train: # If the data point is in the training set is_train.append(True) is_test.append(False) else: # Otherwise, it's in the test set is_train.append(False) is_test.append(True) # Convert lists to NumPy arrays for efficient indexing is_train = np.array(is_train) is_test = np.array(is_test) return (is_train, is_test) # Function to determine if it's time to save the model (based on epoch number) def is_it_time_to_save(self, epoch): return (epoch % self.save_every == 0) # Function to determine if it's time to take metrics (based on epoch number) def is_it_time_to_take_metrics(self, epoch): return epoch % self.take_metrics_every_n_epochs == 0 def update_param(self, param_name, value): setattr(self, param_name, value) ########## Data Manager ########## def gen_train_test(config: Config): '''Generate train and test split with precomputed labels as tensors''' num_to_generate = config.p # Create data and labels as lists first all_pairs = [] all_labels = [] for i in range(num_to_generate): for j in range(num_to_generate): all_pairs.append((i, j)) all_labels.append(config.fn(i, j)) # Convert to tensors on the appropriate device (GPU if available, CPU otherwise) device = config.device if hasattr(config, 'device') else torch.device('cpu') data_tensor = torch.tensor(all_pairs, device=device, dtype=torch.long) labels_tensor = torch.tensor(all_labels, device=device, dtype=torch.long) # Shuffle using torch operations for efficiency random.seed(config.seed) indices = torch.randperm(len(all_pairs), device=device) data_tensor = data_tensor[indices] labels_tensor = labels_tensor[indices] # If frac_train is 1, use the whole dataset for both train and test. if config.frac_train == 1: return (data_tensor, labels_tensor), (data_tensor, labels_tensor) div = int(config.frac_train * len(all_pairs)) train_data = (data_tensor[:div], labels_tensor[:div]) test_data = (data_tensor[div:], labels_tensor[div:]) return train_data, test_data ########## Training Managers ########## # Trainer class has been moved to NNTrainer.py ########## Loss Definition ########## def cross_entropy_high_precision(logits, labels): # Shapes: batch x vocab, batch # Cast logits to float64 because log_softmax has a float32 underflow on overly # confident data and can only return multiples of 1.2e-7 (the smallest float x # such that 1+x is different from 1 in float32). This leads to loss spikes # and dodgy gradients logprobs = F.log_softmax(logits.to(torch.float32), dim=-1) prediction_logprobs = torch.gather(logprobs, index=labels[:, None], dim=-1) loss = -torch.mean(prediction_logprobs) return loss def full_loss(config : Config, model: EmbedMLP, data): '''Takes the cross entropy loss of the model on the data''' # Handle new format: data is (data_tensor, labels_tensor) if isinstance(data, tuple) and len(data) == 2: data_tensor, labels = data else: # Fallback for old format (list of pairs) if not isinstance(data, torch.Tensor): data_tensor = torch.tensor(data, device=config.device) elif data.device != config.device: data_tensor = data.to(config.device) else: data_tensor = data # Compute labels (slow path for backward compatibility) labels = torch.tensor([config.fn(i, j) for i, j in data_tensor]).to(config.device) # Compute logits logits = model(data_tensor) return cross_entropy_high_precision(logits, labels) def acc_rate(logits, labels): predictions = torch.argmax(logits, dim=1) # Get predicted class indices correct = (predictions == labels).sum().item() # Count correct predictions accuracy = correct / labels.size(0) # Calculate accuracy return accuracy def acc(config: Config, model: EmbedMLP, data): '''Compute accuracy of the model on the data''' # Handle new format: data is (data_tensor, labels_tensor) if isinstance(data, tuple) and len(data) == 2: data_tensor, labels = data else: # Fallback for old format (list of pairs) if not isinstance(data, torch.Tensor): data_tensor = torch.tensor(data, device=config.device) elif data.device != config.device: data_tensor = data.to(config.device) else: data_tensor = data # Compute labels (slow path for backward compatibility) labels = torch.tensor([config.fn(i, j) for i, j in data_tensor]).to(config.device) logits = model(data_tensor) predictions = torch.argmax(logits, dim=1) # Get predicted class indices correct = (predictions == labels).sum().item() # Count correct predictions accuracy = correct / labels.size(0) # Calculate accuracy return accuracy