| """ |
| Exp31: Biological Initialization vs Random |
| ============================================= |
| |
| Compares four initialization strategies: |
| 1. V28-Random: Default scalar mu=0.4 (current baseline) |
| 2. V28-Allen: Heterogeneous mu/sigma/crystal from Allen Cell Types |
| 3. V28-MICrONs: h_phys initialized with connectome eigenvectors |
| 4. V28-Full-Bio: Allen + MICrONs combined |
| |
| Metrics: |
| - Epochs to 90% accuracy |
| - Effective dimension of h_phys (covariance matrix rank) |
| - T_mean, h_bimodal, entropy curves |
| - Representational richness (std across dimensions) |
| |
| Requires: dataset files generated by fetch_microns.py and fetch_allen_celltypes.py |
| """ |
|
|
| import sys |
| import os |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
| import torch |
| import torch.nn as nn |
| import numpy as np |
| import json |
| from datetime import datetime |
| from pathlib import Path |
|
|
| from SKYNET_V28_PHYSICAL_CYBORG import SKYNET_V28_PHYSICAL_CYBORG |
| from bio_initializer import load_bio_params, get_microns_init_template, get_spectral_modulation |
|
|
|
|
LOG_DIR = Path(__file__).parent  # results (.log / .png) are written next to this script
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


N_CLASSES = 8    # number of classes in the synthetic dataset (also the model's n_actions)
N_INPUT = 658    # feature dimension of each sample (also the model's n_input)
N_EPOCHS = 150   # training epochs per configuration
N_TRAIN = 1000   # training set size
N_TEST = 200     # held-out test set size
|
|
|
|
def create_model(config_name, device=DEVICE):
    """Create a V28 model with the specified initialization strategy.

    Args:
        config_name: One of 'Random', 'Allen', 'MICrONs', 'Full-Bio'.
        device: Torch device string the model is built on.

    Returns:
        A SKYNET_V28_PHYSICAL_CYBORG instance moved to `device`.

    Raises:
        ValueError: If `config_name` is not one of the known strategies.
    """
    if config_name == 'Random':
        # Baseline: model's built-in scalar defaults (bio_params=None).
        bio_params = None

    elif config_name == 'Allen':
        # Heterogeneous per-dimension parameters from Allen Cell Types data.
        bp = load_bio_params()
        bio_params = {
            'mu': bp['mu'],
            'sigma': bp['sigma'],
            'crystal_strength': bp['crystal_strength'],
            'lambda_base': bp['lambda_base'],
        }

    elif config_name == 'MICrONs':
        # Connectome eigenvector template for h_phys; other parameters kept
        # at homogeneous values matching the Random baseline defaults.
        template = get_microns_init_template()
        d_state = 64  # assumed V28 physical state dimension -- TODO confirm against model
        bio_params = {
            'mu': torch.full((d_state,), 0.4),
            'sigma': torch.full((d_state,), 0.3),
            'crystal_strength': torch.full((d_state,), 1.0),
            'lambda_base': torch.full((d_state,), 0.02),
            'init_template': template,
        }

    elif config_name == 'Full-Bio':
        # Allen per-dimension parameters + MICrONs connectome template.
        bp = load_bio_params()
        bio_params = {
            'mu': bp['mu'],
            'sigma': bp['sigma'],
            'crystal_strength': bp['crystal_strength'],
            'lambda_base': bp['lambda_base'],
            'init_template': get_microns_init_template(),
        }

    else:
        # Previously an unknown name silently fell back to the Random
        # baseline (bio_params stayed None); fail loudly instead.
        raise ValueError(f"Unknown config_name: {config_name!r}")

    model = SKYNET_V28_PHYSICAL_CYBORG(
        n_input=N_INPUT, n_actions=N_CLASSES, device=device,
        bio_params=bio_params
    ).to(device)

    return model
