WCNegentropy's picture
📚 Updated with scientifically rigorous documentation
dc2b9f3 verified
"""Projection module used to slice membranes along code vectors.
This module implements :class:`Slicer`, a light‑weight wrapper around a
single matrix multiplication. Given a bank of membranes ``M`` with shape
``[B, L, H, W]`` and a set of code vectors ``C`` with shape ``[L, K]`` the
module projects the membranes onto the codes resulting in tensors of shape
``[B, K, H, W]``. An optional bias and ReLU non‑linearity can be applied to
the result.
Only tensor operations are performed; the module deliberately avoids any
side effects beyond those on its parameters so that unit tests can reason
about its behaviour deterministically.
"""
from __future__ import annotations
import torch
from torch import nn
class Slicer(nn.Module):
    """Project membranes ``M`` onto code vectors ``C``.

    The projection is a single contraction over the ``L`` axis,
    ``Y[b, k, h, w] = sum_l M[b, l, h, w] * W[l, k]``, optionally followed by
    an additive bias and a ReLU.

    Parameters
    ----------
    C:
        Matrix with shape ``[L, K]`` containing the code vectors. The tensor
        is detached, copied and stored as the weight ``W`` of the module, so
        later in-place changes to ``C`` do not affect the module.
    bias:
        If ``True`` (default) a zero-initialised bias term of shape
        ``[K, 1, 1]`` is added to the output. If ``False`` the ``bias``
        attribute is registered as ``None``.
    relu:
        If ``True`` (default) apply a ReLU non-linearity to the projected
        result.

    Raises
    ------
    ValueError
        If ``C`` is not a 2-D tensor.
    """

    def __init__(self, C: torch.Tensor, bias: bool = True, relu: bool = True):
        super().__init__()

        # Fail fast: the einsum in ``forward`` can only ever work with a 2-D
        # weight, so reject anything else here with a clear message instead
        # of an IndexError below or an einsum error at call time.
        if C.dim() != 2:
            raise ValueError(
                f"C must be a 2-D tensor of shape [L, K], got shape {tuple(C.shape)}"
            )

        self.use_bias = bias
        self.use_relu = relu

        # Store a private copy of the codes as a parameter ``W`` with shape
        # [L, K]. It is created frozen (requires_grad=False).
        W = C.detach().clone()
        self.W = nn.Parameter(W, requires_grad=False)

        if bias:
            # One bias value per code, shaped [K, 1, 1] so that it broadcasts
            # over the batch and spatial dimensions of the output.
            b = torch.zeros(W.shape[1], 1, 1, dtype=W.dtype, device=W.device)
            self.bias = nn.Parameter(b, requires_grad=False)
        else:
            # Register the name so ``self.bias`` exists and is ``None``.
            self.register_parameter("bias", None)

    def forward(self, M: torch.Tensor) -> torch.Tensor:  # [B, L, H, W]
        """Return ``torch.einsum('blhw,lk->bkhw', M, W)`` with optional bias
        and ReLU.

        Parameters
        ----------
        M:
            Membrane bank with shape ``[B, L, H, W]``; ``L`` must match the
            first dimension of the stored weight.

        Returns
        -------
        torch.Tensor
            Tensor with shape ``[B, K, H, W]``.
        """
        Y = torch.einsum("blhw,lk->bkhw", M, self.W)
        if self.use_bias and self.bias is not None:
            Y = Y + self.bias  # [K, 1, 1] broadcasts over batch and spatial dims
        if self.use_relu:
            Y = torch.relu(Y)
        return Y
def make_slicer(C: torch.Tensor, learnable: bool = False) -> Slicer:
    """Build a :class:`Slicer` from the code matrix ``C``.

    The slicer is constructed with its default options (bias and ReLU both
    enabled).

    Parameters
    ----------
    C:
        Code matrix with shape ``[L, K]``.
    learnable:
        When ``True`` every parameter of the returned module is switched to
        require gradients. The default ``False`` leaves the module frozen,
        which matches the requirements for the P0 prototype.

    Returns
    -------
    Slicer
        The newly created module.
    """
    module = Slicer(C)

    # Frozen is the state the constructor already produces; nothing to do.
    if not learnable:
        return module

    # Flip ``requires_grad`` on every registered parameter in one call.
    module.requires_grad_(True)
    return module