AbstractPhil committed on
Commit
84eb1a3
Β·
verified Β·
1 Parent(s): 71159f8

a couple slow ones

Browse files
constellation_relays_activation_effects_analysis.py CHANGED
@@ -12,6 +12,9 @@ Systematic test of:
12
  Each test uses the same random seed and input for fair comparison.
13
  """
14
 
 
 
 
15
  import torch
16
  import torch.nn as nn
17
  import torch.nn.functional as F
@@ -444,6 +447,9 @@ for i, depth in enumerate(depths_to_check):
444
  # TEST 5: TRAINED RELAY β€” ACTIVATION EFFECT ON LEARNING
445
  # ══════════════════════════════════════════════════════════════════
446
 
 
 
 
447
  print(f"\n{'━'*80}")
448
  print(f"TEST 5: Trained Relay β€” does activation choice affect what the relay LEARNS?")
449
  print(f" Setup: 4-layer relay trained to classify 256d embeddings into 10 classes")
@@ -464,9 +470,12 @@ for c in range(N_CLASSES):
464
  noise = torch.randn(N_TRAIN // N_CLASSES, D, device=DEVICE) * 0.3
465
  pts = F.normalize(class_centers[c].unsqueeze(0) + noise, dim=-1)
466
  train_x.append(pts)
467
- train_y.append(torch.full((N_TRAIN // N_CLASSES,), c, device=DEVICE))
468
  train_x = torch.cat(train_x)
469
  train_y = torch.cat(train_y)
 
 
 
470
 
471
  print(f"\n {'pw_act':>14} {'acc':>8} {'loss':>8} {'cos_orig':>10} "
472
  f"{'CV':>8} {'eff_dim':>8} {'drift':>8} {'gate':>8}")
@@ -491,11 +500,15 @@ for act_name in ["none", "relu", "gelu", "silu", "tanh", "squared_relu", "star_r
491
  opt = torch.optim.Adam(model.parameters(), lr=1e-3)
492
 
493
  for step in range(TRAIN_STEPS):
494
- idx = torch.randint(0, len(train_x), (128,), device=DEVICE)
495
  logits = model(train_x[idx])
496
  loss = F.cross_entropy(logits, train_y[idx])
 
 
 
497
  opt.zero_grad()
498
  loss.backward()
 
499
  opt.step()
500
 
501
  # Evaluate
@@ -523,6 +536,8 @@ for act_name in ["none", "relu", "gelu", "silu", "tanh", "squared_relu", "star_r
523
  # TEST 6: HYBRID RELAY β€” INFORMATION RETENTION
524
  # ══════════════════════════════════════════════════════════════════
525
 
 
 
526
  print(f"\n{'━'*80}")
527
  print(f"TEST 6: Hybrid Relay β€” Information Retention")
528
  print(f" Setup: 8 layers of hybrid relay (attention + constellation)")
@@ -598,12 +613,14 @@ keys_a = F.normalize(torch.randn(N_CLS, D, device=DEVICE), dim=-1)
598
  keys_b = F.normalize(torch.randn(N_CLS, D, device=DEVICE), dim=-1)
599
 
600
  task_x = F.normalize(torch.randn(N_SAMPLES, S_TASK, D, device=DEVICE), dim=-1).clone()
601
- label_a = torch.randint(0, N_CLS, (N_SAMPLES,), device=DEVICE)
602
- label_b = torch.randint(0, N_CLS, (N_SAMPLES,), device=DEVICE)
603
  task_x[:, 0] = keys_a[label_a] + torch.randn(N_SAMPLES, D, device=DEVICE) * 0.2
604
  task_x[:, 1] = keys_b[label_b] + torch.randn(N_SAMPLES, D, device=DEVICE) * 0.2
605
  task_x = F.normalize(task_x, dim=-1)
606
- task_y = (label_a + label_b) % N_CLS # class depends on BOTH tokens
 
 
607
 
608
  print(f"\n {'relay_act':>14} {'acc':>8} {'loss':>8} {'g_relay':>8} "
609
  f"{'g_attn':>8} {'cross_Ξ”':>10}")
@@ -629,11 +646,15 @@ for act_name in ["none", "relu", "gelu", "silu", "tanh", "squared_relu"]:
629
  opt = torch.optim.Adam(model.parameters(), lr=3e-4)
630
 
631
  for step in range(STEPS):
632
- idx = torch.randint(0, N_SAMPLES, (128,), device=DEVICE)
633
  logits = model(task_x[idx])
634
  loss = F.cross_entropy(logits, task_y[idx])
 
 
 
635
  opt.zero_grad()
636
  loss.backward()
 
637
  opt.step()
638
 
639
  model.eval()
@@ -696,11 +717,15 @@ for act_name in ["none", "relu", "gelu", "silu", "tanh", "squared_relu", "star_r
696
 
697
  drift_log = {}
698
  for step in range(TRAIN_STEPS):
699
- idx = torch.randint(0, len(train_x), (128,), device=DEVICE)
700
  logits = model(train_x[idx])
701
  loss = F.cross_entropy(logits, train_y[idx])
 
 
 
702
  opt.zero_grad()
703
  loss.backward()
 
704
  opt.step()
705
 
706
  if (step + 1) in [50, 100, 200, 300, 500]:
@@ -743,6 +768,7 @@ for act_name in ["none", "relu", "gelu", "silu", "tanh", "squared_relu", "star_r
743
  h = x
744
  for layer in layers:
745
  h = layer(h)
 
746
 
747
  loss = h.sum()
748
  loss.backward()
@@ -751,7 +777,6 @@ for act_name in ["none", "relu", "gelu", "silu", "tanh", "squared_relu", "star_r
751
  anchor_grads = [l.anchors.grad.norm().item() for l in layers if l.anchors.grad is not None]
752
  gate_grads = [l.gate.grad.item() for l in layers if l.gate.grad is not None]
753
 
754
- # Output gradient (last layer's contribution)
755
  grad_out = h.grad.norm().item() if h.grad is not None else 0
756
 
757
  print(f" {act_name:>14} {grad_in:>10.4f} {grad_out:>10.4f} "
 
12
  Each test uses the same random seed and input for fair comparison.
13
  """
