CharlesCNorton commited on
Commit
5579250
·
1 Parent(s): 224eea2

Add HybridExtractor for digit lookup + word number learning

Browse files

- HybridExtractor: detects digit tokens (32-41) for hardcoded lookup,
uses learned MLP for word numbers ("forty seven plus eighty six")
- int_to_words(): converts 0-255 to English words
- generate_problem(): randomly mixes digit and word formats
- compute_hybrid_loss(): only trains on word samples (digits are free)
- Hybrid is now the default mode for --mode llm

Files changed (2) hide show
  1. llm_integration/model.py +234 -4
  2. llm_integration/train.py +179 -26
llm_integration/model.py CHANGED
@@ -745,6 +745,217 @@ class DigitExtractor(nn.Module):
745
  return a_bits, b_bits, op_logits, a_digit_logits, b_digit_logits
746
 
747
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
748
  class ArithmeticModel(nn.Module):
749
  """
750
  LLM + extractor + frozen threshold circuits.
@@ -753,7 +964,8 @@ class ArithmeticModel(nn.Module):
753
 
754
  def __init__(self, device: str = 'cuda', unfreeze_layers: int = 0,
755
  extract_layer: int = -1, position_extract: bool = False,
756
- digit_pred: bool = False, positional_digit: bool = False):
 
757
  super().__init__()
758
  self.device = device
759
  self.unfreeze_layers = unfreeze_layers
@@ -761,6 +973,7 @@ class ArithmeticModel(nn.Module):
761
  self.position_extract = position_extract
762
  self.digit_pred = digit_pred
763
  self.positional_digit = positional_digit
 
764
 
765
  from transformers import AutoModelForCausalLM, AutoTokenizer
766
 
@@ -801,7 +1014,14 @@ class ArithmeticModel(nn.Module):
801
  print(f" Circuits loaded. {len(self.circuits.weights)} tensors", flush=True)
802
 
803
  print("[4/4] Initializing extractor...", flush=True)
804
- if positional_digit:
 
 
 
 
 
 
 
805
  print(" Using POSITIONAL DIGIT extraction (100% proven)", flush=True)
806
  self.extractor = PositionalDigitExtractor(
807
  hidden_dim=hidden_dim
@@ -875,31 +1095,41 @@ class ArithmeticModel(nn.Module):
875
  """
876
  hidden, mask, token_ids = self.get_hidden_states(texts)
877
 
878
- if self.positional_digit or self.position_extract:
879
  extractor_out = self.extractor(hidden, mask, token_ids)
880
  else:
881
  extractor_out = self.extractor(hidden, mask)
882
 
883
- if self.positional_digit:
 
 
 
 
884
  a_bits, b_bits, op_logits, op_indices_from_tokens, a_values, b_values, a_digit_logits, b_digit_logits = extractor_out
 
885
  elif self.digit_pred:
886
  a_bits, b_bits, op_logits, a_digit_logits, b_digit_logits = extractor_out
887
  op_indices_from_tokens = None
888
  a_values, b_values = None, None
 
889
  elif self.position_extract:
890
  a_bits, b_bits, op_logits, op_indices_from_tokens = extractor_out
891
  a_digit_logits, b_digit_logits = None, None
892
  a_values, b_values = None, None
 
893
  else:
894
  a_bits, b_bits, op_logits = extractor_out
895
  a_digit_logits, b_digit_logits = None, None
896
  op_indices_from_tokens = None
897
  a_values, b_values = None, None
 
898
 
899
  op_probs = torch.softmax(op_logits, dim=-1)
900
 
901
  result_bits = self.circuits(a_bits, b_bits, op_probs)
902
 
 
 
903
  if self.positional_digit:
904
  return result_bits, a_bits, b_bits, op_logits, op_indices_from_tokens, a_values, b_values, a_digit_logits, b_digit_logits
905
  if self.digit_pred:
 
745
  return a_bits, b_bits, op_logits, a_digit_logits, b_digit_logits
746
 
747
 
