# mpknet-v6 / mpknet_components.py
# Author: DJLougen. Uploaded with huggingface_hub (commit 2335bf1, verified).
"""
Shared components for all MPKNet model variants.
Contains building blocks used across V1, V2, V3, V4 and detection models:
- RGCLayer: Biologically accurate retinal ganglion cell preprocessing
- BinocularPreMPK: Legacy retinal preprocessing (deprecated, use RGCLayer)
- StereoDisparity: Stereo disparity simulation
- OcularDominanceConv: Convolution with ocular dominance channels
- BinocularMPKPathway: Pathway with binocular processing
- MonocularPathwayBlock: Pathway keeping eyes separate
- StridedMonocularBlock: Strided pathway for V4
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
class RGCLayer(nn.Module):
    """
    Biologically accurate Retinal Ganglion Cell layer.

    Based on Kim et al. 2021 "Retinal Ganglion Cells—Diversity of Cell Types
    and Clinical Relevance" (Front. Neurol. 12:661938).

    Models three main RGC types that feed the P/K/M pathways:

    1. MIDGET RGCs (~70% of RGCs):
       - Small receptive field (5-10 μm dendritic field)
       - Center-surround via Difference of Gaussians (DoG)
       - Red-Green color opponency (L-M or M-L)
       - Feeds PARVOCELLULAR (P) pathway
       - High spatial acuity, low temporal resolution
    2. PARASOL RGCs (~10% of RGCs):
       - Large receptive field (30-300 μm dendritic field)
       - Center-surround DoG on luminance
       - Achromatic (no color, L+M pooled)
       - Feeds MAGNOCELLULAR (M) pathway
       - Motion detection, high temporal resolution
    3. SMALL BISTRATIFIED RGCs (~5-8% of RGCs):
       - Medium receptive field
       - S-cone ON center, (L+M) OFF surround
       - Blue-Yellow opponency
       - Feeds KONIOCELLULAR (K) pathway
       - Color context, particularly blue

    Key biological details implemented:
    - DoG (Difference of Gaussians) for center-surround RF
    - RF size ratios: Midget < Bistratified < Parasol (default sigmas)
    - Surround ~3-6x larger than center (``surround_ratio``, default 3x)
    - Only the ON-center population is modelled (center - surround)

    Implementation notes:
    - Inputs are (B, C, H, W) tensors with C >= 3; channels 0/1/2 are used
      as R/G/B proxies for L/M/S cones and any extra channels are ignored.
    - Convolutions use zero padding, so responses near the image border are
      biased (the wide surround loses more mass there than the center).
    - Each output is a single-channel map broadcast to 3 channels with
      ``expand``; it is a non-contiguous view, so call ``.contiguous()``
      before modifying it in place.
    """

    def __init__(
        self,
        midget_sigma: float = 0.8,  # Small RF for fine detail
        parasol_sigma: float = 2.5,  # Large RF for motion/gist
        bistrat_sigma: float = 1.2,  # Medium RF for color context
        surround_ratio: float = 3.0,  # Surround sigma = center sigma * ratio
    ):
        super().__init__()
        # A zero sigma makes _make_gaussian divide by zero (NaN kernel) and a
        # negative one yields a negative kernel size; reject both up front.
        for arg_name, value in (
            ("midget_sigma", midget_sigma),
            ("parasol_sigma", parasol_sigma),
            ("bistrat_sigma", bistrat_sigma),
            ("surround_ratio", surround_ratio),
        ):
            if not value > 0:  # also rejects NaN
                raise ValueError(f"{arg_name} must be positive, got {value!r}")

        self.midget_sigma = midget_sigma
        self.parasol_sigma = parasol_sigma
        self.bistrat_sigma = bistrat_sigma
        self.surround_ratio = surround_ratio

        # Create DoG kernels for each cell type (buffers: moved with .to()/.cuda()
        # and saved in state_dict, but not trained).
        self.register_buffer('midget_center', self._make_gaussian(midget_sigma))
        self.register_buffer('midget_surround', self._make_gaussian(midget_sigma * surround_ratio))
        self.register_buffer('parasol_center', self._make_gaussian(parasol_sigma))
        self.register_buffer('parasol_surround', self._make_gaussian(parasol_sigma * surround_ratio))
        self.register_buffer('bistrat_center', self._make_gaussian(bistrat_sigma))
        self.register_buffer('bistrat_surround', self._make_gaussian(bistrat_sigma * surround_ratio))

        # Surround kernel sizes, kept as public attributes for compatibility.
        self.midget_ks = self.midget_surround.shape[-1]
        self.parasol_ks = self.parasol_surround.shape[-1]
        self.bistrat_ks = self.bistrat_surround.shape[-1]

    def _make_gaussian(self, sigma: float) -> torch.Tensor:
        """
        Create a normalized 2D Gaussian kernel of shape (1, 1, ks, ks).

        ``ks`` is ``int(6 * sigma + 1)`` forced odd, i.e. roughly 3 sigma on
        each side of the center tap. The kernel sums to 1.
        """
        ks = int(6 * sigma + 1) | 1  # bitwise OR with 1 guarantees an odd size
        ax = torch.arange(ks, dtype=torch.float32) - ks // 2
        xx, yy = torch.meshgrid(ax, ax, indexing='ij')
        kernel = torch.exp(-(xx**2 + yy**2) / (2 * sigma**2))
        kernel = kernel / kernel.sum()  # Normalize
        return kernel.unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)

    def _apply_dog(
        self,
        x: torch.Tensor,
        center_kernel: torch.Tensor,
        surround_kernel: torch.Tensor,
        kernel_size: int,
    ) -> torch.Tensor:
        """
        Apply an ON-center Difference of Gaussians (center - surround).

        Args:
            x: (B, C, H, W) input; every channel is filtered independently
                (depthwise convolution via ``groups=C``).
            center_kernel: (1, 1, k_c, k_c) normalized Gaussian.
            surround_kernel: (1, 1, k_s, k_s) normalized Gaussian.
            kernel_size: retained for backward compatibility. Padding is
                derived from the larger of the two kernels actually applied,
                so the output always has the same H, W as ``x``. For the
                default configuration this equals the surround size passed in.

        Returns:
            (B, C, H, W) tensor: center response minus surround response.
        """
        _, channels, _, _ = x.shape
        c_size = center_kernel.shape[-1]
        s_size = surround_kernel.shape[-1]
        size = max(c_size, s_size)

        # Zero-pad the smaller kernel to a shared footprint. Both sizes are odd
        # (see _make_gaussian), so the difference splits evenly on each side.
        # Padding whichever is smaller (not only the center) keeps shapes
        # consistent even when surround_ratio < 1.
        if c_size < size:
            pad = (size - c_size) // 2
            center_kernel = F.pad(center_kernel, (pad, pad, pad, pad))
        if s_size < size:
            pad = (size - s_size) // 2
            surround_kernel = F.pad(surround_kernel, (pad, pad, pad, pad))

        # Expand kernels for all channels (depthwise weights: (C, 1, k, k)).
        center_k = center_kernel.expand(channels, 1, -1, -1)
        surround_k = surround_kernel.expand(channels, 1, -1, -1)

        padding = size // 2
        center_response = F.conv2d(x, center_k, padding=padding, groups=channels)
        surround_response = F.conv2d(x, surround_k, padding=padding, groups=channels)
        return center_response - surround_response

    def _encode_eye(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Run one eye's RGB image through the midget, parasol and bistratified DoGs.

        Returns:
            (P, M, K), each a (B, 3, H, W) single-channel map expanded to 3
            channels.

        Raises:
            ValueError: if ``x`` is not 4-D with at least 3 channels.
        """
        if x.dim() != 4 or x.shape[1] < 3:
            raise ValueError(
                f"RGCLayer expects (B, C>=3, H, W) RGB input, got shape {tuple(x.shape)}"
            )
        red, green, blue = x[:, 0:1], x[:, 1:2], x[:, 2:3]

        # MIDGET -> P: L-M opponency approximated as R - G, small-RF DoG.
        rg = red - green
        p = self._apply_dog(rg, self.midget_center, self.midget_surround, self.midget_ks)

        # PARASOL -> M: achromatic luminance (Rec. 601 luma weights), large-RF DoG.
        lum = 0.299 * red + 0.587 * green + 0.114 * blue
        m = self._apply_dog(lum, self.parasol_center, self.parasol_surround, self.parasol_ks)

        # SMALL BISTRATIFIED -> K: S - (L+M), with S ~ B and (L+M) ~ (R+G)/2.
        by = blue - (red + green) / 2
        k = self._apply_dog(by, self.bistrat_center, self.bistrat_surround, self.bistrat_ks)

        # Broadcast to 3 channels for compatibility with RGB-shaped consumers
        # (a view, not a copy).
        return (
            p.expand(-1, 3, -1, -1),
            m.expand(-1, 3, -1, -1),
            k.expand(-1, 3, -1, -1),
        )

    def forward(
        self,
        x_left: torch.Tensor,
        x_right: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Process left and right eye inputs through RGC populations.

        Args:
            x_left, x_right: (B, C>=3, H, W) RGB images for each eye.

        Returns:
            A 6-tuple ordered (P_left, M_left, K_left, P_right, M_right, K_right):
            P = midget RGC output (R-G opponency) -> P pathway,
            M = parasol RGC output (luminance DoG) -> M pathway,
            K = bistratified RGC output (S vs L+M) -> K pathway.
            Every tensor is (B, 3, H, W).

        Raises:
            ValueError: if either input is not 4-D with at least 3 channels.
        """
        p_left, m_left, k_left = self._encode_eye(x_left)
        p_right, m_right, k_right = self._encode_eye(x_right)
        return p_left, m_left, k_left, p_right, m_right, k_right
class BinocularPreMPK(nn.Module):
    """
    Legacy retinal/LGN preprocessing applied to each eye independently
    (superseded by RGCLayer).

    A single normalized Gaussian blur is used in two ways:
    - P (Parvo-like): the image minus its blur, i.e. a high-pass detail map.
    - M (Magno-like): the blurred channel-mean luminance, broadcast to 3
      channels, i.e. a low-pass gist map.

    Args:
        sigma: standard deviation of the Gaussian; the kernel width is
            ``int(4 * sigma + 1)`` forced odd.
    """

    def __init__(self, sigma: float = 1.0):
        super().__init__()
        self.sigma = sigma
        size = int(4 * sigma + 1) | 1  # OR-ing with 1 forces an odd width
        half = size // 2
        coords = torch.arange(size, dtype=torch.float32) - half
        rows, cols = torch.meshgrid(coords, coords, indexing='ij')
        weights = torch.exp(-(rows**2 + cols**2) / (2 * sigma**2))
        weights = weights / weights.sum()
        # Stored as (1, 1, k, k); expanded per channel at blur time.
        self.register_buffer('gauss', weights[None, None])
        self.ks = size

    def _blur(self, x: torch.Tensor) -> torch.Tensor:
        """Depthwise Gaussian blur of a (B, C, H, W) tensor, zero-padded to keep H, W."""
        _, channels, _, _ = x.shape
        weight = self.gauss.expand(channels, 1, self.ks, self.ks)
        return F.conv2d(x, weight, padding=self.ks // 2, groups=channels)

    def _split_eye(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (P, M) for one eye: high-pass detail and low-pass luminance."""
        detail = x - self._blur(x)
        luminance = x.mean(dim=1, keepdim=True)
        gist = self._blur(luminance).expand(-1, 3, -1, -1)
        return detail, gist

    def forward(self, x_left: torch.Tensor, x_right: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """
        Returns (P_left, M_left, P_right, M_right)
        P = high-pass (image minus blur) for detail
        M = low-pass luminance for motion/gist
        """
        parvo_l, magno_l = self._split_eye(x_left)
        parvo_r, magno_r = self._split_eye(x_right)
        return parvo_l, magno_l, parvo_r, magno_r
class StereoDisparity(nn.Module):
    """
    Creates stereo disparity by horizontally shifting left/right views.

    Simulates the slight positional difference between two eyes. For a shift
    d > 0 the left view's content moves d pixels toward column 0 and the right
    view's content moves d pixels the other way, so the two views differ by
    2*d columns. A negative d swaps the two views. Vacated columns are filled
    by replicating the edge column.

    Args:
        disparity_range: maximum per-view pixel shift (positive = crossed
            disparity). Must be >= 0.

    Raises:
        ValueError: if ``disparity_range`` is negative.
    """

    def __init__(self, disparity_range: int = 2):
        super().__init__()
        if disparity_range < 0:
            raise ValueError(f"disparity_range must be >= 0, got {disparity_range!r}")
        self.disparity_range = disparity_range

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Take a single (B, C, H, W) image and return (left_view, right_view).

        In training mode d is drawn uniformly from
        [-disparity_range, disparity_range] (torch's global RNG). In eval mode
        a fixed d = min(1, disparity_range) is used, so a module configured
        with disparity_range=0 never shifts. The shift magnitude is clamped to
        W - 1 so the retained slice is never empty.
        """
        width = x.shape[-1]
        if self.training:
            d = torch.randint(-self.disparity_range, self.disparity_range + 1, (1,)).item()
        else:
            # Never exceed the configured maximum (previously always 1).
            d = min(1, self.disparity_range)

        shift = min(abs(d), width - 1)
        if shift == 0:
            return x, x

        # Content moved toward column 0 / toward the last column, edge-replicated.
        moved_left = F.pad(x[..., shift:], (0, shift, 0, 0), mode='replicate')
        moved_right = F.pad(x[..., :-shift], (shift, 0, 0, 0), mode='replicate')
        if d > 0:
            return moved_left, moved_right
        return moved_right, moved_left
class OcularDominanceConv(nn.Module):
    """
    Convolution with ocular dominance: output channels are split into a
    left-only group, a right-only group and a binocular group.

    ``int(out_ch * monocular_ratio) // 2`` channels are driven by each eye
    alone; the remaining channels sum a left-eye conv and a right-eye conv.
    The three groups are concatenated (left, right, binocular), batch-normed
    and rectified.

    Inspired by V1 ocular dominance columns but applied at LGN stage
    for computational efficiency.
    """

    def __init__(self, in_ch: int, out_ch: int, kernel_size: int,
                 monocular_ratio: float = 0.5):
        super().__init__()
        self.out_ch = out_ch
        self.monocular_ratio = monocular_ratio

        per_eye = int(out_ch * monocular_ratio) // 2
        shared = out_ch - 2 * per_eye  # absorbs any rounding remainder
        self.n_left, self.n_right, self.n_bino = per_eye, per_eye, shared

        pad = kernel_size // 2

        def make_conv(width: int) -> nn.Conv2d:
            return nn.Conv2d(in_ch, width, kernel_size, padding=pad)

        # Construction order is kept fixed (left, right, bino_L, bino_R) so
        # seeded initialisation is reproducible.
        self.conv_left = make_conv(per_eye)
        self.conv_right = make_conv(per_eye)
        self.conv_bino_L = make_conv(shared)
        self.conv_bino_R = make_conv(shared)
        self.bn = nn.BatchNorm2d(out_ch)

    def forward(self, x_left: torch.Tensor, x_right: torch.Tensor) -> torch.Tensor:
        """Return ReLU(BN(cat[left-only, right-only, left+right binocular]))."""
        groups = (
            self.conv_left(x_left),
            self.conv_right(x_right),
            self.conv_bino_L(x_left) + self.conv_bino_R(x_right),
        )
        return F.relu(self.bn(torch.cat(groups, dim=1)))
class BinocularMPKPathway(nn.Module):
    """
    Single pathway (M, P, or K) with binocular processing.

    Receives left and right eye inputs and produces one fused output. The
    first kernel size builds an OcularDominanceConv that fuses the two eyes;
    each remaining kernel size appends a Conv2d(out_ch -> out_ch) +
    BatchNorm2d + ReLU stage operating on the fused map. All convolutions
    use padding ``k // 2``.

    Args:
        in_ch: channels of each eye's input.
        out_ch: output channels of every layer.
        kernel_sizes: one kernel size per layer; must be non-empty.
        monocular_ratio: forwarded to OcularDominanceConv.

    Raises:
        ValueError: if ``kernel_sizes`` is empty.
    """

    def __init__(self, in_ch: int, out_ch: int, kernel_sizes: list,
                 monocular_ratio: float = 0.5):
        super().__init__()
        kernel_sizes = list(kernel_sizes)
        if not kernel_sizes:
            raise ValueError("kernel_sizes must contain at least one kernel size")
        first_ks, *rest_ks = kernel_sizes

        # Built before the remaining stages, matching the original
        # construction order (and thus seeded initialisation).
        self.first_layer = OcularDominanceConv(in_ch, out_ch, first_ks, monocular_ratio)
        stages = [
            nn.Sequential(
                nn.Conv2d(out_ch, out_ch, ks, padding=ks // 2),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            )
            for ks in rest_ks
        ]
        self.rest = nn.Sequential(*stages) if stages else nn.Identity()

    def forward(self, x_left: torch.Tensor, x_right: torch.Tensor) -> torch.Tensor:
        """Fuse both eyes in the first layer, then apply the remaining stages."""
        x = self.first_layer(x_left, x_right)
        return self.rest(x)
class MonocularPathwayBlock(nn.Module):
    """
    Single pathway block that keeps left/right eyes separate.

    Each eye has its own, unshared Conv2d (padding ``kernel_size // 2``) +
    BatchNorm2d + ReLU branch. Used for LGN processing where eye segregation
    persists.
    """

    def __init__(self, in_ch: int, out_ch: int, kernel_size: int):
        super().__init__()

        def make_branch() -> nn.Sequential:
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size, padding=kernel_size // 2),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            )

        # Left branch is built first so seeded initialisation is unchanged.
        self.conv_left = make_branch()
        self.conv_right = make_branch()

    def forward(self, x_left: torch.Tensor, x_right: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (left_features, right_features), each from its own branch."""
        left_out = self.conv_left(x_left)
        right_out = self.conv_right(x_right)
        return left_out, right_out
class StridedMonocularBlock(nn.Module):
    """
    Monocular pathway block with configurable stride.

    Keeps left/right eyes separate: each eye has its own Conv2d (with
    ``stride`` and padding ``kernel_size // 2``) + BatchNorm2d + ReLU branch,
    so the stride controls spatial sampling. Used in V4 for stride-based
    pathway differentiation.
    """

    def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3, stride: int = 1):
        super().__init__()
        # Left branch is built first so seeded initialisation is unchanged.
        self.conv_left = self._branch(in_ch, out_ch, kernel_size, stride)
        self.conv_right = self._branch(in_ch, out_ch, kernel_size, stride)

    @staticmethod
    def _branch(in_ch: int, out_ch: int, kernel_size: int, stride: int) -> nn.Sequential:
        """Build one eye's Conv-BN-ReLU branch."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size, stride=stride, padding=kernel_size // 2),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x_left: torch.Tensor, x_right: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (left_features, right_features), each from its own branch."""
        return (
            self.conv_left(x_left),
            self.conv_right(x_right),
        )
def count_params(model: nn.Module, *, trainable_only: bool = False) -> int:
    """
    Count the scalar parameters of ``model``.

    Args:
        model: module to inspect. ``model.parameters()`` already
            de-duplicates parameters shared between submodules. Buffers
            (e.g. BatchNorm running stats, RGC kernels) are not counted.
        trainable_only: if True, count only parameters with
            ``requires_grad=True``. Defaults to False, which counts every
            parameter (this helper's historical behavior).

    Returns:
        Total number of elements across the selected parameters.
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if not trainable_only or p.requires_grad
    )