|
|
|
|
def generate_dataset(n_samples, seed=42, centroids=None, n_classes=None, n_input=None):
    """Generate a HARD BUT SOLVABLE classification dataset (Iter 2).

    Each sample is a class centroid plus Gaussian noise whose std (1.5)
    exceeds the centroid scale (1.0), so classes overlap substantially.

    Args:
        n_samples: Number of (x, label) pairs to generate.
        seed: Seed for torch's global RNG (reseeds it, so calls with the
            same arguments are fully reproducible).
        centroids: Optional (n_classes, n_input) tensor of class centers;
            pass the training centroids here to build a matched test set.
            When None, fresh centroids are drawn.
        n_classes: Number of classes; defaults to module-level N_CLASSES.
        n_input: Feature dimension; defaults to module-level N_INPUT.

    Returns:
        (data, centroids) where data is a list of (tensor, int-label) pairs.
    """
    if n_classes is None:
        n_classes = N_CLASSES
    if n_input is None:
        n_input = N_INPUT

    torch.manual_seed(seed)

    if centroids is None:
        centroids = torch.randn(n_classes, n_input) * 1.0

    data = []
    for _ in range(n_samples):
        label = torch.randint(0, n_classes, (1,)).item()
        # Noise std 1.5 > centroid scale 1.0: hard but solvable overlap.
        x = centroids[label] + 1.5 * torch.randn(n_input)
        data.append((x, label))
    return data, centroids
|
|
|
|
def compute_effective_dimension(h_phys_samples):
    """Participation ratio of the sample covariance spectrum.

    Computes (sum lambda_i)^2 / sum(lambda_i^2) over the eigenvalues of
    the sample covariance matrix -- a soft count of how many dimensions
    carry appreciable variance. Degenerate inputs (fewer than two
    samples, or near-zero total variance) map to 1.0.
    """
    n = len(h_phys_samples)
    if n < 2:
        return 1.0

    centered = torch.stack(h_phys_samples)
    centered = centered - centered.mean(dim=0, keepdim=True)
    covariance = centered.T.matmul(centered) / (n - 1)

    # eigvalsh is exact for symmetric matrices; clamp guards against tiny
    # negative values from floating-point round-off.
    spectrum = torch.linalg.eigvalsh(covariance).clamp(min=0)

    trace = spectrum.sum()
    if trace < 1e-8:
        return 1.0
    return float(trace * trace / spectrum.square().sum())
