Spaces:
Paused
Paused
| import torch | |
| import torch.nn.functional as F | |
| import re | |
| import os | |
def compute_decay(T, params, clamp_lims=(0, 15)):
    """
    Compute exponential decays for learnable synchronisation interactions
    between pairs of neurons.

    Entry ``[t, i]`` of the result is ``exp(-(T - 1 - t) * clamp(params[i]))``,
    so the last row (``t == T - 1``) is all ones and earlier rows decay
    progressively more.

    Args:
        T: Number of steps, i.e. the number of rows in the output.
        params: Tensor of decay rates; expected to be 1D of shape (N,)
            (its first dimension sets the number of output columns).
        clamp_lims: ``(min, max)`` tuple the decay rates are clamped to
            before being applied.

    Returns:
        A tensor of shape (T, N) on the same device as ``params``.

    Raises:
        AssertionError: If ``clamp_lims`` is not a tuple of length 2.
    """
    assert isinstance(clamp_lims, tuple), 'Clamp lims should be tuple'
    assert len(clamp_lims) == 2, 'Clamp lims should be length 2'
    lower, upper = clamp_lims

    # Steps counted backwards from the most recent one: T-1, T-2, ..., 0.
    steps_back = torch.arange(T - 1, -1, -1, device=params.device)
    steps_back = steps_back.reshape(T, 1).expand(T, params.shape[0])

    rates = torch.clamp(params, lower, upper).unsqueeze(0)
    return torch.exp(-steps_back * rates)
def add_coord_dim(x, scaled=True):
    """
    Build a tensor of 2D coordinates matching the spatial layout of ``x``.

    Only the shape, device and dtype of ``x`` are used; its values are not
    read. Coordinates are created with the dtype of ``x``, so ``x`` is
    expected to be a floating-point tensor when ``scaled`` is True.

    Args:
        x: A PyTorch tensor of shape (B, H, W).
        scaled: If True, divide the coordinates by ``W - 1`` and ``H - 1``
            so they span [0, 1]. A dimension of size 1 is left at 0 rather
            than producing NaN from 0 / 0.

    Returns:
        A PyTorch tensor of shape (B, H, W, 2). ``[..., 0]`` holds the
        column (x) coordinate and ``[..., 1]`` the row (y) coordinate.
    """
    B, H, W = x.shape

    # Create coordinate grids, each of shape (H, W).
    x_coords = torch.arange(W, device=x.device, dtype=x.dtype).repeat(H, 1)
    y_coords = torch.arange(H, device=x.device, dtype=x.dtype).unsqueeze(-1).repeat(1, W)

    if scaled:
        # max(..., 1) avoids dividing by zero when a spatial dim has size 1.
        x_coords /= max(W - 1, 1)
        y_coords /= max(H - 1, 1)

    coords = torch.stack((x_coords, y_coords), dim=-1)  # Shape (H, W, 2)
    coords = coords.unsqueeze(0)  # Shape (1, H, W, 2)
    return coords.repeat(B, 1, 1, 1)  # Shape (B, H, W, 2)
def compute_normalized_entropy(logits, reduction='mean'):
    """
    Compute the entropy of the softmax distribution over the final dimension
    of ``logits``, divided by the maximum possible entropy ``log(num_classes)``.

    Args:
        logits: A PyTorch tensor of logits; classes lie along the last dim.
        reduction: When equal to 'mean' and ``logits`` has more than two
            dimensions, all dimensions after the first are averaged away.
            Any other value leaves the per-position entropies untouched.

    Returns:
        A PyTorch tensor of normalized entropies. Its shape is
        ``logits.shape[:-1]``, or ``(logits.shape[0],)`` when the mean
        reduction is applied.
    """
    log_probs = torch.log_softmax(logits, dim=-1)
    probs = F.softmax(logits, dim=-1)

    # A uniform distribution over n classes attains the largest entropy, log(n).
    n_classes = probs.shape[-1]
    upper_bound = torch.log(torch.tensor(n_classes, dtype=torch.float32))

    result = -(probs * log_probs).sum(dim=-1) / upper_bound

    if reduction == 'mean' and logits.dim() > 2:
        result = result.flatten(1).mean(-1)
    return result
def reshape_predictions(predictions, prediction_reshaper):
    """
    Reshape the middle dimensions of ``predictions`` while keeping the first
    and last dimension sizes fixed.

    Args:
        predictions: Tensor whose first dimension is kept as the leading
            size and whose last dimension is kept as the trailing size.
        prediction_reshaper: Sequence of ints (list, tuple, ``torch.Size``,
            ...) giving the sizes the middle dimensions are reshaped to.
            May contain a single ``-1`` to be inferred.

    Returns:
        ``predictions`` reshaped to ``[B, *prediction_reshaper, T]`` where
        ``B`` and ``T`` are the sizes of its first and last dimensions.
    """
    batch_size, last_size = predictions.size(0), predictions.size(-1)
    # Unpacking (instead of list concatenation) accepts any sequence type.
    new_shape = [batch_size, *prediction_reshaper, last_size]
    return predictions.reshape(new_shape)
