Spaces:
Runtime error
Runtime error
dlokesha
Part 4 final: bio rank 21.4 overtakes baseline 20.2 at epoch 7 — richer representations confirmed
887bef9 | """ | |
| models.py — CNN architectures matching TBC's paper setup. | |
| TBC used: input layer → single conv layer (5 kernels, 4×4) → linear → ReLU → classifier | |
| We keep both training setups identical — the ONLY intended difference is the input representation (note: BioCNN uses fully connected layers rather than a conv layer, because spike rates are not spatially structured like images). | |
| This isolates the effect of bio preprocessing, same as TBC did. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class BaselineCNN(nn.Module):
    """
    Trains directly on binarized MNIST images (64x64 by default).
    This is the control -- no biological preprocessing.

    Architecture:
        conv (5 kernels, 4x4, stride 2, padding 1) -> ReLU -> flatten
        -> linear(128) -> ReLU -> linear(n_classes)

    Args:
        n_classes: Number of output logits.
        input_size: Side length of the square, single-channel input image.
            Keyword-only. The default of 64 reproduces the original
            5 * 32 * 32 flatten size exactly.

    Raises:
        ValueError: If ``input_size`` is too small for conv1 to produce at
            least one output position.
    """

    def __init__(self, n_classes: int = 10, *, input_size: int = 64):
        super().__init__()
        # TBC paper: "single convolutional layer with five 4x4 kernels"
        self.conv1 = nn.Conv2d(1, 5, kernel_size=4, stride=2, padding=1)
        self.input_size = input_size

        # Derive the flattened size from conv1's own hyper-parameters rather
        # than hard-coding it, so other input resolutions also work:
        #   out = (in + 2*pad - dilation*(kernel-1) - 1) // stride + 1
        # Default 64x64 input: (64 + 2 - 3 - 1) // 2 + 1 = 32 -> 5x32x32.
        conv = self.conv1
        out_h, out_w = (
            (input_size + 2 * pad - dil * (ker - 1) - 1) // stride + 1
            for ker, stride, pad, dil in zip(
                conv.kernel_size, conv.stride, conv.padding, conv.dilation
            )
        )
        if out_h < 1 or out_w < 1:
            raise ValueError(
                f"input_size={input_size} is too small for conv1 "
                "(kernel 4, stride 2, padding 1)"
            )
        self.flatten_size = conv.out_channels * out_h * out_w

        self.fc1 = nn.Linear(self.flatten_size, 128)
        self.classifier = nn.Linear(128, n_classes)

    def forward(self, x):
        # x: (B, 1, input_size, input_size) -- conv1 expects one input channel.
        x = F.relu(self.conv1(x))
        x = x.flatten(1)  # (B, flatten_size)
        x = F.relu(self.fc1(x))
        return self.classifier(x)
class BioCNN(nn.Module):
    """
    Classifier over spike-rate vectors produced by the reservoir layer.

    Despite the name there are no convolutions. Spike rates are not
    spatially structured like images, so a fully connected stack is used:

        n_reservoir_units -> 512 -> 128 -> n_classes

    ReLU follows each hidden layer. Dropout (p=0.3) is applied after the
    first hidden layer, only while the module is in training mode.

    Args:
        n_reservoir_units: Length of each input spike-rate vector.
        n_classes: Number of output logits.
    """

    def __init__(self, n_reservoir_units: int = 1024, n_classes: int = 10):
        super().__init__()
        wide, narrow = 512, 128
        # Keep the fc1 -> fc2 -> classifier creation order: it fixes the
        # state_dict key names and which RNG draws seed each layer's weights.
        self.fc1 = nn.Linear(n_reservoir_units, wide)
        self.fc2 = nn.Linear(wide, narrow)
        self.classifier = nn.Linear(narrow, n_classes)

    def forward(self, x):
        hidden = F.dropout(F.relu(self.fc1(x)), p=0.3, training=self.training)
        logits = self.classifier(F.relu(self.fc2(hidden)))
        return logits
class AblationCNN(nn.Module):
    """
    Used for the ablation study -- classifies a subset of reservoir units.

    TBC showed: full array > center > periphery, all above chance.
    The subset is chosen by the caller before the forward pass (e.g. by
    masking/indexing reservoir units); this module only sees ``input_dim``
    features per sample. Despite the name it has no convolutions. It is a
    3-layer MLP:

        input_dim -> hidden_dim -> hidden_dim // 2 -> n_classes

    Args:
        input_dim: Number of reservoir units fed to the model.
        n_classes: Number of output logits.
        hidden_dim: Width of the first hidden layer; the second is half of it.

    Raises:
        ValueError: If ``hidden_dim < 2``, which would make the second hidden
            layer zero-width.
    """

    def __init__(self, input_dim: int, n_classes: int = 10, hidden_dim: int = 256):
        super().__init__()
        if hidden_dim < 2:
            raise ValueError(
                f"hidden_dim must be >= 2 (got {hidden_dim}); "
                "the second hidden layer has hidden_dim // 2 units"
            )
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.classifier = nn.Linear(hidden_dim // 2, n_classes)

    def forward(self, x):
        # Delegate to get_features so the feature path is defined in one place
        # and the logits always match the features used for analysis.
        return self.classifier(self.get_features(x))

    def get_features(self, x):
        """Return representation before classifier (for effective rank, etc.).

        This is the post-ReLU output of fc2, with ``hidden_dim // 2`` features
        per sample.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x