|
|
|
|
def train_and_evaluate(config_name):
    """Train a V28 model under `config_name` init and collect metrics.

    Returns a dict of summary scalars (epochs to 90% train accuracy,
    final train/test accuracy, final audit values) plus per-epoch
    'curves' used for plotting.
    """
    print(f"\n Training {config_name}...")


    model = create_model(config_name)
    batch_size = 32
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()


    # Test set reuses the training centroids, so it measures
    # generalization to new noise, not to new class geometry.
    train_data, centroids = generate_dataset(N_TRAIN, seed=42)
    train_X = torch.stack([x for x, l in train_data]).to(DEVICE)
    train_Y = torch.tensor([l for x, l in train_data]).to(DEVICE)


    test_data, _ = generate_dataset(N_TEST, seed=123, centroids=centroids)


    metrics = {
        'loss': [],
        'accuracy': [],
        'T_mean': [],
        'h_bimodal': [],
        'entropy': [],
        'h_std': [],
        'eff_dim': [],  # sampled only every 10 epochs, so shorter than the others
    }


    # Sentinel: stays at N_EPOCHS if 90% train accuracy is never reached.
    epochs_to_90 = N_EPOCHS


    for epoch in range(N_EPOCHS):
        model.train()
        epoch_loss = 0
        correct = 0
        h_phys_samples = []


        # Fresh shuffle each epoch.
        perm = torch.randperm(N_TRAIN)
        X_sh = train_X[perm]
        Y_sh = train_Y[perm]


        for i in range(0, N_TRAIN, batch_size):
            model.reset()  # clear the model's persistent state between batches
            x_batch = X_sh[i:i+batch_size]
            y_batch = Y_sh[i:i+batch_size]


            out = model(x_batch, training=True)
            loss = criterion(out['logits'], y_batch)


            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Cut the autograd graph carried in the model's retained state
            # so it does not grow across batches.
            model.detach_states()


            epoch_loss += loss.item() * x_batch.shape[0]
            correct += (out['logits'].argmax(dim=-1) == y_batch).sum().item()
            # h_phys is presumably (batch, d_state) -- TODO confirm in V28.
            h_phys_samples.append(model.organ.h_phys.detach().cpu())


        acc = correct / N_TRAIN * 100
        metrics['loss'].append(epoch_loss / N_TRAIN)
        metrics['accuracy'].append(acc)
        # NOTE: audit values are taken from the LAST batch of the epoch only.
        metrics['T_mean'].append(out['audit']['T_mean'])
        metrics['h_bimodal'].append(out['audit']['h_bimodal'])
        metrics['entropy'].append(out['audit']['entropy'])


        # Representational richness: mean per-dimension std across all
        # h_phys samples collected this epoch.
        h_stack = torch.cat(h_phys_samples, dim=0)
        metrics['h_std'].append(h_stack.std(dim=0).mean().item())


        if epoch % 10 == 0:
            # Effective dimension over the 50 most recent h_phys rows.
            eff_dim = compute_effective_dimension(list(h_stack[-50:]))
            metrics['eff_dim'].append(eff_dim)


        if acc >= 90 and epochs_to_90 == N_EPOCHS:
            epochs_to_90 = epoch + 1


        if (epoch + 1) % 30 == 0:
            print(f" Epoch {epoch+1}: acc={acc:.1f}%, "
                  f"T={out['audit']['T_mean']:.3f}")


    # Held-out evaluation: one sample at a time, gradients disabled.
    model.eval()
    test_correct = 0
    for x, label in test_data:
        model.reset()
        with torch.no_grad():
            out = model(x.unsqueeze(0).to(DEVICE), training=False)
        if out['logits'].argmax().item() == label:
            test_correct += 1


    test_acc = test_correct / len(test_data) * 100


    result = {
        'config': config_name,
        'epochs_to_90': epochs_to_90,
        'final_train_acc': metrics['accuracy'][-1],
        'test_acc': test_acc,
        'final_T_mean': metrics['T_mean'][-1],
        'final_h_bimodal': metrics['h_bimodal'][-1],
        'final_entropy': metrics['entropy'][-1],
        'final_h_std': metrics['h_std'][-1],
        'final_eff_dim': metrics['eff_dim'][-1] if metrics['eff_dim'] else 0,
        'curves': {
            'loss': metrics['loss'],
            'accuracy': metrics['accuracy'],
            'T_mean': metrics['T_mean'],
            'h_bimodal': metrics['h_bimodal'],
            'entropy': metrics['entropy'],
            'h_std': metrics['h_std'],
        }
    }


    print(f" => {config_name}: "
          f"train={metrics['accuracy'][-1]:.1f}%, "
          f"test={test_acc:.1f}%, "
          f"ep90={epochs_to_90}, "
          f"eff_dim={result['final_eff_dim']:.1f}")


    return result