748
+ class HybridExtractor(nn.Module):
749
+ """
750
+ Hybrid extractor that handles both digit tokens and word numbers.
751
+
752
+ For digit tokens (32-41): Direct lookup, no training needed
753
+ For word numbers: Learned MLP extraction from pooled hidden states
754
+
755
+ This is the real training target - learning to extract numbers from
756
+ natural language like "forty seven plus eighty six".
757
+ """
758
+
759
+ DIGIT_TOKENS = set(range(32, 42))
760
+ OPERATOR_TOKENS = {
761
+ 1232: 0, # ' +' -> add
762
+ 731: 1, # ' -' -> sub
763
+ 1672: 2, # ' *' -> mul
764
+ 2986: 3, # ' >' -> gt
765
+ 2067: 4, # ' <' -> lt
766
+ 1758: 5, # ' ==' -> eq
767
+ }
768
+ WORD_OP_TOKENS = {
769
+ 'plus': 0, 'minus': 1, 'times': 2,
770
+ 'greater': 3, 'less': 4, 'equals': 5, 'equal': 5,
771
+ }
772
+
773
+ def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 256, num_heads: int = 4):
774
+ super().__init__()
775
+ self.hidden_dim = hidden_dim
776
+
777
+ self.attention_pool = AttentionPooling(hidden_dim, num_heads)
778
+
779
+ self.a_predictor = nn.Sequential(
780
+ nn.Linear(hidden_dim, intermediate_dim),
781
+ nn.GELU(),
782
+ nn.Dropout(0.1),
783
+ nn.Linear(intermediate_dim, intermediate_dim),
784
+ nn.GELU(),
785
+ nn.Linear(intermediate_dim, 256),
786
+ )
787
+
788
+ self.b_predictor = nn.Sequential(
789
+ nn.Linear(hidden_dim, intermediate_dim),
790
+ nn.GELU(),
791
+ nn.Dropout(0.1),
792
+ nn.Linear(intermediate_dim, intermediate_dim),
793
+ nn.GELU(),
794
+ nn.Linear(intermediate_dim, 256),
795
+ )
796
+
797
+ self.op_predictor = nn.Sequential(
798
+ nn.Linear(hidden_dim, intermediate_dim // 2),
799
+ nn.GELU(),
800
+ nn.Linear(intermediate_dim // 2, len(OPERATIONS)),
801
+ )
802
+
803
+ def _has_digit_tokens(self, token_ids: torch.Tensor) -> bool:
804
+ """Check if input contains digit tokens."""
805
+ for tid in token_ids.tolist():
806
+ if tid in self.DIGIT_TOKENS:
807
+ return True
808
+ return False
809
+
810
+ def _extract_from_digits(self, token_ids: torch.Tensor) -> tuple:
811
+ """
812
+ Extract values directly from digit tokens (hardcoded lookup).
813
+ Returns (a_value, b_value, op_idx) or None if pattern not found.
814
+ """
815
+ tokens = token_ids.tolist()
816
+
817
+ op_pos = -1
818
+ op_idx = 0
819
+ for i, tid in enumerate(tokens):
820
+ if tid in self.OPERATOR_TOKENS:
821
+ op_pos = i
822
+ op_idx = self.OPERATOR_TOKENS[tid]
823
+ break
824
+
825
+ if op_pos == -1:
826
+ return None
827
+
828
+ a_digits = []
829
+ for i in range(op_pos):
830
+ if tokens[i] in self.DIGIT_TOKENS:
831
+ a_digits.append(tokens[i] - 32)
832
+
833
+ b_start = op_pos + 1
834
+ if b_start < len(tokens) and tokens[b_start] == 216:
835
+ b_start += 1
836
+
837
+ b_digits = []
838
+ for i in range(b_start, len(tokens)):
839
+ if tokens[i] in self.DIGIT_TOKENS:
840
+ b_digits.append(tokens[i] - 32)
841
+
842
+ if not a_digits or not b_digits:
843
+ return None
844
+
845
+ a_val = 0
846
+ for d in a_digits:
847
+ a_val = a_val * 10 + d
848
+
849
+ b_val = 0
850
+ for d in b_digits:
851
+ b_val = b_val * 10 + d
852
+
853
+ return min(a_val, 255), min(b_val, 255), op_idx
854
+
855
+ def _value_to_bits(self, value: int, device) -> torch.Tensor:
856
+ """Convert integer to 8-bit tensor."""
857
+ bits = torch.zeros(8, device=device)
858
+ for i in range(8):
859
+ bits[7 - i] = (value >> i) & 1
860
+ return bits
861
+
862
+ def forward(self, hidden: torch.Tensor, mask: torch.Tensor, token_ids: torch.Tensor = None):
863
+ """
864
+ Args:
865
+ hidden: [batch, seq_len, hidden_dim]
866
+ mask: [batch, seq_len]
867
+ token_ids: [batch, seq_len] - optional, enables digit lookup
868
+
869
+ Returns:
870
+ a_bits: [batch, 8]
871
+ b_bits: [batch, 8]
872
+ op_logits: [batch, 6]
873
+ a_values: [batch] predicted values (for loss)
874
+ b_values: [batch] predicted values (for loss)
875
+ used_lookup: [batch] bool tensor indicating if lookup was used
876
+ """
877
+ batch_size = hidden.shape[0]
878
+ device = hidden.device
879
+
880
+ a_bits_list = []
881
+ b_bits_list = []
882
+ op_logits_list = []
883
+ a_values_list = []
884
+ b_values_list = []
885
+ used_lookup_list = []
886
+
887
+ pooled = self.attention_pool(hidden, mask)
888
+
889
+ for i in range(batch_size):
890
+ lookup_result = None
891
+ if token_ids is not None:
892
+ seq_mask = mask[i].bool()
893
+ valid_len = seq_mask.sum().item()
894
+ start_pos = hidden.shape[1] - valid_len
895
+ valid_tokens = token_ids[i, start_pos:]
896
+
897
+ if self._has_digit_tokens(valid_tokens):
898
+ lookup_result = self._extract_from_digits(valid_tokens)
899
+
900
+ if lookup_result is not None:
901
+ a_val, b_val, op_idx = lookup_result
902
+ a_bits = self._value_to_bits(a_val, device)
903
+ b_bits = self._value_to_bits(b_val, device)
904
+ op_logits = torch.zeros(len(OPERATIONS), device=device)
905
+ op_logits[op_idx] = 10.0
906
+
907
+ a_bits_list.append(a_bits)
908
+ b_bits_list.append(b_bits)
909
+ op_logits_list.append(op_logits)
910
+ a_values_list.append(float(a_val))
911
+ b_values_list.append(float(b_val))
912
+ used_lookup_list.append(True)
913
+ else:
914
+ sample_pooled = pooled[i]
915
+
916
+ a_logits = self.a_predictor(sample_pooled)
917
+ b_logits = self.b_predictor(sample_pooled)
918
+ op_logits = self.op_predictor(sample_pooled)
919
+
920
+ a_probs = torch.softmax(a_logits, dim=-1)
921
+ b_probs = torch.softmax(b_logits, dim=-1)
922
+
923
+ values = torch.arange(256, device=device, dtype=torch.float32)
924
+ a_val = (a_probs * values).sum()
925
+ b_val = (b_probs * values).sum()
926
+
927
+ a_bits = self._soft_value_to_bits(a_val, device)
928
+ b_bits = self._soft_value_to_bits(b_val, device)
929
+
930
+ a_bits_list.append(a_bits)
931
+ b_bits_list.append(b_bits)
932
+ op_logits_list.append(op_logits)
933
+ a_values_list.append(a_val)
934
+ b_values_list.append(b_val)
935
+ used_lookup_list.append(False)
936
+
937
+ a_bits = torch.stack(a_bits_list)
938
+ b_bits = torch.stack(b_bits_list)
939
+ op_logits = torch.stack(op_logits_list)
940
+ a_values = torch.stack([v if isinstance(v, torch.Tensor) else torch.tensor(v, device=device) for v in a_values_list])
941
+ b_values = torch.stack([v if isinstance(v, torch.Tensor) else torch.tensor(v, device=device) for v in b_values_list])
942
+ used_lookup = torch.tensor(used_lookup_list, device=device, dtype=torch.bool)
943
+
944
+ return a_bits, b_bits, op_logits, a_values, b_values, used_lookup
945
+
946
+ def _soft_value_to_bits(self, value: torch.Tensor, device) -> torch.Tensor:
947
+ """Convert soft value (0-255) to 8-bit representation differentiably."""
948
+ value = torch.clamp(value, 0, 255)
949
+ bits = []
950
+ remaining = value
951
+ for i in range(7, -1, -1):
952
+ threshold = 2 ** i
953
+ bit = torch.sigmoid((remaining - threshold + 0.5) * 10)
954
+ bits.append(bit)
955
+ remaining = remaining - bit * threshold
956
+ return torch.stack(bits)
957
+
958
+
959
  class ArithmeticModel(nn.Module):
960
  """
961
  LLM + extractor + frozen threshold circuits.
 
964
 
965
  def __init__(self, device: str = 'cuda', unfreeze_layers: int = 0,
966
  extract_layer: int = -1, position_extract: bool = False,
967
+ digit_pred: bool = False, positional_digit: bool = False,
968
+ hybrid: bool = False):
969
  super().__init__()
970
  self.device = device
971
  self.unfreeze_layers = unfreeze_layers
 
973
  self.position_extract = position_extract
974
  self.digit_pred = digit_pred
975
  self.positional_digit = positional_digit
976
+ self.hybrid = hybrid
977
 
978
  from transformers import AutoModelForCausalLM, AutoTokenizer
979
 
 
1014
  print(f" Circuits loaded. {len(self.circuits.weights)} tensors", flush=True)
1015
 
1016
  print("[4/4] Initializing extractor...", flush=True)
1017
+ if hybrid:
1018
+ print(" Using HYBRID extraction (digit lookup + word learning)", flush=True)
1019
+ self.extractor = HybridExtractor(
1020
+ hidden_dim=hidden_dim,
1021
+ intermediate_dim=256,
1022
+ num_heads=4
1023
+ ).to(device)
1024
+ elif positional_digit:
1025
  print(" Using POSITIONAL DIGIT extraction (100% proven)", flush=True)
1026
  self.extractor = PositionalDigitExtractor(
1027
  hidden_dim=hidden_dim
 
1095
  """
1096
  hidden, mask, token_ids = self.get_hidden_states(texts)
1097
 
1098
+ if self.hybrid or self.positional_digit or self.position_extract:
1099
  extractor_out = self.extractor(hidden, mask, token_ids)
1100
  else:
1101
  extractor_out = self.extractor(hidden, mask)
1102
 
1103
+ if self.hybrid:
1104
+ a_bits, b_bits, op_logits, a_values, b_values, used_lookup = extractor_out
1105
+ op_indices_from_tokens = None
1106
+ a_digit_logits, b_digit_logits = None, None
1107
+ elif self.positional_digit:
1108
  a_bits, b_bits, op_logits, op_indices_from_tokens, a_values, b_values, a_digit_logits, b_digit_logits = extractor_out
1109
+ used_lookup = None
1110
  elif self.digit_pred:
1111
  a_bits, b_bits, op_logits, a_digit_logits, b_digit_logits = extractor_out
1112
  op_indices_from_tokens = None
1113
  a_values, b_values = None, None
1114
+ used_lookup = None
1115
  elif self.position_extract:
1116
  a_bits, b_bits, op_logits, op_indices_from_tokens = extractor_out
1117
  a_digit_logits, b_digit_logits = None, None
1118
  a_values, b_values = None, None
1119
+ used_lookup = None
1120
  else:
1121
  a_bits, b_bits, op_logits = extractor_out
1122
  a_digit_logits, b_digit_logits = None, None
1123
  op_indices_from_tokens = None
1124
  a_values, b_values = None, None
1125
+ used_lookup = None
1126
 
1127
  op_probs = torch.softmax(op_logits, dim=-1)
1128
 
1129
  result_bits = self.circuits(a_bits, b_bits, op_probs)
1130
 
1131
+ if self.hybrid:
1132
+ return result_bits, a_bits, b_bits, op_logits, a_values, b_values, used_lookup
1133
  if self.positional_digit:
1134
  return result_bits, a_bits, b_bits, op_logits, op_indices_from_tokens, a_values, b_values, a_digit_logits, b_digit_logits
1135
  if self.digit_pred:
llm_integration/train.py CHANGED
@@ -39,6 +39,29 @@ from fitness import generate_batch, compute_fitness, compute_loss
39
 
40
  DEVICE = 'cuda'
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  def int_to_bits(val: int, device: str = 'cuda') -> torch.Tensor:
44
  bits = torch.zeros(8, device=device)
@@ -55,14 +78,84 @@ def bits_to_int(bits: torch.Tensor) -> int:
55
  return val
56
 
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  def generate_problem(max_val: int = 255):
59
- """Generate a random arithmetic problem for LLM training."""
 
 
 
60
  a = random.randint(0, max_val)
61
  b = random.randint(0, max_val)
62
  op = random.choice(OPERATIONS)
63
 
64
- sym = OP_SYMBOLS[op]
65
- text = f"{a} {sym} {b}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  if op == 'add':
68
  result = (a + b) & 0xFF
@@ -457,8 +550,51 @@ def compute_positional_digit_loss(pred_bits, op_logits, a_digit_logits_list, b_d
457
  }
458
 
459
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460
  def evaluate_llm(model, n_samples: int = 500):
461
- """Evaluate LLM model on random problems."""
462
  model.extractor.eval()
463
  correct = 0
464
  op_correct = 0
@@ -493,10 +629,12 @@ def train_llm(epochs: int = 100, batch_size: int = 256, lr: float = 3e-4,
493
  Args:
494
  unfreeze_layers: Number of top transformer layers to unfreeze (0 = fully frozen)
495
  extract_layer: Which layer to extract from (-1 = last)
496
- position_extract: Use position-specific extraction
497
- digit_pred: Predict digits instead of bits
498
- positional_digit: Use positional digit extraction (100% proven accuracy)
499
  """
 
 
500
  print("=" * 70)
501
  print(" LLM TRAINING")
502
  if unfreeze_layers > 0:
@@ -505,12 +643,14 @@ def train_llm(epochs: int = 100, batch_size: int = 256, lr: float = 3e-4,
505
  print(" LLM frozen")
506
  if extract_layer != -1:
507
  print(f" Extracting from layer {extract_layer}")
508
- if positional_digit:
509
- print(" POSITIONAL DIGIT extraction (100% proven)")
 
 
510
  elif position_extract:
511
- print(" Position-specific extraction")
512
- if digit_pred:
513
- print(" Digit-level prediction")
514
  print("=" * 70)
515
 
516
  print("\nInitializing model...")
@@ -520,7 +660,8 @@ def train_llm(epochs: int = 100, batch_size: int = 256, lr: float = 3e-4,
520
  extract_layer=extract_layer,
521
  position_extract=position_extract,
522
  digit_pred=digit_pred,
523
- positional_digit=positional_digit
 
524
  )
525
 
526
  optimizer = optim.AdamW(model.trainable_parameters(), lr=lr)
@@ -534,7 +675,7 @@ def train_llm(epochs: int = 100, batch_size: int = 256, lr: float = 3e-4,
534
  print(f" Samples/epoch: {batch_size * 20}")
535
 
536
  print(f"\nInitial evaluation (200 samples)...")
537
- acc, op_acc = evaluate_llm(model, n_samples=200)
538
  print(f" Accuracy: {acc:.4f}, Op accuracy: {op_acc:.4f}")
539
 
540
  print(f"\nStarting training...")
@@ -551,7 +692,9 @@ def train_llm(epochs: int = 100, batch_size: int = 256, lr: float = 3e-4,
551
  max_val = get_curriculum_max(epoch, epochs)
552
 
553
  epoch_loss = 0
554
- if positional_digit:
 
 
555
  epoch_losses = {'result': 0, 'a_digit': 0, 'b_digit': 0, 'op': 0}
556
  else:
557
  epoch_losses = {'result': 0, 'a': 0, 'b': 0, 'op': 0}
@@ -589,7 +732,13 @@ def train_llm(epochs: int = 100, batch_size: int = 256, lr: float = 3e-4,
589
  outputs = model(batch_texts)
590
  pred_bits, a_bits, b_bits, op_logits = outputs[0], outputs[1], outputs[2], outputs[3]
591
 
592
- if positional_digit:
 
 
 
 
 
 
593
  a_digit_logits_list = outputs[7]
594
  b_digit_logits_list = outputs[8]
595
  loss, losses = compute_positional_digit_loss(
@@ -621,7 +770,7 @@ def train_llm(epochs: int = 100, batch_size: int = 256, lr: float = 3e-4,
621
  for k in epoch_losses:
622
  epoch_losses[k] /= n_batches
623
 
624
- acc, op_acc = evaluate_llm(model, n_samples=300)
625
  elapsed = time.perf_counter() - start_time
626
 
627
  marker = " *" if acc > best_acc else ""
@@ -632,7 +781,11 @@ def train_llm(epochs: int = 100, batch_size: int = 256, lr: float = 3e-4,
632
  print(f"Epoch {epoch+1:3d} | Loss: {avg_loss:.4f} | "
633
  f"Acc: {acc:.4f}{marker} | OpAcc: {op_acc:.4f} | "
634
  f"Range: 0-{max_val} | VRAM: {mem:.0f}MB | Time: {elapsed:.0f}s")
635
- if positional_digit:
 
 
 
 
636
  print(f" Losses - result:{epoch_losses['result']:.4f} "
637
  f"a_digit:{epoch_losses['a_digit']:.4f} b_digit:{epoch_losses['b_digit']:.4f} "
638
  f"op:{epoch_losses['op']:.4f}")
@@ -651,7 +804,7 @@ def train_llm(epochs: int = 100, batch_size: int = 256, lr: float = 3e-4,
651
  print(" FINAL EVALUATION")
652
  print("=" * 70)
653
 
654
- acc, op_acc = evaluate_llm(model, n_samples=1000)
655
  print(f"Final accuracy: {acc:.4f}")
656
  print(f"Final op accuracy: {op_acc:.4f}")
657
  print(f"Best accuracy: {best_acc:.4f}")
@@ -717,19 +870,19 @@ Examples:
717
  choices=['router', 'interface', 'llm'],
718
  help='Training mode')
719
  parser.add_argument('--epochs', type=int, default=100, help='Number of epochs')
720
- parser.add_argument('--batch_size', type=int, default=256, help='Batch size')
721
  parser.add_argument('--lr', type=float, default=None,
722
  help='Learning rate (default: mode-specific)')
723
  parser.add_argument('--unfreeze_layers', type=int, default=0,
724
  help='Unfreeze top N transformer layers (default 0 = frozen)')
725
- parser.add_argument('--extract_layer', type=int, default=-1,
726
- help='Which layer to extract from (-1 = last)')
727
  parser.add_argument('--position_extract', action='store_true',
728
- help='Use position-specific extraction')
729
  parser.add_argument('--digit_pred', action='store_true',
730
- help='Predict digits instead of bits')
731
- parser.add_argument('--positional_digit', action='store_true',
732
- help='Use positional digit extraction (100% proven accuracy)')
733
  parser.add_argument('--device', type=str, default='cuda', help='Device')
734
  args = parser.parse_args()
735
 
 
39
 
40
  DEVICE = 'cuda'
41
 
42
+ ONES = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine',
43
+ 'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen', 'sixteen',
44
+ 'seventeen', 'eighteen', 'nineteen']
45
+ TENS = ['', '', 'twenty', 'thirty', 'forty', 'fifty', 'sixty', 'seventy', 'eighty', 'ninety']
46
+
47
+ def int_to_words(n: int) -> str:
48
+ """Convert integer 0-255 to English words."""
49
+ if n < 0 or n > 255:
50
+ return str(n)
51
+ if n < 20:
52
+ return ONES[n]
53
+ if n < 100:
54
+ if n % 10 == 0:
55
+ return TENS[n // 10]
56
+ return f"{TENS[n // 10]} {ONES[n % 10]}"
57
+ if n % 100 == 0:
58
+ return f"{ONES[n // 100]} hundred"
59
+ if n % 100 < 20:
60
+ return f"{ONES[n // 100]} hundred {ONES[n % 100]}"
61
+ if n % 10 == 0:
62
+ return f"{ONES[n // 100]} hundred {TENS[(n % 100) // 10]}"
63
+ return f"{ONES[n // 100]} hundred {TENS[(n % 100) // 10]} {ONES[n % 10]}"
64
+
65
 
66
  def int_to_bits(val: int, device: str = 'cuda') -> torch.Tensor:
67
  bits = torch.zeros(8, device=device)
 
78
  return val
79
 
80
 
81
+ NL_TEMPLATES = {
82
+ 'add': [
83
+ "What is {a} plus {b}?",
84
+ "Calculate {a} + {b}",
85
+ "Add {a} and {b}",
86
+ "What's the sum of {a} and {b}?",
87
+ "If I have {a} and get {b} more, how many total?",
88
+ "{a} + {b} = ?",
89
+ "Compute {a} plus {b}",
90
+ ],
91
+ 'sub': [
92
+ "What is {a} minus {b}?",
93
+ "Calculate {a} - {b}",
94
+ "Subtract {b} from {a}",
95
+ "What's {a} take away {b}?",
96
+ "If I have {a} and lose {b}, how many left?",
97
+ "{a} - {b} = ?",
98
+ "Compute {a} minus {b}",
99
+ ],
100
+ 'mul': [
101
+ "What is {a} times {b}?",
102
+ "Calculate {a} * {b}",
103
+ "Multiply {a} by {b}",
104
+ "What's {a} multiplied by {b}?",
105
+ "{a} * {b} = ?",
106
+ "Compute {a} times {b}",
107
+ "What is the product of {a} and {b}?",
108
+ ],
109
+ 'gt': [
110
+ "Is {a} greater than {b}?",
111
+ "Is {a} > {b}?",
112
+ "Check if {a} is larger than {b}",
113
+ "Compare: is {a} more than {b}?",
114
+ "{a} > {b}?",
115
+ ],
116
+ 'lt': [
117
+ "Is {a} less than {b}?",
118
+ "Is {a} < {b}?",
119
+ "Check if {a} is smaller than {b}",
120
+ "Compare: is {a} fewer than {b}?",
121
+ "{a} < {b}?",
122
+ ],
123
+ 'eq': [
124
+ "Is {a} equal to {b}?",
125
+ "Is {a} == {b}?",
126
+ "Does {a} equal {b}?",
127
+ "Check if {a} equals {b}",
128
+ "Are {a} and {b} the same?",
129
+ ],
130
+ }
131
+
132
+
133
  def generate_problem(max_val: int = 255):
134
+ """
135
+ Generate a random arithmetic problem for LLM training.
136
+ Randomly mixes digit and word formats.
137
+ """
138
  a = random.randint(0, max_val)
139
  b = random.randint(0, max_val)
140
  op = random.choice(OPERATIONS)
141
 
142
+ fmt = random.choice(['digits', 'words', 'nl_digits', 'nl_words'])
143
+
144
+ if fmt == 'digits':
145
+ sym = OP_SYMBOLS[op]
146
+ text = f"{a} {sym} {b}"
147
+ elif fmt == 'words':
148
+ a_word = int_to_words(a)
149
+ b_word = int_to_words(b)
150
+ op_word = {'add': 'plus', 'sub': 'minus', 'mul': 'times',
151
+ 'gt': 'greater than', 'lt': 'less than', 'eq': 'equals'}[op]
152
+ text = f"{a_word} {op_word} {b_word}"
153
+ elif fmt == 'nl_digits':
154
+ template = random.choice(NL_TEMPLATES[op])
155
+ text = template.format(a=a, b=b)
156
+ else:
157
+ template = random.choice(NL_TEMPLATES[op])
158
+ text = template.format(a=int_to_words(a), b=int_to_words(b))
159
 
160
  if op == 'add':
161
  result = (a + b) & 0xFF
 
550
  }
551
 
552
 
553
+ def compute_hybrid_loss(pred_bits, a_values, b_values, op_logits, used_lookup,
554
+ target_result, target_a_values, target_b_values, target_op_idx,
555
+ device, value_weight: float = 1.0):
556
+ """
557
+ Loss for hybrid extraction.
558
+
559
+ Only compute value loss for samples where lookup was NOT used (word numbers).
560
+ Samples using digit lookup are already 100% accurate.
561
+ """
562
+ result_loss = nn.functional.binary_cross_entropy_with_logits(
563
+ pred_bits, target_result
564
+ )
565
+
566
+ op_loss = nn.functional.cross_entropy(op_logits, target_op_idx)
567
+
568
+ word_mask = ~used_lookup
569
+ n_words = word_mask.sum().item()
570
+
571
+ if n_words > 0:
572
+ a_word_values = a_values[word_mask]
573
+ b_word_values = b_values[word_mask]
574
+ target_a_word = target_a_values[word_mask]
575
+ target_b_word = target_b_values[word_mask]
576
+
577
+ a_value_loss = nn.functional.mse_loss(a_word_values, target_a_word)
578
+ b_value_loss = nn.functional.mse_loss(b_word_values, target_b_word)
579
+ else:
580
+ a_value_loss = torch.tensor(0.0, device=device)
581
+ b_value_loss = torch.tensor(0.0, device=device)
582
+
583
+ total = result_loss + op_loss + value_weight * (a_value_loss + b_value_loss)
584
+ total = torch.nan_to_num(total, nan=10.0, posinf=10.0, neginf=0.0)
585
+
586
+ return total, {
587
+ 'result': result_loss.item() if not torch.isnan(result_loss) else 10.0,
588
+ 'a_value': a_value_loss.item() if not torch.isnan(a_value_loss) else 10.0,
589
+ 'b_value': b_value_loss.item() if not torch.isnan(b_value_loss) else 10.0,
590
+ 'op': op_loss.item() if not torch.isnan(op_loss) else 10.0,
591
+ 'n_words': n_words,
592
+ 'n_lookup': used_lookup.sum().item()
593
+ }
594
+
595
+
596
  def evaluate_llm(model, n_samples: int = 500):
597
+ """Evaluate LLM model on random problems (mixed digit/word format)."""
598
  model.extractor.eval()
599
  correct = 0
600
  op_correct = 0
 
629
  Args:
630
  unfreeze_layers: Number of top transformer layers to unfreeze (0 = fully frozen)
631
  extract_layer: Which layer to extract from (-1 = last)
632
+ position_extract: Use position-specific extraction (legacy)
633
+ digit_pred: Predict digits instead of bits (legacy)
634
+ positional_digit: Use positional digit extraction (legacy, 100% on digits only)
635
  """
636
+ hybrid = not (positional_digit or position_extract or digit_pred)
637
+
638
  print("=" * 70)
639
  print(" LLM TRAINING")
640
  if unfreeze_layers > 0:
 
643
  print(" LLM frozen")
644
  if extract_layer != -1:
645
  print(f" Extracting from layer {extract_layer}")
646
+ if hybrid:
647
+ print(" HYBRID extraction (digit lookup + word learning)")
648
+ elif positional_digit:
649
+ print(" POSITIONAL DIGIT extraction (legacy, 100% on digits only)")
650
  elif position_extract:
651
+ print(" Position-specific extraction (legacy)")
652
+ elif digit_pred:
653
+ print(" Digit-level prediction (legacy)")
654
  print("=" * 70)
655
 
656
  print("\nInitializing model...")
 
660
  extract_layer=extract_layer,
661
  position_extract=position_extract,
662
  digit_pred=digit_pred,
663
+ positional_digit=positional_digit,
664
+ hybrid=hybrid
665
  )
666
 
667
  optimizer = optim.AdamW(model.trainable_parameters(), lr=lr)
 
675
  print(f" Samples/epoch: {batch_size * 20}")
676
 
677
  print(f"\nInitial evaluation (200 samples)...")
678
+ acc, op_acc = evaluate_llm(model, 200)
679
  print(f" Accuracy: {acc:.4f}, Op accuracy: {op_acc:.4f}")
680
 
681
  print(f"\nStarting training...")
 
692
  max_val = get_curriculum_max(epoch, epochs)
693
 
694
  epoch_loss = 0
695
+ if hybrid:
696
+ epoch_losses = {'result': 0, 'a_value': 0, 'b_value': 0, 'op': 0, 'n_words': 0, 'n_lookup': 0}
697
+ elif positional_digit:
698
  epoch_losses = {'result': 0, 'a_digit': 0, 'b_digit': 0, 'op': 0}
699
  else:
700
  epoch_losses = {'result': 0, 'a': 0, 'b': 0, 'op': 0}
 
732
  outputs = model(batch_texts)
733
  pred_bits, a_bits, b_bits, op_logits = outputs[0], outputs[1], outputs[2], outputs[3]
734
 
735
+ if hybrid:
736
+ a_values, b_values, used_lookup = outputs[4], outputs[5], outputs[6]
737
+ loss, losses = compute_hybrid_loss(
738
+ pred_bits, a_values, b_values, op_logits, used_lookup,
739
+ target_result, target_a_values, target_b_values, target_op, device
740
+ )
741
+ elif positional_digit:
742
  a_digit_logits_list = outputs[7]
743
  b_digit_logits_list = outputs[8]
744
  loss, losses = compute_positional_digit_loss(
 
770
  for k in epoch_losses:
771
  epoch_losses[k] /= n_batches
772
 
773
+ acc, op_acc = evaluate_llm(model, 300)
774
  elapsed = time.perf_counter() - start_time
775
 
776
  marker = " *" if acc > best_acc else ""
 
781
  print(f"Epoch {epoch+1:3d} | Loss: {avg_loss:.4f} | "
782
  f"Acc: {acc:.4f}{marker} | OpAcc: {op_acc:.4f} | "
783
  f"Range: 0-{max_val} | VRAM: {mem:.0f}MB | Time: {elapsed:.0f}s")
784
+ if hybrid:
785
+ print(f" Losses - result:{epoch_losses['result']:.4f} "
786
+ f"a_val:{epoch_losses['a_value']:.4f} b_val:{epoch_losses['b_value']:.4f} "
787
+ f"op:{epoch_losses['op']:.4f} | words:{epoch_losses['n_words']:.0f} lookup:{epoch_losses['n_lookup']:.0f}")
788
+ elif positional_digit:
789
  print(f" Losses - result:{epoch_losses['result']:.4f} "
790
  f"a_digit:{epoch_losses['a_digit']:.4f} b_digit:{epoch_losses['b_digit']:.4f} "
791
  f"op:{epoch_losses['op']:.4f}")
 
804
  print(" FINAL EVALUATION")
805
  print("=" * 70)
806
 
807
+ acc, op_acc = evaluate_llm(model, 1000)
808
  print(f"Final accuracy: {acc:.4f}")
809
  print(f"Final op accuracy: {op_acc:.4f}")
810
  print(f"Best accuracy: {best_acc:.4f}")
 
870
  choices=['router', 'interface', 'llm'],
871
  help='Training mode')
872
  parser.add_argument('--epochs', type=int, default=100, help='Number of epochs')
873
+ parser.add_argument('--batch_size', type=int, default=512, help='Batch size (default: 512)')
874
  parser.add_argument('--lr', type=float, default=None,
875
  help='Learning rate (default: mode-specific)')
876
  parser.add_argument('--unfreeze_layers', type=int, default=0,
877
  help='Unfreeze top N transformer layers (default 0 = frozen)')
878
+ parser.add_argument('--extract_layer', type=int, default=0,
879
+ help='Which layer to extract from (default: 0 = embeddings, best for digits)')
880
  parser.add_argument('--position_extract', action='store_true',
881
+ help='Use position-specific extraction (legacy)')
882
  parser.add_argument('--digit_pred', action='store_true',
883
+ help='Predict digits instead of bits (legacy)')
884
+ parser.add_argument('--positional_digit', action='store_true', default=False,
885
+ help='Use positional digit extraction (legacy, 100%% on digits only)')
886
  parser.add_argument('--device', type=str, default='cuda', help='Device')
887
  args = parser.parse_args()
888