"""Projection module used to slice membranes along code vectors. This module implements :class:`Slicer`, a light‑weight wrapper around a single matrix multiplication. Given a bank of membranes ``M`` with shape ``[B, L, H, W]`` and a set of code vectors ``C`` with shape ``[L, K]`` the module projects the membranes onto the codes resulting in tensors of shape ``[B, K, H, W]``. An optional bias and ReLU non‑linearity can be applied to the result. Only tensor operations are performed; the module deliberately avoids any side effects beyond those on its parameters so that unit tests can reason about its behaviour deterministically. """ from __future__ import annotations import torch from torch import nn class Slicer(nn.Module): """Project membranes ``M`` onto code vectors ``C``. Parameters ---------- C: Matrix with shape ``[L, K]`` containing the code vectors. The tensor is copied and stored as the weight ``W`` of the module. bias: If ``True`` (default) a bias term of shape ``[K, 1, 1]`` is added to the output. relu: If ``True`` (default) apply a ReLU non‑linearity to the projected result. """ def __init__(self, C: torch.Tensor, bias: bool = True, relu: bool = True): super().__init__() self.use_bias = bias self.use_relu = relu # Store the codes as a non‑trainable parameter ``W`` with shape [L, K]. W = C.detach().clone() self.W = nn.Parameter(W, requires_grad=False) if bias: b = torch.zeros(W.shape[1], 1, 1, dtype=W.dtype, device=W.device) self.bias = nn.Parameter(b, requires_grad=False) else: self.register_parameter("bias", None) def forward(self, M: torch.Tensor) -> torch.Tensor: # [B, L, H, W] """Return ``torch.einsum('blhw,lk->bkhw', M, W)`` with optional bias and ReLU. """ Y = torch.einsum("blhw,lk->bkhw", M, self.W) if self.use_bias and self.bias is not None: Y = Y + self.bias # bias shape [K, 1, 1] broadcasts over batch and spatial dims if self.use_relu: Y = torch.relu(Y) return Y def make_slicer(C: torch.Tensor, learnable: bool = False) -> Slicer: """Utility helper returning a :class:`Slicer` initialised with ``C``. Parameters ---------- C: Code matrix with shape ``[L, K]``. learnable: If ``True`` all parameters of the returned module will require gradients. By default the slicer is non‑learnable which matches the requirements for the P0 prototype. """ slicer = Slicer(C) if learnable: for p in slicer.parameters(): if p is not None: p.requires_grad_(True) return slicer