"""Magnitude-preserving layers and helper functions.""" from functools import partial import math import numpy as np import torch import torch.nn as nn def normalize(x, dim=None, eps=1e-4): norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True) norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel())) return x / norm def resample(x, mode='keep', factor=2): """Resample the input tensor x. If mode is 'keep', the input tensor is returned as is. If the mode is 'down', the input tensor is downsampled by a factor of 2 by a 1x1 convolution with stride 2. If the mode is 'up', the input tensor is upsampled by a factor of 2 by a 2x2 convolution with stride 2 and uniform weight 1. If the mode is 'up_bilinear', the input tensor is upsampled using bilinear interpolation. """ if mode == 'keep': return x c = x.shape[1] if mode == 'down': return torch.nn.functional.conv2d(x, torch.ones([c, 1, 1, 1], device=x.device, dtype=x.dtype), groups=c, stride=factor) if mode == 'up_bilinear': return torch.nn.functional.interpolate(x, scale_factor=factor, mode='bilinear', align_corners=False) assert mode == 'up' return torch.nn.functional.conv_transpose2d(x, torch.ones([c, 1, factor, factor], device=x.device, dtype=x.dtype), groups=c, stride=factor) def mp_silu(x): return torch.nn.functional.silu(x) / 0.596 def mp_hardsilu(x): return torch.nn.functional.hardswish(x) / 0.576 def mp_sigmoid(x): return torch.sigmoid(x) / 0.208 def mp_leaky_relu(x, alpha): factor = np.sqrt((1 + alpha**2) / 2) return torch.nn.functional.leaky_relu(x, alpha) / factor def mp_sum(args, w=None): """ Magnitude preserving sum of tensors. parameters: args: list of tensors to sum. w: list of weights for each tensor. If None, all tensors are weighted equally. Should sum to 1 to preserve magnitude. If a float, the weights are [1-w, w] (a linear interpolation). """ if w is None: w = torch.full((len(args),), 1 / len(args), dtype=args[0].dtype, device=args[0].device) elif isinstance(w, float): w = torch.tensor([1-w, w], dtype=args[0].dtype, device=args[0].device) else: w = torch.tensor(w, dtype=args[0].dtype, device=args[0].device) return torch.sum(torch.stack([args * w for args, w in zip(args, w)]), dim=0) / torch.linalg.vector_norm(w) def mp_concat(args, dim=1, w=None): """ Magnitude preserving concatenation of tensors. It should be noted that the concatenated tensors are already magnitude preserving, however the contribution of each tensor in subsequent layers is proportional to the number of channels it has. This function corrects for this by scaling the tensors to have the same overall magnitude, but the contributions of each tensor is the same. parameters: args: list of tensors to concatenate. w: list of weights for each tensor. If None, all tensors are weighted equally. Should sum to 1 to preserve magnitude. If a float, the weights are [1-w, w] (a linear interpolation). """ if w is None: w = torch.full((len(args),), 1 / len(args), dtype=args[0].dtype, device=args[0].device) elif isinstance(w, float): w = torch.tensor([1-w, w], dtype=args[0].dtype, device=args[0].device) else: w = torch.tensor(w, dtype=args[0].dtype, device=args[0].device) N = [x.shape[dim] for x in args] sum_N = torch.tensor(sum(N), dtype=args[0].dtype, device=args[0].device) C = torch.sqrt(sum_N / torch.sum(torch.square(w))) return torch.concat([args[i] * (C / np.sqrt(args[i].shape[dim]) * w[i]) for i in range(len(args))], dim=dim) class MPPositionalEmbedding(nn.Module): def __init__(self, num_channels): super().__init__() self.num_channels = num_channels half_dim = num_channels // 2 emb = math.log(10) / (half_dim - 1) self.register_buffer('freqs', torch.exp(torch.arange(half_dim) * -emb)) def forward(self, x): # Convert input to float32 for higher precision calculations y = x.to(torch.float32) # Compute outer product of input with frequencies y = y.outer(self.freqs.to(torch.float32)) # Apply sin and cos, concatenate, and normalize by sqrt(2) to maintain unit variance y = torch.cat([torch.sin(y), torch.cos(y)], dim=1) * np.sqrt(2) # Convert back to original dtype and return return y.to(x.dtype) class MPFourier(nn.Module): def __init__(self, num_channels, s=1): super().__init__() self.register_buffer('freqs', 2 * np.pi * torch.randn(num_channels) * s) self.register_buffer('phases', 2 * np.pi * torch.rand(num_channels)) def forward(self, x): # Convert input to float32 for higher precision calculations y = x.to(torch.float32) # Compute outer product of input with frequencies # This creates a 2D tensor where each row is the input multiplied by a frequency y = y.outer(self.freqs.to(torch.float32)) # Add phase shifts to each element y = y + self.phases.to(torch.float32) # Apply cosine function to get periodic features # Multiply by sqrt(2) to maintain unit variance y = y.cos() * np.sqrt(2) # Convert back to original dtype and return return y.to(x.dtype) class MPConvResample(nn.Module): def __init__(self, resample_mode, kernel, in_channels, out_channels, skip_weight=0.0): """Resamples a tensor with MP convolution or transposed convolution. Args: resample_mode (str): Either 'up', 'up_bilinear', or 'down'. kernel (list): Kernel size for the convolution. in_channels (int): Number of input channels. out_channels (int): Number of output channels. skip_weight (float): Weight for the skip connection. """ super().__init__() self.resample_mode = resample_mode self.in_channels = in_channels self.out_channels = out_channels self.skip_weight = skip_weight self.stride = kernel[0] if self.resample_mode == 'down': self.weight = nn.Parameter(torch.ones(out_channels, in_channels, *kernel)) elif self.resample_mode == 'up' or self.resample_mode == 'up_bilinear': self.weight = nn.Parameter(torch.ones(in_channels, out_channels, *kernel)) else: raise ValueError("resample_mode must be either 'up' or 'down'") def forward(self, x, gain=1): # Keep weight in float32 during normalization w = self.weight.to(torch.float32) # For numerical stability, we normalize the weights to internally have a norm of 1. if self.training: with torch.no_grad(): self.weight.copy_(normalize(w)) # Weights are already normalized, but this is critical so that gradients are propogated through the normalization. w = normalize(w) w = w * (gain / np.sqrt(w[0].numel())) w = w.to(x.dtype) upsampled = resample(x, mode=self.resample_mode, factor=self.stride) if self.resample_mode == 'down': y = torch.nn.functional.conv2d(x, w, stride=self.stride, padding=0) else: y = torch.nn.functional.conv_transpose2d(x, w, stride=self.stride, padding=0) return mp_sum([y, upsampled], w=self.skip_weight) def norm_weights(self): with torch.no_grad(): self.weight.copy_(normalize(self.weight.to(torch.float32))) class MPConv(nn.Module): """ Magnitude preserving convolution. Conveniently, a kernel of [] is the same as a linear layer. This class is a wrapper around the standard Conv2d layer, but with the following modifications: - During training, the weight is forced to be normalized to have a magnitude of 1. - The weights are then normalized to have a norm of 1 and then scaled to preserve the magnitude of the outputs. `gain` is used to scale the output of the layer to potentially provide more control. The default value of 1 keeps output magnitudes similar to input magnitudes. """ def __init__(self, in_channels, out_channels, kernel, groups=1, no_padding=False): super().__init__() self.out_channels = out_channels assert in_channels % groups == 0, "in_channels must be divisible by groups" assert groups == 1 or len(kernel) == 2, "Groups other than 1 require a 2D kernel" self.weight = nn.Parameter(torch.randn(out_channels, in_channels // groups, *kernel)) self.groups = groups self.no_padding = no_padding def forward(self, x, gain=1): # Keep weight in float32 during normalization w = self.weight.to(torch.float32) # For numerical stability, we normalize the weights to internally have a norm of 1. if self.training: with torch.no_grad(): self.weight.copy_(normalize(w)) # Weights are already normalized, but this is critical so that gradients are propogated through the normalization. w = normalize(w) w = w * (gain / np.sqrt(w[0].numel())) w = w.to(x.dtype) # If the kernel is 0D, just do a linear layer if w.ndim == 2: return nn.functional.linear(x, w) # Otherwise do a 2D convolution assert w.ndim == 4 return nn.functional.conv2d(x, w, padding=(0 if self.no_padding else w.shape[-1]//2,), groups=self.groups) def norm_weights(self): with torch.no_grad(): self.weight.copy_(normalize(self.weight.to(torch.float32))) class MPEmbedding(nn.Module): def __init__(self, in_channels, out_channels): super().__init__() self.out_channels = out_channels self.weight = nn.Parameter(torch.randn(in_channels, out_channels)) def forward(self, x): w = self.weight.to(torch.float32) if self.training: with torch.no_grad(): self.weight.copy_(normalize(w)) w = normalize(w) w = w.to(x.dtype) assert torch.max(x) < self.weight.shape[0], f"Embedding index out of bounds: {torch.max(x).item()}" return nn.functional.embedding(x, self.weight) def norm_weights(self): with torch.no_grad(): self.weight.copy_(normalize(self.weight.to(torch.float32)))