|
|
|
|
def save_results(results):
    """Write a JSON report to disk and render the comparison figure.

    The bulky per-epoch 'curves' entry of each result is excluded from
    the JSON report (it is used only for plotting). Plotting is
    best-effort: if matplotlib is unavailable, the figure is skipped.
    """
    log_path = LOG_DIR / 'exp31_bio_initialization.log'


    report = {
        'experiment': 'Exp31: Bio-Initialization vs Random',
        'timestamp': datetime.now().isoformat(),
        'device': DEVICE,
        # Keep only summary scalars per config; drop the per-epoch curves.
        'results': {r['config']: {k: v for k, v in r.items() if k != 'curves'} for r in results},
    }


    with open(log_path, 'w') as f:
        # default=str stringifies anything json cannot encode natively.
        f.write(json.dumps(report, indent=2, default=str))
    print(f"\n[SAVED] {log_path}")


    try:
        import matplotlib
        matplotlib.use('Agg')  # headless backend: no display required
        import matplotlib.pyplot as plt


        colors = ['#2196F3', '#4CAF50', '#FF9800', '#E91E63']  # one per config
        configs = [r['config'] for r in results]


        fig, axes = plt.subplots(2, 3, figsize=(18, 10))
        fig.suptitle('Exp31: Biological Initialization vs Random', fontsize=14)


        # (0,0) training accuracy with the 90% target line.
        ax = axes[0, 0]
        for r, c in zip(results, colors):
            ax.plot(r['curves']['accuracy'], color=c, label=r['config'])
        ax.axhline(y=90, color='gray', linestyle='--', alpha=0.5)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Accuracy (%)')
        ax.set_title('Training Accuracy')
        ax.legend()


        # (0,1) training loss.
        ax = axes[0, 1]
        for r, c in zip(results, colors):
            ax.plot(r['curves']['loss'], color=c, label=r['config'])
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title('Training Loss')
        ax.legend()


        # (0,2) temperature audit metric.
        ax = axes[0, 2]
        for r, c in zip(results, colors):
            ax.plot(r['curves']['T_mean'], color=c, label=r['config'])
        ax.set_xlabel('Epoch')
        ax.set_ylabel('T_mean')
        ax.set_title('Temperature Evolution')
        ax.legend()


        # (1,0) bimodality audit metric.
        ax = axes[1, 0]
        for r, c in zip(results, colors):
            ax.plot(r['curves']['h_bimodal'], color=c, label=r['config'])
        ax.set_xlabel('Epoch')
        ax.set_ylabel('h_bimodal')
        ax.set_title('Bimodal Index')
        ax.legend()


        # (1,1) representational richness (mean per-dim std of h_phys).
        ax = axes[1, 1]
        for r, c in zip(results, colors):
            ax.plot(r['curves']['h_std'], color=c, label=r['config'])
        ax.set_xlabel('Epoch')
        ax.set_ylabel('h_std (across dims)')
        ax.set_title('Representational Richness')
        ax.legend()


        # (1,2) summary bars: epochs-to-90 (left axis) vs test acc (right).
        ax = axes[1, 2]
        x = np.arange(len(configs))
        width = 0.35
        ep90 = [r['epochs_to_90'] for r in results]
        test_acc = [r['test_acc'] for r in results]
        ax.bar(x - width/2, ep90, width, label='Epochs to 90%', color=colors)
        ax2 = ax.twinx()
        # '80' hex suffix adds ~50% alpha to each base color.
        ax2.bar(x + width/2, test_acc, width, label='Test Acc (%)',
                color=[c + '80' for c in colors], alpha=0.7)
        ax.set_xticks(x)
        ax.set_xticklabels(configs, rotation=15)
        ax.set_ylabel('Epochs to 90%')
        ax2.set_ylabel('Test Accuracy (%)')
        ax.set_title('Summary')
        ax.legend(loc='upper left')
        ax2.legend(loc='upper right')


        plt.tight_layout()
        png_path = LOG_DIR / 'exp31_bio_initialization.png'
        plt.savefig(png_path, dpi=150)
        print(f"[SAVED] {png_path}")
        plt.close()
    except ImportError:
        # Plotting is optional; the JSON report above is already saved.
        print("[SKIP] matplotlib not available for plotting")
|
|
|
|
if __name__ == "__main__":
    print("=" * 60)
    print("EXP31: BIOLOGICAL INITIALIZATION vs RANDOM")
    print("=" * 60)

    # Run every initialization strategy under identical data and seeds.
    config_names = ['Random', 'Allen', 'MICrONs', 'Full-Bio']
    results = [train_and_evaluate(name) for name in config_names]

    save_results(results)

    # Console summary table.
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f"{'Config':<12} {'Train%':>8} {'Test%':>8} {'Ep90':>6} {'EffDim':>8} {'h_std':>8}")
    print("-" * 60)
    for r in results:
        print(f"{r['config']:<12} {r['final_train_acc']:>7.1f}% "
              f"{r['test_acc']:>7.1f}% "
              f"{r['epochs_to_90']:>6d} "
              f"{r['final_eff_dim']:>8.1f} "
              f"{r['final_h_std']:>8.4f}")
    print("=" * 60)
|
|