CharlesCNorton committed
Commit fe691a6 · Parent: 7542035

Add operator-aware splitting for word number extraction


- SYMBOL_OP_TOKENS for ' +', ' -', etc.
- WORD_OP_TOKENS for 'plus', 'minus', etc. (token IDs)
- ALL_OP_TOKENS combines both for unified lookup
- Separate a_pool and b_pool attention modules
- _find_op_position() to locate operator in sequence
- Split hidden states at operator: a tokens before, b tokens after (see the toy sketch below)
- Each operand pooled separately before digit prediction
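
In practice the split operates on the token sequence itself: everything before the first operator token is pooled as operand a, everything after it as operand b. A toy sketch of the idea (only 2068, the 'plus' ID from the diff below, is a real token ID; the rest are placeholders):

    # Toy illustration of the operator-aware split described above.
    tokens = [605, 17, 2068, 604, 19]   # roughly "12 plus 34" after tokenization
    ALL_OP_TOKENS = {2068: 0}           # word operator 'plus' -> add

    op_pos = next((i for i, t in enumerate(tokens) if t in ALL_OP_TOKENS), -1)
    a_tokens = tokens[:op_pos]          # operand a: tokens before the operator
    b_tokens = tokens[op_pos + 1:]      # operand b: tokens after the operator
    # In the model, the hidden states at these positions are attention-pooled
    # separately (a_pool / b_pool) before the per-digit predictions.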

Files changed (2)
  1. .gitignore +1 -0
  2. llm_integration/model.py +49 -9
.gitignore CHANGED
@@ -1,2 +1,3 @@
 __pycache__/
 *.pyc
+.pt file
llm_integration/model.py CHANGED
@@ -757,7 +757,7 @@ class HybridExtractor(nn.Module):
     """
 
     DIGIT_TOKENS = set(range(32, 42))
-    OPERATOR_TOKENS = {
+    SYMBOL_OP_TOKENS = {
         1232: 0,   # ' +' -> add
         731: 1,    # ' -' -> sub
         1672: 2,   # ' *' -> mul
@@ -766,15 +766,23 @@ class HybridExtractor(nn.Module):
         1758: 5,   # ' ==' -> eq
     }
     WORD_OP_TOKENS = {
-        'plus': 0, 'minus': 1, 'times': 2,
-        'greater': 3, 'less': 4, 'equals': 5, 'equal': 5,
+        2068: 0,    # 'plus' -> add
+        8500: 1,    # 'minus' -> sub
+        1580: 2,    # 'times' -> mul
+        6301: 3,    # 'greater' -> gt
+        1912: 4,    # 'less' -> lt
+        16364: 5,   # 'equals' -> eq
+        11540: 5,   # 'equal' -> eq
     }
+    ALL_OP_TOKENS = {**SYMBOL_OP_TOKENS, **WORD_OP_TOKENS}
 
     def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 256, num_heads: int = 4):
         super().__init__()
         self.hidden_dim = hidden_dim
 
         self.attention_pool = AttentionPooling(hidden_dim, num_heads)
+        self.a_pool = AttentionPooling(hidden_dim, num_heads)
+        self.b_pool = AttentionPooling(hidden_dim, num_heads)
 
         self.a_digit_pred = nn.Sequential(
             nn.Linear(hidden_dim, intermediate_dim),
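
Note on the hunk above: the old WORD_OP_TOKENS keyed on strings ('plus': 0), which an integer token-ID lookup could never match; keying on token IDs makes the map mergeable with SYMBOL_OP_TOKENS into ALL_OP_TOKENS. The IDs are tokenizer-specific, so they are worth re-checking against whichever checkpoint feeds HybridExtractor. A minimal sanity-check sketch (MODEL_NAME and the leading-space spelling are assumptions, not pinned by this commit):

    # Hedged check of the word-operator token IDs; MODEL_NAME is a placeholder.
    from transformers import AutoTokenizer

    MODEL_NAME = "path/to/base-model"
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    for word in [" plus", " minus", " times", " greater", " less", " equals", " equal"]:
        ids = tok(word, add_special_tokens=False)["input_ids"]
        # Single-token spellings can map straight to an entry in WORD_OP_TOKENS;
        # multi-token spellings would never be hit by a per-token lookup.
        print(repr(word), ids)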
@@ -803,9 +811,18 @@ class HybridExtractor(nn.Module):
                 return True
         return False
 
+    def _find_op_position(self, token_ids: torch.Tensor) -> int:
+        """Find position of operator token, returns -1 if not found."""
+        tokens = token_ids.tolist()
+        for i, tid in enumerate(tokens):
+            if tid in self.ALL_OP_TOKENS:
+                return i
+        return -1
+
     def _extract_from_digits(self, token_ids: torch.Tensor) -> tuple:
         """
         Extract values directly from digit tokens (hardcoded lookup).
+        Handles both symbol operators (' +') and word operators ('plus').
         Returns (a_value, b_value, op_idx) or None if pattern not found.
         """
         tokens = token_ids.tolist()
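
The helper above returns the index of the first token present in ALL_OP_TOKENS, or -1 when no single-token operator exists; the forward-pass change further down only applies the split when that index leaves at least one token on each side, and otherwise both operands fall back to the shared attention_pool output. A standalone illustration of that contract (1232 is the ' +' ID from the map above; the other IDs are placeholders):

    import torch

    ALL_OP_TOKENS = {1232: 0, 2068: 0}   # ' +' and 'plus' from the maps above

    def find_op_position(token_ids: torch.Tensor) -> int:
        # Mirrors HybridExtractor._find_op_position, for illustration only.
        for i, tid in enumerate(token_ids.tolist()):
            if tid in ALL_OP_TOKENS:
                return i
        return -1

    print(find_op_position(torch.tensor([604, 1232, 605])))  # 1  -> split applies
    print(find_op_position(torch.tensor([604, 605])))        # -1 -> shared-pool fallback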
@@ -813,9 +830,9 @@ class HybridExtractor(nn.Module):
         op_pos = -1
         op_idx = 0
         for i, tid in enumerate(tokens):
-            if tid in self.OPERATOR_TOKENS:
+            if tid in self.ALL_OP_TOKENS:
                 op_pos = i
-                op_idx = self.OPERATOR_TOKENS[tid]
+                op_idx = self.ALL_OP_TOKENS[tid]
                 break
 
         if op_pos == -1:
@@ -934,11 +951,34 @@ class HybridExtractor(nn.Module):
                 a_digit_logits_list.append(None)
                 b_digit_logits_list.append(None)
             else:
-                sample_pooled = pooled[i]
+                sample_hidden = hidden[i:i+1]
+                sample_mask = mask[i:i+1]
+
+                seq_mask = mask[i].bool()
+                valid_len = int(seq_mask.sum().item())
+                start_pos = hidden.shape[1] - valid_len
+                valid_tokens = token_ids[i, start_pos:] if token_ids is not None else None
+
+                op_pos = self._find_op_position(valid_tokens) if valid_tokens is not None else -1
+
+                if op_pos > 0 and op_pos < valid_len - 1:
+                    a_end = start_pos + op_pos
+                    b_start = start_pos + op_pos + 1
+
+                    a_mask = torch.zeros_like(sample_mask)
+                    a_mask[0, start_pos:a_end] = 1.0
+                    b_mask = torch.zeros_like(sample_mask)
+                    b_mask[0, b_start:] = sample_mask[0, b_start:]
+
+                    a_pooled = self.a_pool(sample_hidden, a_mask)[0]
+                    b_pooled = self.b_pool(sample_hidden, b_mask)[0]
+                else:
+                    a_pooled = pooled[i]
+                    b_pooled = pooled[i]
 
-                a_digit_logits = self.a_digit_pred(sample_pooled)
-                b_digit_logits = self.b_digit_pred(sample_pooled)
-                op_logits = self.op_predictor(sample_pooled)
+                a_digit_logits = self.a_digit_pred(a_pooled)
+                b_digit_logits = self.b_digit_pred(b_pooled)
+                op_logits = self.op_predictor(pooled[i])
 
                 a_val, a_bits = self._digits_to_value_and_bits(a_digit_logits, device)
                 b_val, b_bits = self._digits_to_value_and_bits(b_digit_logits, device)
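
The last hunk assumes left padding: the valid tokens sit at the end of the sequence, so start_pos = seq_len - valid_len is the first real position, and the operand masks are carved out on either side of the operator before each side is pooled by a_pool / b_pool. A worked example of that mask arithmetic (shapes and IDs are illustrative; 2068 is the 'plus' ID):

    # Worked example of the operand-mask construction, assuming left padding.
    import torch

    seq_len = 8
    mask = torch.tensor([[0., 0., 0., 1., 1., 1., 1., 1.]])        # 5 valid tokens
    token_ids = torch.tensor([[0, 0, 0, 604, 17, 2068, 604, 19]])  # "... 12 plus 34"-style

    valid_len = int(mask[0].sum().item())    # 5
    start_pos = seq_len - valid_len          # 3
    valid_tokens = token_ids[0, start_pos:]  # tensor([604, 17, 2068, 604, 19])
    op_pos = 2                               # index of 2068 within valid_tokens

    a_end = start_pos + op_pos               # 5
    b_start = start_pos + op_pos + 1         # 6
    a_mask = torch.zeros_like(mask); a_mask[0, start_pos:a_end] = 1.0
    b_mask = torch.zeros_like(mask); b_mask[0, b_start:] = mask[0, b_start:]
    print(a_mask)  # positions 3-4 active -> operand a ("12")
    print(b_mask)  # positions 6-7 active -> operand b ("34")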