"""
Trainable interface layers for frozen threshold circuits.
BitEncoder, OpRouter, BitDecoder wrap the frozen circuits.
HiddenStateExtractor and AugmentedArithmeticModel for LLM integration.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from circuits import FrozenThresholdCircuits, heaviside_ste
MODEL_ID = 'HuggingFaceTB/SmolLM2-360M-Instruct'
OPERATIONS = ['add', 'sub', 'mul', 'gt', 'lt', 'eq']
OP_SYMBOLS = {'add': '+', 'sub': '-', 'mul': '*', 'gt': '>', 'lt': '<', 'eq': '=='}
class BitEncoder(nn.Module):
"""
    Encodes two 8-bit operands from the input representation.
    Uses a residual connection to preserve the ground-truth bits while allowing learned refinement.
"""
def __init__(self, input_dim: int = 16 + 6, hidden_dim: int = 32):
super().__init__()
self.refine = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, 16),
)
self.scale = nn.Parameter(torch.tensor(0.0))
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x: [batch, input_dim] input with first 16 dims being a_bits, b_bits
Returns:
a_bits: [batch, 8] first operand bits
b_bits: [batch, 8] second operand bits
"""
base_bits = x[:, :16]
refinement = self.refine(x) * torch.sigmoid(self.scale)
bits = base_bits + refinement
bits = torch.clamp(bits, 0, 1)
hard_bits = heaviside_ste(bits - 0.5)
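        # Straight-through estimator: the forward pass emits hard 0/1 bits,
        # while gradients flow back through the soft clamped bits.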
out = hard_bits - bits.detach() + bits
return out[:, :8], out[:, 8:]
class OpRouter(nn.Module):
"""
Routes computation to the appropriate circuit based on input.
Outputs soft weights over operations for gradient flow.
"""
def __init__(self, input_dim: int = 16 + 6, hidden_dim: int = 32, n_ops: int = 6):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, n_ops),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: [batch, input_dim] input features
Returns:
op_weights: [batch, n_ops] soft operation weights (softmax)
"""
logits = self.net(x)
return F.softmax(logits, dim=-1)
class BitDecoder(nn.Module):
"""
Decodes circuit output bits to target representation.
For standalone training: outputs soft bits for loss computation.
For LLM integration: would project to hidden state delta.
"""
def __init__(self, output_dim: int = 8):
super().__init__()
self.output_dim = output_dim
def forward(self, result_bits: torch.Tensor) -> torch.Tensor:
return result_bits
class ThresholdALU(nn.Module):
"""
Complete trainable interface + frozen circuits.
Learns to encode inputs, route to circuits, decode outputs.
"""
def __init__(self, device: str = 'cuda'):
super().__init__()
self.device = device
self.circuits = FrozenThresholdCircuits(device=device)
for key in self.circuits.weights:
self.circuits.weights[key].requires_grad = False
self.encoder = BitEncoder(input_dim=16 + 6, hidden_dim=64).to(device)
self.router = OpRouter(input_dim=16 + 6, hidden_dim=32, n_ops=6).to(device)
self.decoder = BitDecoder(output_dim=8).to(device)
def forward(self, a_bits_in: torch.Tensor, b_bits_in: torch.Tensor,
op_onehot: torch.Tensor) -> torch.Tensor:
"""
Forward pass through trainable interface + frozen circuits.
Args:
a_bits_in: [batch, 8] input A bits (ground truth for training)
b_bits_in: [batch, 8] input B bits (ground truth for training)
op_onehot: [batch, 6] one-hot operation selector
Returns:
result_bits: [batch, 8] output bits
"""
x = torch.cat([a_bits_in, b_bits_in, op_onehot], dim=-1)
a_bits, b_bits = self.encoder(x)
op_weights = self.router(x)
result = self.circuits(a_bits, b_bits, op_weights)
output = self.decoder(result)
return output
def forward_direct(self, a_bits: torch.Tensor, b_bits: torch.Tensor,
op_onehot: torch.Tensor) -> torch.Tensor:
"""
Direct forward through circuits (bypass encoder/router for testing).
"""
return self.circuits(a_bits, b_bits, op_onehot)
class DirectCircuitModel(nn.Module):
"""
    Minimal model that uses the circuits directly, without learned encoding.
    For validating that the circuits themselves achieve 100% fitness.
"""
def __init__(self, device: str = 'cuda'):
super().__init__()
self.device = device
self.circuits = FrozenThresholdCircuits(device=device)
def forward(self, a_bits: torch.Tensor, b_bits: torch.Tensor,
op_onehot: torch.Tensor) -> torch.Tensor:
return self.circuits(a_bits, b_bits, op_onehot)
class HiddenStateExtractor(nn.Module):
"""
Extracts operands and operation from LLM hidden states.
    This is the hard part: the extractor must learn to parse numbers from embeddings.
"""
def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 256):
super().__init__()
self.a_extractor = nn.Sequential(
nn.Linear(hidden_dim, intermediate_dim),
nn.GELU(),
nn.Linear(intermediate_dim, 8),
)
self.b_extractor = nn.Sequential(
nn.Linear(hidden_dim, intermediate_dim),
nn.GELU(),
nn.Linear(intermediate_dim, 8),
)
self.op_router = nn.Sequential(
nn.Linear(hidden_dim, intermediate_dim),
nn.GELU(),
nn.Linear(intermediate_dim, len(OPERATIONS)),
)
def forward(self, hidden_states: torch.Tensor):
"""
Args:
hidden_states: [batch, hidden_dim] from LLM
Returns:
a_bits: [batch, 8]
b_bits: [batch, 8]
op_logits: [batch, 6]
"""
a_logits = self.a_extractor(hidden_states)
b_logits = self.b_extractor(hidden_states)
op_logits = self.op_router(hidden_states)
a_soft = torch.sigmoid(a_logits)
b_soft = torch.sigmoid(b_logits)
a_hard = heaviside_ste(a_logits)
b_hard = heaviside_ste(b_logits)
a_bits = a_hard - a_soft.detach() + a_soft
b_bits = b_hard - b_soft.detach() + b_soft
return a_bits, b_bits, op_logits
class AttentionPooling(nn.Module):
"""
Learnable attention pooling over sequence positions.
Replaces mean pooling - learns which tokens matter for extraction.
"""
def __init__(self, hidden_dim: int = 960, num_heads: int = 4):
super().__init__()
self.num_heads = num_heads
self.head_dim = hidden_dim // num_heads
self.query = nn.Linear(hidden_dim, hidden_dim)
self.key = nn.Linear(hidden_dim, hidden_dim)
self.value = nn.Linear(hidden_dim, hidden_dim)
self.out_proj = nn.Linear(hidden_dim, hidden_dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02)
def forward(self, embeddings: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
"""
Args:
embeddings: [batch, seq_len, hidden_dim]
mask: [batch, seq_len] attention mask (1 = attend, 0 = ignore)
Returns:
pooled: [batch, hidden_dim]
"""
batch_size, seq_len, hidden_dim = embeddings.shape
cls_expanded = self.cls_token.expand(batch_size, -1, -1)
embeddings = torch.cat([cls_expanded, embeddings], dim=1)
cls_mask = torch.ones(batch_size, 1, device=mask.device)
mask = torch.cat([cls_mask, mask], dim=1)
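        # A single learned [CLS] query attends over itself plus the sequence;
        # padded positions are masked to -1e9 before the softmax below.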
Q = self.query(embeddings[:, :1, :])
K = self.key(embeddings)
V = self.value(embeddings)
Q = Q.view(batch_size, 1, self.num_heads, self.head_dim).transpose(1, 2)
K = K.view(batch_size, seq_len + 1, self.num_heads, self.head_dim).transpose(1, 2)
V = V.view(batch_size, seq_len + 1, self.num_heads, self.head_dim).transpose(1, 2)
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
mask_expanded = mask.unsqueeze(1).unsqueeze(2)
scores = scores.masked_fill(mask_expanded == 0, -1e9)
attn_weights = torch.softmax(scores, dim=-1)
attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
context = torch.matmul(attn_weights, V)
context = context.transpose(1, 2).contiguous().view(batch_size, 1, hidden_dim)
pooled = self.out_proj(context).squeeze(1)
pooled = torch.nan_to_num(pooled, nan=0.0)
return pooled
class MultiHeadBitExtractor(nn.Module):
"""
8 separate extractors for 8 bits - each bit gets its own specialized network.
    More expressive than a single MLP predicting all 8 bits at once.
"""
def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 128):
super().__init__()
self.bit_extractors = nn.ModuleList([
nn.Sequential(
nn.Linear(hidden_dim, intermediate_dim),
nn.GELU(),
nn.Linear(intermediate_dim, 1),
)
for _ in range(8)
])
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Args:
hidden_states: [batch, hidden_dim]
Returns:
            bits: [batch, 8] straight-through bits, one from each extractor
            logits: [batch, 8] clamped pre-sigmoid logits
"""
hidden_states = torch.nan_to_num(hidden_states, nan=0.0)
bit_logits = [extractor(hidden_states) for extractor in self.bit_extractors]
logits = torch.cat(bit_logits, dim=-1)
logits = torch.clamp(logits, -20, 20)
soft = torch.sigmoid(logits)
hard = heaviside_ste(logits)
bits = hard - soft.detach() + soft
return bits, logits
class Extractor(nn.Module):
"""
Extracts operands and operation from LLM hidden states.
Uses attention pooling and per-bit extraction networks.
"""
def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 256, num_heads: int = 4):
super().__init__()
self.attention_pool = AttentionPooling(hidden_dim, num_heads)
self.a_extractor = MultiHeadBitExtractor(hidden_dim, intermediate_dim // 2)
self.b_extractor = MultiHeadBitExtractor(hidden_dim, intermediate_dim // 2)
self.op_router = nn.Sequential(
nn.Linear(hidden_dim, intermediate_dim),
nn.GELU(),
nn.Linear(intermediate_dim, len(OPERATIONS)),
)
def forward(self, embeddings: torch.Tensor, mask: torch.Tensor):
"""
Args:
embeddings: [batch, seq_len, hidden_dim]
mask: [batch, seq_len]
Returns:
a_bits: [batch, 8]
b_bits: [batch, 8]
op_logits: [batch, 6]
"""
pooled = self.attention_pool(embeddings, mask)
a_bits, _ = self.a_extractor(pooled)
b_bits, _ = self.b_extractor(pooled)
op_logits = self.op_router(pooled)
return a_bits, b_bits, op_logits
class PositionExtractor(nn.Module):
"""
Position-specific extraction with dynamic operator detection.
Tokenization pattern for "A op B":
[A_digits...] [operator] [space] [B_digits...]
Examples:
"5 + 3" -> ['5', ' +', ' ', '3'] (positions: A=0, op=1, B=3)
"47 + 86" -> ['4', '7', ' +', ' ', '8', '6'] (positions: A=0-1, op=2, B=4-5)
"127 + 128" -> ['1','2','7',' +', ' ','1','2','8'] (positions: A=0-2, op=3, B=5-7)
Token IDs (SmolLM2):
Digits '0'-'9': 32-41
Operators: ' +'=1232, ' -'=731, ' *'=1672, ' >'=2986, ' <'=2067, ' =='=1758
Space: 216
"""
DIGIT_TOKENS = set(range(32, 42))
OPERATOR_TOKENS = {
1232: 0, # ' +' -> add
731: 1, # ' -' -> sub
1672: 2, # ' *' -> mul
2986: 3, # ' >' -> gt
2067: 4, # ' <' -> lt
1758: 5, # ' ==' -> eq
}
SPACE_TOKEN = 216
MAX_DIGITS = 3
def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 256):
super().__init__()
self.hidden_dim = hidden_dim
self.a_extractor = nn.Sequential(
nn.Linear(hidden_dim * self.MAX_DIGITS, intermediate_dim),
nn.GELU(),
nn.Linear(intermediate_dim, intermediate_dim // 2),
nn.GELU(),
nn.Linear(intermediate_dim // 2, 8),
)
self.b_extractor = nn.Sequential(
nn.Linear(hidden_dim * self.MAX_DIGITS, intermediate_dim),
nn.GELU(),
nn.Linear(intermediate_dim, intermediate_dim // 2),
nn.GELU(),
nn.Linear(intermediate_dim // 2, 8),
)
self.op_extractor = nn.Sequential(
nn.Linear(hidden_dim, intermediate_dim // 2),
nn.GELU(),
nn.Linear(intermediate_dim // 2, len(OPERATIONS)),
)
def _find_operator_position(self, token_ids: torch.Tensor) -> tuple[int, int]:
"""
Find operator token position and its operation index.
Args:
token_ids: [seq_len] tensor of token IDs
Returns:
(position, op_index) or (-1, -1) if not found
"""
for pos, tid in enumerate(token_ids.tolist()):
if tid in self.OPERATOR_TOKENS:
return pos, self.OPERATOR_TOKENS[tid]
return -1, -1
def _extract_digit_features(self, hidden: torch.Tensor, start: int, end: int) -> torch.Tensor:
"""
Extract and pad digit hidden states to fixed size.
Args:
hidden: [seq_len, hidden_dim]
start: start position (inclusive)
end: end position (exclusive)
Returns:
[hidden_dim * MAX_DIGITS] flattened features, zero-padded on the LEFT
(so units digit is always at the same position regardless of number length)
"""
n_digits = end - start
features = torch.zeros(self.MAX_DIGITS * self.hidden_dim, device=hidden.device)
if n_digits > 0 and n_digits <= self.MAX_DIGITS:
digit_hidden = hidden[start:end, :].reshape(-1)
pad_size = (self.MAX_DIGITS - n_digits) * self.hidden_dim
features[pad_size:] = digit_hidden
return features
def forward(self, hidden: torch.Tensor, mask: torch.Tensor, token_ids: torch.Tensor = None):
"""
Args:
hidden: [batch, seq_len, hidden_dim]
mask: [batch, seq_len] attention mask
token_ids: [batch, seq_len] token IDs (required for operator detection)
Returns:
a_bits: [batch, 8]
b_bits: [batch, 8]
            op_logits: [batch, 6]
            op_indices: [batch] operation index parsed from tokens (0 if no operator found)
"""
if token_ids is None:
raise ValueError("PositionExtractor requires token_ids for operator detection")
batch_size, seq_len, hidden_dim = hidden.shape
device = hidden.device
a_features = []
b_features = []
op_features = []
op_indices = []
for i in range(batch_size):
seq_mask = mask[i].bool()
valid_len = seq_mask.sum().item()
start_pos = seq_len - valid_len
valid_tokens = token_ids[i, start_pos:]
valid_hidden = hidden[i, start_pos:, :]
op_pos, op_idx = self._find_operator_position(valid_tokens)
if op_pos == -1:
a_feat = torch.zeros(self.MAX_DIGITS * hidden_dim, device=device)
b_feat = torch.zeros(self.MAX_DIGITS * hidden_dim, device=device)
op_feat = torch.zeros(hidden_dim, device=device)
op_idx = 0
else:
a_feat = self._extract_digit_features(valid_hidden, 0, op_pos)
op_feat = valid_hidden[op_pos, :]
b_start = op_pos + 2 if (op_pos + 1 < valid_len and
valid_tokens[op_pos + 1].item() == self.SPACE_TOKEN) else op_pos + 1
b_feat = self._extract_digit_features(valid_hidden, b_start, valid_len)
a_features.append(a_feat)
b_features.append(b_feat)
op_features.append(op_feat)
op_indices.append(op_idx)
a_features = torch.stack(a_features)
b_features = torch.stack(b_features)
op_features = torch.stack(op_features)
op_indices_tensor = torch.tensor(op_indices, device=device, dtype=torch.long)
a_logits = self.a_extractor(a_features)
b_logits = self.b_extractor(b_features)
op_logits = self.op_extractor(op_features)
a_soft = torch.sigmoid(a_logits)
b_soft = torch.sigmoid(b_logits)
a_hard = heaviside_ste(a_logits)
b_hard = heaviside_ste(b_logits)
a_bits = a_hard - a_soft.detach() + a_soft
b_bits = b_hard - b_soft.detach() + b_soft
return a_bits, b_bits, op_logits, op_indices_tensor
class PositionalDigitExtractor(nn.Module):
"""
Position-aware digit extraction: classifies each digit position independently.
This approach achieves 100% accuracy because:
1. Each digit token position is classified independently (100% accuracy on layer 0)
2. Numbers are reconstructed using place values (×100, ×10, ×1)
3. No information is lost through pooling
Token IDs (SmolLM2):
Digits '0'-'9': 32-41
Operators: ' +'=1232, ' -'=731, ' *'=1672, ' >'=2986, ' <'=2067, ' =='=1758
Space: 216
"""
DIGIT_TOKENS = set(range(32, 42))
OPERATOR_TOKENS = {
1232: 0, # ' +' -> add
731: 1, # ' -' -> sub
1672: 2, # ' *' -> mul
2986: 3, # ' >' -> gt
2067: 4, # ' <' -> lt
1758: 5, # ' ==' -> eq
}
SPACE_TOKEN = 216
def __init__(self, hidden_dim: int = 960):
super().__init__()
self.hidden_dim = hidden_dim
self.digit_classifier = nn.Linear(hidden_dim, 10)
self.op_classifier = nn.Linear(hidden_dim, len(OPERATIONS))
def _find_positions(self, token_ids: torch.Tensor) -> tuple:
"""Find A digit positions, B digit positions, and operator position."""
token_list = token_ids.tolist()
op_pos = -1
op_idx = 0
for i, tid in enumerate(token_list):
if tid in self.OPERATOR_TOKENS:
op_pos = i
op_idx = self.OPERATOR_TOKENS[tid]
break
if op_pos == -1:
return [], [], -1, 0
a_positions = [i for i in range(op_pos) if token_list[i] in self.DIGIT_TOKENS]
b_start = op_pos + 2 if (op_pos + 1 < len(token_list) and
token_list[op_pos + 1] == self.SPACE_TOKEN) else op_pos + 1
b_positions = [i for i in range(b_start, len(token_list)) if token_list[i] in self.DIGIT_TOKENS]
return a_positions, b_positions, op_pos, op_idx
def _predict_value(self, hidden: torch.Tensor, positions: list) -> tuple:
"""Predict digit at each position and reconstruct number."""
if not positions:
return torch.tensor(0.0, device=hidden.device), []
digit_logits_list = []
soft_value = torch.tensor(0.0, device=hidden.device)
for idx, pos in enumerate(positions):
logits = self.digit_classifier(hidden[pos])
digit_logits_list.append(logits)
probs = torch.softmax(logits, dim=-1)
digit_values = torch.arange(10, device=hidden.device, dtype=torch.float32)
soft_digit = (probs * digit_values).sum()
place_value = 10 ** (len(positions) - idx - 1)
soft_value = soft_value + soft_digit * place_value
return soft_value, digit_logits_list
def _value_to_bits(self, value: torch.Tensor) -> torch.Tensor:
"""Convert soft value to 8 bits using differentiable operations."""
value = torch.clamp(value, 0, 255)
bits = []
for i in range(7, -1, -1):
bit = torch.sigmoid((value - (2 ** i - 0.5)) * 10)
value = value - bit * (2 ** i)
bits.append(bit)
return torch.stack(bits)
def forward(self, hidden: torch.Tensor, mask: torch.Tensor, token_ids: torch.Tensor):
"""
Args:
hidden: [batch, seq_len, hidden_dim]
mask: [batch, seq_len]
token_ids: [batch, seq_len]
Returns:
a_bits: [batch, 8]
b_bits: [batch, 8]
op_logits: [batch, 6]
            op_indices: [batch] operation index parsed from tokens
            a_values: [batch] soft reconstructed value of operand A
            b_values: [batch] soft reconstructed value of operand B
            a_digit_logits: per-sample list of per-position [10] digit logits for A
            b_digit_logits: per-sample list of per-position [10] digit logits for B
"""
batch_size = hidden.shape[0]
device = hidden.device
a_bits_list = []
b_bits_list = []
op_logits_list = []
op_indices_list = []
a_values_list = []
b_values_list = []
a_digit_logits_list = []
b_digit_logits_list = []
for i in range(batch_size):
seq_mask = mask[i].bool()
valid_len = seq_mask.sum().item()
start_pos = hidden.shape[1] - valid_len
valid_hidden = hidden[i, start_pos:]
valid_tokens = token_ids[i, start_pos:]
a_pos, b_pos, op_pos, op_idx = self._find_positions(valid_tokens)
a_value, a_digit_logits = self._predict_value(valid_hidden, a_pos)
b_value, b_digit_logits = self._predict_value(valid_hidden, b_pos)
a_bits = self._value_to_bits(a_value)
b_bits = self._value_to_bits(b_value)
if op_pos >= 0:
op_logits = self.op_classifier(valid_hidden[op_pos])
else:
op_logits = torch.zeros(len(OPERATIONS), device=device)
a_bits_list.append(a_bits)
b_bits_list.append(b_bits)
op_logits_list.append(op_logits)
op_indices_list.append(op_idx)
a_values_list.append(a_value)
b_values_list.append(b_value)
a_digit_logits_list.append(a_digit_logits)
b_digit_logits_list.append(b_digit_logits)
a_bits = torch.stack(a_bits_list)
b_bits = torch.stack(b_bits_list)
op_logits = torch.stack(op_logits_list)
op_indices = torch.tensor(op_indices_list, device=device, dtype=torch.long)
a_values = torch.stack(a_values_list)
b_values = torch.stack(b_values_list)
return a_bits, b_bits, op_logits, op_indices, a_values, b_values, a_digit_logits_list, b_digit_logits_list
class DigitExtractor(nn.Module):
"""
Digit-level extraction: predicts digits (0-9) then converts to bits.
Uses attention pooling (less accurate than PositionalDigitExtractor).
"""
def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 256, num_heads: int = 4):
super().__init__()
self.attention_pool = AttentionPooling(hidden_dim, num_heads)
self.a_digit_pred = nn.Sequential(
nn.Linear(hidden_dim, intermediate_dim),
nn.GELU(),
nn.Linear(intermediate_dim, 3 * 10),
)
self.b_digit_pred = nn.Sequential(
nn.Linear(hidden_dim, intermediate_dim),
nn.GELU(),
nn.Linear(intermediate_dim, 3 * 10),
)
self.op_router = nn.Sequential(
nn.Linear(hidden_dim, intermediate_dim),
nn.GELU(),
nn.Linear(intermediate_dim, len(OPERATIONS)),
)
def digits_to_bits(self, digit_logits: torch.Tensor) -> torch.Tensor:
"""
Convert 3-digit predictions to 8-bit representation.
digit_logits: [batch, 30] (3 digits * 10 classes each)
Returns: [batch, 8] bits
"""
batch_size = digit_logits.shape[0]
logits = digit_logits.view(batch_size, 3, 10)
probs = torch.softmax(logits, dim=-1)
digit_values = torch.arange(10, device=digit_logits.device).float()
soft_digits = (probs * digit_values).sum(dim=-1)
hundreds = soft_digits[:, 0]
tens = soft_digits[:, 1]
ones = soft_digits[:, 2]
value = hundreds * 100 + tens * 10 + ones
value = torch.clamp(value, 0, 255)
bits = []
for i in range(7, -1, -1):
bit = torch.fmod(torch.floor(value / (2 ** i)), 2)
bits.append(bit)
return torch.stack(bits, dim=-1)
def forward(self, hidden: torch.Tensor, mask: torch.Tensor):
"""
Returns:
a_bits, b_bits, op_logits, a_digit_logits, b_digit_logits
"""
pooled = self.attention_pool(hidden, mask)
a_digit_logits = self.a_digit_pred(pooled)
b_digit_logits = self.b_digit_pred(pooled)
op_logits = self.op_router(pooled)
a_bits = self.digits_to_bits(a_digit_logits)
b_bits = self.digits_to_bits(b_digit_logits)
return a_bits, b_bits, op_logits, a_digit_logits, b_digit_logits
class HybridExtractor(nn.Module):
"""
Hybrid extractor that handles both digit tokens and word numbers.
For digit tokens (32-41): Direct lookup, no training needed
For word numbers: Learned MLP extraction from pooled hidden states
This is the real training target - learning to extract numbers from
natural language like "forty seven plus eighty six".
"""
DIGIT_TOKENS = set(range(32, 42))
SYMBOL_OP_TOKENS = {
1232: 0, # ' +' -> add
731: 1, # ' -' -> sub
1672: 2, # ' *' -> mul
2986: 3, # ' >' -> gt
2067: 4, # ' <' -> lt
1758: 5, # ' ==' -> eq
}
WORD_OP_TOKENS = {
2068: 0, # 'plus' -> add
8500: 1, # 'minus' -> sub
1580: 2, # 'times' -> mul
6301: 3, # 'greater' -> gt
1912: 4, # 'less' -> lt
16364: 5, # 'equals' -> eq
11540: 5, # 'equal' -> eq
}
ALL_OP_TOKENS = {**SYMBOL_OP_TOKENS, **WORD_OP_TOKENS}
def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 256, num_heads: int = 4):
super().__init__()
self.hidden_dim = hidden_dim
self.attention_pool = AttentionPooling(hidden_dim, num_heads)
self.a_pool = AttentionPooling(hidden_dim, num_heads)
self.b_pool = AttentionPooling(hidden_dim, num_heads)
self.a_digit_pred = nn.Sequential(
nn.Linear(hidden_dim, intermediate_dim),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(intermediate_dim, 3 * 10),
)
self.b_digit_pred = nn.Sequential(
nn.Linear(hidden_dim, intermediate_dim),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(intermediate_dim, 3 * 10),
)
self.op_predictor = nn.Sequential(
nn.Linear(hidden_dim, intermediate_dim // 2),
nn.GELU(),
nn.Linear(intermediate_dim // 2, len(OPERATIONS)),
)
def _has_digit_tokens(self, token_ids: torch.Tensor) -> bool:
"""Check if input contains digit tokens."""
for tid in token_ids.tolist():
if tid in self.DIGIT_TOKENS:
return True
return False
def _find_op_position(self, token_ids: torch.Tensor) -> int:
"""Find position of operator token, returns -1 if not found."""
tokens = token_ids.tolist()
for i, tid in enumerate(tokens):
if tid in self.ALL_OP_TOKENS:
return i
return -1
def _extract_from_digits(self, token_ids: torch.Tensor) -> tuple:
"""
Extract values directly from digit tokens (hardcoded lookup).
Handles both symbol operators (' +') and word operators ('plus').
Returns (a_value, b_value, op_idx) or None if pattern not found.
"""
tokens = token_ids.tolist()
op_pos = -1
op_idx = 0
for i, tid in enumerate(tokens):
if tid in self.ALL_OP_TOKENS:
op_pos = i
op_idx = self.ALL_OP_TOKENS[tid]
break
if op_pos == -1:
return None
a_digits = []
for i in range(op_pos):
if tokens[i] in self.DIGIT_TOKENS:
a_digits.append(tokens[i] - 32)
b_start = op_pos + 1
if b_start < len(tokens) and tokens[b_start] == 216:
b_start += 1
b_digits = []
for i in range(b_start, len(tokens)):
if tokens[i] in self.DIGIT_TOKENS:
b_digits.append(tokens[i] - 32)
if not a_digits or not b_digits:
return None
a_val = 0
for d in a_digits:
a_val = a_val * 10 + d
b_val = 0
for d in b_digits:
b_val = b_val * 10 + d
return min(a_val, 255), min(b_val, 255), op_idx
def _value_to_bits(self, value: int, device) -> torch.Tensor:
"""Convert integer to 8-bit tensor."""
bits = torch.zeros(8, device=device)
for i in range(8):
bits[7 - i] = (value >> i) & 1
return bits
def _digits_to_value_and_bits(self, digit_logits: torch.Tensor, device) -> tuple:
"""
Convert 3-digit logits to value and bits.
digit_logits: [30] (3 digits × 10 classes)
Returns: (value tensor, bits tensor [8])
"""
logits = digit_logits.view(3, 10)
probs = torch.softmax(logits, dim=-1)
digit_values = torch.arange(10, device=device, dtype=torch.float32)
soft_digits = (probs * digit_values).sum(dim=-1)
hundreds = soft_digits[0]
tens = soft_digits[1]
ones = soft_digits[2]
value = hundreds * 100 + tens * 10 + ones
value = torch.clamp(value, 0, 255)
bits = []
for i in range(7, -1, -1):
threshold = 2 ** i
bit = torch.sigmoid((value - threshold + 0.5) * 10)
bits.append(bit)
value = value - bit * threshold
return hundreds * 100 + tens * 10 + ones, torch.stack(bits)
def forward(self, hidden: torch.Tensor, mask: torch.Tensor, token_ids: torch.Tensor = None):
"""
Args:
hidden: [batch, seq_len, hidden_dim]
mask: [batch, seq_len]
token_ids: [batch, seq_len] - optional, enables digit lookup
Returns:
a_bits, b_bits, op_logits, a_values, b_values, used_lookup,
a_digit_logits, b_digit_logits
"""
batch_size = hidden.shape[0]
device = hidden.device
a_bits_list = []
a_digit_logits_list = []
b_digit_logits_list = []
b_bits_list = []
op_logits_list = []
a_values_list = []
b_values_list = []
used_lookup_list = []
pooled = self.attention_pool(hidden, mask)
for i in range(batch_size):
lookup_result = None
if token_ids is not None:
seq_mask = mask[i].bool()
valid_len = seq_mask.sum().item()
start_pos = hidden.shape[1] - valid_len
valid_tokens = token_ids[i, start_pos:]
if self._has_digit_tokens(valid_tokens):
lookup_result = self._extract_from_digits(valid_tokens)
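            # Digit prompts (e.g. "47 + 86") are decoded exactly from token IDs;
            # word-number prompts fall through to the learned extraction branch below.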
if lookup_result is not None:
a_val, b_val, op_idx = lookup_result
a_bits = self._value_to_bits(a_val, device)
b_bits = self._value_to_bits(b_val, device)
op_logits = torch.zeros(len(OPERATIONS), device=device)
op_logits[op_idx] = 10.0
a_bits_list.append(a_bits)
b_bits_list.append(b_bits)
op_logits_list.append(op_logits)
a_values_list.append(float(a_val))
b_values_list.append(float(b_val))
used_lookup_list.append(True)
a_digit_logits_list.append(None)
b_digit_logits_list.append(None)
else:
sample_hidden = hidden[i:i+1]
sample_mask = mask[i:i+1]
seq_mask = mask[i].bool()
valid_len = int(seq_mask.sum().item())
start_pos = hidden.shape[1] - valid_len
valid_tokens = token_ids[i, start_pos:] if token_ids is not None else None
op_pos = self._find_op_position(valid_tokens) if valid_tokens is not None else -1
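                # Operator-aware splitting: build separate masks for the tokens to the
                # left and right of the operator so each operand is pooled independently.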
if op_pos > 0 and op_pos < valid_len - 1:
a_end = start_pos + op_pos
b_start = start_pos + op_pos + 1
a_mask = torch.zeros_like(sample_mask)
a_mask[0, start_pos:a_end] = 1.0
b_mask = torch.zeros_like(sample_mask)
b_mask[0, b_start:] = sample_mask[0, b_start:]
a_pooled = self.a_pool(sample_hidden, a_mask)[0]
b_pooled = self.b_pool(sample_hidden, b_mask)[0]
else:
a_pooled = pooled[i]
b_pooled = pooled[i]
a_digit_logits = self.a_digit_pred(a_pooled)
b_digit_logits = self.b_digit_pred(b_pooled)
op_logits = self.op_predictor(pooled[i])
a_val, a_bits = self._digits_to_value_and_bits(a_digit_logits, device)
b_val, b_bits = self._digits_to_value_and_bits(b_digit_logits, device)
a_bits_list.append(a_bits)
b_bits_list.append(b_bits)
op_logits_list.append(op_logits)
a_values_list.append(a_val)
b_values_list.append(b_val)
used_lookup_list.append(False)
a_digit_logits_list.append(a_digit_logits)
b_digit_logits_list.append(b_digit_logits)
a_bits = torch.stack(a_bits_list)
b_bits = torch.stack(b_bits_list)
op_logits = torch.stack(op_logits_list)
a_values = torch.stack([v if isinstance(v, torch.Tensor) else torch.tensor(v, device=device) for v in a_values_list])
b_values = torch.stack([v if isinstance(v, torch.Tensor) else torch.tensor(v, device=device) for v in b_values_list])
used_lookup = torch.tensor(used_lookup_list, device=device, dtype=torch.bool)
valid_a_logits = [x for x in a_digit_logits_list if x is not None]
valid_b_logits = [x for x in b_digit_logits_list if x is not None]
a_digit_logits_out = torch.stack(valid_a_logits) if valid_a_logits else None
b_digit_logits_out = torch.stack(valid_b_logits) if valid_b_logits else None
return a_bits, b_bits, op_logits, a_values, b_values, used_lookup, a_digit_logits_out, b_digit_logits_out
def _soft_value_to_bits(self, value: torch.Tensor, device) -> torch.Tensor:
"""Convert soft value (0-255) to 8-bit representation differentiably."""
value = torch.clamp(value, 0, 255)
bits = []
remaining = value
for i in range(7, -1, -1):
threshold = 2 ** i
bit = torch.sigmoid((remaining - threshold + 0.5) * 10)
bits.append(bit)
remaining = remaining - bit * threshold
return torch.stack(bits)
class ArithmeticModel(nn.Module):
"""
LLM + extractor + frozen threshold circuits.
Optionally unfreeze top N transformer layers with --unfreeze_layers.
"""
def __init__(self, device: str = 'cuda', unfreeze_layers: int = 0,
extract_layer: int = -1, position_extract: bool = False,
digit_pred: bool = False, positional_digit: bool = False,
hybrid: bool = False):
super().__init__()
self.device = device
self.unfreeze_layers = unfreeze_layers
self.extract_layer = extract_layer
self.position_extract = position_extract
self.digit_pred = digit_pred
self.positional_digit = positional_digit
self.hybrid = hybrid
from transformers import AutoModelForCausalLM, AutoTokenizer
print("[1/4] Loading tokenizer...", flush=True)
self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
self.tokenizer.padding_side = 'left'
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
print(" Tokenizer loaded.", flush=True)
print("[2/4] Loading SmolLM2-360M...", flush=True)
self.llm = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.float16,
device_map=device,
output_hidden_states=True
)
for param in self.llm.parameters():
param.requires_grad = False
if unfreeze_layers > 0:
num_layers = len(self.llm.model.layers)
layers_to_unfreeze = list(range(num_layers - unfreeze_layers, num_layers))
print(f" Unfreezing layers {layers_to_unfreeze}...", flush=True)
for layer_idx in layers_to_unfreeze:
for param in self.llm.model.layers[layer_idx].parameters():
param.requires_grad = True
hidden_dim = self.llm.config.hidden_size
llm_params = sum(p.numel() for p in self.llm.parameters())
trainable_llm = sum(p.numel() for p in self.llm.parameters() if p.requires_grad)
print(f" LLM loaded. Hidden dim: {hidden_dim}", flush=True)
print(f" LLM params: {llm_params:,} total, {trainable_llm:,} trainable", flush=True)
print("[3/4] Loading threshold circuits...", flush=True)
self.circuits = FrozenThresholdCircuits(device=device)
print(f" Circuits loaded. {len(self.circuits.weights)} tensors", flush=True)
print("[4/4] Initializing extractor...", flush=True)
if hybrid:
print(" Using HYBRID extraction (digit lookup + word learning)", flush=True)
self.extractor = HybridExtractor(
hidden_dim=hidden_dim,
intermediate_dim=256,
num_heads=4
).to(device)
elif positional_digit:
print(" Using POSITIONAL DIGIT extraction (100% proven)", flush=True)
self.extractor = PositionalDigitExtractor(
hidden_dim=hidden_dim
).to(device)
elif position_extract:
print(" Using position-specific extraction", flush=True)
self.extractor = PositionExtractor(
hidden_dim=hidden_dim,
intermediate_dim=256
).to(device)
elif digit_pred:
print(" Using digit-level prediction", flush=True)
self.extractor = DigitExtractor(
hidden_dim=hidden_dim,
intermediate_dim=256,
num_heads=4
).to(device)
else:
self.extractor = Extractor(
hidden_dim=hidden_dim,
intermediate_dim=256,
num_heads=4
).to(device)
if extract_layer != -1:
print(f" Extracting from layer {extract_layer}", flush=True)
trainable_ext = sum(p.numel() for p in self.extractor.parameters())
total_trainable = trainable_llm + trainable_ext
print(f" Extractor params: {trainable_ext:,}", flush=True)
print(f" Total trainable: {total_trainable:,}", flush=True)
def get_hidden_states(self, texts: list[str]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Get hidden states from specified layer.
Returns:
hidden: [batch, seq_len, hidden_dim] hidden states
mask: [batch, seq_len] attention mask
token_ids: [batch, seq_len] input token IDs
"""
inputs = self.tokenizer(
texts,
return_tensors='pt',
padding=True,
truncation=True,
max_length=64
).to(self.device)
if self.unfreeze_layers > 0:
outputs = self.llm(**inputs, output_hidden_states=True)
else:
with torch.no_grad():
outputs = self.llm(**inputs, output_hidden_states=True)
hidden = outputs.hidden_states[self.extract_layer].float()
mask = inputs.attention_mask.float()
token_ids = inputs.input_ids
return hidden, mask, token_ids
def forward(self, texts: list[str]):
"""
Full forward pass: text -> hidden states -> extractor -> circuits -> result
Returns:
            result_bits, a_bits, b_bits, op_logits
            If hybrid: also returns a_values, b_values, used_lookup, a_digit_logits, b_digit_logits
            If positional_digit: also returns op_indices, a_values, b_values, a_digit_logits, b_digit_logits
            If digit_pred: also returns a_digit_logits, b_digit_logits
            If position_extract: also returns op_indices
"""
hidden, mask, token_ids = self.get_hidden_states(texts)
if self.hybrid or self.positional_digit or self.position_extract:
extractor_out = self.extractor(hidden, mask, token_ids)
else:
extractor_out = self.extractor(hidden, mask)
if self.hybrid:
a_bits, b_bits, op_logits, a_values, b_values, used_lookup, a_digit_logits, b_digit_logits = extractor_out
op_indices_from_tokens = None
elif self.positional_digit:
a_bits, b_bits, op_logits, op_indices_from_tokens, a_values, b_values, a_digit_logits, b_digit_logits = extractor_out
used_lookup = None
elif self.digit_pred:
a_bits, b_bits, op_logits, a_digit_logits, b_digit_logits = extractor_out
op_indices_from_tokens = None
a_values, b_values = None, None
used_lookup = None
elif self.position_extract:
a_bits, b_bits, op_logits, op_indices_from_tokens = extractor_out
a_digit_logits, b_digit_logits = None, None
a_values, b_values = None, None
used_lookup = None
else:
a_bits, b_bits, op_logits = extractor_out
a_digit_logits, b_digit_logits = None, None
op_indices_from_tokens = None
a_values, b_values = None, None
used_lookup = None
op_probs = torch.softmax(op_logits, dim=-1)
result_bits = self.circuits(a_bits, b_bits, op_probs)
if self.hybrid:
return result_bits, a_bits, b_bits, op_logits, a_values, b_values, used_lookup, a_digit_logits, b_digit_logits
if self.positional_digit:
return result_bits, a_bits, b_bits, op_logits, op_indices_from_tokens, a_values, b_values, a_digit_logits, b_digit_logits
if self.digit_pred:
return result_bits, a_bits, b_bits, op_logits, a_digit_logits, b_digit_logits
if self.position_extract:
return result_bits, a_bits, b_bits, op_logits, op_indices_from_tokens
return result_bits, a_bits, b_bits, op_logits
def trainable_parameters(self):
"""Return all trainable parameters for optimizer."""
params = list(self.extractor.parameters())
if self.unfreeze_layers > 0:
params += [p for p in self.llm.parameters() if p.requires_grad]
return params
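# Example usage of ArithmeticModel (sketch, assuming a CUDA device and prompts in the
# formats shown in the extractor docstrings; not exercised by the smoke test below):
#     model = ArithmeticModel(device='cuda', hybrid=True)
#     result_bits, a_bits, b_bits, op_logits, *extras = model(
#         ["47 + 86", "forty seven plus eighty six"])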
if __name__ == "__main__":
import sys
sys.path.insert(0, '.')
from fitness import generate_batch, compute_fitness, OPERATIONS
print("Testing model components...")
device = 'cuda'
batch = generate_batch(32, device)
print("\n1. Testing DirectCircuitModel (should get ~100% fitness)...")
direct_model = DirectCircuitModel(device=device)
def direct_fn(a, b, op):
return direct_model(a, b, op)
fitness, details = compute_fitness(direct_fn, n_samples=2000, batch_size=128,
device=device, return_details=True)
print(f" Direct circuit fitness: {fitness:.4f}")
for op in OPERATIONS:
acc = details['by_op'][op]['accuracy']
print(f" {op}: {acc:.4f}")
print("\n2. Testing ThresholdALU (trainable interface)...")
model = ThresholdALU(device=device)
x = torch.cat([batch['a_bits'], batch['b_bits'], batch['op_onehot']], dim=-1)
a_enc, b_enc = model.encoder(x)
print(f" Encoder output shapes: a={a_enc.shape}, b={b_enc.shape}")
op_weights = model.router(x)
print(f" Router output shape: {op_weights.shape}")
print(f" Router output sample: {op_weights[0].tolist()}")
result = model(batch['a_bits'], batch['b_bits'], batch['op_onehot'])
print(f" Full model output shape: {result.shape}")
print("\n3. Testing untrained ThresholdALU fitness...")
def model_fn(a, b, op):
return model(a, b, op)
fitness = compute_fitness(model_fn, n_samples=1000, batch_size=128, device=device)
print(f" Untrained model fitness: {fitness:.4f} (expected low)")
print("\n4. Counting parameters...")
total = sum(p.numel() for p in model.parameters() if p.requires_grad)
encoder_params = sum(p.numel() for p in model.encoder.parameters())
router_params = sum(p.numel() for p in model.router.parameters())
print(f" Encoder: {encoder_params:,}")
print(f" Router: {router_params:,}")
print(f" Total trainable: {total:,}")
print("\nDone.")