14
 
15
+ import os
16
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
17
+
18
  import torch
19
  import torch.nn as nn
20
  import torch.nn.functional as F
 
447
  # TEST 5: TRAINED RELAY β€” ACTIVATION EFFECT ON LEARNING
448
  # ══════════════════════════════════════════════════════════════════
449
 
450
+ # Cleanup from Tests 1-4
451
+ torch.cuda.empty_cache()
452
+
453
  print(f"\n{'━'*80}")
454
  print(f"TEST 5: Trained Relay β€” does activation choice affect what the relay LEARNS?")
455
  print(f" Setup: 4-layer relay trained to classify 256d embeddings into 10 classes")
 
470
  noise = torch.randn(N_TRAIN // N_CLASSES, D, device=DEVICE) * 0.3
471
  pts = F.normalize(class_centers[c].unsqueeze(0) + noise, dim=-1)
472
  train_x.append(pts)
473
+ train_y.append(torch.full((N_TRAIN // N_CLASSES,), c, dtype=torch.long, device=DEVICE))
474
  train_x = torch.cat(train_x)
475
  train_y = torch.cat(train_y)
476
+ assert train_y.max() < N_CLASSES, f"Label OOB: max={train_y.max()}, n_classes={N_CLASSES}"
477
+ assert train_y.min() >= 0, f"Negative label: min={train_y.min()}"
478
+ torch.cuda.synchronize()
479
 
480
  print(f"\n {'pw_act':>14} {'acc':>8} {'loss':>8} {'cos_orig':>10} "
481
  f"{'CV':>8} {'eff_dim':>8} {'drift':>8} {'gate':>8}")
 
500
  opt = torch.optim.Adam(model.parameters(), lr=1e-3)
501
 
502
  for step in range(TRAIN_STEPS):
503
+ idx = torch.randint(0, len(train_x), (128,))
504
  logits = model(train_x[idx])
505
  loss = F.cross_entropy(logits, train_y[idx])
506
+ if torch.isnan(loss) or torch.isinf(loss):
507
+ print(f" ⚠ Bad loss at step {step}, act={act_name}")
508
+ break
509
  opt.zero_grad()
510
  loss.backward()
511
+ nn.utils.clip_grad_norm_(model.parameters(), 1.0)
512
  opt.step()
513
 
514
  # Evaluate
 
536
  # TEST 6: HYBRID RELAY β€” INFORMATION RETENTION
537
  # ══════════════════════════════════════════════════════════════════
538
 
539
+ torch.cuda.empty_cache()
540
+
541
  print(f"\n{'━'*80}")
542
  print(f"TEST 6: Hybrid Relay β€” Information Retention")
543
  print(f" Setup: 8 layers of hybrid relay (attention + constellation)")
 
613
  keys_b = F.normalize(torch.randn(N_CLS, D, device=DEVICE), dim=-1)
614
 
615
  task_x = F.normalize(torch.randn(N_SAMPLES, S_TASK, D, device=DEVICE), dim=-1).clone()
616
+ label_a = torch.randint(0, N_CLS, (N_SAMPLES,), dtype=torch.long, device=DEVICE)
617
+ label_b = torch.randint(0, N_CLS, (N_SAMPLES,), dtype=torch.long, device=DEVICE)
618
  task_x[:, 0] = keys_a[label_a] + torch.randn(N_SAMPLES, D, device=DEVICE) * 0.2
619
  task_x[:, 1] = keys_b[label_b] + torch.randn(N_SAMPLES, D, device=DEVICE) * 0.2
620
  task_x = F.normalize(task_x, dim=-1)
621
+ task_y = ((label_a + label_b) % N_CLS).long()
622
+ assert task_y.max() < N_CLS and task_y.min() >= 0
623
+ torch.cuda.synchronize()
624
 
625
  print(f"\n {'relay_act':>14} {'acc':>8} {'loss':>8} {'g_relay':>8} "
626
  f"{'g_attn':>8} {'cross_Ξ”':>10}")
 
646
  opt = torch.optim.Adam(model.parameters(), lr=3e-4)
647
 
648
  for step in range(STEPS):
649
+ idx = torch.randint(0, N_SAMPLES, (128,))
650
  logits = model(task_x[idx])
651
  loss = F.cross_entropy(logits, task_y[idx])
652
+ if torch.isnan(loss) or torch.isinf(loss):
653
+ print(f" ⚠ Bad loss at step {step}, act={act_name}")
654
+ break
655
  opt.zero_grad()
656
  loss.backward()
657
+ nn.utils.clip_grad_norm_(model.parameters(), 1.0)
658
  opt.step()
659
 
660
  model.eval()
 
717
 
718
  drift_log = {}
719
  for step in range(TRAIN_STEPS):
720
+ idx = torch.randint(0, len(train_x), (128,))
721
  logits = model(train_x[idx])
722
  loss = F.cross_entropy(logits, train_y[idx])
723
+ if torch.isnan(loss) or torch.isinf(loss):
724
+ print(f" ⚠ Bad loss at step {step}, act={act_name}")
725
+ break
726
  opt.zero_grad()
727
  loss.backward()
728
+ nn.utils.clip_grad_norm_(model.parameters(), 1.0)
729
  opt.step()
730
 
731
  if (step + 1) in [50, 100, 200, 300, 500]:
 
768
  h = x
769
  for layer in layers:
770
  h = layer(h)
771
+ h.retain_grad()
772
 
773
  loss = h.sum()
774
  loss.backward()
 
777
  anchor_grads = [l.anchors.grad.norm().item() for l in layers if l.anchors.grad is not None]
778
  gate_grads = [l.gate.grad.item() for l in layers if l.gate.grad is not None]
779
 
 
780
  grad_out = h.grad.norm().item() if h.grad is not None else 0
781
 
782
  print(f" {act_name:>14} {grad_in:>10.4f} {grad_out:>10.4f} "