AbstractPhil committed on
Commit
2fc8ef5
·
verified ·
1 Parent(s): b91a2cb

Create geolip_core.py

Browse files
Files changed (1) hide show
  1. geolip_core.py +395 -0
geolip_core.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GeoLIP Core β€” Geometric Building Blocks
3
+ ==========================================
4
+ All reusable geometric components. No losses, no training loops.
5
+
6
+ Components:
7
+ Activations: SquaredReLU, StarReLU, make_activation
8
+ Anchor Init: xavier, orthogonal, repulsion
9
+ Constellation: Triangulation on S^(d-1)
10
+ Patchwork: Round-robin compartmentalized interpretation
11
+ RelayLayer: Single constellation relay (vectorized, gated, no attention)
12
+ ConstellationRelay: Per-token geometric layer (O(S), 99.4% at depth 16)
13
+ MagnitudeFlow: Relay-stack per-compartment magnitude prediction
14
+ AnchorPush: Push strategies (raw, gru, momentum)
15
+ FlowAttention: ODE flow in tangent space (historical)
16
+
17
+ Usage:
18
+ from geolip_core import Constellation, Patchwork, MagnitudeFlow, AnchorPush
19
+ """
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ import math
25
+
26
+ # ── ACTIVATIONS ──
27
+
28
class SquaredReLU(nn.Module):
    """Squared ReLU activation: f(x) = max(x, 0)^2."""

    def forward(self, x):
        r = F.relu(x)
        return r * r

class StarReLU(nn.Module):
    """StarReLU: s * relu(x)^2 + b with learnable scalar scale and bias."""

    def __init__(self):
        super().__init__()
        # Initial values s ≈ 0.8944, b ≈ -0.4472 (StarReLU defaults).
        self.scale = nn.Parameter(torch.full((1,), 0.8944))
        self.bias = nn.Parameter(torch.full((1,), -0.4472))

    def forward(self, x):
        r = F.relu(x)
        return r * r * self.scale + self.bias

# Registry of activation factories, keyed by name.
ACTIVATIONS = {
    'squared_relu': SquaredReLU,
    'star_relu': StarReLU,
    'gelu': lambda: nn.GELU(),
    'relu': lambda: nn.ReLU(),
    'sigmoid': lambda: nn.Sigmoid(),
}

def make_activation(name='squared_relu'):
    """Instantiate an activation module by registry name.

    Raises:
        ValueError: if `name` is not a key of ACTIVATIONS.
    """
    if name in ACTIVATIONS:
        return ACTIVATIONS[name]()
    raise ValueError(f"Unknown activation '{name}'. Choose from: {list(ACTIVATIONS.keys())}")
47
+
48
+
49
+ # ── ANCHOR INITIALIZATION ──
50
+
51
def init_anchors_xavier(n, d):
    """Return n unit vectors on S^(d-1), drawn via Xavier-normal init."""
    weights = torch.empty(n, d)
    nn.init.xavier_normal_(weights)
    return F.normalize(weights, dim=-1)
53
+
54
def init_anchors_orthogonal(n, d):
    """Return n anchors on S^(d-1); the first min(n, d) are mutually orthonormal.

    When n > d the sphere cannot hold more than d orthogonal directions, so
    the surplus anchors are plain random unit vectors.
    """
    if n > d:
        basis, _ = torch.linalg.qr(torch.randn(d, d))
        extra = F.normalize(torch.randn(n - d, d), dim=-1)
        return torch.cat([basis.T, extra], dim=0)
    basis, _ = torch.linalg.qr(torch.randn(d, n))
    return basis.T.contiguous()
60
+
61
def init_anchors_repulsion(n, d, iters=200, lr=0.05):
    """Orthogonal init refined by repulsion: each anchor is repeatedly pushed
    away from its current nearest neighbour, staying on the unit sphere."""
    anchors = F.normalize(init_anchors_orthogonal(n, d), dim=-1)
    for _ in range(iters):
        cos = anchors @ anchors.T
        cos.fill_diagonal_(-2.0)  # exclude self from the nearest-neighbour search
        nearest = cos.argmax(dim=1)
        anchors = F.normalize(anchors - lr * anchors[nearest], dim=-1)
    return anchors
67
+
68
# Registry of anchor-initialization strategies, keyed by name (used by Constellation.__init__).
INIT_METHODS = {'xavier': init_anchors_xavier, 'orthogonal': init_anchors_orthogonal, 'repulsion': init_anchors_repulsion}
69
+
70
+
71
+ # ── CONSTELLATION ──
72
+
73
class Constellation(nn.Module):
    """A set of learnable anchors on S^(d-1); triangulates embeddings against them."""

    def __init__(self, n_anchors, dim, anchor_drop=0.0, anchor_init='repulsion'):
        super().__init__()
        self.anchors = nn.Parameter(INIT_METHODS[anchor_init](n_anchors, dim))
        self.anchor_drop = anchor_drop
        self.n_anchors = n_anchors
        self.dim = dim

    def triangulate(self, emb, training=False):
        """Return (1 - cosine) distances to every anchor and the nearest-anchor index.

        With training=True and anchor_drop > 0, a random subset of anchors is
        dropped for this call; returned nearest indices still refer to the full
        anchor set.
        """
        anchors = F.normalize(self.anchors, dim=-1)
        dropping = training and self.anchor_drop > 0
        if not dropping:
            cos = emb @ anchors.T
            return 1.0 - cos, cos.argmax(dim=-1)
        keep = torch.rand(anchors.shape[0], device=anchors.device) > self.anchor_drop
        if keep.sum() < 2:  # never triangulate against fewer than two anchors
            keep[:2] = True
        cos = emb @ anchors[keep].T
        # Map local (post-drop) argmax indices back to global anchor indices.
        global_idx = keep.nonzero(as_tuple=True)[0]
        return 1.0 - cos, global_idx[cos.argmax(dim=-1)]

    def forward(self, emb, training=False):
        return self.triangulate(emb, training)
95
+
96
+
97
+ # ── PATCHWORK ──
98
+
99
class Patchwork(nn.Module):
    """Round-robin compartments: each small MLP reads its own interleaved anchor subset."""

    def __init__(self, n_anchors, n_comp=8, d_comp=64, activation='squared_relu'):
        super().__init__()
        self.n_comp, self.d_comp = n_comp, d_comp
        self.output_dim = n_comp * d_comp
        # Anchor k belongs to compartment k mod n_comp (round-robin assignment).
        self.register_buffer('asgn', torch.arange(n_anchors) % n_comp)
        anchors_per_comp = n_anchors // n_comp
        self.comps = nn.ModuleList([
            nn.Sequential(
                nn.Linear(anchors_per_comp, d_comp * 2),
                make_activation(activation),
                nn.Linear(d_comp * 2, d_comp),
                nn.LayerNorm(d_comp),
            )
            for _ in range(n_comp)
        ])

    def forward(self, tri):
        """Run each compartment on its anchor columns and concatenate the results."""
        pieces = []
        for k in range(self.n_comp):
            pieces.append(self.comps[k](tri[:, self.asgn == k]))
        return torch.cat(pieces, dim=-1)
114
+
115
+
116
+ # ── RELAY LAYER ──
117
+
118
class RelayLayer(nn.Module):
    """Single constellation relay: vectorized, gated, attention-free.

    Pipeline: split the input into patches, normalize each patch onto
    S^(patch_dim-1), triangulate it against per-patch anchors sampled at
    `n_phases` SLERP phases between the frozen 'home' positions and the
    learnable current positions, run a per-patch 2-layer MLP ("patchwork"),
    then blend MLP output with the patch via a learned sigmoid gate and add
    the result back residually.

    Args:
        input_dim: total input width; must be divisible by patch_dim.
        patch_dim: width of each patch (dimension of the patch sphere).
        n_anchors: anchors per patch.
        n_phases: number of SLERP phases t = k/n_phases, k = 0..n_phases-1.
        pw_hidden: hidden width of the per-patch MLP.
        gate_init: initial gate logit (negative starts close to identity).
    """

    def __init__(self, input_dim, patch_dim=16, n_anchors=16, n_phases=3, pw_hidden=32, gate_init=-3.0):
        super().__init__()
        assert input_dim % patch_dim == 0
        self.input_dim, self.patch_dim = input_dim, patch_dim
        self.n_patches = input_dim // patch_dim
        self.n_anchors, self.n_phases = n_anchors, n_phases
        P, A, d = self.n_patches, n_anchors, patch_dim

        # Frozen 'home' anchor positions (buffer) plus learnable current anchors.
        home = torch.empty(P, A, d)
        nn.init.xavier_normal_(home.view(P * A, d))
        home = F.normalize(home.view(P, A, d), dim=-1)
        self.register_buffer('home', home)
        self.anchors = nn.Parameter(home.clone())

        # Per-patch 2-layer MLP weights, applied batched via einsum.
        tri_dim = n_phases * A
        self.pw_w1 = nn.Parameter(torch.empty(P, tri_dim, pw_hidden))
        self.pw_b1 = nn.Parameter(torch.zeros(1, P, pw_hidden))
        self.pw_w2 = nn.Parameter(torch.empty(P, pw_hidden, d))
        self.pw_b2 = nn.Parameter(torch.zeros(1, P, d))
        for p in range(P):
            nn.init.xavier_normal_(self.pw_w1.data[p])
            nn.init.xavier_normal_(self.pw_w2.data[p])
        self.pw_norm = nn.LayerNorm(d)
        self.gates = nn.Parameter(torch.full((P,), gate_init))
        self.norm = nn.LayerNorm(input_dim)

    def drift(self):
        """Geodesic angle (radians) between each anchor and its home, shape (P, A)."""
        h, c = F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)
        return torch.acos((h * c).sum(dim=-1).clamp(-1 + 1e-7, 1 - 1e-7))

    def at_phase(self, t):
        """SLERP between home (t=0) and current anchors (t=1), shape (P, A, d)."""
        h, c = F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)
        omega = self.drift().unsqueeze(-1)
        so = omega.sin().clamp(min=1e-7)  # guard sin(omega)=0 when anchors haven't drifted
        return torch.sin((1 - t) * omega) / so * h + torch.sin(t * omega) / so * c

    def forward(self, x):
        """Gated residual update, (B, input_dim) -> (B, input_dim)."""
        B, D = x.shape
        P, A, d = self.n_patches, self.n_anchors, self.patch_dim
        patches = self.norm(x).reshape(B, P, d)
        patches_n = F.normalize(patches, dim=-1)
        tris = []
        # Triangulate at n_phases evenly spaced SLERP phases t = k/n_phases.
        # (Was hard-coded to [0, 1/3, 2/3], which crashed with a shape mismatch
        # for n_phases != 3 since tri_dim above is sized as n_phases * A;
        # identical values at the default n_phases=3.)
        for k in range(self.n_phases):
            at = F.normalize(self.at_phase(k / self.n_phases), dim=-1)
            tris.append(1.0 - torch.einsum('bpd,pad->bpa', patches_n, at))
        tri = torch.cat(tris, dim=-1)
        h = F.gelu(torch.einsum('bpt,pth->bph', tri, self.pw_w1) + self.pw_b1)
        pw = self.pw_norm(torch.einsum('bph,phd->bpd', h, self.pw_w2) + self.pw_b2)
        gate = self.gates.sigmoid().unsqueeze(0).unsqueeze(-1)
        return x + (gate * pw + (1 - gate) * patches).reshape(B, D)
168
+
169
+
170
+ # ── CONSTELLATION RELAY (sequence-aware) ──
171
+
172
+ class ConstellationRelay(nn.Module):
173
+ """Per-token geometric processing. O(S). Handles (B,D) and (B,S,D)."""
174
+ def __init__(self, dim, n_anchors=16, n_comp=8, d_comp=64,
175
+ gate_init=-3.0, anchor_init='repulsion', activation='squared_relu'):
176
+ super().__init__()
177
+ self.dim = dim; self.norm = nn.LayerNorm(dim)
178
+ self.constellation = Constellation(n_anchors, dim, anchor_init=anchor_init)
179
+ self.patchwork = Patchwork(n_anchors, n_comp, d_comp, activation)
180
+ self.proj = nn.Linear(self.patchwork.output_dim, dim)
181
+ self.gate = nn.Parameter(torch.full((dim,), gate_init))
182
+
183
+ def forward(self, x):
184
+ squeeze = x.dim() == 2
185
+ if squeeze: x = x.unsqueeze(1)
186
+ B, S, D = x.shape; residual = x
187
+ h_flat = F.normalize(self.norm(x).reshape(B*S, D), dim=-1)
188
+ tri, _ = self.constellation.triangulate(h_flat)
189
+ update = self.proj(self.patchwork(tri)).reshape(B, S, D)
190
+ out = residual + torch.sigmoid(self.gate) * update
191
+ return out.squeeze(1) if squeeze else out
192
+
193
+
194
+ # ── MAGNITUDE FLOW ──
195
+
196
class MagnitudeFlow(nn.Module):
    """Relay-stack per-compartment magnitude. No attention.

    Predicts one magnitude per compartment from (embedding, triangulation,
    raw magnitude) context passed through a stack of RelayLayers, then
    broadcasts each compartment's magnitude to its anchors. Outputs are
    squashed into [mag_min, mag_max] by a sigmoid.

    NOTE(review): `n_heads` is accepted but never used in this class.
    """
    def __init__(self, dim, n_anchors, hidden_dim=64, n_heads=4,
                 n_layers=2, mag_min=0.1, mag_max=5.0, n_comp=8):
        super().__init__()
        self.dim, self.n_anchors = dim, n_anchors
        self.mag_min, self.mag_max, self.n_comp, self.n_layers = mag_min, mag_max, n_comp, n_layers
        # Fixed relay geometry: one 16-wide patch per compartment.
        patch_dim = 16; relay_dim = n_comp * patch_dim
        self.patch_dim, self.relay_dim = patch_dim, relay_dim

        # Context encoder: embedding + triangulation + scalar raw magnitude -> relay_dim
        # (the "+ 1" in ctx_proj is the raw-magnitude scalar column).
        self.emb_proj = nn.Linear(dim, relay_dim // 2)
        self.tri_proj = nn.Linear(n_anchors, relay_dim // 4)
        self.ctx_proj = nn.Linear(relay_dim // 2 + relay_dim // 4 + 1, relay_dim)
        self.relays = nn.ModuleList([
            RelayLayer(relay_dim, patch_dim, 16, 3, hidden_dim, -3.0) for _ in range(n_layers)])
        # One small head per compartment: patch features -> scalar magnitude logit.
        self.mag_heads = nn.ModuleList([
            nn.Sequential(nn.Linear(patch_dim, patch_dim//2), nn.GELU(), nn.Linear(patch_dim//2, 1))
            for _ in range(n_comp)])
        # Per-compartment logit bias derived from anchor-push stats (not checkpointed).
        self.register_buffer('stats_bias_cached', torch.zeros(n_comp), persistent=False)

    def update_stats(self, push_diag, anchor_push):
        """Refresh the cached per-compartment bias from push momentum norms.

        Only the 'momentum' strategy with a live accumulator contributes;
        any other strategy zeroes the bias. `push_diag` is not read here.
        """
        with torch.no_grad():
            device = self.stats_bias_cached.device  # NOTE(review): unused local
            if anchor_push.strategy == 'momentum' and anchor_push.accumulator is not None:
                mn = anchor_push.accumulator.norm(dim=-1)
                apc = self.n_anchors // self.n_comp
                # Mean momentum norm per compartment; the last compartment also
                # absorbs remainder anchors when n_anchors % n_comp != 0.
                self.stats_bias_cached = torch.stack([
                    mn[k*apc : (k+1)*apc if k < self.n_comp-1 else self.n_anchors].mean()
                    for k in range(self.n_comp)])
            else: self.stats_bias_cached.zero_()

    def forward(self, emb, triangulation, raw_magnitude):
        """Predict magnitudes.

        Args:
            emb: (B, dim) embeddings.
            triangulation: (B, n_anchors) triangulation features.
            raw_magnitude: (B, 1) raw magnitude scalar.
        Returns:
            (mag, mc): (B, n_anchors) per-anchor magnitudes and (B, n_comp)
            per-compartment magnitudes, both in [mag_min, mag_max].
        """
        B, A = emb.shape[0], self.n_anchors
        x = self.ctx_proj(torch.cat([self.emb_proj(emb), self.tri_proj(triangulation), raw_magnitude], -1))
        for relay in self.relays: x = relay(x)
        patches = x.reshape(B, self.n_comp, self.patch_dim)
        mc = torch.cat([self.mag_heads[k](patches[:, k]) for k in range(self.n_comp)], -1)
        # Squash logits (plus the cached stats bias) into [mag_min, mag_max].
        mc = self.mag_min + (self.mag_max - self.mag_min) * torch.sigmoid(mc + self.stats_bias_cached)
        apc = A // self.n_comp
        # Broadcast each compartment magnitude to its anchors; the last
        # compartment also covers any remainder anchors.
        mag = torch.cat([mc[:, k:k+1].expand(-1, apc if k < self.n_comp-1 else A - k*apc)
                         for k in range(self.n_comp)], -1)
        return mag, mc

    def get_relay_diagnostics(self):
        """Return per-relay drift/gate summaries as a list of dicts."""
        return [{'layer': i, 'drift_mean': r.drift().mean().item(),
                 'gate_mean': r.gates.sigmoid().mean().item()} for i, r in enumerate(self.relays)]
242
+
243
+
244
+ # ── ANCHOR PUSH ──
245
+
246
+ def _project_tangent(vec, point):
247
+ return vec - (vec * point).sum(dim=-1, keepdim=True) * point
248
+
249
def _compute_centroids_and_assign(anchors_n, emb_n, label_buffer, device):
    """Compute per-class centroids and greedily assign anchors to classes.

    Args:
        anchors_n: (n_a, d) unit anchors.
        emb_n: (N, d) unit embeddings.
        label_buffer: (N,) integer labels aligned with emb_n.
        device: device for the assignment tensors.
    Returns:
        (centroids, assigned, util, classes): `assigned[a]` indexes into
        `classes`; `util` is the normalized fraction of embeddings whose
        nearest anchor is each anchor. All four are None when no centroid
        could be built.
    """
    n_a = anchors_n.shape[0]; classes = label_buffer.unique(); n_cls = classes.shape[0]
    # One unit-normalized mean embedding per class present in the buffer.
    centroids = torch.cat([F.normalize(emb_n[label_buffer==c].mean(0, keepdim=True), dim=-1)
                           for c in classes if (label_buffer==c).sum() > 0], dim=0)
    if centroids.shape[0] == 0: return None, None, None, None
    cos = anchors_n @ centroids.T; apc = n_a // n_cls
    assigned = torch.full((n_a,), -1, dtype=torch.long, device=device)
    cc = torch.zeros(n_cls, dtype=torch.long, device=device)
    # Greedy matching: visit (anchor, class) pairs by descending cosine,
    # capping each class at apc + 1 anchors. Flat index = a * n_cls + c.
    for idx in cos.flatten().sort(descending=True).indices:
        a, c = (idx // n_cls).item(), (idx % n_cls).item()
        if assigned[a] >= 0 or cc[c] >= apc + 1: continue
        assigned[a] = c; cc[c] += 1
        if (assigned >= 0).all(): break
    # Any anchor still unassigned (caps exhausted) takes its most similar class.
    u = (assigned < 0).nonzero(as_tuple=True)[0]
    if len(u) > 0: assigned[u] = (anchors_n[u] @ centroids.T).argmax(1)
    # Utilization: how many embeddings pick each anchor as their nearest.
    nearest = (emb_n @ anchors_n.T).argmax(1)
    util = torch.bincount(nearest, minlength=n_a).float()
    return centroids, assigned, util / util.sum().clamp(min=1), classes
267
+
268
+ def _perturb_target(target, apc, rank):
269
+ if apc > 1 and rank > 0:
270
+ noise = torch.randn_like(target) * 0.05
271
+ return F.normalize(target + noise - (noise * target).sum() * target, dim=-1)
272
+ return target
273
+
274
class AnchorPush:
    """Configurable anchor push. Strategies: raw, gru, momentum.

    Periodically nudges constellation anchors toward class centroids computed
    from a buffer of (embedding, label) pairs. Not an nn.Module: anchors are
    updated in-place under torch.no_grad().

    NOTE(review): the 'gru' strategy initializes state in __init__, but push()
    has no 'gru' branch — it falls through to the empty-diagnostics case and
    leaves the anchors untouched.
    """
    def __init__(self, strategy, n_anchors, dim, **kw):
        self.strategy, self.n_anchors, self.dim, self.push_count = strategy, n_anchors, dim, 0
        if strategy == 'raw': self.lr = kw.get('lr', 0.1)
        elif strategy == 'momentum':
            # decay: momentum decay; alpha: direct-residual step; beta: momentum step.
            self.decay, self.alpha, self.beta = kw.get('decay', 0.9), kw.get('alpha', 0.1), kw.get('beta', 0.05)
            self.util_floor, self.accumulator = kw.get('util_floor', 0.001), None
        elif strategy == 'gru':
            self.ema_decay = kw.get('ema_decay', 0.9); self.z_scale = kw.get('z_scale', 3.0)
            self.r_scale = kw.get('r_scale', 5.0)
            self.prev_pos = self.util_ema = self.drift_ema = None

    @torch.no_grad()
    def push(self, core, emb_buf, lbl_buf):
        """Run one anchor-push step.

        Args:
            core: object exposing `constellation.anchors`; optional
                `anchor_classes` / `class_centroids` side tables are refreshed
                when present.
            emb_buf: (N, dim) buffered embeddings.
            lbl_buf: (N,) integer class labels aligned with emb_buf.
        Returns:
            dict of diagnostics; {'moved': 0} when no centroid could be built.
        """
        anchors = core.constellation.anchors.data; n_a = anchors.shape[0]; device = anchors.device
        emb_n = F.normalize(emb_buf, dim=-1); anchors_n = F.normalize(anchors, dim=-1)
        centroids, assigned, util, classes = _compute_centroids_and_assign(anchors_n, emb_n, lbl_buf, device)
        if centroids is None: return {'moved': 0}
        # Refresh optional per-anchor class / per-class centroid tables on the core.
        if hasattr(core, 'anchor_classes'):
            for a in range(n_a): core.anchor_classes[a] = classes[assigned[a]]
        if hasattr(core, 'class_centroids'):
            for i, c in enumerate(classes): core.class_centroids[c] = centroids[i]
        apc = n_a // centroids.shape[0]
        # Per-anchor target: its assigned centroid, jittered for duplicate ranks
        # (rank = how many earlier anchors share the same class).
        targets = torch.stack([_perturb_target(centroids[assigned[a].item()], apc,
                                               (assigned[:a]==assigned[a]).sum().item()) for a in range(n_a)])
        if self.strategy == 'raw':
            # Plain LERP toward the target, renormalized to the sphere (in-place).
            for a in range(n_a): anchors[a] = F.normalize(anchors_n[a] + self.lr*(targets[a]-anchors_n[a]), dim=-1)
            # Drift = geodesic angle between pre- and post-push anchors.
            d = torch.acos((anchors_n * F.normalize(anchors, dim=-1)).sum(-1).clamp(-1+1e-6, 1-1e-6))
            diag = {'drift_mean': d.mean().item(), 'drift_max': d.max().item()}
        elif self.strategy == 'momentum':
            if self.accumulator is None: self.accumulator = torch.zeros(n_a, self.dim, device=device)
            # Residual toward the target, in the tangent space at each anchor.
            res = _project_tangent(targets - anchors_n, anchors_n)
            # Re-project the carried momentum before accumulating the new residual.
            self.accumulator = self.decay * _project_tangent(self.accumulator, anchors_n) + res
            corr = self.alpha * res + self.beta * self.accumulator
            # Under-utilized ("dead") anchors get a half-strength direct pull instead.
            dead = util < self.util_floor
            if dead.any(): corr[dead] = res[dead] * 0.5
            new = F.normalize(anchors_n + corr, dim=-1)
            d = torch.acos((anchors_n * new).sum(-1).clamp(-1+1e-6, 1-1e-6))
            anchors.copy_(new)
            diag = {'drift_mean': d.mean().item(), 'drift_max': d.max().item(),
                    'momentum_mean': self.accumulator.norm(dim=-1).mean().item(), 'dead_count': dead.sum().item()}
        else:
            diag = {}
        diag.update({'moved': n_a, 'n_active': (util > 0).sum().item(),
                     'util_min': util.min().item(), 'util_max': util.max().item()})
        self.push_count += 1; return diag
321
+
322
+
323
+ # ── FLOW ATTENTION (historical) ──
324
+
325
class FlowAttention(nn.Module):
    """3-step Euler flow in tangent space. Superseded by relay.

    Integrates a learned velocity field over n_steps Euler steps in a flow
    latent, decodes a correction, projects it onto the tangent space at the
    embedding, and applies it through a sigmoid gate. Kept for reference.
    """
    def __init__(self, dim, n_anchors, flow_dim=64, n_steps=3, time_dim=32, gate_init=-3.0):
        super().__init__()
        self.dim, self.flow_dim, self.n_anchors, self.n_steps, self.time_dim = dim, flow_dim, n_anchors, n_steps, time_dim
        self.to_flow = nn.Sequential(nn.Linear(n_anchors+dim, flow_dim), nn.LayerNorm(flow_dim))
        self.time_mlp = nn.Sequential(nn.Linear(time_dim, flow_dim), nn.GELU())
        # Maps (momentum, util, drift) anchor stats to a flow-space bias.
        self.stats_proj = nn.Linear(3, flow_dim, bias=False)
        self.velocity = nn.Sequential(nn.Linear(flow_dim, flow_dim*2), nn.GELU(), nn.Linear(flow_dim*2, flow_dim))
        self.to_correction = nn.Linear(flow_dim, dim, bias=False)
        self.gate = nn.Parameter(torch.full((dim,), gate_init))
        # Cached flow-space bias from anchor-push stats (not checkpointed).
        self.register_buffer('stats_bias_cached', torch.zeros(flow_dim), persistent=False)

    def update_stats(self, push_diag, anchor_push):
        """Cache a flow-space bias from anchor-push diagnostics (no grad)."""
        with torch.no_grad():
            dev = self.stats_proj.weight.device
            # Momentum norms per anchor; zeros when not using momentum pushes.
            mn = anchor_push.accumulator.norm(dim=-1) if (anchor_push.strategy=='momentum' and anchor_push.accumulator is not None) else torch.zeros(self.n_anchors, device=dev)
            # Scalar diagnostics broadcast across all anchors.
            dr = torch.tensor(push_diag.get('drift_mean',0.0), device=dev).expand(self.n_anchors)
            ut = torch.tensor(push_diag.get('util_max',0.0), device=dev).expand(self.n_anchors)
            self.stats_bias_cached = self.stats_proj(torch.stack([mn, ut, dr], -1)).mean(0)

    def forward(self, emb, constellation):
        """Flow-correct embeddings; returns a re-normalized (B, dim) tensor.

        Assumes emb is (B, dim) and row-normalized (the tangent projection
        below treats each row as a unit vector). -- TODO confirm with callers.
        """
        B, D, dev = *emb.shape, emb.device
        tri = emb @ F.normalize(constellation.anchors, dim=-1).T
        z = self.to_flow(torch.cat([tri, emb], -1)); dt = 1.0/self.n_steps
        # Sinusoidal time embedding (transformer-style frequencies).
        half = self.time_dim // 2
        freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=dev) / half)
        for s in range(self.n_steps):
            args = (s*dt)*freqs; t_emb = torch.cat([args.sin(), args.cos()])
            # Euler step: z += dt * v(z, t), with the cached stats bias folded in.
            z = z + dt * (self.velocity(z + self.time_mlp(t_emb)) + self.stats_bias_cached)
        # Decode the correction and project it onto the tangent space at emb.
        c = self.to_correction(z); c = c - (c*emb).sum(-1,keepdim=True)*emb
        return F.normalize(emb + torch.sigmoid(self.gate)*c, dim=-1)
357
+
358
+
359
+ # ── GEOMETRIC AUTOGRAD ──
360
+
361
class GeometricAutograd(torch.autograd.Function):
    """Manifold-aware gradient correction on S^(D-1). Forward: identity.

    Backward removes `tang_strength` times the radial gradient component and,
    when `sep_strength` > 0, damps the component that pushes an embedding
    toward its nearest anchor.
    """
    @staticmethod
    def forward(ctx, emb, anchors, tang_strength, sep_strength):
        # Identity in the forward pass; stash tensors/strengths for backward.
        ctx.save_for_backward(emb, anchors); ctx.tang, ctx.sep = tang_strength, sep_strength
        return emb

    @staticmethod
    def backward(ctx, grad):
        emb, anchors = ctx.saved_tensors
        # Subtract tang * (radial component); assumes emb rows are unit
        # vectors so (grad·emb)emb is the radial part. -- TODO confirm callers.
        dot = (grad * emb).sum(-1, keepdim=True)
        corrected = grad - ctx.tang * dot * emb
        if ctx.sep > 0:
            an = F.normalize(anchors.detach(), dim=-1)
            nearest = an[(emb @ an.T).argmax(-1)]
            # Damp only the gradient component pointing toward the nearest
            # anchor (relu keeps the away-pointing component untouched).
            toward = (corrected * nearest).sum(-1, keepdim=True)
            corrected = corrected - ctx.sep * F.relu(toward) * nearest
        # No gradients flow to anchors or to the two strength scalars.
        return corrected, None, None, None
379
+
380
+
381
+ # ── UTILITIES ──
382
+
383
def param_count(module, name=""):
    """Return (total, trainable) parameter counts; print them when `name` is given."""
    total = 0
    trainable = 0
    for p in module.parameters():
        count = p.numel()
        total += count
        if p.requires_grad:
            trainable += count
    if name:
        print(f" {name}: {total:,} ({trainable:,} trainable)")
    return total, trainable
388
+
389
def model_summary(model):
    """Print the total and per-child parameter counts; return the total."""
    total = sum(p.numel() for p in model.parameters())
    print(f" Total: {total:,}")
    for child_name, child in model.named_children():
        child_total = sum(p.numel() for p in child.parameters())
        if child_total > 0:
            print(f" {child_name}: {child_total:,}")
    return total