Text Classification
English
code
FurkanNar committed on
Commit
d907b19
·
verified ·
1 Parent(s): 5b25cac

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +225 -0
model.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Spatial Context Networks (SCN)
3
+ Geometric Semantic Routing in Neural Architectures
4
+
5
+ Author: Furkan Nar
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import math
12
+
13
+
14
class GeometricActivation(nn.Module):
    """Geometric activation based on normalized Euclidean distance.

    Every neuron owns a learnable centroid in ``dim``-dimensional space. Its
    activation is inversely proportional to the normalized distance between
    the input and that centroid:

        f(v) = 1 / (||v - mu||_2 / sqrt(d) + epsilon)

    With ``epsilon = 1 / stability_factor`` the activation is bounded above
    by ``stability_factor`` (reached when the input sits on the centroid).

    Args:
        n_neurons (int): Number of neurons (centroids) in this layer.
        dim (int): Dimensionality of the input semantic space.
        stability_factor (float): SF in the paper; epsilon = 1/SF. Default: 10.0

    Raises:
        ValueError: If ``dim`` or ``stability_factor`` is not strictly positive.
    """

    def __init__(self, n_neurons: int, dim: int, stability_factor: float = 10.0):
        super().__init__()
        # sqrt(dim) is used as a divisor in forward(); dim <= 0 would make the
        # normalization undefined.
        if dim <= 0:
            raise ValueError(f"dim must be positive, got {dim}")
        # A non-positive stability factor yields a zero/negative epsilon, which
        # lets the denominator reach zero and the activation blow up or flip sign.
        if stability_factor <= 0:
            raise ValueError(
                f"stability_factor must be positive, got {stability_factor}"
            )

        self.n_neurons = n_neurons
        self.dim = dim
        self.epsilon = 1.0 / stability_factor

        # Learnable centroids: shape (n_neurons, dim)
        self.centroids = nn.Parameter(torch.randn(n_neurons, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the activation of every neuron for every input vector.

        Args:
            x: Input tensor of shape (..., dim); typically (batch_size, dim).
        Returns:
            activations: Tensor of shape (..., n_neurons).
        Raises:
            ValueError: If the last dimension of ``x`` is not ``dim``.
        """
        # Reject mismatched feature sizes explicitly; a trailing size of 1 would
        # otherwise broadcast silently against the centroids.
        if x.dim() == 0 or x.shape[-1] != self.dim:
            raise ValueError(
                f"expected input with last dimension {self.dim}, "
                f"got shape {tuple(x.shape)}"
            )

        # x: (..., dim) -> (..., 1, dim); centroids (n_neurons, dim) broadcast
        # over the leading dimensions, giving (..., n_neurons, dim).
        diff = x.unsqueeze(-2) - self.centroids
        dist = torch.linalg.vector_norm(diff, dim=-1)  # (..., n_neurons)
        normalized_dist = dist / math.sqrt(self.dim)
        return 1.0 / (normalized_dist + self.epsilon)
52
+
53
+
54
class SemanticRoutingLayer(nn.Module):
    """Route inputs to neurons by thresholding their geometric activation.

    A neuron is considered active for an input when its activation strictly
    exceeds the routing threshold:

        Active set:  S = { n_i | f_i(q) > tau }
        Binary mask: M_ij = I[ f_j(v_i) > tau ]

    Args:
        n_neurons (int): Number of neurons.
        dim (int): Input dimensionality.
        routing_threshold (float): Activation threshold tau. Default: 0.5
        stability_factor (float): Passed to GeometricActivation. Default: 10.0
    """

    def __init__(
        self,
        n_neurons: int,
        dim: int,
        routing_threshold: float = 0.5,
        stability_factor: float = 10.0,
    ):
        super().__init__()
        # The activation module holds the learnable centroids.
        self.geo_activation = GeometricActivation(
            n_neurons=n_neurons,
            dim=dim,
            stability_factor=stability_factor,
        )
        self.routing_threshold = routing_threshold

    def forward(self, x: torch.Tensor):
        """Return the raw activations together with the routing mask.

        Args:
            x: Input tensor of shape (batch_size, dim)
        Returns:
            activations: Raw activations, shape (batch_size, n_neurons)
            mask: Float tensor of 0s and 1s with the same shape as the
                activations; 1 where the activation is above the threshold.
        """
        scores = self.geo_activation(x)
        # Strict comparison: an activation equal to tau is NOT routed.
        gate = torch.gt(scores, self.routing_threshold).to(torch.float32)
        return scores, gate
91
+
92
+
93
class ConnectionDensityLayer(nn.Module):
    """Collapse masked activations into one density-scaled context score.

        C = sum_{i in S} w_i / (alpha / z)

    Here alpha is the total neuron count and z = |S| is the number of active
    neurons for the sample (floored at 1 so the scale is never zero). Scores
    above ``explosion_threshold`` are replaced by their square root:
    C_stable = sqrt(C). Scores at or below the threshold, including all
    negative scores, pass through unchanged.

    Args:
        n_neurons (int): Total number of neurons (alpha).
        explosion_threshold (float): tau_exp. Default: 2.0
    """

    def __init__(self, n_neurons: int, explosion_threshold: float = 2.0):
        super().__init__()
        self.n_neurons = n_neurons
        self.explosion_threshold = explosion_threshold

        # One learnable connection weight per neuron, shape (n_neurons,)
        self.connection_weights = nn.Parameter(torch.randn(n_neurons))

    def forward(self, activations: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Compute the context score for each sample.

        Args:
            activations: Shape (batch_size, n_neurons)
            mask: Binary mask, shape (batch_size, n_neurons)
        Returns:
            context: Context score per sample, shape (batch_size, 1)
        """
        total_neurons = float(self.n_neurons)

        # Active-neuron count per sample, floored at 1 for all-zero masks.
        active_count = torch.clamp(mask.sum(dim=-1, keepdim=True), min=1.0)  # (B, 1)
        density_scale = total_neurons / active_count  # alpha / z, (B, 1)

        # Zero out inactive neurons, then apply the per-neuron weights.
        gated = activations * mask * self.connection_weights.unsqueeze(0)  # (B, n)
        raw_score = gated.sum(dim=-1, keepdim=True) / density_scale  # (B, 1)

        # Sign-preserving sqrt; abs() + 1e-8 keeps the branch finite for every
        # element, since torch.where evaluates it even where it is not selected.
        damped = raw_score.sign() * torch.sqrt(raw_score.abs() + 1e-8)
        exploding = raw_score > self.explosion_threshold
        return torch.where(exploding, damped, raw_score)
137
+
138
+
139
class SpatialContextNetwork(nn.Module):
    """Spatial Context Network (SCN).

    Pipeline applied by ``forward``:
        1. SemanticRoutingLayer   - geometric activation + binary routing mask
        2. ConnectionDensityLayer - adaptive normalization + explosion control
        3. Linear projection      - maps the per-sample context score
                                    (1 feature) to ``output_dim`` features
        4. Pattern distribution   - element-wise multiply by
                                    softmax(pattern_weights)

    Args:
        input_dim (int): Dimensionality of input features.
        n_neurons (int): Number of hidden geometric neurons. Default: 32
        output_dim (int): Number of output classes/dimensions. Default: 4
        routing_threshold (float): Routing threshold tau. Default: 0.5
        stability_factor (float): Controls epsilon = 1/SF. Default: 10.0
        explosion_threshold (float): Threshold for sqrt damping. Default: 2.0

    Example::

        model = SpatialContextNetwork(input_dim=10, n_neurons=32, output_dim=4)
        x = torch.randn(8, 10)
        output = model(x)  # (8, 4)
    """

    def __init__(
        self,
        input_dim: int = 10,
        n_neurons: int = 32,
        output_dim: int = 4,
        routing_threshold: float = 0.5,
        stability_factor: float = 10.0,
        explosion_threshold: float = 2.0,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.n_neurons = n_neurons
        self.output_dim = output_dim

        # Sub-modules are built in pipeline order.
        self.routing = SemanticRoutingLayer(
            n_neurons, input_dim, routing_threshold, stability_factor
        )
        self.density = ConnectionDensityLayer(n_neurons, explosion_threshold)
        self.projection = nn.Linear(1, output_dim)

        # Learnable pattern logits; all-zero logits give a uniform softmax.
        self.pattern_weights = nn.Parameter(torch.zeros(output_dim))

        # Only the 4-output configuration is seeded with the priors from the
        # paper: [Mathematics=0.38, Language=0.25, Vision=0.22, Reasoning=0.15].
        # Storing log(prior) makes softmax(pattern_weights) start near them.
        if output_dim == 4:
            paper_prior = torch.tensor([0.38, 0.25, 0.22, 0.15])
            with torch.no_grad():
                self.pattern_weights.copy_(torch.log(paper_prior + 1e-8))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the full SCN pipeline on a batch.

        Args:
            x: Input tensor of shape (batch_size, input_dim)
        Returns:
            output: Tensor of shape (batch_size, output_dim)
        """
        scores, gate = self.routing(x)
        context = self.density(scores, gate)  # (B, 1)
        projected = self.projection(context)  # (B, output_dim)
        pattern_dist = F.softmax(self.pattern_weights, dim=-1)  # (output_dim,)
        return projected * pattern_dist

    def get_network_stats(self, x: torch.Tensor) -> dict:
        """Collect routing diagnostics for a batch without tracking gradients.

        Args:
            x: Input tensor of shape (batch_size, input_dim)
        Returns:
            dict with keys:
                mean_active_neurons: mean number of active neurons per sample
                network_efficiency: mean fraction of neurons that are active
                mean_context_score: mean context score over the batch
                activations: raw activations from the routing layer
                mask: routing mask from the routing layer
        """
        with torch.no_grad():
            scores, gate = self.routing(x)
            context = self.density(scores, gate)
            active_per_sample = gate.sum(dim=-1)
            active_fraction = active_per_sample / self.n_neurons

            stats = {
                "mean_active_neurons": active_per_sample.mean().item(),
                "network_efficiency": active_fraction.mean().item(),
                "mean_context_score": context.mean().item(),
                "activations": scores,
                "mask": gate,
            }
        return stats