|
|
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
|
|
|
|
class NdLinear(nn.Module):
    """NdLinear: A PyTorch layer for projecting tensors into multi-space representations.

    Unlike conventional embedding layers that map into a single vector space, NdLinear
    transforms tensors across a collection of vector spaces — one independent
    ``nn.Linear`` per input dimension — capturing multivariate structure and topical
    information that standard deep learning architectures typically lose.

    Input shape:  ``[batch_size, *input_dims]``
    Output shape: ``[batch_size, *hidden_size]``
    """

    def __init__(self, input_dims: tuple, hidden_size: tuple, transform_outer=True, act_func=None, use_bias=True):
        """Construct one linear projection per tensor dimension.

        Args:
            input_dims (tuple): Shape of input tensor (excluding batch dimension).
            hidden_size (tuple): Target hidden dimensions after transformation.
                Must have the same length as ``input_dims``.
            transform_outer (bool): If True, apply layers from the outermost
                (first) dimension inward; if False, from the innermost outward.
            act_func (callable | None): Activation applied after each per-dim
                linear map. Defaults to ``nn.Identity()`` (no activation).
            use_bias (bool): Whether each per-dimension ``nn.Linear`` has a bias.

        Raises:
            ValueError: If ``input_dims`` and ``hidden_size`` differ in length.
        """
        super().__init__()

        # One linear layer is built per dimension, so the two shape tuples
        # must align one-to-one.
        if len(input_dims) != len(hidden_size):
            raise ValueError("Input shape and hidden shape do not match.")

        self.input_dims = input_dims
        self.hidden_size = hidden_size
        self.num_layers = len(input_dims)

        self.act_func = act_func if act_func is not None else nn.Identity()
        self.transform_outer = transform_outer

        # align_layers[i] maps dimension i: input_dims[i] -> hidden_size[i].
        self.align_layers = nn.ModuleList([
            nn.Linear(input_dims[i], hidden_size[i], bias=use_bias) for i in range(self.num_layers)
        ])
        self.initialize_weights()

    def initialize_weights(self, mean=0.0, std=0.02):
        """Initialize all weights from N(mean, std) and zero all biases.

        Args:
            mean (float): Mean of the normal distribution for weights.
            std (float): Standard deviation of the normal distribution.
        """
        for layer in self.align_layers:
            nn.init.normal_(layer.weight, mean=mean, std=std)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def forward(self, X):
        """Project the input tensor into a new multi-space representation.

        For each dimension i, the corresponding axis is swapped into the last
        position, the tensor is flattened to 2-D, the i-th linear layer (plus
        activation) is applied, and the shape is restored — so every dimension
        is transformed independently.

        Args:
            X (torch.Tensor): Input tensor with shape [batch_size, *input_dims].

        Returns:
            torch.Tensor: Output tensor with shape [batch_size, *hidden_size].
        """
        num_transforms = self.num_layers

        for i in range(num_transforms):
            if self.transform_outer:
                # Outer-first order: dimension 1, 2, ..., num_transforms.
                layer = self.align_layers[i]
                transpose_dim = i + 1
            else:
                # Inner-first order: dimension num_transforms, ..., 2, 1.
                layer = self.align_layers[num_transforms - (i + 1)]
                transpose_dim = num_transforms - i

            # Move the target axis to the last position so the linear layer
            # can act on it; contiguous() is required before view().
            X = torch.transpose(X, transpose_dim, num_transforms).contiguous()

            # Remember the leading shape, collapse to 2-D, transform, restore.
            X_size = X.shape[:-1]
            X = X.view(-1, X.shape[-1])
            X = self.act_func(layer(X))
            X = X.view(*X_size, X.shape[-1])

            # Swap the (now transformed) axis back to its original position.
            X = torch.transpose(X, transpose_dim, num_transforms).contiguous()

        return X