File size: 3,842 Bytes
3169a72
 
 
 
 
 
 
 
 
 
 
 
 
f17d188
3169a72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""
Training script for Spatial Context Networks (SCN).

Example usage:
    python train.py --input_dim 10 --n_neurons 32 --output_dim 4 --epochs 50
"""

import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from model import SpatialContextNetwork


def make_synthetic_dataset(n_samples=256, input_dim=10, output_dim=4, seed=42):
    """Build a random classification TensorDataset for smoke-testing.

    Args:
        n_samples: Number of (feature, label) pairs to generate.
        input_dim: Dimensionality of each feature vector.
        output_dim: Number of classes; labels are drawn uniformly from [0, output_dim).
        seed: Value passed to ``torch.manual_seed`` so draws are reproducible.
            NOTE: this mutates torch's *global* RNG state as a side effect.

    Returns:
        A ``TensorDataset`` of float features (n_samples, input_dim) and
        integer class labels (n_samples,).
    """
    torch.manual_seed(seed)
    # Draw order matters for reproducibility: features first, then labels.
    features = torch.randn(n_samples, input_dim)
    labels = torch.randint(0, output_dim, (n_samples,))
    return TensorDataset(features, labels)


def train(args):
    """Run the full SCN training loop described by the parsed CLI ``args``.

    Builds a synthetic dataset, trains a ``SpatialContextNetwork`` with Adam
    and cross-entropy, periodically prints loss/accuracy plus network
    efficiency stats, and optionally saves the final state dict.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on: {device}")

    # Synthetic data + loader.
    dataset = make_synthetic_dataset(
        n_samples=args.n_samples,
        input_dim=args.input_dim,
        output_dim=args.output_dim,
    )
    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    # Model construction mirrors the CLI architecture flags one-to-one.
    model = SpatialContextNetwork(
        input_dim=args.input_dim,
        n_neurons=args.n_neurons,
        output_dim=args.output_dim,
        routing_threshold=args.routing_threshold,
        stability_factor=args.stability_factor,
        explosion_threshold=args.explosion_threshold,
    ).to(device)

    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {n_params}")
    print(model)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(1, args.epochs + 1):
        model.train()
        running_loss = 0.0
        n_correct = 0

        for xb, yb in data_loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            logits = model(xb)
            batch_loss = criterion(logits, yb)
            batch_loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch average is per-sample.
            running_loss += batch_loss.item() * len(xb)
            n_correct += (logits.argmax(dim=-1) == yb).sum().item()

        avg_loss = running_loss / len(dataset)
        accuracy = n_correct / len(dataset)

        # Report on the first epoch and every tenth one thereafter.
        if epoch == 1 or epoch % 10 == 0:
            # Probe network efficiency on a fresh random batch (eval mode,
            # no grad); model.train() is restored at the top of the next epoch.
            model.eval()
            with torch.no_grad():
                probe = torch.randn(args.batch_size, args.input_dim).to(device)
                stats = model.get_network_stats(probe)
            print(
                f"Epoch {epoch:3d}/{args.epochs} | "
                f"Loss: {avg_loss:.4f} | Acc: {accuracy:.3f} | "
                f"Active neurons: {stats['mean_active_neurons']:.1f}/{args.n_neurons} "
                f"(eff={stats['network_efficiency']:.2f}) | "
                f"Context score: {stats['mean_context_score']:.3f}"
            )

    if args.save_path:
        torch.save(model.state_dict(), args.save_path)
        print(f"\nModel saved to {args.save_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a Spatial Context Network")

    # Architecture
    parser.add_argument("--input_dim", type=int, default=10)
    parser.add_argument("--n_neurons", type=int, default=32)
    parser.add_argument("--output_dim", type=int, default=4)
    parser.add_argument("--routing_threshold", type=float, default=0.5)
    parser.add_argument("--stability_factor", type=float, default=10.0)
    parser.add_argument("--explosion_threshold", type=float, default=2.0)

    # Training
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--n_samples", type=int, default=256)
    parser.add_argument("--save_path", type=str, default=None)

    args = parser.parse_args()
    train(args)