neuroforge / train.py
dlokesha
Part 2 final: local plasticity mean rank 37.1 β€” continual backprop rank 33.1 confirms reactive methods worse than standard
c3a6dfa
"""
train.py β€” Run the full TBC experiment comparison.
Usage:
python train.py # full run
python train.py --n_samples 500 # quick smoke test
python train.py --epochs 5 # fast iteration
What this does:
1. Loads MNIST, binarizes it, resizes to 64x64
2. Runs all images through the BioPreprocessor (reservoir layer)
3. Trains BaselineCNN on raw images
4. Trains BioCNN on bio-processed spike rates
5. Plots accuracy curves β€” should show bio beats baseline
6. Saves results to outputs/
"""
import argparse
import os
import pickle
import time
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms
from models import AblationCNN, BaselineCNN, BioCNN
from reservoir import BioPreprocessor, MEAEncoder
def load_mnist(n_samples: int = 2000, data_dir: str = "./data"):
"""Load MNIST and binarize β€” matching TBC's preprocessing exactly."""
print("Loading MNIST...")
transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.MNIST(data_dir, train=True, download=True, transform=transform)
test_data = datasets.MNIST(data_dir, train=False, download=True, transform=transform)
# Subsample
train_images = train_data.data[:n_samples].numpy() / 255.0
train_labels = train_data.targets[:n_samples].numpy()
test_images = test_data.data[: n_samples // 5].numpy() / 255.0
test_labels = test_data.targets[: n_samples // 5].numpy()
# Binarize: white digits on black background
train_images = (train_images > 0.5).astype(float)
test_images = (test_images > 0.5).astype(float)
return train_images, train_labels, test_images, test_labels
def encode_to_grid(images: np.ndarray, grid_size: int = 64) -> np.ndarray:
"""Resize binarized MNIST to 64Γ—64 electrode grid."""
encoder = MEAEncoder(grid_size=grid_size)
encoded = np.array([encoder.encode(img) for img in images])
return encoded # (N, 64, 64)
def run_bio_preprocessing(images: np.ndarray, cache_path: str = None) -> np.ndarray:
"""
Run all images through the reservoir.
Caches results to disk β€” reservoir is slow, don't rerun unnecessarily.
"""
if cache_path and os.path.exists(cache_path):
print(f"Loading cached bio representations from {cache_path}")
with open(cache_path, "rb") as f:
return pickle.load(f)
print("Running bio preprocessing (this takes a few minutes)...")
t0 = time.time()
preprocessor = BioPreprocessor(n_reservoir_units=1024, grid_size=64, steps=10)
bio_reps = preprocessor.process_batch(images)
elapsed = time.time() - t0
print(f"Done in {elapsed:.1f}s β€” {len(images)/elapsed:.1f} images/sec")
if cache_path:
os.makedirs(os.path.dirname(cache_path), exist_ok=True)
with open(cache_path, "wb") as f:
pickle.dump(bio_reps, f)
print(f"Cached to {cache_path}")
return bio_reps # (N, 1024)
def train_model(model, train_loader, test_loader, epochs: int, lr: float = 1e-3, label: str = ""):
"""Train a model and return per-epoch accuracy on both splits."""
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
train_accs, test_accs = [], []
for epoch in range(epochs):
# Train
model.train()
for x_batch, y_batch in train_loader:
x_batch, y_batch = x_batch.to(device), y_batch.to(device)
optimizer.zero_grad()
logits = model(x_batch)
loss = criterion(logits, y_batch)
loss.backward()
optimizer.step()
# Eval train
model.eval()
with torch.no_grad():
correct = total = 0
for x_batch, y_batch in train_loader:
x_batch, y_batch = x_batch.to(device), y_batch.to(device)
preds = model(x_batch).argmax(1)
correct += (preds == y_batch).sum().item()
total += len(y_batch)
train_accs.append(correct / total)
# Eval test
correct = total = 0
for x_batch, y_batch in test_loader:
x_batch, y_batch = x_batch.to(device), y_batch.to(device)
preds = model(x_batch).argmax(1)
correct += (preds == y_batch).sum().item()
total += len(y_batch)
test_accs.append(correct / total)
print(f" [{label}] Epoch {epoch+1}/{epochs} β€” test acc: {test_accs[-1]*100:.1f}%")
return train_accs, test_accs
def run_ablation(
bio_train: np.ndarray,
bio_test: np.ndarray,
train_labels: np.ndarray,
test_labels: np.ndarray,
n_units: int = 1024,
epochs: int = 10,
):
"""
Ablation study replicating TBC Figure 4 (right panel).
Splits reservoir units into:
- Whole: all 1024 units
- Center: first 512 units (approximates directly stimulated region)
- Periphery: last 512 units (approximates spread-beyond-stimulation)
TBC result: Whole β‰ˆ Center >> Periphery > chance (10%)
All three should beat chance β€” proving bio spread carries real info.
"""
print("\nRunning ablation study...")
results = {}
splits = {
"Whole": (0, n_units),
"Center": (0, n_units // 2),
"Periphery": (n_units // 2, n_units),
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for name, (start, end) in splits.items():
subset_train = bio_train[:, start:end]
subset_test = bio_test[:, start:end]
input_dim = end - start
train_ds = TensorDataset(
torch.FloatTensor(subset_train), torch.LongTensor(train_labels)
)
test_ds = TensorDataset(
torch.FloatTensor(subset_test), torch.LongTensor(test_labels)
)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=64)
model = AblationCNN(input_dim=input_dim)
_, test_accs = train_model(model, train_loader, test_loader, epochs, label=name)
results[name] = test_accs[-1]
print(f" {name}: {test_accs[-1]*100:.1f}%")
return results
def plot_results(baseline_test, bio_test, save_path: str = "outputs/accuracy_curves.png"):
"""
Reproduce TBC Figure 4 (left panel):
Bio accuracy curve should be higher and converge faster than baseline.
"""
os.makedirs(os.path.dirname(save_path), exist_ok=True)
epochs = range(1, len(baseline_test) + 1)
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
fig.suptitle("TBC Replication: Biological Preprocessing vs Baseline", fontsize=14)
# Accuracy curves
ax = axes[0]
ax.plot(epochs, [a * 100 for a in baseline_test], "gold", linewidth=2, label="Baseline (raw MNIST)")
ax.plot(epochs, [a * 100 for a in bio_test], "#1D9E75", linewidth=2, label="Bio-preprocessed")
ax.axhline(y=10, color="gray", linestyle="--", alpha=0.5, label="Chance (10%)")
ax.set_xlabel("Training epoch")
ax.set_ylabel("Test accuracy (%)")
ax.set_title("Classification accuracy over training")
ax.legend()
ax.set_ylim(0, 100)
ax.grid(True, alpha=0.3)
# Final accuracy bar
ax = axes[1]
final_baseline = baseline_test[-1] * 100
final_bio = bio_test[-1] * 100
bars = ax.bar(
["Baseline", "Bio-preprocessed"],
[final_baseline, final_bio],
color=["gold", "#1D9E75"],
width=0.5,
)
ax.bar_label(bars, fmt="%.1f%%", padding=3)
ax.set_ylabel("Final test accuracy (%)")
ax.set_title(f"Final accuracy (Ξ” = {final_bio - final_baseline:+.1f}%)")
ax.set_ylim(0, 100)
ax.axhline(y=10, color="gray", linestyle="--", alpha=0.5, label="Chance")
ax.grid(True, alpha=0.3, axis="y")
plt.tight_layout()
plt.savefig(save_path, dpi=150, bbox_inches="tight")
print(f"Saved accuracy plot to {save_path}")
def plot_activation_spread(
raw_image: np.ndarray,
electrode_grid: np.ndarray,
spike_readout: np.ndarray,
save_path: str = "outputs/activation_spread.png",
):
"""
Visualize TBC Figure 3:
Shows how activity spreads beyond the stimulated region.
Raw digit β†’ electrode grid β†’ reservoir spread.
"""
from reservoir import BioPreprocessor
os.makedirs(os.path.dirname(save_path), exist_ok=True)
# Get spatial readout
preprocessor = BioPreprocessor(n_reservoir_units=1024)
spatial = preprocessor.reservoir.get_spatial_readout(spike_readout, grid_size=64)
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
fig.suptitle("Neural activity spread beyond stimulation region", fontsize=13)
axes[0].imshow(raw_image, cmap="gray")
axes[0].set_title("Original MNIST digit")
axes[0].axis("off")
axes[1].imshow(electrode_grid, cmap="gray")
axes[1].set_title("Electrode grid (stimulation pattern)")
axes[1].axis("off")
# Red box showing original 28Γ—28 digit area
rect = plt.Rectangle(
(18, 18), 28, 28, linewidth=2, edgecolor="red", facecolor="none"
)
axes[1].add_patch(rect)
axes[1].text(22, 15, "28Γ—28 digit area", color="red", fontsize=8)
axes[2].imshow(spatial, cmap="hot")
axes[2].set_title("Reservoir activity spread")
axes[2].axis("off")
rect2 = plt.Rectangle(
(18, 18), 28, 28, linewidth=2, edgecolor="cyan", facecolor="none"
)
axes[2].add_patch(rect2)
axes[2].text(22, 15, "original region", color="cyan", fontsize=8)
plt.tight_layout()
plt.savefig(save_path, dpi=150, bbox_inches="tight")
print(f"Saved activation spread plot to {save_path}")
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--n_samples", type=int, default=2000, help="Training samples (500 for smoke test)")
parser.add_argument("--epochs", type=int, default=10, help="Training epochs")
parser.add_argument("--skip_ablation", action="store_true", help="Skip ablation study")
args = parser.parse_args()
os.makedirs("outputs", exist_ok=True)
# ── 1. Load and encode data ─────────────────────────────────────────────
train_images, train_labels, test_images, test_labels = load_mnist(args.n_samples)
print("Encoding images to electrode grid...")
train_grids = encode_to_grid(train_images) # (N, 64, 64)
test_grids = encode_to_grid(test_images)
# ── 2. Bio preprocessing ────────────────────────────────────────────────
bio_train = run_bio_preprocessing(
train_grids, cache_path=f"outputs/bio_train_{args.n_samples}.pkl"
)
bio_test = run_bio_preprocessing(
test_grids, cache_path=f"outputs/bio_test_{args.n_samples//5}.pkl"
)
# ── 3. Build dataloaders ─────────────────────────────────────────────────
# Baseline: raw 64Γ—64 grids as images
baseline_train_ds = TensorDataset(
torch.FloatTensor(train_grids).unsqueeze(1), # (N, 1, 64, 64)
torch.LongTensor(train_labels),
)
baseline_test_ds = TensorDataset(
torch.FloatTensor(test_grids).unsqueeze(1),
torch.LongTensor(test_labels),
)
# Bio: spike rate vectors
bio_train_ds = TensorDataset(
torch.FloatTensor(bio_train), # (N, 1024)
torch.LongTensor(train_labels),
)
bio_test_ds = TensorDataset(
torch.FloatTensor(bio_test),
torch.LongTensor(test_labels),
)
batch_size = 64
baseline_train_loader = DataLoader(baseline_train_ds, batch_size=batch_size, shuffle=True)
baseline_test_loader = DataLoader(baseline_test_ds, batch_size=batch_size)
bio_train_loader = DataLoader(bio_train_ds, batch_size=batch_size, shuffle=True)
bio_test_loader = DataLoader(bio_test_ds, batch_size=batch_size)
# ── 4. Train both models ─────────────────────────────────────────────────
print("\nTraining baseline CNN on raw MNIST...")
baseline_model = BaselineCNN()
_, baseline_test = train_model(
baseline_model, baseline_train_loader, baseline_test_loader,
args.epochs, label="Baseline"
)
print("\nTraining Bio CNN on reservoir spike rates...")
bio_model = BioCNN()
_, bio_test = train_model(
bio_model, bio_train_loader, bio_test_loader,
args.epochs, label="Bio"
)
# ── 5. Plot accuracy curves ──────────────────────────────────────────────
plot_results(baseline_test, bio_test)
# ── 6. Activation spread visualization ──────────────────────────────────
print("\nGenerating activation spread visualization...")
sample_idx = 0
plot_activation_spread(
raw_image=train_images[sample_idx],
electrode_grid=train_grids[sample_idx],
spike_readout=bio_train[sample_idx],
)
# ── 7. Ablation study ────────────────────────────────────────────────────
if not args.skip_ablation:
ablation_results = run_ablation(
bio_train, bio_test, train_labels, test_labels, epochs=args.epochs
)
print("\nAblation results:")
for k, v in ablation_results.items():
print(f" {k}: {v*100:.1f}%")
print(" Chance: 10.0%")
# ── 8. Save summary ──────────────────────────────────────────────────────
final_baseline = baseline_test[-1] * 100
final_bio = bio_test[-1] * 100
summary = {
"n_samples": args.n_samples,
"epochs": args.epochs,
"baseline_final_acc": final_baseline,
"bio_final_acc": final_bio,
"improvement": final_bio - final_baseline,
"baseline_curve": baseline_test,
"bio_curve": bio_test,
}
import json
with open("outputs/results.json", "w") as f:
json.dump(summary, f, indent=2)
print(f"\n{'='*50}")
print(f"RESULTS SUMMARY")
print(f"{'='*50}")
print(f"Baseline final accuracy: {final_baseline:.1f}%")
print(f"Bio-preprocessed accuracy: {final_bio:.1f}%")
print(f"Improvement: {final_bio - final_baseline:+.1f}%")
print(f"(TBC paper reported: +4.7%)")
print(f"{'='*50}")
print(f"\nOutputs saved to ./outputs/")
if __name__ == "__main__":
main()