Text Classification
English
code
FurkanNar committed on
Commit
3169a72
·
verified ·
1 Parent(s): f974a69

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +111 -0
train.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training script for Spatial Context Networks (SCN).
3
+
4
+ Example usage:
5
+ python train.py --input_dim 10 --n_neurons 32 --output_dim 4 --epochs 50
6
+ """
7
+
8
+ import argparse
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.optim as optim
12
+ from torch.utils.data import DataLoader, TensorDataset
13
+
14
+ from spatial_context_networks import SpatialContextNetwork
15
+
16
+
17
def make_synthetic_dataset(n_samples=256, input_dim=10, output_dim=4, seed=42):
    """Build a random toy classification dataset for demonstration runs.

    Features are drawn from a standard normal distribution and labels are
    drawn uniformly from ``[0, output_dim)``, independently of the features.

    Args:
        n_samples: Number of (feature, label) pairs to generate.
        input_dim: Length of each feature vector.
        output_dim: Number of distinct class labels.
        seed: Value passed to ``torch.manual_seed`` before sampling.

    Returns:
        A ``TensorDataset`` wrapping a ``(n_samples, input_dim)`` feature
        tensor and a ``(n_samples,)`` integer label tensor.
    """
    # NOTE: this reseeds torch's *global* RNG, so any random draws made after
    # this call are affected by ``seed`` as well.
    torch.manual_seed(seed)

    # Features must be sampled before labels to keep the random stream order.
    features = torch.randn((n_samples, input_dim))
    labels = torch.randint(low=0, high=output_dim, size=(n_samples,))

    return TensorDataset(features, labels)
23
+
24
+
25
def train(args):
    """Train a SpatialContextNetwork on the synthetic dataset and print progress.

    Builds the dataset with ``make_synthetic_dataset``, constructs the model,
    optimises it with Adam against a cross-entropy loss, and periodically
    prints loss, accuracy and the model's own network statistics. When
    ``args.save_path`` is set, the model's ``state_dict`` is written there
    after the last epoch.

    Args:
        args: Namespace-like object exposing ``n_samples``, ``input_dim``,
            ``output_dim``, ``batch_size``, ``n_neurons``,
            ``routing_threshold``, ``stability_factor``,
            ``explosion_threshold``, ``lr``, ``epochs`` and ``save_path``
            (as produced by this script's argument parser).

    Returns:
        None.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on: {device}")

    # Data
    dataset = make_synthetic_dataset(
        n_samples=args.n_samples,
        input_dim=args.input_dim,
        output_dim=args.output_dim,
    )
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    # Model
    model = SpatialContextNetwork(
        input_dim=args.input_dim,
        n_neurons=args.n_neurons,
        output_dim=args.output_dim,
        routing_threshold=args.routing_threshold,
        stability_factor=args.stability_factor,
        explosion_threshold=args.explosion_threshold,
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {total_params}")
    print(model)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(1, args.epochs + 1):
        # Re-enter training mode every epoch: the reporting branch below
        # switches the model to eval mode.
        model.train()
        total_loss = 0.0
        correct = 0

        for X_batch, y_batch in loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            logits = model(X_batch)
            loss = criterion(logits, y_batch)
            loss.backward()
            optimizer.step()

            # Weight the batch-mean loss by batch size so that a smaller
            # final batch does not skew the per-sample average.
            total_loss += loss.item() * len(X_batch)
            correct += (logits.argmax(dim=-1) == y_batch).sum().item()

        avg_loss = total_loss / len(dataset)
        accuracy = correct / len(dataset)

        # Report on the first epoch, every 10th epoch, and always on the last
        # epoch so the final metrics are shown even when the epoch count is
        # not a multiple of 10.
        is_report_epoch = epoch == 1 or epoch % 10 == 0 or epoch == args.epochs
        if is_report_epoch:
            # Network efficiency stats, measured on a fresh random batch
            # rather than on the training data.
            model.eval()
            with torch.no_grad():
                sample_x = torch.randn(args.batch_size, args.input_dim).to(device)
                stats = model.get_network_stats(sample_x)
            print(
                f"Epoch {epoch:3d}/{args.epochs} | "
                f"Loss: {avg_loss:.4f} | Acc: {accuracy:.3f} | "
                f"Active neurons: {stats['mean_active_neurons']:.1f}/{args.n_neurons} "
                f"(eff={stats['network_efficiency']:.2f}) | "
                f"Context score: {stats['mean_context_score']:.3f}"
            )

    if args.save_path:
        torch.save(model.state_dict(), args.save_path)
        print(f"\nModel saved to {args.save_path}")
90
+
91
+
92
if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, and train.
    parser = argparse.ArgumentParser(description="Train a Spatial Context Network")

    # Each entry is (flag, value type, default); registration order is kept
    # so the generated --help output lists the options in the same sequence.
    architecture_options = (
        ("--input_dim", int, 10),
        ("--n_neurons", int, 32),
        ("--output_dim", int, 4),
        ("--routing_threshold", float, 0.5),
        ("--stability_factor", float, 10.0),
        ("--explosion_threshold", float, 2.0),
    )
    training_options = (
        ("--epochs", int, 50),
        ("--batch_size", int, 8),
        ("--lr", float, 1e-3),
        ("--n_samples", int, 256),
        ("--save_path", str, None),
    )

    for option_group in (architecture_options, training_options):
        for flag, value_type, default_value in option_group:
            parser.add_argument(flag, type=value_type, default=default_value)

    args = parser.parse_args()
    train(args)