a couple slow ones
Browse files
constellation_relays_activation_effects_analysis.py
CHANGED
|
@@ -12,6 +12,9 @@ Systematic test of:
|
|
| 12 |
Each test uses the same random seed and input for fair comparison.
|
| 13 |
"""
|
| 14 |
|
|
|
|
|
|
|
|
|
|
| 15 |
import torch
|
| 16 |
import torch.nn as nn
|
| 17 |
import torch.nn.functional as F
|
|
@@ -444,6 +447,9 @@ for i, depth in enumerate(depths_to_check):
|
|
| 444 |
# TEST 5: TRAINED RELAY — ACTIVATION EFFECT ON LEARNING
|
| 445 |
# ──────────────────────────────────────────────────────────────────
|
| 446 |
|
|
|
|
|
|
|
|
|
|
| 447 |
print(f"\n{'─'*80}")
|
| 448 |
print(f"TEST 5: Trained Relay — does activation choice affect what the relay LEARNS?")
|
| 449 |
print(f" Setup: 4-layer relay trained to classify 256d embeddings into 10 classes")
|
|
@@ -464,9 +470,12 @@ for c in range(N_CLASSES):
|
|
| 464 |
noise = torch.randn(N_TRAIN // N_CLASSES, D, device=DEVICE) * 0.3
|
| 465 |
pts = F.normalize(class_centers[c].unsqueeze(0) + noise, dim=-1)
|
| 466 |
train_x.append(pts)
|
| 467 |
-
train_y.append(torch.full((N_TRAIN // N_CLASSES,), c, device=DEVICE))
|
| 468 |
train_x = torch.cat(train_x)
|
| 469 |
train_y = torch.cat(train_y)
|
|
|
|
|
|
|
|
|
|
| 470 |
|
| 471 |
print(f"\n {'pw_act':>14} {'acc':>8} {'loss':>8} {'cos_orig':>10} "
|
| 472 |
f"{'CV':>8} {'eff_dim':>8} {'drift':>8} {'gate':>8}")
|
|
@@ -491,11 +500,15 @@ for act_name in ["none", "relu", "gelu", "silu", "tanh", "squared_relu", "star_r
|
|
| 491 |
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
|
| 492 |
|
| 493 |
for step in range(TRAIN_STEPS):
|
| 494 |
-
idx = torch.randint(0, len(train_x), (128,)
|
| 495 |
logits = model(train_x[idx])
|
| 496 |
loss = F.cross_entropy(logits, train_y[idx])
|
|
|
|
|
|
|
|
|
|
| 497 |
opt.zero_grad()
|
| 498 |
loss.backward()
|
|
|
|
| 499 |
opt.step()
|
| 500 |
|
| 501 |
# Evaluate
|
|
@@ -523,6 +536,8 @@ for act_name in ["none", "relu", "gelu", "silu", "tanh", "squared_relu", "star_r
|
|
| 523 |
# TEST 6: HYBRID RELAY — INFORMATION RETENTION
|
| 524 |
# ──────────────────────────────────────────────────────────────────
|
| 525 |
|
|
|
|
|
|
|
| 526 |
print(f"\n{'─'*80}")
|
| 527 |
print(f"TEST 6: Hybrid Relay — Information Retention")
|
| 528 |
print(f" Setup: 8 layers of hybrid relay (attention + constellation)")
|
|
@@ -598,12 +613,14 @@ keys_a = F.normalize(torch.randn(N_CLS, D, device=DEVICE), dim=-1)
|
|
| 598 |
keys_b = F.normalize(torch.randn(N_CLS, D, device=DEVICE), dim=-1)
|
| 599 |
|
| 600 |
task_x = F.normalize(torch.randn(N_SAMPLES, S_TASK, D, device=DEVICE), dim=-1).clone()
|
| 601 |
-
label_a = torch.randint(0, N_CLS, (N_SAMPLES,), device=DEVICE)
|
| 602 |
-
label_b = torch.randint(0, N_CLS, (N_SAMPLES,), device=DEVICE)
|
| 603 |
task_x[:, 0] = keys_a[label_a] + torch.randn(N_SAMPLES, D, device=DEVICE) * 0.2
|
| 604 |
task_x[:, 1] = keys_b[label_b] + torch.randn(N_SAMPLES, D, device=DEVICE) * 0.2
|
| 605 |
task_x = F.normalize(task_x, dim=-1)
|
| 606 |
-
task_y = (label_a + label_b) % N_CLS
|
|
|
|
|
|
|
| 607 |
|
| 608 |
print(f"\n {'relay_act':>14} {'acc':>8} {'loss':>8} {'g_relay':>8} "
|
| 609 |
f"{'g_attn':>8} {'cross_Δ':>10}")
|
|
@@ -629,11 +646,15 @@ for act_name in ["none", "relu", "gelu", "silu", "tanh", "squared_relu"]:
|
|
| 629 |
opt = torch.optim.Adam(model.parameters(), lr=3e-4)
|
| 630 |
|
| 631 |
for step in range(STEPS):
|
| 632 |
-
idx = torch.randint(0, N_SAMPLES, (128,)
|
| 633 |
logits = model(task_x[idx])
|
| 634 |
loss = F.cross_entropy(logits, task_y[idx])
|
|
|
|
|
|
|
|
|
|
| 635 |
opt.zero_grad()
|
| 636 |
loss.backward()
|
|
|
|
| 637 |
opt.step()
|
| 638 |
|
| 639 |
model.eval()
|
|
@@ -696,11 +717,15 @@ for act_name in ["none", "relu", "gelu", "silu", "tanh", "squared_relu", "star_r
|
|
| 696 |
|
| 697 |
drift_log = {}
|
| 698 |
for step in range(TRAIN_STEPS):
|
| 699 |
-
idx = torch.randint(0, len(train_x), (128,)
|
| 700 |
logits = model(train_x[idx])
|
| 701 |
loss = F.cross_entropy(logits, train_y[idx])
|
|
|
|
|
|
|
|
|
|
| 702 |
opt.zero_grad()
|
| 703 |
loss.backward()
|
|
|
|
| 704 |
opt.step()
|
| 705 |
|
| 706 |
if (step + 1) in [50, 100, 200, 300, 500]:
|
|
@@ -743,6 +768,7 @@ for act_name in ["none", "relu", "gelu", "silu", "tanh", "squared_relu", "star_r
|
|
| 743 |
h = x
|
| 744 |
for layer in layers:
|
| 745 |
h = layer(h)
|
|
|
|
| 746 |
|
| 747 |
loss = h.sum()
|
| 748 |
loss.backward()
|
|
@@ -751,7 +777,6 @@ for act_name in ["none", "relu", "gelu", "silu", "tanh", "squared_relu", "star_r
|
|
| 751 |
anchor_grads = [l.anchors.grad.norm().item() for l in layers if l.anchors.grad is not None]
|
| 752 |
gate_grads = [l.gate.grad.item() for l in layers if l.gate.grad is not None]
|
| 753 |
|
| 754 |
-
# Output gradient (last layer's contribution)
|
| 755 |
grad_out = h.grad.norm().item() if h.grad is not None else 0
|
| 756 |
|
| 757 |
print(f" {act_name:>14} {grad_in:>10.4f} {grad_out:>10.4f} "
|
|
|
|
| 12 |
Each test uses the same random seed and input for fair comparison.
|
| 13 |
"""
|
| 14 |
|
| 15 |
+
import os
|
| 16 |
+
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
| 17 |
+
|
| 18 |
import torch
|
| 19 |
import torch.nn as nn
|
| 20 |
import torch.nn.functional as F
|
|
|
|
| 447 |
# TEST 5: TRAINED RELAY — ACTIVATION EFFECT ON LEARNING
|
| 448 |
# ──────────────────────────────────────────────────────────────────
|
| 449 |
|
| 450 |
+
# Cleanup from Tests 1-4
|
| 451 |
+
torch.cuda.empty_cache()
|
| 452 |
+
|
| 453 |
print(f"\n{'─'*80}")
|
| 454 |
print(f"TEST 5: Trained Relay — does activation choice affect what the relay LEARNS?")
|
| 455 |
print(f" Setup: 4-layer relay trained to classify 256d embeddings into 10 classes")
|
|
|
|
| 470 |
noise = torch.randn(N_TRAIN // N_CLASSES, D, device=DEVICE) * 0.3
|
| 471 |
pts = F.normalize(class_centers[c].unsqueeze(0) + noise, dim=-1)
|
| 472 |
train_x.append(pts)
|
| 473 |
+
train_y.append(torch.full((N_TRAIN // N_CLASSES,), c, dtype=torch.long, device=DEVICE))
|
| 474 |
train_x = torch.cat(train_x)
|
| 475 |
train_y = torch.cat(train_y)
|
| 476 |
+
assert train_y.max() < N_CLASSES, f"Label OOB: max={train_y.max()}, n_classes={N_CLASSES}"
|
| 477 |
+
assert train_y.min() >= 0, f"Negative label: min={train_y.min()}"
|
| 478 |
+
torch.cuda.synchronize()
|
| 479 |
|
| 480 |
print(f"\n {'pw_act':>14} {'acc':>8} {'loss':>8} {'cos_orig':>10} "
|
| 481 |
f"{'CV':>8} {'eff_dim':>8} {'drift':>8} {'gate':>8}")
|
|
|
|
| 500 |
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
|
| 501 |
|
| 502 |
for step in range(TRAIN_STEPS):
|
| 503 |
+
idx = torch.randint(0, len(train_x), (128,))
|
| 504 |
logits = model(train_x[idx])
|
| 505 |
loss = F.cross_entropy(logits, train_y[idx])
|
| 506 |
+
if torch.isnan(loss) or torch.isinf(loss):
|
| 507 |
+
print(f"  ⚠ Bad loss at step {step}, act={act_name}")
|
| 508 |
+
break
|
| 509 |
opt.zero_grad()
|
| 510 |
loss.backward()
|
| 511 |
+
nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 512 |
opt.step()
|
| 513 |
|
| 514 |
# Evaluate
|
|
|
|
| 536 |
# TEST 6: HYBRID RELAY — INFORMATION RETENTION
|
| 537 |
# ──────────────────────────────────────────────────────────────────
|
| 538 |
|
| 539 |
+
torch.cuda.empty_cache()
|
| 540 |
+
|
| 541 |
print(f"\n{'─'*80}")
|
| 542 |
print(f"TEST 6: Hybrid Relay — Information Retention")
|
| 543 |
print(f" Setup: 8 layers of hybrid relay (attention + constellation)")
|
|
|
|
| 613 |
keys_b = F.normalize(torch.randn(N_CLS, D, device=DEVICE), dim=-1)
|
| 614 |
|
| 615 |
task_x = F.normalize(torch.randn(N_SAMPLES, S_TASK, D, device=DEVICE), dim=-1).clone()
|
| 616 |
+
label_a = torch.randint(0, N_CLS, (N_SAMPLES,), dtype=torch.long, device=DEVICE)
|
| 617 |
+
label_b = torch.randint(0, N_CLS, (N_SAMPLES,), dtype=torch.long, device=DEVICE)
|
| 618 |
task_x[:, 0] = keys_a[label_a] + torch.randn(N_SAMPLES, D, device=DEVICE) * 0.2
|
| 619 |
task_x[:, 1] = keys_b[label_b] + torch.randn(N_SAMPLES, D, device=DEVICE) * 0.2
|
| 620 |
task_x = F.normalize(task_x, dim=-1)
|
| 621 |
+
task_y = ((label_a + label_b) % N_CLS).long()
|
| 622 |
+
assert task_y.max() < N_CLS and task_y.min() >= 0
|
| 623 |
+
torch.cuda.synchronize()
|
| 624 |
|
| 625 |
print(f"\n {'relay_act':>14} {'acc':>8} {'loss':>8} {'g_relay':>8} "
|
| 626 |
f"{'g_attn':>8} {'cross_Δ':>10}")
|
|
|
|
| 646 |
opt = torch.optim.Adam(model.parameters(), lr=3e-4)
|
| 647 |
|
| 648 |
for step in range(STEPS):
|
| 649 |
+
idx = torch.randint(0, N_SAMPLES, (128,))
|
| 650 |
logits = model(task_x[idx])
|
| 651 |
loss = F.cross_entropy(logits, task_y[idx])
|
| 652 |
+
if torch.isnan(loss) or torch.isinf(loss):
|
| 653 |
+
print(f"  ⚠ Bad loss at step {step}, act={act_name}")
|
| 654 |
+
break
|
| 655 |
opt.zero_grad()
|
| 656 |
loss.backward()
|
| 657 |
+
nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 658 |
opt.step()
|
| 659 |
|
| 660 |
model.eval()
|
|
|
|
| 717 |
|
| 718 |
drift_log = {}
|
| 719 |
for step in range(TRAIN_STEPS):
|
| 720 |
+
idx = torch.randint(0, len(train_x), (128,))
|
| 721 |
logits = model(train_x[idx])
|
| 722 |
loss = F.cross_entropy(logits, train_y[idx])
|
| 723 |
+
if torch.isnan(loss) or torch.isinf(loss):
|
| 724 |
+
print(f"  ⚠ Bad loss at step {step}, act={act_name}")
|
| 725 |
+
break
|
| 726 |
opt.zero_grad()
|
| 727 |
loss.backward()
|
| 728 |
+
nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 729 |
opt.step()
|
| 730 |
|
| 731 |
if (step + 1) in [50, 100, 200, 300, 500]:
|
|
|
|
| 768 |
h = x
|
| 769 |
for layer in layers:
|
| 770 |
h = layer(h)
|
| 771 |
+
h.retain_grad()
|
| 772 |
|
| 773 |
loss = h.sum()
|
| 774 |
loss.backward()
|
|
|
|
| 777 |
anchor_grads = [l.anchors.grad.norm().item() for l in layers if l.anchors.grad is not None]
|
| 778 |
gate_grads = [l.gate.grad.item() for l in layers if l.gate.grad is not None]
|
| 779 |
|
|
|
|
| 780 |
grad_out = h.grad.norm().item() if h.grad is not None else 0
|
| 781 |
|
| 782 |
print(f" {act_name:>14} {grad_in:>10.4f} {grad_out:>10.4f} "
|