Create train.py
Browse files
train.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training script for Spatial Context Networks (SCN).
|
| 3 |
+
|
| 4 |
+
Example usage:
|
| 5 |
+
python train.py --input_dim 10 --n_neurons 32 --output_dim 4 --epochs 50
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import argparse
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.optim as optim
|
| 12 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 13 |
+
|
| 14 |
+
from spatial_context_networks import SpatialContextNetwork
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def make_synthetic_dataset(n_samples=256, input_dim=10, output_dim=4, seed=42):
|
| 18 |
+
"""Creates a simple synthetic classification dataset for demonstration."""
|
| 19 |
+
torch.manual_seed(seed)
|
| 20 |
+
X = torch.randn(n_samples, input_dim)
|
| 21 |
+
y = torch.randint(0, output_dim, (n_samples,))
|
| 22 |
+
return TensorDataset(X, y)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def train(args):
|
| 26 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 27 |
+
print(f"Training on: {device}")
|
| 28 |
+
|
| 29 |
+
# Data
|
| 30 |
+
dataset = make_synthetic_dataset(
|
| 31 |
+
n_samples=args.n_samples,
|
| 32 |
+
input_dim=args.input_dim,
|
| 33 |
+
output_dim=args.output_dim,
|
| 34 |
+
)
|
| 35 |
+
loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
|
| 36 |
+
|
| 37 |
+
# Model
|
| 38 |
+
model = SpatialContextNetwork(
|
| 39 |
+
input_dim=args.input_dim,
|
| 40 |
+
n_neurons=args.n_neurons,
|
| 41 |
+
output_dim=args.output_dim,
|
| 42 |
+
routing_threshold=args.routing_threshold,
|
| 43 |
+
stability_factor=args.stability_factor,
|
| 44 |
+
explosion_threshold=args.explosion_threshold,
|
| 45 |
+
).to(device)
|
| 46 |
+
|
| 47 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 48 |
+
print(f"Model parameters: {total_params}")
|
| 49 |
+
print(model)
|
| 50 |
+
|
| 51 |
+
optimizer = optim.Adam(model.parameters(), lr=args.lr)
|
| 52 |
+
criterion = nn.CrossEntropyLoss()
|
| 53 |
+
|
| 54 |
+
for epoch in range(1, args.epochs + 1):
|
| 55 |
+
model.train()
|
| 56 |
+
total_loss = 0.0
|
| 57 |
+
correct = 0
|
| 58 |
+
|
| 59 |
+
for X_batch, y_batch in loader:
|
| 60 |
+
X_batch, y_batch = X_batch.to(device), y_batch.to(device)
|
| 61 |
+
optimizer.zero_grad()
|
| 62 |
+
logits = model(X_batch)
|
| 63 |
+
loss = criterion(logits, y_batch)
|
| 64 |
+
loss.backward()
|
| 65 |
+
optimizer.step()
|
| 66 |
+
|
| 67 |
+
total_loss += loss.item() * len(X_batch)
|
| 68 |
+
correct += (logits.argmax(dim=-1) == y_batch).sum().item()
|
| 69 |
+
|
| 70 |
+
avg_loss = total_loss / len(dataset)
|
| 71 |
+
accuracy = correct / len(dataset)
|
| 72 |
+
|
| 73 |
+
if epoch % 10 == 0 or epoch == 1:
|
| 74 |
+
# Network efficiency stats
|
| 75 |
+
model.eval()
|
| 76 |
+
with torch.no_grad():
|
| 77 |
+
sample_x = torch.randn(args.batch_size, args.input_dim).to(device)
|
| 78 |
+
stats = model.get_network_stats(sample_x)
|
| 79 |
+
print(
|
| 80 |
+
f"Epoch {epoch:3d}/{args.epochs} | "
|
| 81 |
+
f"Loss: {avg_loss:.4f} | Acc: {accuracy:.3f} | "
|
| 82 |
+
f"Active neurons: {stats['mean_active_neurons']:.1f}/{args.n_neurons} "
|
| 83 |
+
f"(eff={stats['network_efficiency']:.2f}) | "
|
| 84 |
+
f"Context score: {stats['mean_context_score']:.3f}"
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
if args.save_path:
|
| 88 |
+
torch.save(model.state_dict(), args.save_path)
|
| 89 |
+
print(f"\nModel saved to {args.save_path}")
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
if __name__ == "__main__":
|
| 93 |
+
parser = argparse.ArgumentParser(description="Train a Spatial Context Network")
|
| 94 |
+
|
| 95 |
+
# Architecture
|
| 96 |
+
parser.add_argument("--input_dim", type=int, default=10)
|
| 97 |
+
parser.add_argument("--n_neurons", type=int, default=32)
|
| 98 |
+
parser.add_argument("--output_dim", type=int, default=4)
|
| 99 |
+
parser.add_argument("--routing_threshold", type=float, default=0.5)
|
| 100 |
+
parser.add_argument("--stability_factor", type=float, default=10.0)
|
| 101 |
+
parser.add_argument("--explosion_threshold", type=float, default=2.0)
|
| 102 |
+
|
| 103 |
+
# Training
|
| 104 |
+
parser.add_argument("--epochs", type=int, default=50)
|
| 105 |
+
parser.add_argument("--batch_size", type=int, default=8)
|
| 106 |
+
parser.add_argument("--lr", type=float, default=1e-3)
|
| 107 |
+
parser.add_argument("--n_samples", type=int, default=256)
|
| 108 |
+
parser.add_argument("--save_path", type=str, default=None)
|
| 109 |
+
|
| 110 |
+
args = parser.parse_args()
|
| 111 |
+
train(args)
|