CharlesCNorton committed on
Commit ·
fe691a6
1
Parent(s): 7542035
Add operator-aware splitting for word number extraction
Browse files- SYMBOL_OP_TOKENS for ' +', ' -', etc.
- WORD_OP_TOKENS for 'plus', 'minus', etc. (token IDs)
- ALL_OP_TOKENS combines both for unified lookup
- Separate a_pool and b_pool attention modules
- _find_op_position() to locate operator in sequence
- Split hidden states at operator: a tokens before, b tokens after
- Each operand pooled separately before digit prediction
- .gitignore +1 -0
- llm_integration/model.py +49 -9
.gitignore
CHANGED
|
@@ -1,2 +1,3 @@
|
|
| 1 |
__pycache__/
|
| 2 |
*.pyc
|
|
|
|
|
|
| 1 |
__pycache__/
|
| 2 |
*.pyc
|
| 3 |
+
.pt file
|
llm_integration/model.py
CHANGED
|
@@ -757,7 +757,7 @@ class HybridExtractor(nn.Module):
|
|
| 757 |
"""
|
| 758 |
|
| 759 |
DIGIT_TOKENS = set(range(32, 42))
|
| 760 |
-
|
| 761 |
1232: 0, # ' +' -> add
|
| 762 |
731: 1, # ' -' -> sub
|
| 763 |
1672: 2, # ' *' -> mul
|
|
@@ -766,15 +766,23 @@ class HybridExtractor(nn.Module):
|
|
| 766 |
1758: 5, # ' ==' -> eq
|
| 767 |
}
|
| 768 |
WORD_OP_TOKENS = {
|
| 769 |
-
|
| 770 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 771 |
}
|
|
|
|
| 772 |
|
| 773 |
def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 256, num_heads: int = 4):
|
| 774 |
super().__init__()
|
| 775 |
self.hidden_dim = hidden_dim
|
| 776 |
|
| 777 |
self.attention_pool = AttentionPooling(hidden_dim, num_heads)
|
|
|
|
|
|
|
| 778 |
|
| 779 |
self.a_digit_pred = nn.Sequential(
|
| 780 |
nn.Linear(hidden_dim, intermediate_dim),
|
|
@@ -803,9 +811,18 @@ class HybridExtractor(nn.Module):
|
|
| 803 |
return True
|
| 804 |
return False
|
| 805 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 806 |
def _extract_from_digits(self, token_ids: torch.Tensor) -> tuple:
|
| 807 |
"""
|
| 808 |
Extract values directly from digit tokens (hardcoded lookup).
|
|
|
|
| 809 |
Returns (a_value, b_value, op_idx) or None if pattern not found.
|
| 810 |
"""
|
| 811 |
tokens = token_ids.tolist()
|
|
@@ -813,9 +830,9 @@ class HybridExtractor(nn.Module):
|
|
| 813 |
op_pos = -1
|
| 814 |
op_idx = 0
|
| 815 |
for i, tid in enumerate(tokens):
|
| 816 |
-
if tid in self.
|
| 817 |
op_pos = i
|
| 818 |
-
op_idx = self.
|
| 819 |
break
|
| 820 |
|
| 821 |
if op_pos == -1:
|
|
@@ -934,11 +951,34 @@ class HybridExtractor(nn.Module):
|
|
| 934 |
a_digit_logits_list.append(None)
|
| 935 |
b_digit_logits_list.append(None)
|
| 936 |
else:
|
| 937 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 938 |
|
| 939 |
-
a_digit_logits = self.a_digit_pred(
|
| 940 |
-
b_digit_logits = self.b_digit_pred(
|
| 941 |
-
op_logits = self.op_predictor(
|
| 942 |
|
| 943 |
a_val, a_bits = self._digits_to_value_and_bits(a_digit_logits, device)
|
| 944 |
b_val, b_bits = self._digits_to_value_and_bits(b_digit_logits, device)
|
|
|
|
| 757 |
"""
|
| 758 |
|
| 759 |
DIGIT_TOKENS = set(range(32, 42))
|
| 760 |
+
SYMBOL_OP_TOKENS = {
|
| 761 |
1232: 0, # ' +' -> add
|
| 762 |
731: 1, # ' -' -> sub
|
| 763 |
1672: 2, # ' *' -> mul
|
|
|
|
| 766 |
1758: 5, # ' ==' -> eq
|
| 767 |
}
|
| 768 |
WORD_OP_TOKENS = {
|
| 769 |
+
2068: 0, # 'plus' -> add
|
| 770 |
+
8500: 1, # 'minus' -> sub
|
| 771 |
+
1580: 2, # 'times' -> mul
|
| 772 |
+
6301: 3, # 'greater' -> gt
|
| 773 |
+
1912: 4, # 'less' -> lt
|
| 774 |
+
16364: 5, # 'equals' -> eq
|
| 775 |
+
11540: 5, # 'equal' -> eq
|
| 776 |
}
|
| 777 |
+
ALL_OP_TOKENS = {**SYMBOL_OP_TOKENS, **WORD_OP_TOKENS}
|
| 778 |
|
| 779 |
def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 256, num_heads: int = 4):
|
| 780 |
super().__init__()
|
| 781 |
self.hidden_dim = hidden_dim
|
| 782 |
|
| 783 |
self.attention_pool = AttentionPooling(hidden_dim, num_heads)
|
| 784 |
+
self.a_pool = AttentionPooling(hidden_dim, num_heads)
|
| 785 |
+
self.b_pool = AttentionPooling(hidden_dim, num_heads)
|
| 786 |
|
| 787 |
self.a_digit_pred = nn.Sequential(
|
| 788 |
nn.Linear(hidden_dim, intermediate_dim),
|
|
|
|
| 811 |
return True
|
| 812 |
return False
|
| 813 |
|
| 814 |
+
def _find_op_position(self, token_ids: torch.Tensor) -> int:
|
| 815 |
+
"""Find position of operator token, returns -1 if not found."""
|
| 816 |
+
tokens = token_ids.tolist()
|
| 817 |
+
for i, tid in enumerate(tokens):
|
| 818 |
+
if tid in self.ALL_OP_TOKENS:
|
| 819 |
+
return i
|
| 820 |
+
return -1
|
| 821 |
+
|
| 822 |
def _extract_from_digits(self, token_ids: torch.Tensor) -> tuple:
|
| 823 |
"""
|
| 824 |
Extract values directly from digit tokens (hardcoded lookup).
|
| 825 |
+
Handles both symbol operators (' +') and word operators ('plus').
|
| 826 |
Returns (a_value, b_value, op_idx) or None if pattern not found.
|
| 827 |
"""
|
| 828 |
tokens = token_ids.tolist()
|
|
|
|
| 830 |
op_pos = -1
|
| 831 |
op_idx = 0
|
| 832 |
for i, tid in enumerate(tokens):
|
| 833 |
+
if tid in self.ALL_OP_TOKENS:
|
| 834 |
op_pos = i
|
| 835 |
+
op_idx = self.ALL_OP_TOKENS[tid]
|
| 836 |
break
|
| 837 |
|
| 838 |
if op_pos == -1:
|
|
|
|
| 951 |
a_digit_logits_list.append(None)
|
| 952 |
b_digit_logits_list.append(None)
|
| 953 |
else:
|
| 954 |
+
sample_hidden = hidden[i:i+1]
|
| 955 |
+
sample_mask = mask[i:i+1]
|
| 956 |
+
|
| 957 |
+
seq_mask = mask[i].bool()
|
| 958 |
+
valid_len = int(seq_mask.sum().item())
|
| 959 |
+
start_pos = hidden.shape[1] - valid_len
|
| 960 |
+
valid_tokens = token_ids[i, start_pos:] if token_ids is not None else None
|
| 961 |
+
|
| 962 |
+
op_pos = self._find_op_position(valid_tokens) if valid_tokens is not None else -1
|
| 963 |
+
|
| 964 |
+
if op_pos > 0 and op_pos < valid_len - 1:
|
| 965 |
+
a_end = start_pos + op_pos
|
| 966 |
+
b_start = start_pos + op_pos + 1
|
| 967 |
+
|
| 968 |
+
a_mask = torch.zeros_like(sample_mask)
|
| 969 |
+
a_mask[0, start_pos:a_end] = 1.0
|
| 970 |
+
b_mask = torch.zeros_like(sample_mask)
|
| 971 |
+
b_mask[0, b_start:] = sample_mask[0, b_start:]
|
| 972 |
+
|
| 973 |
+
a_pooled = self.a_pool(sample_hidden, a_mask)[0]
|
| 974 |
+
b_pooled = self.b_pool(sample_hidden, b_mask)[0]
|
| 975 |
+
else:
|
| 976 |
+
a_pooled = pooled[i]
|
| 977 |
+
b_pooled = pooled[i]
|
| 978 |
|
| 979 |
+
a_digit_logits = self.a_digit_pred(a_pooled)
|
| 980 |
+
b_digit_logits = self.b_digit_pred(b_pooled)
|
| 981 |
+
op_logits = self.op_predictor(pooled[i])
|
| 982 |
|
| 983 |
a_val, a_bits = self._digits_to_value_and_bits(a_digit_logits, device)
|
| 984 |
b_val, b_bits = self._digits_to_value_and_bits(b_digit_logits, device)
|