| | """Projection module used to slice membranes along code vectors. |
| | |
| | This module implements :class:`Slicer`, a light‑weight wrapper around a |
| | single matrix multiplication. Given a bank of membranes ``M`` with shape |
| | ``[B, L, H, W]`` and a set of code vectors ``C`` with shape ``[L, K]`` the |
| | module projects the membranes onto the codes resulting in tensors of shape |
| | ``[B, K, H, W]``. An optional bias and ReLU non‑linearity can be applied to |
| | the result. |
| | |
| | Only tensor operations are performed; the module deliberately avoids any |
| | side effects beyond those on its parameters so that unit tests can reason |
| | about its behaviour deterministically. |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import torch |
| | from torch import nn |
| |
|
| |
|
class Slicer(nn.Module):
    """Linear projection of a membrane bank onto a fixed set of code vectors.

    For membranes ``M`` of shape ``[B, L, H, W]`` and codes ``C`` of shape
    ``[L, K]``, :meth:`forward` contracts the shared ``L`` axis and yields a
    tensor of shape ``[B, K, H, W]``. An additive bias and a ReLU can follow.

    Parameters
    ----------
    C:
        Code matrix of shape ``[L, K]``. A detached copy is stored as the
        parameter ``W`` with ``requires_grad=False``. Later in-place edits
        to ``C`` therefore do not affect the module.
    bias:
        When ``True`` (default), a zero-initialised parameter ``bias`` of
        shape ``[K, 1, 1]`` (``requires_grad=False``, same dtype/device as
        ``W``) is created and added to the output. When ``False``, ``bias``
        is registered as ``None``.
    relu:
        When ``True`` (default), the projected result goes through
        ``torch.relu``.
    """

    def __init__(self, C: torch.Tensor, bias: bool = True, relu: bool = True):
        super().__init__()
        self.use_bias = bias
        self.use_relu = relu

        # Keep a private, graph-free copy so the caller's tensor is untouched.
        codes = C.detach().clone()
        self.W = nn.Parameter(codes, requires_grad=False)

        if not bias:
            # Register the slot explicitly so ``self.bias`` exists and is None.
            self.register_parameter("bias", None)
        else:
            # One offset per output channel; the trailing singleton axes let it
            # broadcast over H and W of the ``[B, K, H, W]`` result.
            n_codes = codes.shape[1]
            offset = codes.new_zeros((n_codes, 1, 1))
            self.bias = nn.Parameter(offset, requires_grad=False)

    def forward(self, M: torch.Tensor) -> torch.Tensor:
        """Project ``M`` onto the stored codes.

        Computes ``torch.einsum('blhw,lk->bkhw', M, W)``. It then adds
        ``bias`` when enabled and present, and applies ReLU when enabled.
        """
        out = torch.einsum("blhw,lk->bkhw", M, self.W)
        add_offset = self.use_bias and self.bias is not None
        if add_offset:
            out = out + self.bias
        return torch.relu(out) if self.use_relu else out
| |
|
| |
|
def make_slicer(
    C: torch.Tensor,
    learnable: bool = False,
    *,
    bias: bool = True,
    relu: bool = True,
) -> Slicer:
    """Build a :class:`Slicer` initialised with the code matrix ``C``.

    Parameters
    ----------
    C:
        Code matrix with shape ``[L, K]``.
    learnable:
        If ``True``, every parameter of the returned module (``W`` and, when
        present, ``bias``) is switched to ``requires_grad=True``. By default
        the slicer is non-learnable, which matches the requirements for the
        P0 prototype.
    bias:
        Passed to :class:`Slicer`; whether to create the bias term.
        Defaults to ``True``, the previous hard-coded behaviour.
    relu:
        Passed to :class:`Slicer`; whether to apply ReLU to the output.
        Defaults to ``True``, the previous hard-coded behaviour.

    Returns
    -------
    Slicer
        The configured projection module.
    """
    slicer = Slicer(C, bias=bias, relu=relu)
    if learnable:
        # ``Module.parameters()`` skips slots registered as ``None`` (such as
        # a disabled bias), so every yielded item is a real Parameter.
        for param in slicer.parameters():
            param.requires_grad_(True)
    return slicer
| |
|
| |
|