""" Trainable interface layers for frozen threshold circuits. BitEncoder, OpRouter, BitDecoder wrap the frozen circuits. HiddenStateExtractor and AugmentedArithmeticModel for LLM integration. """ import torch import torch.nn as nn import torch.nn.functional as F from circuits import FrozenThresholdCircuits, heaviside_ste MODEL_ID = 'HuggingFaceTB/SmolLM2-360M-Instruct' OPERATIONS = ['add', 'sub', 'mul', 'gt', 'lt', 'eq'] OP_SYMBOLS = {'add': '+', 'sub': '-', 'mul': '*', 'gt': '>', 'lt': '<', 'eq': '=='} class BitEncoder(nn.Module): """ Encodes two 8-bit operands from input representation. Uses residual connection to preserve ground truth bits while allowing learned refinement. """ def __init__(self, input_dim: int = 16 + 6, hidden_dim: int = 32): super().__init__() self.refine = nn.Sequential( nn.Linear(input_dim, hidden_dim), nn.Tanh(), nn.Linear(hidden_dim, 16), ) self.scale = nn.Parameter(torch.tensor(0.0)) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ Args: x: [batch, input_dim] input with first 16 dims being a_bits, b_bits Returns: a_bits: [batch, 8] first operand bits b_bits: [batch, 8] second operand bits """ base_bits = x[:, :16] refinement = self.refine(x) * torch.sigmoid(self.scale) bits = base_bits + refinement bits = torch.clamp(bits, 0, 1) hard_bits = heaviside_ste(bits - 0.5) out = hard_bits - bits.detach() + bits return out[:, :8], out[:, 8:] class OpRouter(nn.Module): """ Routes computation to the appropriate circuit based on input. Outputs soft weights over operations for gradient flow. """ def __init__(self, input_dim: int = 16 + 6, hidden_dim: int = 32, n_ops: int = 6): super().__init__() self.net = nn.Sequential( nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, n_ops), ) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x: [batch, input_dim] input features Returns: op_weights: [batch, n_ops] soft operation weights (softmax) """ logits = self.net(x) return F.softmax(logits, dim=-1) class BitDecoder(nn.Module): """ Decodes circuit output bits to target representation. For standalone training: outputs soft bits for loss computation. For LLM integration: would project to hidden state delta. """ def __init__(self, output_dim: int = 8): super().__init__() self.output_dim = output_dim def forward(self, result_bits: torch.Tensor) -> torch.Tensor: return result_bits class ThresholdALU(nn.Module): """ Complete trainable interface + frozen circuits. Learns to encode inputs, route to circuits, decode outputs. """ def __init__(self, device: str = 'cuda'): super().__init__() self.device = device self.circuits = FrozenThresholdCircuits(device=device) for key in self.circuits.weights: self.circuits.weights[key].requires_grad = False self.encoder = BitEncoder(input_dim=16 + 6, hidden_dim=64).to(device) self.router = OpRouter(input_dim=16 + 6, hidden_dim=32, n_ops=6).to(device) self.decoder = BitDecoder(output_dim=8).to(device) def forward(self, a_bits_in: torch.Tensor, b_bits_in: torch.Tensor, op_onehot: torch.Tensor) -> torch.Tensor: """ Forward pass through trainable interface + frozen circuits. Args: a_bits_in: [batch, 8] input A bits (ground truth for training) b_bits_in: [batch, 8] input B bits (ground truth for training) op_onehot: [batch, 6] one-hot operation selector Returns: result_bits: [batch, 8] output bits """ x = torch.cat([a_bits_in, b_bits_in, op_onehot], dim=-1) a_bits, b_bits = self.encoder(x) op_weights = self.router(x) result = self.circuits(a_bits, b_bits, op_weights) output = self.decoder(result) return output def forward_direct(self, a_bits: torch.Tensor, b_bits: torch.Tensor, op_onehot: torch.Tensor) -> torch.Tensor: """ Direct forward through circuits (bypass encoder/router for testing). """ return self.circuits(a_bits, b_bits, op_onehot) class DirectCircuitModel(nn.Module): """ Minimal model that directly uses circuits without learned encoding. For validating that circuits themselves achieve 100% fitness. """ def __init__(self, device: str = 'cuda'): super().__init__() self.device = device self.circuits = FrozenThresholdCircuits(device=device) def forward(self, a_bits: torch.Tensor, b_bits: torch.Tensor, op_onehot: torch.Tensor) -> torch.Tensor: return self.circuits(a_bits, b_bits, op_onehot) class HiddenStateExtractor(nn.Module): """ Extracts operands and operation from LLM hidden states. This is the hard part - must learn to parse numbers from embeddings. """ def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 256): super().__init__() self.a_extractor = nn.Sequential( nn.Linear(hidden_dim, intermediate_dim), nn.GELU(), nn.Linear(intermediate_dim, 8), ) self.b_extractor = nn.Sequential( nn.Linear(hidden_dim, intermediate_dim), nn.GELU(), nn.Linear(intermediate_dim, 8), ) self.op_router = nn.Sequential( nn.Linear(hidden_dim, intermediate_dim), nn.GELU(), nn.Linear(intermediate_dim, len(OPERATIONS)), ) def forward(self, hidden_states: torch.Tensor): """ Args: hidden_states: [batch, hidden_dim] from LLM Returns: a_bits: [batch, 8] b_bits: [batch, 8] op_logits: [batch, 6] """ a_logits = self.a_extractor(hidden_states) b_logits = self.b_extractor(hidden_states) op_logits = self.op_router(hidden_states) a_soft = torch.sigmoid(a_logits) b_soft = torch.sigmoid(b_logits) a_hard = heaviside_ste(a_logits) b_hard = heaviside_ste(b_logits) a_bits = a_hard - a_soft.detach() + a_soft b_bits = b_hard - b_soft.detach() + b_soft return a_bits, b_bits, op_logits class AttentionPooling(nn.Module): """ Learnable attention pooling over sequence positions. Replaces mean pooling - learns which tokens matter for extraction. """ def __init__(self, hidden_dim: int = 960, num_heads: int = 4): super().__init__() self.num_heads = num_heads self.head_dim = hidden_dim // num_heads self.query = nn.Linear(hidden_dim, hidden_dim) self.key = nn.Linear(hidden_dim, hidden_dim) self.value = nn.Linear(hidden_dim, hidden_dim) self.out_proj = nn.Linear(hidden_dim, hidden_dim) self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02) def forward(self, embeddings: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: """ Args: embeddings: [batch, seq_len, hidden_dim] mask: [batch, seq_len] attention mask (1 = attend, 0 = ignore) Returns: pooled: [batch, hidden_dim] """ batch_size, seq_len, hidden_dim = embeddings.shape cls_expanded = self.cls_token.expand(batch_size, -1, -1) embeddings = torch.cat([cls_expanded, embeddings], dim=1) cls_mask = torch.ones(batch_size, 1, device=mask.device) mask = torch.cat([cls_mask, mask], dim=1) Q = self.query(embeddings[:, :1, :]) K = self.key(embeddings) V = self.value(embeddings) Q = Q.view(batch_size, 1, self.num_heads, self.head_dim).transpose(1, 2) K = K.view(batch_size, seq_len + 1, self.num_heads, self.head_dim).transpose(1, 2) V = V.view(batch_size, seq_len + 1, self.num_heads, self.head_dim).transpose(1, 2) scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5) mask_expanded = mask.unsqueeze(1).unsqueeze(2) scores = scores.masked_fill(mask_expanded == 0, -1e9) attn_weights = torch.softmax(scores, dim=-1) attn_weights = torch.nan_to_num(attn_weights, nan=0.0) context = torch.matmul(attn_weights, V) context = context.transpose(1, 2).contiguous().view(batch_size, 1, hidden_dim) pooled = self.out_proj(context).squeeze(1) pooled = torch.nan_to_num(pooled, nan=0.0) return pooled class MultiHeadBitExtractor(nn.Module): """ 8 separate extractors for 8 bits - each bit gets its own specialized network. More expressive than single MLP predicting all 8 bits at once. """ def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 128): super().__init__() self.bit_extractors = nn.ModuleList([ nn.Sequential( nn.Linear(hidden_dim, intermediate_dim), nn.GELU(), nn.Linear(intermediate_dim, 1), ) for _ in range(8) ]) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ Args: hidden_states: [batch, hidden_dim] Returns: bits: [batch, 8] - one bit from each extractor """ hidden_states = torch.nan_to_num(hidden_states, nan=0.0) bit_logits = [extractor(hidden_states) for extractor in self.bit_extractors] logits = torch.cat(bit_logits, dim=-1) logits = torch.clamp(logits, -20, 20) soft = torch.sigmoid(logits) hard = heaviside_ste(logits) bits = hard - soft.detach() + soft return bits, logits class Extractor(nn.Module): """ Extracts operands and operation from LLM hidden states. Uses attention pooling and per-bit extraction networks. """ def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 256, num_heads: int = 4): super().__init__() self.attention_pool = AttentionPooling(hidden_dim, num_heads) self.a_extractor = MultiHeadBitExtractor(hidden_dim, intermediate_dim // 2) self.b_extractor = MultiHeadBitExtractor(hidden_dim, intermediate_dim // 2) self.op_router = nn.Sequential( nn.Linear(hidden_dim, intermediate_dim), nn.GELU(), nn.Linear(intermediate_dim, len(OPERATIONS)), ) def forward(self, embeddings: torch.Tensor, mask: torch.Tensor): """ Args: embeddings: [batch, seq_len, hidden_dim] mask: [batch, seq_len] Returns: a_bits: [batch, 8] b_bits: [batch, 8] op_logits: [batch, 6] """ pooled = self.attention_pool(embeddings, mask) a_bits, _ = self.a_extractor(pooled) b_bits, _ = self.b_extractor(pooled) op_logits = self.op_router(pooled) return a_bits, b_bits, op_logits class PositionExtractor(nn.Module): """ Position-specific extraction with dynamic operator detection. Tokenization pattern for "A op B": [A_digits...] [operator] [space] [B_digits...] Examples: "5 + 3" -> ['5', ' +', ' ', '3'] (positions: A=0, op=1, B=3) "47 + 86" -> ['4', '7', ' +', ' ', '8', '6'] (positions: A=0-1, op=2, B=4-5) "127 + 128" -> ['1','2','7',' +', ' ','1','2','8'] (positions: A=0-2, op=3, B=5-7) Token IDs (SmolLM2): Digits '0'-'9': 32-41 Operators: ' +'=1232, ' -'=731, ' *'=1672, ' >'=2986, ' <'=2067, ' =='=1758 Space: 216 """ DIGIT_TOKENS = set(range(32, 42)) OPERATOR_TOKENS = { 1232: 0, # ' +' -> add 731: 1, # ' -' -> sub 1672: 2, # ' *' -> mul 2986: 3, # ' >' -> gt 2067: 4, # ' <' -> lt 1758: 5, # ' ==' -> eq } SPACE_TOKEN = 216 MAX_DIGITS = 3 def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 256): super().__init__() self.hidden_dim = hidden_dim self.a_extractor = nn.Sequential( nn.Linear(hidden_dim * self.MAX_DIGITS, intermediate_dim), nn.GELU(), nn.Linear(intermediate_dim, intermediate_dim // 2), nn.GELU(), nn.Linear(intermediate_dim // 2, 8), ) self.b_extractor = nn.Sequential( nn.Linear(hidden_dim * self.MAX_DIGITS, intermediate_dim), nn.GELU(), nn.Linear(intermediate_dim, intermediate_dim // 2), nn.GELU(), nn.Linear(intermediate_dim // 2, 8), ) self.op_extractor = nn.Sequential( nn.Linear(hidden_dim, intermediate_dim // 2), nn.GELU(), nn.Linear(intermediate_dim // 2, len(OPERATIONS)), ) def _find_operator_position(self, token_ids: torch.Tensor) -> tuple[int, int]: """ Find operator token position and its operation index. Args: token_ids: [seq_len] tensor of token IDs Returns: (position, op_index) or (-1, -1) if not found """ for pos, tid in enumerate(token_ids.tolist()): if tid in self.OPERATOR_TOKENS: return pos, self.OPERATOR_TOKENS[tid] return -1, -1 def _extract_digit_features(self, hidden: torch.Tensor, start: int, end: int) -> torch.Tensor: """ Extract and pad digit hidden states to fixed size. Args: hidden: [seq_len, hidden_dim] start: start position (inclusive) end: end position (exclusive) Returns: [hidden_dim * MAX_DIGITS] flattened features, zero-padded on the LEFT (so units digit is always at the same position regardless of number length) """ n_digits = end - start features = torch.zeros(self.MAX_DIGITS * self.hidden_dim, device=hidden.device) if n_digits > 0 and n_digits <= self.MAX_DIGITS: digit_hidden = hidden[start:end, :].reshape(-1) pad_size = (self.MAX_DIGITS - n_digits) * self.hidden_dim features[pad_size:] = digit_hidden return features def forward(self, hidden: torch.Tensor, mask: torch.Tensor, token_ids: torch.Tensor = None): """ Args: hidden: [batch, seq_len, hidden_dim] mask: [batch, seq_len] attention mask token_ids: [batch, seq_len] token IDs (required for operator detection) Returns: a_bits: [batch, 8] b_bits: [batch, 8] op_logits: [batch, 6] """ if token_ids is None: raise ValueError("PositionExtractor requires token_ids for operator detection") batch_size, seq_len, hidden_dim = hidden.shape device = hidden.device a_features = [] b_features = [] op_features = [] op_indices = [] for i in range(batch_size): seq_mask = mask[i].bool() valid_len = seq_mask.sum().item() start_pos = seq_len - valid_len valid_tokens = token_ids[i, start_pos:] valid_hidden = hidden[i, start_pos:, :] op_pos, op_idx = self._find_operator_position(valid_tokens) if op_pos == -1: a_feat = torch.zeros(self.MAX_DIGITS * hidden_dim, device=device) b_feat = torch.zeros(self.MAX_DIGITS * hidden_dim, device=device) op_feat = torch.zeros(hidden_dim, device=device) op_idx = 0 else: a_feat = self._extract_digit_features(valid_hidden, 0, op_pos) op_feat = valid_hidden[op_pos, :] b_start = op_pos + 2 if (op_pos + 1 < valid_len and valid_tokens[op_pos + 1].item() == self.SPACE_TOKEN) else op_pos + 1 b_feat = self._extract_digit_features(valid_hidden, b_start, valid_len) a_features.append(a_feat) b_features.append(b_feat) op_features.append(op_feat) op_indices.append(op_idx) a_features = torch.stack(a_features) b_features = torch.stack(b_features) op_features = torch.stack(op_features) op_indices_tensor = torch.tensor(op_indices, device=device, dtype=torch.long) a_logits = self.a_extractor(a_features) b_logits = self.b_extractor(b_features) op_logits = self.op_extractor(op_features) a_soft = torch.sigmoid(a_logits) b_soft = torch.sigmoid(b_logits) a_hard = heaviside_ste(a_logits) b_hard = heaviside_ste(b_logits) a_bits = a_hard - a_soft.detach() + a_soft b_bits = b_hard - b_soft.detach() + b_soft return a_bits, b_bits, op_logits, op_indices_tensor class PositionalDigitExtractor(nn.Module): """ Position-aware digit extraction: classifies each digit position independently. This approach achieves 100% accuracy because: 1. Each digit token position is classified independently (100% accuracy on layer 0) 2. Numbers are reconstructed using place values (×100, ×10, ×1) 3. No information is lost through pooling Token IDs (SmolLM2): Digits '0'-'9': 32-41 Operators: ' +'=1232, ' -'=731, ' *'=1672, ' >'=2986, ' <'=2067, ' =='=1758 Space: 216 """ DIGIT_TOKENS = set(range(32, 42)) OPERATOR_TOKENS = { 1232: 0, # ' +' -> add 731: 1, # ' -' -> sub 1672: 2, # ' *' -> mul 2986: 3, # ' >' -> gt 2067: 4, # ' <' -> lt 1758: 5, # ' ==' -> eq } SPACE_TOKEN = 216 def __init__(self, hidden_dim: int = 960): super().__init__() self.hidden_dim = hidden_dim self.digit_classifier = nn.Linear(hidden_dim, 10) self.op_classifier = nn.Linear(hidden_dim, len(OPERATIONS)) def _find_positions(self, token_ids: torch.Tensor) -> tuple: """Find A digit positions, B digit positions, and operator position.""" token_list = token_ids.tolist() op_pos = -1 op_idx = 0 for i, tid in enumerate(token_list): if tid in self.OPERATOR_TOKENS: op_pos = i op_idx = self.OPERATOR_TOKENS[tid] break if op_pos == -1: return [], [], -1, 0 a_positions = [i for i in range(op_pos) if token_list[i] in self.DIGIT_TOKENS] b_start = op_pos + 2 if (op_pos + 1 < len(token_list) and token_list[op_pos + 1] == self.SPACE_TOKEN) else op_pos + 1 b_positions = [i for i in range(b_start, len(token_list)) if token_list[i] in self.DIGIT_TOKENS] return a_positions, b_positions, op_pos, op_idx def _predict_value(self, hidden: torch.Tensor, positions: list) -> tuple: """Predict digit at each position and reconstruct number.""" if not positions: return torch.tensor(0.0, device=hidden.device), [] digit_logits_list = [] soft_value = torch.tensor(0.0, device=hidden.device) for idx, pos in enumerate(positions): logits = self.digit_classifier(hidden[pos]) digit_logits_list.append(logits) probs = torch.softmax(logits, dim=-1) digit_values = torch.arange(10, device=hidden.device, dtype=torch.float32) soft_digit = (probs * digit_values).sum() place_value = 10 ** (len(positions) - idx - 1) soft_value = soft_value + soft_digit * place_value return soft_value, digit_logits_list def _value_to_bits(self, value: torch.Tensor) -> torch.Tensor: """Convert soft value to 8 bits using differentiable operations.""" value = torch.clamp(value, 0, 255) bits = [] for i in range(7, -1, -1): bit = torch.sigmoid((value - (2 ** i - 0.5)) * 10) value = value - bit * (2 ** i) bits.append(bit) return torch.stack(bits) def forward(self, hidden: torch.Tensor, mask: torch.Tensor, token_ids: torch.Tensor): """ Args: hidden: [batch, seq_len, hidden_dim] mask: [batch, seq_len] token_ids: [batch, seq_len] Returns: a_bits: [batch, 8] b_bits: [batch, 8] op_logits: [batch, 6] op_indices: [batch] ground truth op from tokens a_digit_logits: list of [batch, 10] per digit position b_digit_logits: list of [batch, 10] per digit position """ batch_size = hidden.shape[0] device = hidden.device a_bits_list = [] b_bits_list = [] op_logits_list = [] op_indices_list = [] a_values_list = [] b_values_list = [] a_digit_logits_list = [] b_digit_logits_list = [] for i in range(batch_size): seq_mask = mask[i].bool() valid_len = seq_mask.sum().item() start_pos = hidden.shape[1] - valid_len valid_hidden = hidden[i, start_pos:] valid_tokens = token_ids[i, start_pos:] a_pos, b_pos, op_pos, op_idx = self._find_positions(valid_tokens) a_value, a_digit_logits = self._predict_value(valid_hidden, a_pos) b_value, b_digit_logits = self._predict_value(valid_hidden, b_pos) a_bits = self._value_to_bits(a_value) b_bits = self._value_to_bits(b_value) if op_pos >= 0: op_logits = self.op_classifier(valid_hidden[op_pos]) else: op_logits = torch.zeros(len(OPERATIONS), device=device) a_bits_list.append(a_bits) b_bits_list.append(b_bits) op_logits_list.append(op_logits) op_indices_list.append(op_idx) a_values_list.append(a_value) b_values_list.append(b_value) a_digit_logits_list.append(a_digit_logits) b_digit_logits_list.append(b_digit_logits) a_bits = torch.stack(a_bits_list) b_bits = torch.stack(b_bits_list) op_logits = torch.stack(op_logits_list) op_indices = torch.tensor(op_indices_list, device=device, dtype=torch.long) a_values = torch.stack(a_values_list) b_values = torch.stack(b_values_list) return a_bits, b_bits, op_logits, op_indices, a_values, b_values, a_digit_logits_list, b_digit_logits_list class DigitExtractor(nn.Module): """ Digit-level extraction: predicts digits (0-9) then converts to bits. Uses attention pooling (less accurate than PositionalDigitExtractor). """ def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 256, num_heads: int = 4): super().__init__() self.attention_pool = AttentionPooling(hidden_dim, num_heads) self.a_digit_pred = nn.Sequential( nn.Linear(hidden_dim, intermediate_dim), nn.GELU(), nn.Linear(intermediate_dim, 3 * 10), ) self.b_digit_pred = nn.Sequential( nn.Linear(hidden_dim, intermediate_dim), nn.GELU(), nn.Linear(intermediate_dim, 3 * 10), ) self.op_router = nn.Sequential( nn.Linear(hidden_dim, intermediate_dim), nn.GELU(), nn.Linear(intermediate_dim, len(OPERATIONS)), ) def digits_to_bits(self, digit_logits: torch.Tensor) -> torch.Tensor: """ Convert 3-digit predictions to 8-bit representation. digit_logits: [batch, 30] (3 digits * 10 classes each) Returns: [batch, 8] bits """ batch_size = digit_logits.shape[0] logits = digit_logits.view(batch_size, 3, 10) probs = torch.softmax(logits, dim=-1) digit_values = torch.arange(10, device=digit_logits.device).float() soft_digits = (probs * digit_values).sum(dim=-1) hundreds = soft_digits[:, 0] tens = soft_digits[:, 1] ones = soft_digits[:, 2] value = hundreds * 100 + tens * 10 + ones value = torch.clamp(value, 0, 255) bits = [] for i in range(7, -1, -1): bit = torch.fmod(torch.floor(value / (2 ** i)), 2) bits.append(bit) return torch.stack(bits, dim=-1) def forward(self, hidden: torch.Tensor, mask: torch.Tensor): """ Returns: a_bits, b_bits, op_logits, a_digit_logits, b_digit_logits """ pooled = self.attention_pool(hidden, mask) a_digit_logits = self.a_digit_pred(pooled) b_digit_logits = self.b_digit_pred(pooled) op_logits = self.op_router(pooled) a_bits = self.digits_to_bits(a_digit_logits) b_bits = self.digits_to_bits(b_digit_logits) return a_bits, b_bits, op_logits, a_digit_logits, b_digit_logits class HybridExtractor(nn.Module): """ Hybrid extractor that handles both digit tokens and word numbers. For digit tokens (32-41): Direct lookup, no training needed For word numbers: Learned MLP extraction from pooled hidden states This is the real training target - learning to extract numbers from natural language like "forty seven plus eighty six". """ DIGIT_TOKENS = set(range(32, 42)) SYMBOL_OP_TOKENS = { 1232: 0, # ' +' -> add 731: 1, # ' -' -> sub 1672: 2, # ' *' -> mul 2986: 3, # ' >' -> gt 2067: 4, # ' <' -> lt 1758: 5, # ' ==' -> eq } WORD_OP_TOKENS = { 2068: 0, # 'plus' -> add 8500: 1, # 'minus' -> sub 1580: 2, # 'times' -> mul 6301: 3, # 'greater' -> gt 1912: 4, # 'less' -> lt 16364: 5, # 'equals' -> eq 11540: 5, # 'equal' -> eq } ALL_OP_TOKENS = {**SYMBOL_OP_TOKENS, **WORD_OP_TOKENS} def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 256, num_heads: int = 4): super().__init__() self.hidden_dim = hidden_dim self.attention_pool = AttentionPooling(hidden_dim, num_heads) self.a_pool = AttentionPooling(hidden_dim, num_heads) self.b_pool = AttentionPooling(hidden_dim, num_heads) self.a_digit_pred = nn.Sequential( nn.Linear(hidden_dim, intermediate_dim), nn.GELU(), nn.Dropout(0.1), nn.Linear(intermediate_dim, 3 * 10), ) self.b_digit_pred = nn.Sequential( nn.Linear(hidden_dim, intermediate_dim), nn.GELU(), nn.Dropout(0.1), nn.Linear(intermediate_dim, 3 * 10), ) self.op_predictor = nn.Sequential( nn.Linear(hidden_dim, intermediate_dim // 2), nn.GELU(), nn.Linear(intermediate_dim // 2, len(OPERATIONS)), ) def _has_digit_tokens(self, token_ids: torch.Tensor) -> bool: """Check if input contains digit tokens.""" for tid in token_ids.tolist(): if tid in self.DIGIT_TOKENS: return True return False def _find_op_position(self, token_ids: torch.Tensor) -> int: """Find position of operator token, returns -1 if not found.""" tokens = token_ids.tolist() for i, tid in enumerate(tokens): if tid in self.ALL_OP_TOKENS: return i return -1 def _extract_from_digits(self, token_ids: torch.Tensor) -> tuple: """ Extract values directly from digit tokens (hardcoded lookup). Handles both symbol operators (' +') and word operators ('plus'). Returns (a_value, b_value, op_idx) or None if pattern not found. """ tokens = token_ids.tolist() op_pos = -1 op_idx = 0 for i, tid in enumerate(tokens): if tid in self.ALL_OP_TOKENS: op_pos = i op_idx = self.ALL_OP_TOKENS[tid] break if op_pos == -1: return None a_digits = [] for i in range(op_pos): if tokens[i] in self.DIGIT_TOKENS: a_digits.append(tokens[i] - 32) b_start = op_pos + 1 if b_start < len(tokens) and tokens[b_start] == 216: b_start += 1 b_digits = [] for i in range(b_start, len(tokens)): if tokens[i] in self.DIGIT_TOKENS: b_digits.append(tokens[i] - 32) if not a_digits or not b_digits: return None a_val = 0 for d in a_digits: a_val = a_val * 10 + d b_val = 0 for d in b_digits: b_val = b_val * 10 + d return min(a_val, 255), min(b_val, 255), op_idx def _value_to_bits(self, value: int, device) -> torch.Tensor: """Convert integer to 8-bit tensor.""" bits = torch.zeros(8, device=device) for i in range(8): bits[7 - i] = (value >> i) & 1 return bits def _digits_to_value_and_bits(self, digit_logits: torch.Tensor, device) -> tuple: """ Convert 3-digit logits to value and bits. digit_logits: [30] (3 digits × 10 classes) Returns: (value tensor, bits tensor [8]) """ logits = digit_logits.view(3, 10) probs = torch.softmax(logits, dim=-1) digit_values = torch.arange(10, device=device, dtype=torch.float32) soft_digits = (probs * digit_values).sum(dim=-1) hundreds = soft_digits[0] tens = soft_digits[1] ones = soft_digits[2] value = hundreds * 100 + tens * 10 + ones value = torch.clamp(value, 0, 255) bits = [] for i in range(7, -1, -1): threshold = 2 ** i bit = torch.sigmoid((value - threshold + 0.5) * 10) bits.append(bit) value = value - bit * threshold return hundreds * 100 + tens * 10 + ones, torch.stack(bits) def forward(self, hidden: torch.Tensor, mask: torch.Tensor, token_ids: torch.Tensor = None): """ Args: hidden: [batch, seq_len, hidden_dim] mask: [batch, seq_len] token_ids: [batch, seq_len] - optional, enables digit lookup Returns: a_bits, b_bits, op_logits, a_values, b_values, used_lookup, a_digit_logits, b_digit_logits """ batch_size = hidden.shape[0] device = hidden.device a_bits_list = [] a_digit_logits_list = [] b_digit_logits_list = [] b_bits_list = [] op_logits_list = [] a_values_list = [] b_values_list = [] used_lookup_list = [] pooled = self.attention_pool(hidden, mask) for i in range(batch_size): lookup_result = None if token_ids is not None: seq_mask = mask[i].bool() valid_len = seq_mask.sum().item() start_pos = hidden.shape[1] - valid_len valid_tokens = token_ids[i, start_pos:] if self._has_digit_tokens(valid_tokens): lookup_result = self._extract_from_digits(valid_tokens) if lookup_result is not None: a_val, b_val, op_idx = lookup_result a_bits = self._value_to_bits(a_val, device) b_bits = self._value_to_bits(b_val, device) op_logits = torch.zeros(len(OPERATIONS), device=device) op_logits[op_idx] = 10.0 a_bits_list.append(a_bits) b_bits_list.append(b_bits) op_logits_list.append(op_logits) a_values_list.append(float(a_val)) b_values_list.append(float(b_val)) used_lookup_list.append(True) a_digit_logits_list.append(None) b_digit_logits_list.append(None) else: sample_hidden = hidden[i:i+1] sample_mask = mask[i:i+1] seq_mask = mask[i].bool() valid_len = int(seq_mask.sum().item()) start_pos = hidden.shape[1] - valid_len valid_tokens = token_ids[i, start_pos:] if token_ids is not None else None op_pos = self._find_op_position(valid_tokens) if valid_tokens is not None else -1 if op_pos > 0 and op_pos < valid_len - 1: a_end = start_pos + op_pos b_start = start_pos + op_pos + 1 a_mask = torch.zeros_like(sample_mask) a_mask[0, start_pos:a_end] = 1.0 b_mask = torch.zeros_like(sample_mask) b_mask[0, b_start:] = sample_mask[0, b_start:] a_pooled = self.a_pool(sample_hidden, a_mask)[0] b_pooled = self.b_pool(sample_hidden, b_mask)[0] else: a_pooled = pooled[i] b_pooled = pooled[i] a_digit_logits = self.a_digit_pred(a_pooled) b_digit_logits = self.b_digit_pred(b_pooled) op_logits = self.op_predictor(pooled[i]) a_val, a_bits = self._digits_to_value_and_bits(a_digit_logits, device) b_val, b_bits = self._digits_to_value_and_bits(b_digit_logits, device) a_bits_list.append(a_bits) b_bits_list.append(b_bits) op_logits_list.append(op_logits) a_values_list.append(a_val) b_values_list.append(b_val) used_lookup_list.append(False) a_digit_logits_list.append(a_digit_logits) b_digit_logits_list.append(b_digit_logits) a_bits = torch.stack(a_bits_list) b_bits = torch.stack(b_bits_list) op_logits = torch.stack(op_logits_list) a_values = torch.stack([v if isinstance(v, torch.Tensor) else torch.tensor(v, device=device) for v in a_values_list]) b_values = torch.stack([v if isinstance(v, torch.Tensor) else torch.tensor(v, device=device) for v in b_values_list]) used_lookup = torch.tensor(used_lookup_list, device=device, dtype=torch.bool) valid_a_logits = [x for x in a_digit_logits_list if x is not None] valid_b_logits = [x for x in b_digit_logits_list if x is not None] a_digit_logits_out = torch.stack(valid_a_logits) if valid_a_logits else None b_digit_logits_out = torch.stack(valid_b_logits) if valid_b_logits else None return a_bits, b_bits, op_logits, a_values, b_values, used_lookup, a_digit_logits_out, b_digit_logits_out def _soft_value_to_bits(self, value: torch.Tensor, device) -> torch.Tensor: """Convert soft value (0-255) to 8-bit representation differentiably.""" value = torch.clamp(value, 0, 255) bits = [] remaining = value for i in range(7, -1, -1): threshold = 2 ** i bit = torch.sigmoid((remaining - threshold + 0.5) * 10) bits.append(bit) remaining = remaining - bit * threshold return torch.stack(bits) class ArithmeticModel(nn.Module): """ LLM + extractor + frozen threshold circuits. Optionally unfreeze top N transformer layers with --unfreeze_layers. """ def __init__(self, device: str = 'cuda', unfreeze_layers: int = 0, extract_layer: int = -1, position_extract: bool = False, digit_pred: bool = False, positional_digit: bool = False, hybrid: bool = False): super().__init__() self.device = device self.unfreeze_layers = unfreeze_layers self.extract_layer = extract_layer self.position_extract = position_extract self.digit_pred = digit_pred self.positional_digit = positional_digit self.hybrid = hybrid from transformers import AutoModelForCausalLM, AutoTokenizer print("[1/4] Loading tokenizer...", flush=True) self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) self.tokenizer.padding_side = 'left' if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token print(" Tokenizer loaded.", flush=True) print("[2/4] Loading SmolLM2-360M...", flush=True) self.llm = AutoModelForCausalLM.from_pretrained( MODEL_ID, torch_dtype=torch.float16, device_map=device, output_hidden_states=True ) for param in self.llm.parameters(): param.requires_grad = False if unfreeze_layers > 0: num_layers = len(self.llm.model.layers) layers_to_unfreeze = list(range(num_layers - unfreeze_layers, num_layers)) print(f" Unfreezing layers {layers_to_unfreeze}...", flush=True) for layer_idx in layers_to_unfreeze: for param in self.llm.model.layers[layer_idx].parameters(): param.requires_grad = True hidden_dim = self.llm.config.hidden_size llm_params = sum(p.numel() for p in self.llm.parameters()) trainable_llm = sum(p.numel() for p in self.llm.parameters() if p.requires_grad) print(f" LLM loaded. Hidden dim: {hidden_dim}", flush=True) print(f" LLM params: {llm_params:,} total, {trainable_llm:,} trainable", flush=True) print("[3/4] Loading threshold circuits...", flush=True) self.circuits = FrozenThresholdCircuits(device=device) print(f" Circuits loaded. {len(self.circuits.weights)} tensors", flush=True) print("[4/4] Initializing extractor...", flush=True) if hybrid: print(" Using HYBRID extraction (digit lookup + word learning)", flush=True) self.extractor = HybridExtractor( hidden_dim=hidden_dim, intermediate_dim=256, num_heads=4 ).to(device) elif positional_digit: print(" Using POSITIONAL DIGIT extraction (100% proven)", flush=True) self.extractor = PositionalDigitExtractor( hidden_dim=hidden_dim ).to(device) elif position_extract: print(" Using position-specific extraction", flush=True) self.extractor = PositionExtractor( hidden_dim=hidden_dim, intermediate_dim=256 ).to(device) elif digit_pred: print(" Using digit-level prediction", flush=True) self.extractor = DigitExtractor( hidden_dim=hidden_dim, intermediate_dim=256, num_heads=4 ).to(device) else: self.extractor = Extractor( hidden_dim=hidden_dim, intermediate_dim=256, num_heads=4 ).to(device) if extract_layer != -1: print(f" Extracting from layer {extract_layer}", flush=True) trainable_ext = sum(p.numel() for p in self.extractor.parameters()) total_trainable = trainable_llm + trainable_ext print(f" Extractor params: {trainable_ext:,}", flush=True) print(f" Total trainable: {total_trainable:,}", flush=True) def get_hidden_states(self, texts: list[str]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Get hidden states from specified layer. Returns: hidden: [batch, seq_len, hidden_dim] hidden states mask: [batch, seq_len] attention mask token_ids: [batch, seq_len] input token IDs """ inputs = self.tokenizer( texts, return_tensors='pt', padding=True, truncation=True, max_length=64 ).to(self.device) if self.unfreeze_layers > 0: outputs = self.llm(**inputs, output_hidden_states=True) else: with torch.no_grad(): outputs = self.llm(**inputs, output_hidden_states=True) hidden = outputs.hidden_states[self.extract_layer].float() mask = inputs.attention_mask.float() token_ids = inputs.input_ids return hidden, mask, token_ids def forward(self, texts: list[str]): """ Full forward pass: text -> hidden states -> extractor -> circuits -> result Returns: result_bits, a_bits, b_bits, op_logits If digit_pred: also returns a_digit_logits, b_digit_logits If position_extract or positional_digit: also returns op_indices If positional_digit: also returns a_values, b_values """ hidden, mask, token_ids = self.get_hidden_states(texts) if self.hybrid or self.positional_digit or self.position_extract: extractor_out = self.extractor(hidden, mask, token_ids) else: extractor_out = self.extractor(hidden, mask) if self.hybrid: a_bits, b_bits, op_logits, a_values, b_values, used_lookup, a_digit_logits, b_digit_logits = extractor_out op_indices_from_tokens = None elif self.positional_digit: a_bits, b_bits, op_logits, op_indices_from_tokens, a_values, b_values, a_digit_logits, b_digit_logits = extractor_out used_lookup = None elif self.digit_pred: a_bits, b_bits, op_logits, a_digit_logits, b_digit_logits = extractor_out op_indices_from_tokens = None a_values, b_values = None, None used_lookup = None elif self.position_extract: a_bits, b_bits, op_logits, op_indices_from_tokens = extractor_out a_digit_logits, b_digit_logits = None, None a_values, b_values = None, None used_lookup = None else: a_bits, b_bits, op_logits = extractor_out a_digit_logits, b_digit_logits = None, None op_indices_from_tokens = None a_values, b_values = None, None used_lookup = None op_probs = torch.softmax(op_logits, dim=-1) result_bits = self.circuits(a_bits, b_bits, op_probs) if self.hybrid: return result_bits, a_bits, b_bits, op_logits, a_values, b_values, used_lookup, a_digit_logits, b_digit_logits if self.positional_digit: return result_bits, a_bits, b_bits, op_logits, op_indices_from_tokens, a_values, b_values, a_digit_logits, b_digit_logits if self.digit_pred: return result_bits, a_bits, b_bits, op_logits, a_digit_logits, b_digit_logits if self.position_extract: return result_bits, a_bits, b_bits, op_logits, op_indices_from_tokens return result_bits, a_bits, b_bits, op_logits def trainable_parameters(self): """Return all trainable parameters for optimizer.""" params = list(self.extractor.parameters()) if self.unfreeze_layers > 0: params += [p for p in self.llm.parameters() if p.requires_grad] return params if __name__ == "__main__": import sys sys.path.insert(0, '.') from fitness import generate_batch, compute_fitness, OPERATIONS print("Testing model components...") device = 'cuda' batch = generate_batch(32, device) print("\n1. Testing DirectCircuitModel (should get ~100% fitness)...") direct_model = DirectCircuitModel(device=device) def direct_fn(a, b, op): return direct_model(a, b, op) fitness, details = compute_fitness(direct_fn, n_samples=2000, batch_size=128, device=device, return_details=True) print(f" Direct circuit fitness: {fitness:.4f}") for op in OPERATIONS: acc = details['by_op'][op]['accuracy'] print(f" {op}: {acc:.4f}") print("\n2. Testing ThresholdALU (trainable interface)...") model = ThresholdALU(device=device) x = torch.cat([batch['a_bits'], batch['b_bits'], batch['op_onehot']], dim=-1) a_enc, b_enc = model.encoder(x) print(f" Encoder output shapes: a={a_enc.shape}, b={b_enc.shape}") op_weights = model.router(x) print(f" Router output shape: {op_weights.shape}") print(f" Router output sample: {op_weights[0].tolist()}") result = model(batch['a_bits'], batch['b_bits'], batch['op_onehot']) print(f" Full model output shape: {result.shape}") print("\n3. Testing untrained ThresholdALU fitness...") def model_fn(a, b, op): return model(a, b, op) fitness = compute_fitness(model_fn, n_samples=1000, batch_size=128, device=device) print(f" Untrained model fitness: {fitness:.4f} (expected low)") print("\n4. Counting parameters...") total = sum(p.numel() for p in model.parameters() if p.requires_grad) encoder_params = sum(p.numel() for p in model.encoder.parameters()) router_params = sum(p.numel() for p in model.router.parameters()) print(f" Encoder: {encoder_params:,}") print(f" Router: {router_params:,}") print(f" Total trainable: {total:,}") print("\nDone.")