Spaces:
Runtime error
Runtime error
dlokesha
Part 2 final: local plasticity mean rank 37.1 β continual backprop rank 33.1 confirms reactive methods worse than standard
c3a6dfa | """ | |
| train.py β Run the full TBC experiment comparison. | |
| Usage: | |
| python train.py # full run | |
| python train.py --n_samples 500 # quick smoke test | |
| python train.py --epochs 5 # fast iteration | |
| What this does: | |
| 1. Loads MNIST, binarizes it, resizes to 64x64 | |
| 2. Runs all images through the BioPreprocessor (reservoir layer) | |
| 3. Trains BaselineCNN on raw images | |
| 4. Trains BioCNN on bio-processed spike rates | |
| 5. Plots accuracy curves β should show bio beats baseline | |
| 6. Saves results to outputs/ | |
| """ | |
| import argparse | |
| import os | |
| import pickle | |
| import time | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader, TensorDataset | |
| from torchvision import datasets, transforms | |
| from models import AblationCNN, BaselineCNN, BioCNN | |
| from reservoir import BioPreprocessor, MEAEncoder | |
| def load_mnist(n_samples: int = 2000, data_dir: str = "./data"): | |
| """Load MNIST and binarize β matching TBC's preprocessing exactly.""" | |
| print("Loading MNIST...") | |
| transform = transforms.Compose([transforms.ToTensor()]) | |
| train_data = datasets.MNIST(data_dir, train=True, download=True, transform=transform) | |
| test_data = datasets.MNIST(data_dir, train=False, download=True, transform=transform) | |
| # Subsample | |
| train_images = train_data.data[:n_samples].numpy() / 255.0 | |
| train_labels = train_data.targets[:n_samples].numpy() | |
| test_images = test_data.data[: n_samples // 5].numpy() / 255.0 | |
| test_labels = test_data.targets[: n_samples // 5].numpy() | |
| # Binarize: white digits on black background | |
| train_images = (train_images > 0.5).astype(float) | |
| test_images = (test_images > 0.5).astype(float) | |
| return train_images, train_labels, test_images, test_labels | |
| def encode_to_grid(images: np.ndarray, grid_size: int = 64) -> np.ndarray: | |
| """Resize binarized MNIST to 64Γ64 electrode grid.""" | |
| encoder = MEAEncoder(grid_size=grid_size) | |
| encoded = np.array([encoder.encode(img) for img in images]) | |
| return encoded # (N, 64, 64) | |
| def run_bio_preprocessing(images: np.ndarray, cache_path: str = None) -> np.ndarray: | |
| """ | |
| Run all images through the reservoir. | |
| Caches results to disk β reservoir is slow, don't rerun unnecessarily. | |
| """ | |
| if cache_path and os.path.exists(cache_path): | |
| print(f"Loading cached bio representations from {cache_path}") | |
| with open(cache_path, "rb") as f: | |
| return pickle.load(f) | |
| print("Running bio preprocessing (this takes a few minutes)...") | |
| t0 = time.time() | |
| preprocessor = BioPreprocessor(n_reservoir_units=1024, grid_size=64, steps=10) | |
| bio_reps = preprocessor.process_batch(images) | |
| elapsed = time.time() - t0 | |
| print(f"Done in {elapsed:.1f}s β {len(images)/elapsed:.1f} images/sec") | |
| if cache_path: | |
| os.makedirs(os.path.dirname(cache_path), exist_ok=True) | |
| with open(cache_path, "wb") as f: | |
| pickle.dump(bio_reps, f) | |
| print(f"Cached to {cache_path}") | |
| return bio_reps # (N, 1024) | |
| def train_model(model, train_loader, test_loader, epochs: int, lr: float = 1e-3, label: str = ""): | |
| """Train a model and return per-epoch accuracy on both splits.""" | |
| optimizer = optim.Adam(model.parameters(), lr=lr) | |
| criterion = nn.CrossEntropyLoss() | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| model = model.to(device) | |
| train_accs, test_accs = [], [] | |
| for epoch in range(epochs): | |
| # Train | |
| model.train() | |
| for x_batch, y_batch in train_loader: | |
| x_batch, y_batch = x_batch.to(device), y_batch.to(device) | |
| optimizer.zero_grad() | |
| logits = model(x_batch) | |
| loss = criterion(logits, y_batch) | |
| loss.backward() | |
| optimizer.step() | |
| # Eval train | |
| model.eval() | |
| with torch.no_grad(): | |
| correct = total = 0 | |
| for x_batch, y_batch in train_loader: | |
| x_batch, y_batch = x_batch.to(device), y_batch.to(device) | |
| preds = model(x_batch).argmax(1) | |
| correct += (preds == y_batch).sum().item() | |
| total += len(y_batch) | |
| train_accs.append(correct / total) | |
| # Eval test | |
| correct = total = 0 | |
| for x_batch, y_batch in test_loader: | |
| x_batch, y_batch = x_batch.to(device), y_batch.to(device) | |
| preds = model(x_batch).argmax(1) | |
| correct += (preds == y_batch).sum().item() | |
| total += len(y_batch) | |
| test_accs.append(correct / total) | |
| print(f" [{label}] Epoch {epoch+1}/{epochs} β test acc: {test_accs[-1]*100:.1f}%") | |
| return train_accs, test_accs | |
| def run_ablation( | |
| bio_train: np.ndarray, | |
| bio_test: np.ndarray, | |
| train_labels: np.ndarray, | |
| test_labels: np.ndarray, | |
| n_units: int = 1024, | |
| epochs: int = 10, | |
| ): | |
| """ | |
| Ablation study replicating TBC Figure 4 (right panel). | |
| Splits reservoir units into: | |
| - Whole: all 1024 units | |
| - Center: first 512 units (approximates directly stimulated region) | |
| - Periphery: last 512 units (approximates spread-beyond-stimulation) | |
| TBC result: Whole β Center >> Periphery > chance (10%) | |
| All three should beat chance β proving bio spread carries real info. | |
| """ | |
| print("\nRunning ablation study...") | |
| results = {} | |
| splits = { | |
| "Whole": (0, n_units), | |
| "Center": (0, n_units // 2), | |
| "Periphery": (n_units // 2, n_units), | |
| } | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| for name, (start, end) in splits.items(): | |
| subset_train = bio_train[:, start:end] | |
| subset_test = bio_test[:, start:end] | |
| input_dim = end - start | |
| train_ds = TensorDataset( | |
| torch.FloatTensor(subset_train), torch.LongTensor(train_labels) | |
| ) | |
| test_ds = TensorDataset( | |
| torch.FloatTensor(subset_test), torch.LongTensor(test_labels) | |
| ) | |
| train_loader = DataLoader(train_ds, batch_size=64, shuffle=True) | |
| test_loader = DataLoader(test_ds, batch_size=64) | |
| model = AblationCNN(input_dim=input_dim) | |
| _, test_accs = train_model(model, train_loader, test_loader, epochs, label=name) | |
| results[name] = test_accs[-1] | |
| print(f" {name}: {test_accs[-1]*100:.1f}%") | |
| return results | |
| def plot_results(baseline_test, bio_test, save_path: str = "outputs/accuracy_curves.png"): | |
| """ | |
| Reproduce TBC Figure 4 (left panel): | |
| Bio accuracy curve should be higher and converge faster than baseline. | |
| """ | |
| os.makedirs(os.path.dirname(save_path), exist_ok=True) | |
| epochs = range(1, len(baseline_test) + 1) | |
| fig, axes = plt.subplots(1, 2, figsize=(12, 5)) | |
| fig.suptitle("TBC Replication: Biological Preprocessing vs Baseline", fontsize=14) | |
| # Accuracy curves | |
| ax = axes[0] | |
| ax.plot(epochs, [a * 100 for a in baseline_test], "gold", linewidth=2, label="Baseline (raw MNIST)") | |
| ax.plot(epochs, [a * 100 for a in bio_test], "#1D9E75", linewidth=2, label="Bio-preprocessed") | |
| ax.axhline(y=10, color="gray", linestyle="--", alpha=0.5, label="Chance (10%)") | |
| ax.set_xlabel("Training epoch") | |
| ax.set_ylabel("Test accuracy (%)") | |
| ax.set_title("Classification accuracy over training") | |
| ax.legend() | |
| ax.set_ylim(0, 100) | |
| ax.grid(True, alpha=0.3) | |
| # Final accuracy bar | |
| ax = axes[1] | |
| final_baseline = baseline_test[-1] * 100 | |
| final_bio = bio_test[-1] * 100 | |
| bars = ax.bar( | |
| ["Baseline", "Bio-preprocessed"], | |
| [final_baseline, final_bio], | |
| color=["gold", "#1D9E75"], | |
| width=0.5, | |
| ) | |
| ax.bar_label(bars, fmt="%.1f%%", padding=3) | |
| ax.set_ylabel("Final test accuracy (%)") | |
| ax.set_title(f"Final accuracy (Ξ = {final_bio - final_baseline:+.1f}%)") | |
| ax.set_ylim(0, 100) | |
| ax.axhline(y=10, color="gray", linestyle="--", alpha=0.5, label="Chance") | |
| ax.grid(True, alpha=0.3, axis="y") | |
| plt.tight_layout() | |
| plt.savefig(save_path, dpi=150, bbox_inches="tight") | |
| print(f"Saved accuracy plot to {save_path}") | |
| def plot_activation_spread( | |
| raw_image: np.ndarray, | |
| electrode_grid: np.ndarray, | |
| spike_readout: np.ndarray, | |
| save_path: str = "outputs/activation_spread.png", | |
| ): | |
| """ | |
| Visualize TBC Figure 3: | |
| Shows how activity spreads beyond the stimulated region. | |
| Raw digit β electrode grid β reservoir spread. | |
| """ | |
| from reservoir import BioPreprocessor | |
| os.makedirs(os.path.dirname(save_path), exist_ok=True) | |
| # Get spatial readout | |
| preprocessor = BioPreprocessor(n_reservoir_units=1024) | |
| spatial = preprocessor.reservoir.get_spatial_readout(spike_readout, grid_size=64) | |
| fig, axes = plt.subplots(1, 3, figsize=(12, 4)) | |
| fig.suptitle("Neural activity spread beyond stimulation region", fontsize=13) | |
| axes[0].imshow(raw_image, cmap="gray") | |
| axes[0].set_title("Original MNIST digit") | |
| axes[0].axis("off") | |
| axes[1].imshow(electrode_grid, cmap="gray") | |
| axes[1].set_title("Electrode grid (stimulation pattern)") | |
| axes[1].axis("off") | |
| # Red box showing original 28Γ28 digit area | |
| rect = plt.Rectangle( | |
| (18, 18), 28, 28, linewidth=2, edgecolor="red", facecolor="none" | |
| ) | |
| axes[1].add_patch(rect) | |
| axes[1].text(22, 15, "28Γ28 digit area", color="red", fontsize=8) | |
| axes[2].imshow(spatial, cmap="hot") | |
| axes[2].set_title("Reservoir activity spread") | |
| axes[2].axis("off") | |
| rect2 = plt.Rectangle( | |
| (18, 18), 28, 28, linewidth=2, edgecolor="cyan", facecolor="none" | |
| ) | |
| axes[2].add_patch(rect2) | |
| axes[2].text(22, 15, "original region", color="cyan", fontsize=8) | |
| plt.tight_layout() | |
| plt.savefig(save_path, dpi=150, bbox_inches="tight") | |
| print(f"Saved activation spread plot to {save_path}") | |
| def main(): | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--n_samples", type=int, default=2000, help="Training samples (500 for smoke test)") | |
| parser.add_argument("--epochs", type=int, default=10, help="Training epochs") | |
| parser.add_argument("--skip_ablation", action="store_true", help="Skip ablation study") | |
| args = parser.parse_args() | |
| os.makedirs("outputs", exist_ok=True) | |
| # ββ 1. Load and encode data βββββββββββββββββββββββββββββββββββββββββββββ | |
| train_images, train_labels, test_images, test_labels = load_mnist(args.n_samples) | |
| print("Encoding images to electrode grid...") | |
| train_grids = encode_to_grid(train_images) # (N, 64, 64) | |
| test_grids = encode_to_grid(test_images) | |
| # ββ 2. Bio preprocessing ββββββββββββββββββββββββββββββββββββββββββββββββ | |
| bio_train = run_bio_preprocessing( | |
| train_grids, cache_path=f"outputs/bio_train_{args.n_samples}.pkl" | |
| ) | |
| bio_test = run_bio_preprocessing( | |
| test_grids, cache_path=f"outputs/bio_test_{args.n_samples//5}.pkl" | |
| ) | |
| # ββ 3. Build dataloaders βββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Baseline: raw 64Γ64 grids as images | |
| baseline_train_ds = TensorDataset( | |
| torch.FloatTensor(train_grids).unsqueeze(1), # (N, 1, 64, 64) | |
| torch.LongTensor(train_labels), | |
| ) | |
| baseline_test_ds = TensorDataset( | |
| torch.FloatTensor(test_grids).unsqueeze(1), | |
| torch.LongTensor(test_labels), | |
| ) | |
| # Bio: spike rate vectors | |
| bio_train_ds = TensorDataset( | |
| torch.FloatTensor(bio_train), # (N, 1024) | |
| torch.LongTensor(train_labels), | |
| ) | |
| bio_test_ds = TensorDataset( | |
| torch.FloatTensor(bio_test), | |
| torch.LongTensor(test_labels), | |
| ) | |
| batch_size = 64 | |
| baseline_train_loader = DataLoader(baseline_train_ds, batch_size=batch_size, shuffle=True) | |
| baseline_test_loader = DataLoader(baseline_test_ds, batch_size=batch_size) | |
| bio_train_loader = DataLoader(bio_train_ds, batch_size=batch_size, shuffle=True) | |
| bio_test_loader = DataLoader(bio_test_ds, batch_size=batch_size) | |
| # ββ 4. Train both models βββββββββββββββββββββββββββββββββββββββββββββββββ | |
| print("\nTraining baseline CNN on raw MNIST...") | |
| baseline_model = BaselineCNN() | |
| _, baseline_test = train_model( | |
| baseline_model, baseline_train_loader, baseline_test_loader, | |
| args.epochs, label="Baseline" | |
| ) | |
| print("\nTraining Bio CNN on reservoir spike rates...") | |
| bio_model = BioCNN() | |
| _, bio_test = train_model( | |
| bio_model, bio_train_loader, bio_test_loader, | |
| args.epochs, label="Bio" | |
| ) | |
| # ββ 5. Plot accuracy curves ββββββββββββββββββββββββββββββββββββββββββββββ | |
| plot_results(baseline_test, bio_test) | |
| # ββ 6. Activation spread visualization ββββββββββββββββββββββββββββββββββ | |
| print("\nGenerating activation spread visualization...") | |
| sample_idx = 0 | |
| plot_activation_spread( | |
| raw_image=train_images[sample_idx], | |
| electrode_grid=train_grids[sample_idx], | |
| spike_readout=bio_train[sample_idx], | |
| ) | |
| # ββ 7. Ablation study ββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| if not args.skip_ablation: | |
| ablation_results = run_ablation( | |
| bio_train, bio_test, train_labels, test_labels, epochs=args.epochs | |
| ) | |
| print("\nAblation results:") | |
| for k, v in ablation_results.items(): | |
| print(f" {k}: {v*100:.1f}%") | |
| print(" Chance: 10.0%") | |
| # ββ 8. Save summary ββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| final_baseline = baseline_test[-1] * 100 | |
| final_bio = bio_test[-1] * 100 | |
| summary = { | |
| "n_samples": args.n_samples, | |
| "epochs": args.epochs, | |
| "baseline_final_acc": final_baseline, | |
| "bio_final_acc": final_bio, | |
| "improvement": final_bio - final_baseline, | |
| "baseline_curve": baseline_test, | |
| "bio_curve": bio_test, | |
| } | |
| import json | |
| with open("outputs/results.json", "w") as f: | |
| json.dump(summary, f, indent=2) | |
| print(f"\n{'='*50}") | |
| print(f"RESULTS SUMMARY") | |
| print(f"{'='*50}") | |
| print(f"Baseline final accuracy: {final_baseline:.1f}%") | |
| print(f"Bio-preprocessed accuracy: {final_bio:.1f}%") | |
| print(f"Improvement: {final_bio - final_baseline:+.1f}%") | |
| print(f"(TBC paper reported: +4.7%)") | |
| print(f"{'='*50}") | |
| print(f"\nOutputs saved to ./outputs/") | |
| if __name__ == "__main__": | |
| main() | |