def get_all_log_dirs(root_dir):
    """
    Collect every directory at or below ``root_dir`` that directly contains
    at least one file whose name ends in ``.pt``.

    Args:
        root_dir: Directory at which the recursive walk starts.

    Returns:
        A list of directory paths in the order ``os.walk`` visits them.
    """
    return [
        current_dir
        for current_dir, _subdirs, names in os.walk(root_dir)
        if any(name.endswith(".pt") for name in names)
    ]
def get_latest_checkpoint(log_dir):
    """
    Return the path of the checkpoint in ``log_dir`` with the highest
    iteration number.

    Only files named exactly ``checkpoint_<digits>.pt`` are considered, so
    leftovers such as ``checkpoint_5.pt.tmp`` or ``checkpoint_5.pth`` are
    ignored.

    Args:
        log_dir: Directory to look in (not searched recursively).

    Returns:
        ``os.path.join(log_dir, name)`` for the newest checkpoint, or None
        when the directory holds no checkpoint files.
    """
    pattern = re.compile(r'checkpoint_(\d+)\.pt')

    iterations = {}
    for name in os.listdir(log_dir):
        # fullmatch anchors both ends; re.match would accept any suffix.
        matched = pattern.fullmatch(name)
        if matched is not None:
            iterations[name] = int(matched.group(1))

    if not iterations:
        return None
    return os.path.join(log_dir, max(iterations, key=iterations.get))
def get_latest_checkpoint_file(filepath, limit=300000):
    """
    Return the newest checkpoint in ``filepath`` whose iteration number does
    not exceed ``limit``.

    Relies on ``get_checkpoint_files`` returning paths sorted by ascending
    iteration number, so the last eligible entry is the newest one.

    Args:
        filepath: Directory containing ``checkpoint_<digits>.pt`` files.
        limit: Largest iteration number (inclusive) that is accepted.

    Returns:
        Path of the selected checkpoint, or None if none qualifies.
    """
    pattern = re.compile(r'checkpoint_(\d+)\.pt')

    def iteration_of(path):
        # Parse the file name only: a directory component that happens to
        # look like a checkpoint name must not be picked up instead.
        return int(pattern.search(os.path.basename(path)).group(1))

    eligible = [
        path for path in get_checkpoint_files(filepath)
        if iteration_of(path) <= limit
    ]
    if not eligible:
        return None
    return eligible[-1]
def get_checkpoint_files(filepath):
    """
    List the checkpoint files in ``filepath`` ordered by iteration number.

    Only files named exactly ``checkpoint_<digits>.pt`` are returned; names
    with extra trailing characters (e.g. ``checkpoint_5.pth``) are skipped.

    Args:
        filepath: Directory to look in (not searched recursively).

    Returns:
        A list of ``os.path.join(filepath, name)`` paths sorted by ascending
        iteration number; empty if there are no checkpoints.
    """
    pattern = re.compile(r'checkpoint_(\d+)\.pt')

    numbered = []
    for name in os.listdir(filepath):
        # fullmatch anchors both ends; re.match would accept any suffix.
        matched = pattern.fullmatch(name)
        if matched is not None:
            numbered.append((int(matched.group(1)), name))

    # Sort numerically (not lexically) so that 10 comes after 2.
    numbered.sort(key=lambda item: item[0])
    return [os.path.join(filepath, name) for _, name in numbered]
def load_checkpoint(checkpoint_path, device):
    """
    Load a checkpoint from disk with ``torch.load``.

    Args:
        checkpoint_path: Path of the checkpoint file to read.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        Whatever object was stored in the checkpoint file.
    """
    # NOTE: weights_only=False unpickles arbitrary Python objects, which can
    # execute code embedded in the file. Only load trusted checkpoints.
    return torch.load(checkpoint_path, map_location=device, weights_only=False)
def get_model_args_from_checkpoint(checkpoint):
    """
    Fetch the value stored under the ``"args"`` key of a checkpoint.

    Args:
        checkpoint: Mapping-like checkpoint object.

    Returns:
        The value stored under ``"args"``.

    Raises:
        ValueError: If the checkpoint has no ``"args"`` entry.
    """
    if "args" not in checkpoint:
        raise ValueError("Checkpoint does not contain saved args.")
    return checkpoint["args"]
def get_accuracy_and_loss_from_checkpoint(checkpoint, device="cpu"):
    """
    Pull the training progress records out of a checkpoint dictionary.

    Missing entries fall back to ``0`` for the iteration counter and to an
    empty list for every other field.

    Args:
        checkpoint: Checkpoint dictionary to read from.
        device: Unused by this function.

    Returns:
        A tuple ``(training_iteration, train_losses, test_losses,
        train_accuracies, test_accuracies)``; the two accuracy entries are
        read from the ``*_accuracies_most_certain`` keys.
    """
    return (
        checkpoint.get('training_iteration', 0),
        checkpoint.get('train_losses', []),
        checkpoint.get('test_losses', []),
        checkpoint.get('train_accuracies_most_certain', []),
        checkpoint.get('test_accuracies_most_certain', []),
    )