|
|
""" |
|
|
Trainable interface layers for frozen threshold circuits. |
|
|
BitEncoder, OpRouter, BitDecoder wrap the frozen circuits. |
|
|
HiddenStateExtractor and AugmentedArithmeticModel for LLM integration. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from circuits import FrozenThresholdCircuits, heaviside_ste |
|
|
|
|
|
MODEL_ID = 'HuggingFaceTB/SmolLM2-360M-Instruct' |
|
|
OPERATIONS = ['add', 'sub', 'mul', 'gt', 'lt', 'eq'] |
|
|
OP_SYMBOLS = {'add': '+', 'sub': '-', 'mul': '*', 'gt': '>', 'lt': '<', 'eq': '=='} |
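# Operands are 8-bit unsigned integers. The extractors below encode them MSB-first:
# e.g. 47 = 0b00101111 becomes the bit vector [0, 0, 1, 0, 1, 1, 1, 1].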
|
|
|
|
|
|
|
|
class BitEncoder(nn.Module): |
|
|
""" |
|
|
    Encodes two 8-bit operands from the input representation.

    A residual connection preserves the ground-truth bits while allowing learned refinement.
|
|
""" |
|
|
|
|
|
def __init__(self, input_dim: int = 16 + 6, hidden_dim: int = 32): |
|
|
super().__init__() |
|
|
self.refine = nn.Sequential( |
|
|
nn.Linear(input_dim, hidden_dim), |
|
|
nn.Tanh(), |
|
|
nn.Linear(hidden_dim, 16), |
|
|
) |
|
|
self.scale = nn.Parameter(torch.tensor(0.0)) |
|
|
|
|
|
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: |
|
|
""" |
|
|
Args: |
|
|
x: [batch, input_dim] input with first 16 dims being a_bits, b_bits |
|
|
|
|
|
Returns: |
|
|
a_bits: [batch, 8] first operand bits |
|
|
b_bits: [batch, 8] second operand bits |
|
|
""" |
|
|
base_bits = x[:, :16] |
|
|
refinement = self.refine(x) * torch.sigmoid(self.scale) |
|
|
bits = base_bits + refinement |
|
|
bits = torch.clamp(bits, 0, 1) |
|
|
hard_bits = heaviside_ste(bits - 0.5) |
|
|
out = hard_bits - bits.detach() + bits |
|
|
|
|
|
return out[:, :8], out[:, 8:] |
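# Usage sketch for BitEncoder (illustrative shapes only; the batch size of 4 is arbitrary):
#     enc = BitEncoder()                                   # input_dim defaults to 16 + 6
#     a = torch.randint(0, 2, (4, 8)).float()
#     b = torch.randint(0, 2, (4, 8)).float()
#     op = F.one_hot(torch.zeros(4, dtype=torch.long), 6).float()
#     a_bits, b_bits = enc(torch.cat([a, b, op], dim=-1))  # each [4, 8], hard 0/1 values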
|
|
|
|
|
|
|
|
class OpRouter(nn.Module): |
|
|
""" |
|
|
Routes computation to the appropriate circuit based on input. |
|
|
Outputs soft weights over operations for gradient flow. |
|
|
""" |
|
|
|
|
|
def __init__(self, input_dim: int = 16 + 6, hidden_dim: int = 32, n_ops: int = 6): |
|
|
super().__init__() |
|
|
self.net = nn.Sequential( |
|
|
nn.Linear(input_dim, hidden_dim), |
|
|
nn.ReLU(), |
|
|
nn.Linear(hidden_dim, n_ops), |
|
|
) |
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Args: |
|
|
x: [batch, input_dim] input features |
|
|
|
|
|
Returns: |
|
|
op_weights: [batch, n_ops] soft operation weights (softmax) |
|
|
""" |
|
|
logits = self.net(x) |
|
|
return F.softmax(logits, dim=-1) |
|
|
|
|
|
|
|
|
class BitDecoder(nn.Module): |
|
|
""" |
|
|
Decodes circuit output bits to target representation. |
|
|
For standalone training: outputs soft bits for loss computation. |
|
|
For LLM integration: would project to hidden state delta. |
|
|
""" |
|
|
|
|
|
def __init__(self, output_dim: int = 8): |
|
|
super().__init__() |
|
|
self.output_dim = output_dim |
|
|
|
|
|
    def forward(self, result_bits: torch.Tensor) -> torch.Tensor:
        # Identity pass-through: the circuit's result bits are returned unchanged.
        return result_bits
|
|
|
|
|
|
|
|
class ThresholdALU(nn.Module): |
|
|
""" |
|
|
Complete trainable interface + frozen circuits. |
|
|
Learns to encode inputs, route to circuits, decode outputs. |
|
|
""" |
|
|
|
|
|
def __init__(self, device: str = 'cuda'): |
|
|
super().__init__() |
|
|
self.device = device |
|
|
|
|
|
self.circuits = FrozenThresholdCircuits(device=device) |
|
|
|
|
|
for key in self.circuits.weights: |
|
|
self.circuits.weights[key].requires_grad = False |
|
|
|
|
|
self.encoder = BitEncoder(input_dim=16 + 6, hidden_dim=64).to(device) |
|
|
self.router = OpRouter(input_dim=16 + 6, hidden_dim=32, n_ops=6).to(device) |
|
|
self.decoder = BitDecoder(output_dim=8).to(device) |
|
|
|
|
|
def forward(self, a_bits_in: torch.Tensor, b_bits_in: torch.Tensor, |
|
|
op_onehot: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Forward pass through trainable interface + frozen circuits. |
|
|
|
|
|
Args: |
|
|
a_bits_in: [batch, 8] input A bits (ground truth for training) |
|
|
b_bits_in: [batch, 8] input B bits (ground truth for training) |
|
|
op_onehot: [batch, 6] one-hot operation selector |
|
|
|
|
|
Returns: |
|
|
result_bits: [batch, 8] output bits |
|
|
""" |
|
|
x = torch.cat([a_bits_in, b_bits_in, op_onehot], dim=-1) |
|
|
|
|
|
a_bits, b_bits = self.encoder(x) |
|
|
|
|
|
op_weights = self.router(x) |
|
|
|
|
|
result = self.circuits(a_bits, b_bits, op_weights) |
|
|
|
|
|
output = self.decoder(result) |
|
|
|
|
|
return output |
|
|
|
|
|
def forward_direct(self, a_bits: torch.Tensor, b_bits: torch.Tensor, |
|
|
op_onehot: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Direct forward through circuits (bypass encoder/router for testing). |
|
|
""" |
|
|
return self.circuits(a_bits, b_bits, op_onehot) |
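# Training sketch (illustrative; a_bits/b_bits/op_onehot/target_bits stand for tensors
# supplied by the data pipeline, e.g. fitness.generate_batch as used in __main__ below):
#     alu = ThresholdALU(device='cuda')
#     opt = torch.optim.Adam([p for p in alu.parameters() if p.requires_grad], lr=1e-3)
#     pred = alu(a_bits, b_bits, op_onehot)    # [batch, 8] hard bits with STE gradients
#     loss = F.mse_loss(pred, target_bits)
#     loss.backward(); opt.step(); opt.zero_grad()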
|
|
|
|
|
|
|
|
class DirectCircuitModel(nn.Module): |
|
|
""" |
|
|
Minimal model that directly uses circuits without learned encoding. |
|
|
    Used to validate that the circuits themselves achieve 100% fitness.
|
|
""" |
|
|
|
|
|
def __init__(self, device: str = 'cuda'): |
|
|
super().__init__() |
|
|
self.device = device |
|
|
self.circuits = FrozenThresholdCircuits(device=device) |
|
|
|
|
|
def forward(self, a_bits: torch.Tensor, b_bits: torch.Tensor, |
|
|
op_onehot: torch.Tensor) -> torch.Tensor: |
|
|
return self.circuits(a_bits, b_bits, op_onehot) |
|
|
|
|
|
|
|
|
class HiddenStateExtractor(nn.Module): |
|
|
""" |
|
|
Extracts operands and operation from LLM hidden states. |
|
|
    This is the hard part: the network must learn to parse numbers from the embeddings.
|
|
""" |
|
|
|
|
|
def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 256): |
|
|
super().__init__() |
|
|
|
|
|
self.a_extractor = nn.Sequential( |
|
|
nn.Linear(hidden_dim, intermediate_dim), |
|
|
nn.GELU(), |
|
|
nn.Linear(intermediate_dim, 8), |
|
|
) |
|
|
|
|
|
self.b_extractor = nn.Sequential( |
|
|
nn.Linear(hidden_dim, intermediate_dim), |
|
|
nn.GELU(), |
|
|
nn.Linear(intermediate_dim, 8), |
|
|
) |
|
|
|
|
|
self.op_router = nn.Sequential( |
|
|
nn.Linear(hidden_dim, intermediate_dim), |
|
|
nn.GELU(), |
|
|
nn.Linear(intermediate_dim, len(OPERATIONS)), |
|
|
) |
|
|
|
|
|
def forward(self, hidden_states: torch.Tensor): |
|
|
""" |
|
|
Args: |
|
|
hidden_states: [batch, hidden_dim] from LLM |
|
|
|
|
|
Returns: |
|
|
a_bits: [batch, 8] |
|
|
b_bits: [batch, 8] |
|
|
op_logits: [batch, 6] |
|
|
""" |
|
|
a_logits = self.a_extractor(hidden_states) |
|
|
b_logits = self.b_extractor(hidden_states) |
|
|
op_logits = self.op_router(hidden_states) |
|
|
|
|
|
a_soft = torch.sigmoid(a_logits) |
|
|
b_soft = torch.sigmoid(b_logits) |
|
|
|
|
|
a_hard = heaviside_ste(a_logits) |
|
|
b_hard = heaviside_ste(b_logits) |
|
|
|
|
|
a_bits = a_hard - a_soft.detach() + a_soft |
|
|
b_bits = b_hard - b_soft.detach() + b_soft |
|
|
|
|
|
return a_bits, b_bits, op_logits |
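# Usage sketch (illustrative; a pooled LLM hidden state is stubbed with random values):
#     ext = HiddenStateExtractor(hidden_dim=960)
#     pooled = torch.randn(4, 960)
#     a_bits, b_bits, op_logits = ext(pooled)   # [4, 8], [4, 8], [4, 6]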
|
|
|
|
|
|
|
|
class AttentionPooling(nn.Module): |
|
|
""" |
|
|
Learnable attention pooling over sequence positions. |
|
|
Replaces mean pooling - learns which tokens matter for extraction. |
|
|
""" |
|
|
|
|
|
def __init__(self, hidden_dim: int = 960, num_heads: int = 4): |
|
|
super().__init__() |
|
|
self.num_heads = num_heads |
|
|
self.head_dim = hidden_dim // num_heads |
|
|
|
|
|
self.query = nn.Linear(hidden_dim, hidden_dim) |
|
|
self.key = nn.Linear(hidden_dim, hidden_dim) |
|
|
self.value = nn.Linear(hidden_dim, hidden_dim) |
|
|
self.out_proj = nn.Linear(hidden_dim, hidden_dim) |
|
|
|
|
|
self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02) |
|
|
|
|
|
def forward(self, embeddings: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Args: |
|
|
embeddings: [batch, seq_len, hidden_dim] |
|
|
mask: [batch, seq_len] attention mask (1 = attend, 0 = ignore) |
|
|
|
|
|
Returns: |
|
|
pooled: [batch, hidden_dim] |
|
|
""" |
|
|
batch_size, seq_len, hidden_dim = embeddings.shape |
|
|
|
|
|
        # Prepend a learned CLS token that serves as the single pooling query.
        cls_expanded = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat([cls_expanded, embeddings], dim=1)
|
|
|
|
|
cls_mask = torch.ones(batch_size, 1, device=mask.device) |
|
|
mask = torch.cat([cls_mask, mask], dim=1) |
|
|
|
|
|
Q = self.query(embeddings[:, :1, :]) |
|
|
K = self.key(embeddings) |
|
|
V = self.value(embeddings) |
|
|
|
|
|
Q = Q.view(batch_size, 1, self.num_heads, self.head_dim).transpose(1, 2) |
|
|
K = K.view(batch_size, seq_len + 1, self.num_heads, self.head_dim).transpose(1, 2) |
|
|
V = V.view(batch_size, seq_len + 1, self.num_heads, self.head_dim).transpose(1, 2) |
|
|
|
|
|
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5) |
|
|
|
|
|
mask_expanded = mask.unsqueeze(1).unsqueeze(2) |
|
|
scores = scores.masked_fill(mask_expanded == 0, -1e9) |
|
|
|
|
|
attn_weights = torch.softmax(scores, dim=-1) |
|
|
attn_weights = torch.nan_to_num(attn_weights, nan=0.0) |
|
|
|
|
|
context = torch.matmul(attn_weights, V) |
|
|
context = context.transpose(1, 2).contiguous().view(batch_size, 1, hidden_dim) |
|
|
|
|
|
pooled = self.out_proj(context).squeeze(1) |
|
|
pooled = torch.nan_to_num(pooled, nan=0.0) |
|
|
|
|
|
return pooled |
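# Usage sketch (illustrative): the learned CLS query attends over all unmasked tokens.
#     pool = AttentionPooling(hidden_dim=960, num_heads=4)
#     emb, mask = torch.randn(2, 10, 960), torch.ones(2, 10)
#     pooled = pool(emb, mask)                  # [2, 960]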
|
|
|
|
|
|
|
|
class MultiHeadBitExtractor(nn.Module): |
|
|
""" |
|
|
    Eight separate extractors, one per bit: each bit gets its own specialized network.

    This is more expressive than a single MLP predicting all 8 bits at once.
|
|
""" |
|
|
|
|
|
def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 128): |
|
|
super().__init__() |
|
|
|
|
|
self.bit_extractors = nn.ModuleList([ |
|
|
nn.Sequential( |
|
|
nn.Linear(hidden_dim, intermediate_dim), |
|
|
nn.GELU(), |
|
|
nn.Linear(intermediate_dim, 1), |
|
|
) |
|
|
for _ in range(8) |
|
|
]) |
|
|
|
|
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Args: |
|
|
hidden_states: [batch, hidden_dim] |
|
|
|
|
|
        Returns:
            bits: [batch, 8] hard bits with straight-through gradients
            logits: [batch, 8] clamped pre-sigmoid logits
        """
|
|
hidden_states = torch.nan_to_num(hidden_states, nan=0.0) |
|
|
|
|
|
bit_logits = [extractor(hidden_states) for extractor in self.bit_extractors] |
|
|
logits = torch.cat(bit_logits, dim=-1) |
|
|
logits = torch.clamp(logits, -20, 20) |
|
|
|
|
|
soft = torch.sigmoid(logits) |
|
|
hard = heaviside_ste(logits) |
|
|
bits = hard - soft.detach() + soft |
|
|
|
|
|
return bits, logits |
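# Usage sketch (illustrative):
#     mhe = MultiHeadBitExtractor(hidden_dim=960)
#     bits, logits = mhe(torch.randn(4, 960))   # both [4, 8]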
|
|
|
|
|
|
|
|
class Extractor(nn.Module): |
|
|
""" |
|
|
Extracts operands and operation from LLM hidden states. |
|
|
Uses attention pooling and per-bit extraction networks. |
|
|
""" |
|
|
|
|
|
def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 256, num_heads: int = 4): |
|
|
super().__init__() |
|
|
|
|
|
self.attention_pool = AttentionPooling(hidden_dim, num_heads) |
|
|
|
|
|
self.a_extractor = MultiHeadBitExtractor(hidden_dim, intermediate_dim // 2) |
|
|
self.b_extractor = MultiHeadBitExtractor(hidden_dim, intermediate_dim // 2) |
|
|
|
|
|
self.op_router = nn.Sequential( |
|
|
nn.Linear(hidden_dim, intermediate_dim), |
|
|
nn.GELU(), |
|
|
nn.Linear(intermediate_dim, len(OPERATIONS)), |
|
|
) |
|
|
|
|
|
def forward(self, embeddings: torch.Tensor, mask: torch.Tensor): |
|
|
""" |
|
|
Args: |
|
|
embeddings: [batch, seq_len, hidden_dim] |
|
|
mask: [batch, seq_len] |
|
|
|
|
|
Returns: |
|
|
a_bits: [batch, 8] |
|
|
b_bits: [batch, 8] |
|
|
op_logits: [batch, 6] |
|
|
""" |
|
|
pooled = self.attention_pool(embeddings, mask) |
|
|
|
|
|
a_bits, _ = self.a_extractor(pooled) |
|
|
b_bits, _ = self.b_extractor(pooled) |
|
|
op_logits = self.op_router(pooled) |
|
|
|
|
|
return a_bits, b_bits, op_logits |
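# Usage sketch (illustrative; random embeddings stand in for LLM hidden states):
#     ext = Extractor(hidden_dim=960)
#     emb, mask = torch.randn(2, 10, 960), torch.ones(2, 10)
#     a_bits, b_bits, op_logits = ext(emb, mask)   # [2, 8], [2, 8], [2, 6]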
|
|
|
|
|
|
|
|
class PositionExtractor(nn.Module): |
|
|
""" |
|
|
Position-specific extraction with dynamic operator detection. |
|
|
|
|
|
Tokenization pattern for "A op B": |
|
|
[A_digits...] [operator] [space] [B_digits...] |
|
|
|
|
|
Examples: |
|
|
"5 + 3" -> ['5', ' +', ' ', '3'] (positions: A=0, op=1, B=3) |
|
|
"47 + 86" -> ['4', '7', ' +', ' ', '8', '6'] (positions: A=0-1, op=2, B=4-5) |
|
|
"127 + 128" -> ['1','2','7',' +', ' ','1','2','8'] (positions: A=0-2, op=3, B=5-7) |
|
|
|
|
|
Token IDs (SmolLM2): |
|
|
Digits '0'-'9': 32-41 |
|
|
Operators: ' +'=1232, ' -'=731, ' *'=1672, ' >'=2986, ' <'=2067, ' =='=1758 |
|
|
Space: 216 |
|
|
""" |
|
|
|
|
|
DIGIT_TOKENS = set(range(32, 42)) |
|
|
    OPERATOR_TOKENS = {
        1232: 0,   # ' +'  -> add
        731: 1,    # ' -'  -> sub
        1672: 2,   # ' *'  -> mul
        2986: 3,   # ' >'  -> gt
        2067: 4,   # ' <'  -> lt
        1758: 5,   # ' ==' -> eq
    }
|
|
SPACE_TOKEN = 216 |
|
|
MAX_DIGITS = 3 |
|
|
|
|
|
def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 256): |
|
|
super().__init__() |
|
|
self.hidden_dim = hidden_dim |
|
|
|
|
|
self.a_extractor = nn.Sequential( |
|
|
nn.Linear(hidden_dim * self.MAX_DIGITS, intermediate_dim), |
|
|
nn.GELU(), |
|
|
nn.Linear(intermediate_dim, intermediate_dim // 2), |
|
|
nn.GELU(), |
|
|
nn.Linear(intermediate_dim // 2, 8), |
|
|
) |
|
|
|
|
|
self.b_extractor = nn.Sequential( |
|
|
nn.Linear(hidden_dim * self.MAX_DIGITS, intermediate_dim), |
|
|
nn.GELU(), |
|
|
nn.Linear(intermediate_dim, intermediate_dim // 2), |
|
|
nn.GELU(), |
|
|
nn.Linear(intermediate_dim // 2, 8), |
|
|
) |
|
|
|
|
|
self.op_extractor = nn.Sequential( |
|
|
nn.Linear(hidden_dim, intermediate_dim // 2), |
|
|
nn.GELU(), |
|
|
nn.Linear(intermediate_dim // 2, len(OPERATIONS)), |
|
|
) |
|
|
|
|
|
def _find_operator_position(self, token_ids: torch.Tensor) -> tuple[int, int]: |
|
|
""" |
|
|
Find operator token position and its operation index. |
|
|
|
|
|
Args: |
|
|
token_ids: [seq_len] tensor of token IDs |
|
|
|
|
|
Returns: |
|
|
(position, op_index) or (-1, -1) if not found |
|
|
""" |
|
|
for pos, tid in enumerate(token_ids.tolist()): |
|
|
if tid in self.OPERATOR_TOKENS: |
|
|
return pos, self.OPERATOR_TOKENS[tid] |
|
|
return -1, -1 |
|
|
|
|
|
def _extract_digit_features(self, hidden: torch.Tensor, start: int, end: int) -> torch.Tensor: |
|
|
""" |
|
|
Extract and pad digit hidden states to fixed size. |
|
|
|
|
|
Args: |
|
|
hidden: [seq_len, hidden_dim] |
|
|
start: start position (inclusive) |
|
|
end: end position (exclusive) |
|
|
|
|
|
Returns: |
|
|
[hidden_dim * MAX_DIGITS] flattened features, zero-padded on the LEFT |
|
|
(so units digit is always at the same position regardless of number length) |
|
|
""" |
|
|
n_digits = end - start |
|
|
features = torch.zeros(self.MAX_DIGITS * self.hidden_dim, device=hidden.device) |
|
|
|
|
|
if n_digits > 0 and n_digits <= self.MAX_DIGITS: |
|
|
digit_hidden = hidden[start:end, :].reshape(-1) |
|
|
pad_size = (self.MAX_DIGITS - n_digits) * self.hidden_dim |
|
|
features[pad_size:] = digit_hidden |
|
|
|
|
|
return features |
|
|
|
|
|
def forward(self, hidden: torch.Tensor, mask: torch.Tensor, token_ids: torch.Tensor = None): |
|
|
""" |
|
|
Args: |
|
|
hidden: [batch, seq_len, hidden_dim] |
|
|
mask: [batch, seq_len] attention mask |
|
|
token_ids: [batch, seq_len] token IDs (required for operator detection) |
|
|
|
|
|
Returns: |
|
|
a_bits: [batch, 8] |
|
|
b_bits: [batch, 8] |
|
|
op_logits: [batch, 6] |
|
|
""" |
|
|
if token_ids is None: |
|
|
raise ValueError("PositionExtractor requires token_ids for operator detection") |
|
|
|
|
|
batch_size, seq_len, hidden_dim = hidden.shape |
|
|
device = hidden.device |
|
|
|
|
|
a_features = [] |
|
|
b_features = [] |
|
|
op_features = [] |
|
|
op_indices = [] |
|
|
|
|
|
for i in range(batch_size): |
|
|
seq_mask = mask[i].bool() |
|
|
valid_len = seq_mask.sum().item() |
|
|
start_pos = seq_len - valid_len |
|
|
|
|
|
valid_tokens = token_ids[i, start_pos:] |
|
|
valid_hidden = hidden[i, start_pos:, :] |
|
|
|
|
|
op_pos, op_idx = self._find_operator_position(valid_tokens) |
|
|
|
|
|
if op_pos == -1: |
|
|
a_feat = torch.zeros(self.MAX_DIGITS * hidden_dim, device=device) |
|
|
b_feat = torch.zeros(self.MAX_DIGITS * hidden_dim, device=device) |
|
|
op_feat = torch.zeros(hidden_dim, device=device) |
|
|
op_idx = 0 |
|
|
else: |
|
|
a_feat = self._extract_digit_features(valid_hidden, 0, op_pos) |
|
|
|
|
|
op_feat = valid_hidden[op_pos, :] |
|
|
|
|
|
b_start = op_pos + 2 if (op_pos + 1 < valid_len and |
|
|
valid_tokens[op_pos + 1].item() == self.SPACE_TOKEN) else op_pos + 1 |
|
|
b_feat = self._extract_digit_features(valid_hidden, b_start, valid_len) |
|
|
|
|
|
a_features.append(a_feat) |
|
|
b_features.append(b_feat) |
|
|
op_features.append(op_feat) |
|
|
op_indices.append(op_idx) |
|
|
|
|
|
a_features = torch.stack(a_features) |
|
|
b_features = torch.stack(b_features) |
|
|
op_features = torch.stack(op_features) |
|
|
op_indices_tensor = torch.tensor(op_indices, device=device, dtype=torch.long) |
|
|
|
|
|
a_logits = self.a_extractor(a_features) |
|
|
b_logits = self.b_extractor(b_features) |
|
|
op_logits = self.op_extractor(op_features) |
|
|
|
|
|
a_soft = torch.sigmoid(a_logits) |
|
|
b_soft = torch.sigmoid(b_logits) |
|
|
a_hard = heaviside_ste(a_logits) |
|
|
b_hard = heaviside_ste(b_logits) |
|
|
a_bits = a_hard - a_soft.detach() + a_soft |
|
|
b_bits = b_hard - b_soft.detach() + b_soft |
|
|
|
|
|
return a_bits, b_bits, op_logits, op_indices_tensor |
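# Usage sketch (illustrative): token IDs follow the docstring above, where "47 + 86"
# tokenizes to ['4', '7', ' +', ' ', '8', '6'] = [36, 39, 1232, 216, 40, 38].
#     ext = PositionExtractor(hidden_dim=960)
#     token_ids = torch.tensor([[36, 39, 1232, 216, 40, 38]])
#     hidden, mask = torch.randn(1, 6, 960), torch.ones(1, 6)
#     a_bits, b_bits, op_logits, op_idx = ext(hidden, mask, token_ids)   # op_idx == tensor([0])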
|
|
|
|
|
|
|
|
class PositionalDigitExtractor(nn.Module): |
|
|
""" |
|
|
Position-aware digit extraction: classifies each digit position independently. |
|
|
|
|
|
This approach achieves 100% accuracy because: |
|
|
1. Each digit token position is classified independently (100% accuracy on layer 0) |
|
|
2. Numbers are reconstructed using place values (×100, ×10, ×1) |
|
|
3. No information is lost through pooling |
|
|
|
|
|
Token IDs (SmolLM2): |
|
|
Digits '0'-'9': 32-41 |
|
|
Operators: ' +'=1232, ' -'=731, ' *'=1672, ' >'=2986, ' <'=2067, ' =='=1758 |
|
|
Space: 216 |
|
|
""" |
|
|
|
|
|
DIGIT_TOKENS = set(range(32, 42)) |
|
|
OPERATOR_TOKENS = { |
|
|
1232: 0, |
|
|
731: 1, |
|
|
1672: 2, |
|
|
2986: 3, |
|
|
2067: 4, |
|
|
1758: 5, |
|
|
} |
|
|
SPACE_TOKEN = 216 |
|
|
|
|
|
def __init__(self, hidden_dim: int = 960): |
|
|
super().__init__() |
|
|
self.hidden_dim = hidden_dim |
|
|
|
|
|
self.digit_classifier = nn.Linear(hidden_dim, 10) |
|
|
|
|
|
self.op_classifier = nn.Linear(hidden_dim, len(OPERATIONS)) |
|
|
|
|
|
def _find_positions(self, token_ids: torch.Tensor) -> tuple: |
|
|
"""Find A digit positions, B digit positions, and operator position.""" |
|
|
token_list = token_ids.tolist() |
|
|
|
|
|
op_pos = -1 |
|
|
op_idx = 0 |
|
|
for i, tid in enumerate(token_list): |
|
|
if tid in self.OPERATOR_TOKENS: |
|
|
op_pos = i |
|
|
op_idx = self.OPERATOR_TOKENS[tid] |
|
|
break |
|
|
|
|
|
if op_pos == -1: |
|
|
return [], [], -1, 0 |
|
|
|
|
|
a_positions = [i for i in range(op_pos) if token_list[i] in self.DIGIT_TOKENS] |
|
|
|
|
|
b_start = op_pos + 2 if (op_pos + 1 < len(token_list) and |
|
|
token_list[op_pos + 1] == self.SPACE_TOKEN) else op_pos + 1 |
|
|
b_positions = [i for i in range(b_start, len(token_list)) if token_list[i] in self.DIGIT_TOKENS] |
|
|
|
|
|
return a_positions, b_positions, op_pos, op_idx |
|
|
|
|
|
def _predict_value(self, hidden: torch.Tensor, positions: list) -> tuple: |
|
|
"""Predict digit at each position and reconstruct number.""" |
|
|
if not positions: |
|
|
return torch.tensor(0.0, device=hidden.device), [] |
|
|
|
|
|
digit_logits_list = [] |
|
|
soft_value = torch.tensor(0.0, device=hidden.device) |
|
|
|
|
|
for idx, pos in enumerate(positions): |
|
|
logits = self.digit_classifier(hidden[pos]) |
|
|
digit_logits_list.append(logits) |
|
|
|
|
|
probs = torch.softmax(logits, dim=-1) |
|
|
digit_values = torch.arange(10, device=hidden.device, dtype=torch.float32) |
|
|
soft_digit = (probs * digit_values).sum() |
|
|
|
|
|
place_value = 10 ** (len(positions) - idx - 1) |
|
|
soft_value = soft_value + soft_digit * place_value |
|
|
|
|
|
return soft_value, digit_logits_list |
|
|
|
|
|
def _value_to_bits(self, value: torch.Tensor) -> torch.Tensor: |
|
|
"""Convert soft value to 8 bits using differentiable operations.""" |
|
|
value = torch.clamp(value, 0, 255) |
|
|
|
|
|
bits = [] |
|
|
for i in range(7, -1, -1): |
|
|
bit = torch.sigmoid((value - (2 ** i - 0.5)) * 10) |
|
|
value = value - bit * (2 ** i) |
|
|
bits.append(bit) |
|
|
|
|
|
return torch.stack(bits) |
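    # Worked example of the soft decomposition above (values are approximate):
    #     value = 203 -> bit7 ~ sigmoid((203 - 127.5) * 10) ~ 1, remainder ~ 75
    #                 -> bit6 ~ 1 (remainder ~ 11), bits 5-4 ~ 0, bit3 ~ 1 (remainder ~ 3),
    #                    bit2 ~ 0, bits 1-0 ~ 1, i.e. roughly 0b11001011 = 203.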
|
|
|
|
|
def forward(self, hidden: torch.Tensor, mask: torch.Tensor, token_ids: torch.Tensor): |
|
|
""" |
|
|
Args: |
|
|
hidden: [batch, seq_len, hidden_dim] |
|
|
mask: [batch, seq_len] |
|
|
token_ids: [batch, seq_len] |
|
|
|
|
|
        Returns:
            a_bits: [batch, 8]
            b_bits: [batch, 8]
            op_logits: [batch, 6]
            op_indices: [batch] ground-truth op indices parsed from the tokens
            a_values: [batch] soft decimal value of operand A
            b_values: [batch] soft decimal value of operand B
            a_digit_logits: per-sample list of per-digit [10] logits for A
            b_digit_logits: per-sample list of per-digit [10] logits for B
        """
|
|
batch_size = hidden.shape[0] |
|
|
device = hidden.device |
|
|
|
|
|
a_bits_list = [] |
|
|
b_bits_list = [] |
|
|
op_logits_list = [] |
|
|
op_indices_list = [] |
|
|
a_values_list = [] |
|
|
b_values_list = [] |
|
|
a_digit_logits_list = [] |
|
|
b_digit_logits_list = [] |
|
|
|
|
|
for i in range(batch_size): |
|
|
seq_mask = mask[i].bool() |
|
|
valid_len = seq_mask.sum().item() |
|
|
start_pos = hidden.shape[1] - valid_len |
|
|
|
|
|
valid_hidden = hidden[i, start_pos:] |
|
|
valid_tokens = token_ids[i, start_pos:] |
|
|
|
|
|
a_pos, b_pos, op_pos, op_idx = self._find_positions(valid_tokens) |
|
|
|
|
|
a_value, a_digit_logits = self._predict_value(valid_hidden, a_pos) |
|
|
b_value, b_digit_logits = self._predict_value(valid_hidden, b_pos) |
|
|
|
|
|
a_bits = self._value_to_bits(a_value) |
|
|
b_bits = self._value_to_bits(b_value) |
|
|
|
|
|
if op_pos >= 0: |
|
|
op_logits = self.op_classifier(valid_hidden[op_pos]) |
|
|
else: |
|
|
op_logits = torch.zeros(len(OPERATIONS), device=device) |
|
|
|
|
|
a_bits_list.append(a_bits) |
|
|
b_bits_list.append(b_bits) |
|
|
op_logits_list.append(op_logits) |
|
|
op_indices_list.append(op_idx) |
|
|
a_values_list.append(a_value) |
|
|
b_values_list.append(b_value) |
|
|
a_digit_logits_list.append(a_digit_logits) |
|
|
b_digit_logits_list.append(b_digit_logits) |
|
|
|
|
|
a_bits = torch.stack(a_bits_list) |
|
|
b_bits = torch.stack(b_bits_list) |
|
|
op_logits = torch.stack(op_logits_list) |
|
|
op_indices = torch.tensor(op_indices_list, device=device, dtype=torch.long) |
|
|
a_values = torch.stack(a_values_list) |
|
|
b_values = torch.stack(b_values_list) |
|
|
|
|
|
return a_bits, b_bits, op_logits, op_indices, a_values, b_values, a_digit_logits_list, b_digit_logits_list |
|
|
|
|
|
|
|
|
class DigitExtractor(nn.Module): |
|
|
""" |
|
|
Digit-level extraction: predicts digits (0-9) then converts to bits. |
|
|
Uses attention pooling (less accurate than PositionalDigitExtractor). |
|
|
""" |
|
|
|
|
|
def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 256, num_heads: int = 4): |
|
|
super().__init__() |
|
|
|
|
|
self.attention_pool = AttentionPooling(hidden_dim, num_heads) |
|
|
|
|
|
self.a_digit_pred = nn.Sequential( |
|
|
nn.Linear(hidden_dim, intermediate_dim), |
|
|
nn.GELU(), |
|
|
nn.Linear(intermediate_dim, 3 * 10), |
|
|
) |
|
|
|
|
|
self.b_digit_pred = nn.Sequential( |
|
|
nn.Linear(hidden_dim, intermediate_dim), |
|
|
nn.GELU(), |
|
|
nn.Linear(intermediate_dim, 3 * 10), |
|
|
) |
|
|
|
|
|
self.op_router = nn.Sequential( |
|
|
nn.Linear(hidden_dim, intermediate_dim), |
|
|
nn.GELU(), |
|
|
nn.Linear(intermediate_dim, len(OPERATIONS)), |
|
|
) |
|
|
|
|
|
def digits_to_bits(self, digit_logits: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Convert 3-digit predictions to 8-bit representation. |
|
|
digit_logits: [batch, 30] (3 digits * 10 classes each) |
|
|
Returns: [batch, 8] bits |
|
|
""" |
|
|
batch_size = digit_logits.shape[0] |
|
|
|
|
|
logits = digit_logits.view(batch_size, 3, 10) |
|
|
probs = torch.softmax(logits, dim=-1) |
|
|
|
|
|
digit_values = torch.arange(10, device=digit_logits.device).float() |
|
|
soft_digits = (probs * digit_values).sum(dim=-1) |
|
|
|
|
|
hundreds = soft_digits[:, 0] |
|
|
tens = soft_digits[:, 1] |
|
|
ones = soft_digits[:, 2] |
|
|
|
|
|
value = hundreds * 100 + tens * 10 + ones |
|
|
value = torch.clamp(value, 0, 255) |
|
|
|
|
|
        # Hard binary decomposition (MSB first). floor/fmod have zero gradient, so no
        # gradient flows back through these bits; supervision relies on the digit logits.
        bits = []
        for i in range(7, -1, -1):
            bit = torch.fmod(torch.floor(value / (2 ** i)), 2)
|
|
bits.append(bit) |
|
|
|
|
|
return torch.stack(bits, dim=-1) |
|
|
|
|
|
def forward(self, hidden: torch.Tensor, mask: torch.Tensor): |
|
|
""" |
|
|
Returns: |
|
|
a_bits, b_bits, op_logits, a_digit_logits, b_digit_logits |
|
|
""" |
|
|
pooled = self.attention_pool(hidden, mask) |
|
|
|
|
|
a_digit_logits = self.a_digit_pred(pooled) |
|
|
b_digit_logits = self.b_digit_pred(pooled) |
|
|
op_logits = self.op_router(pooled) |
|
|
|
|
|
a_bits = self.digits_to_bits(a_digit_logits) |
|
|
b_bits = self.digits_to_bits(b_digit_logits) |
|
|
|
|
|
return a_bits, b_bits, op_logits, a_digit_logits, b_digit_logits |
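# Usage sketch (illustrative):
#     dex = DigitExtractor(hidden_dim=960)
#     emb, mask = torch.randn(2, 10, 960), torch.ones(2, 10)
#     a_bits, b_bits, op_logits, a_dl, b_dl = dex(emb, mask)   # bits [2, 8], digit logits [2, 30]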
|
|
|
|
|
|
|
|
class HybridExtractor(nn.Module): |
|
|
""" |
|
|
Hybrid extractor that handles both digit tokens and word numbers. |
|
|
|
|
|
For digit tokens (32-41): Direct lookup, no training needed |
|
|
For word numbers: Learned MLP extraction from pooled hidden states |
|
|
|
|
|
This is the real training target - learning to extract numbers from |
|
|
natural language like "forty seven plus eighty six". |
|
|
""" |
|
|
|
|
|
DIGIT_TOKENS = set(range(32, 42)) |
|
|
SYMBOL_OP_TOKENS = { |
|
|
1232: 0, |
|
|
731: 1, |
|
|
1672: 2, |
|
|
2986: 3, |
|
|
2067: 4, |
|
|
1758: 5, |
|
|
} |
|
|
WORD_OP_TOKENS = { |
|
|
2068: 0, |
|
|
8500: 1, |
|
|
1580: 2, |
|
|
6301: 3, |
|
|
1912: 4, |
|
|
16364: 5, |
|
|
11540: 5, |
|
|
} |
|
|
ALL_OP_TOKENS = {**SYMBOL_OP_TOKENS, **WORD_OP_TOKENS} |
|
|
|
|
|
def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 256, num_heads: int = 4): |
|
|
super().__init__() |
|
|
self.hidden_dim = hidden_dim |
|
|
|
|
|
self.attention_pool = AttentionPooling(hidden_dim, num_heads) |
|
|
self.a_pool = AttentionPooling(hidden_dim, num_heads) |
|
|
self.b_pool = AttentionPooling(hidden_dim, num_heads) |
|
|
|
|
|
self.a_digit_pred = nn.Sequential( |
|
|
nn.Linear(hidden_dim, intermediate_dim), |
|
|
nn.GELU(), |
|
|
nn.Dropout(0.1), |
|
|
nn.Linear(intermediate_dim, 3 * 10), |
|
|
) |
|
|
|
|
|
self.b_digit_pred = nn.Sequential( |
|
|
nn.Linear(hidden_dim, intermediate_dim), |
|
|
nn.GELU(), |
|
|
nn.Dropout(0.1), |
|
|
nn.Linear(intermediate_dim, 3 * 10), |
|
|
) |
|
|
|
|
|
self.op_predictor = nn.Sequential( |
|
|
nn.Linear(hidden_dim, intermediate_dim // 2), |
|
|
nn.GELU(), |
|
|
nn.Linear(intermediate_dim // 2, len(OPERATIONS)), |
|
|
) |
|
|
|
|
|
def _has_digit_tokens(self, token_ids: torch.Tensor) -> bool: |
|
|
"""Check if input contains digit tokens.""" |
|
|
for tid in token_ids.tolist(): |
|
|
if tid in self.DIGIT_TOKENS: |
|
|
return True |
|
|
return False |
|
|
|
|
|
def _find_op_position(self, token_ids: torch.Tensor) -> int: |
|
|
"""Find position of operator token, returns -1 if not found.""" |
|
|
tokens = token_ids.tolist() |
|
|
for i, tid in enumerate(tokens): |
|
|
if tid in self.ALL_OP_TOKENS: |
|
|
return i |
|
|
return -1 |
|
|
|
|
|
def _extract_from_digits(self, token_ids: torch.Tensor) -> tuple: |
|
|
""" |
|
|
Extract values directly from digit tokens (hardcoded lookup). |
|
|
Handles both symbol operators (' +') and word operators ('plus'). |
|
|
Returns (a_value, b_value, op_idx) or None if pattern not found. |
|
|
""" |
|
|
tokens = token_ids.tolist() |
|
|
|
|
|
op_pos = -1 |
|
|
op_idx = 0 |
|
|
for i, tid in enumerate(tokens): |
|
|
if tid in self.ALL_OP_TOKENS: |
|
|
op_pos = i |
|
|
op_idx = self.ALL_OP_TOKENS[tid] |
|
|
break |
|
|
|
|
|
if op_pos == -1: |
|
|
return None |
|
|
|
|
|
a_digits = [] |
|
|
for i in range(op_pos): |
|
|
if tokens[i] in self.DIGIT_TOKENS: |
|
|
a_digits.append(tokens[i] - 32) |
|
|
|
|
|
b_start = op_pos + 1 |
|
|
if b_start < len(tokens) and tokens[b_start] == 216: |
|
|
b_start += 1 |
|
|
|
|
|
b_digits = [] |
|
|
for i in range(b_start, len(tokens)): |
|
|
if tokens[i] in self.DIGIT_TOKENS: |
|
|
b_digits.append(tokens[i] - 32) |
|
|
|
|
|
if not a_digits or not b_digits: |
|
|
return None |
|
|
|
|
|
a_val = 0 |
|
|
for d in a_digits: |
|
|
a_val = a_val * 10 + d |
|
|
|
|
|
b_val = 0 |
|
|
for d in b_digits: |
|
|
b_val = b_val * 10 + d |
|
|
|
|
|
return min(a_val, 255), min(b_val, 255), op_idx |
|
|
|
|
|
def _value_to_bits(self, value: int, device) -> torch.Tensor: |
|
|
"""Convert integer to 8-bit tensor.""" |
|
|
bits = torch.zeros(8, device=device) |
|
|
for i in range(8): |
|
|
bits[7 - i] = (value >> i) & 1 |
|
|
return bits |
|
|
|
|
|
def _digits_to_value_and_bits(self, digit_logits: torch.Tensor, device) -> tuple: |
|
|
""" |
|
|
Convert 3-digit logits to value and bits. |
|
|
digit_logits: [30] (3 digits × 10 classes) |
|
|
Returns: (value tensor, bits tensor [8]) |
|
|
""" |
|
|
logits = digit_logits.view(3, 10) |
|
|
probs = torch.softmax(logits, dim=-1) |
|
|
|
|
|
digit_values = torch.arange(10, device=device, dtype=torch.float32) |
|
|
soft_digits = (probs * digit_values).sum(dim=-1) |
|
|
|
|
|
hundreds = soft_digits[0] |
|
|
tens = soft_digits[1] |
|
|
ones = soft_digits[2] |
|
|
|
|
|
value = hundreds * 100 + tens * 10 + ones |
|
|
value = torch.clamp(value, 0, 255) |
|
|
|
|
|
bits = [] |
|
|
for i in range(7, -1, -1): |
|
|
threshold = 2 ** i |
|
|
bit = torch.sigmoid((value - threshold + 0.5) * 10) |
|
|
bits.append(bit) |
|
|
value = value - bit * threshold |
|
|
return hundreds * 100 + tens * 10 + ones, torch.stack(bits) |
|
|
|
|
|
def forward(self, hidden: torch.Tensor, mask: torch.Tensor, token_ids: torch.Tensor = None): |
|
|
""" |
|
|
Args: |
|
|
hidden: [batch, seq_len, hidden_dim] |
|
|
mask: [batch, seq_len] |
|
|
token_ids: [batch, seq_len] - optional, enables digit lookup |
|
|
|
|
|
Returns: |
|
|
a_bits, b_bits, op_logits, a_values, b_values, used_lookup, |
|
|
a_digit_logits, b_digit_logits |
|
|
""" |
|
|
batch_size = hidden.shape[0] |
|
|
device = hidden.device |
|
|
|
|
|
a_bits_list = [] |
|
|
a_digit_logits_list = [] |
|
|
b_digit_logits_list = [] |
|
|
b_bits_list = [] |
|
|
op_logits_list = [] |
|
|
a_values_list = [] |
|
|
b_values_list = [] |
|
|
used_lookup_list = [] |
|
|
|
|
|
pooled = self.attention_pool(hidden, mask) |
|
|
|
|
|
for i in range(batch_size): |
|
|
lookup_result = None |
|
|
if token_ids is not None: |
|
|
seq_mask = mask[i].bool() |
|
|
valid_len = seq_mask.sum().item() |
|
|
start_pos = hidden.shape[1] - valid_len |
|
|
valid_tokens = token_ids[i, start_pos:] |
|
|
|
|
|
if self._has_digit_tokens(valid_tokens): |
|
|
lookup_result = self._extract_from_digits(valid_tokens) |
|
|
|
|
|
if lookup_result is not None: |
|
|
a_val, b_val, op_idx = lookup_result |
|
|
a_bits = self._value_to_bits(a_val, device) |
|
|
b_bits = self._value_to_bits(b_val, device) |
|
|
op_logits = torch.zeros(len(OPERATIONS), device=device) |
|
|
op_logits[op_idx] = 10.0 |
|
|
|
|
|
a_bits_list.append(a_bits) |
|
|
b_bits_list.append(b_bits) |
|
|
op_logits_list.append(op_logits) |
|
|
a_values_list.append(float(a_val)) |
|
|
b_values_list.append(float(b_val)) |
|
|
used_lookup_list.append(True) |
|
|
a_digit_logits_list.append(None) |
|
|
b_digit_logits_list.append(None) |
|
|
else: |
|
|
sample_hidden = hidden[i:i+1] |
|
|
sample_mask = mask[i:i+1] |
|
|
|
|
|
seq_mask = mask[i].bool() |
|
|
valid_len = int(seq_mask.sum().item()) |
|
|
start_pos = hidden.shape[1] - valid_len |
|
|
valid_tokens = token_ids[i, start_pos:] if token_ids is not None else None |
|
|
|
|
|
op_pos = self._find_op_position(valid_tokens) if valid_tokens is not None else -1 |
|
|
|
|
|
if op_pos > 0 and op_pos < valid_len - 1: |
|
|
a_end = start_pos + op_pos |
|
|
b_start = start_pos + op_pos + 1 |
|
|
|
|
|
a_mask = torch.zeros_like(sample_mask) |
|
|
a_mask[0, start_pos:a_end] = 1.0 |
|
|
b_mask = torch.zeros_like(sample_mask) |
|
|
b_mask[0, b_start:] = sample_mask[0, b_start:] |
|
|
|
|
|
a_pooled = self.a_pool(sample_hidden, a_mask)[0] |
|
|
b_pooled = self.b_pool(sample_hidden, b_mask)[0] |
|
|
else: |
|
|
a_pooled = pooled[i] |
|
|
b_pooled = pooled[i] |
|
|
|
|
|
a_digit_logits = self.a_digit_pred(a_pooled) |
|
|
b_digit_logits = self.b_digit_pred(b_pooled) |
|
|
op_logits = self.op_predictor(pooled[i]) |
|
|
|
|
|
a_val, a_bits = self._digits_to_value_and_bits(a_digit_logits, device) |
|
|
b_val, b_bits = self._digits_to_value_and_bits(b_digit_logits, device) |
|
|
|
|
|
a_bits_list.append(a_bits) |
|
|
b_bits_list.append(b_bits) |
|
|
op_logits_list.append(op_logits) |
|
|
a_values_list.append(a_val) |
|
|
b_values_list.append(b_val) |
|
|
used_lookup_list.append(False) |
|
|
a_digit_logits_list.append(a_digit_logits) |
|
|
b_digit_logits_list.append(b_digit_logits) |
|
|
|
|
|
a_bits = torch.stack(a_bits_list) |
|
|
b_bits = torch.stack(b_bits_list) |
|
|
op_logits = torch.stack(op_logits_list) |
|
|
a_values = torch.stack([v if isinstance(v, torch.Tensor) else torch.tensor(v, device=device) for v in a_values_list]) |
|
|
b_values = torch.stack([v if isinstance(v, torch.Tensor) else torch.tensor(v, device=device) for v in b_values_list]) |
|
|
used_lookup = torch.tensor(used_lookup_list, device=device, dtype=torch.bool) |
|
|
valid_a_logits = [x for x in a_digit_logits_list if x is not None] |
|
|
valid_b_logits = [x for x in b_digit_logits_list if x is not None] |
|
|
a_digit_logits_out = torch.stack(valid_a_logits) if valid_a_logits else None |
|
|
b_digit_logits_out = torch.stack(valid_b_logits) if valid_b_logits else None |
|
|
|
|
|
return a_bits, b_bits, op_logits, a_values, b_values, used_lookup, a_digit_logits_out, b_digit_logits_out |
|
|
|
|
|
def _soft_value_to_bits(self, value: torch.Tensor, device) -> torch.Tensor: |
|
|
"""Convert soft value (0-255) to 8-bit representation differentiably.""" |
|
|
value = torch.clamp(value, 0, 255) |
|
|
bits = [] |
|
|
remaining = value |
|
|
for i in range(7, -1, -1): |
|
|
threshold = 2 ** i |
|
|
bit = torch.sigmoid((remaining - threshold + 0.5) * 10) |
|
|
bits.append(bit) |
|
|
remaining = remaining - bit * threshold |
|
|
return torch.stack(bits) |
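# Usage sketch (illustrative): when digit tokens are present, the hardcoded lookup path
# is taken, so the extracted values are exact and used_lookup is True.
#     hyb = HybridExtractor(hidden_dim=960)
#     token_ids = torch.tensor([[36, 39, 1232, 216, 40, 38]])       # "47 + 86"
#     hidden, mask = torch.randn(1, 6, 960), torch.ones(1, 6)
#     a_bits, b_bits, op_logits, a_vals, b_vals, used, _, _ = hyb(hidden, mask, token_ids)
#     # a_vals == tensor([47.]), b_vals == tensor([86.]), used == tensor([True])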
|
|
|
|
|
|
|
|
class ArithmeticModel(nn.Module): |
|
|
""" |
|
|
LLM + extractor + frozen threshold circuits. |
|
|
Optionally unfreeze top N transformer layers with --unfreeze_layers. |
|
|
""" |
|
|
|
|
|
def __init__(self, device: str = 'cuda', unfreeze_layers: int = 0, |
|
|
extract_layer: int = -1, position_extract: bool = False, |
|
|
digit_pred: bool = False, positional_digit: bool = False, |
|
|
hybrid: bool = False): |
|
|
super().__init__() |
|
|
self.device = device |
|
|
self.unfreeze_layers = unfreeze_layers |
|
|
self.extract_layer = extract_layer |
|
|
self.position_extract = position_extract |
|
|
self.digit_pred = digit_pred |
|
|
self.positional_digit = positional_digit |
|
|
self.hybrid = hybrid |
|
|
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
|
print("[1/4] Loading tokenizer...", flush=True) |
|
|
self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) |
|
|
self.tokenizer.padding_side = 'left' |
|
|
if self.tokenizer.pad_token is None: |
|
|
self.tokenizer.pad_token = self.tokenizer.eos_token |
|
|
print(" Tokenizer loaded.", flush=True) |
|
|
|
|
|
print("[2/4] Loading SmolLM2-360M...", flush=True) |
|
|
self.llm = AutoModelForCausalLM.from_pretrained( |
|
|
MODEL_ID, |
|
|
torch_dtype=torch.float16, |
|
|
device_map=device, |
|
|
output_hidden_states=True |
|
|
) |
|
|
|
|
|
for param in self.llm.parameters(): |
|
|
param.requires_grad = False |
|
|
|
|
|
if unfreeze_layers > 0: |
|
|
num_layers = len(self.llm.model.layers) |
|
|
layers_to_unfreeze = list(range(num_layers - unfreeze_layers, num_layers)) |
|
|
print(f" Unfreezing layers {layers_to_unfreeze}...", flush=True) |
|
|
for layer_idx in layers_to_unfreeze: |
|
|
for param in self.llm.model.layers[layer_idx].parameters(): |
|
|
param.requires_grad = True |
|
|
|
|
|
hidden_dim = self.llm.config.hidden_size |
|
|
llm_params = sum(p.numel() for p in self.llm.parameters()) |
|
|
trainable_llm = sum(p.numel() for p in self.llm.parameters() if p.requires_grad) |
|
|
print(f" LLM loaded. Hidden dim: {hidden_dim}", flush=True) |
|
|
print(f" LLM params: {llm_params:,} total, {trainable_llm:,} trainable", flush=True) |
|
|
|
|
|
print("[3/4] Loading threshold circuits...", flush=True) |
|
|
self.circuits = FrozenThresholdCircuits(device=device) |
|
|
print(f" Circuits loaded. {len(self.circuits.weights)} tensors", flush=True) |
|
|
|
|
|
print("[4/4] Initializing extractor...", flush=True) |
|
|
if hybrid: |
|
|
print(" Using HYBRID extraction (digit lookup + word learning)", flush=True) |
|
|
self.extractor = HybridExtractor( |
|
|
hidden_dim=hidden_dim, |
|
|
intermediate_dim=256, |
|
|
num_heads=4 |
|
|
).to(device) |
|
|
elif positional_digit: |
|
|
print(" Using POSITIONAL DIGIT extraction (100% proven)", flush=True) |
|
|
self.extractor = PositionalDigitExtractor( |
|
|
hidden_dim=hidden_dim |
|
|
).to(device) |
|
|
elif position_extract: |
|
|
print(" Using position-specific extraction", flush=True) |
|
|
self.extractor = PositionExtractor( |
|
|
hidden_dim=hidden_dim, |
|
|
intermediate_dim=256 |
|
|
).to(device) |
|
|
elif digit_pred: |
|
|
print(" Using digit-level prediction", flush=True) |
|
|
self.extractor = DigitExtractor( |
|
|
hidden_dim=hidden_dim, |
|
|
intermediate_dim=256, |
|
|
num_heads=4 |
|
|
).to(device) |
|
|
else: |
|
|
self.extractor = Extractor( |
|
|
hidden_dim=hidden_dim, |
|
|
intermediate_dim=256, |
|
|
num_heads=4 |
|
|
).to(device) |
|
|
|
|
|
if extract_layer != -1: |
|
|
print(f" Extracting from layer {extract_layer}", flush=True) |
|
|
|
|
|
trainable_ext = sum(p.numel() for p in self.extractor.parameters()) |
|
|
total_trainable = trainable_llm + trainable_ext |
|
|
print(f" Extractor params: {trainable_ext:,}", flush=True) |
|
|
print(f" Total trainable: {total_trainable:,}", flush=True) |
|
|
|
|
|
def get_hidden_states(self, texts: list[str]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
|
""" |
|
|
Get hidden states from specified layer. |
|
|
|
|
|
Returns: |
|
|
hidden: [batch, seq_len, hidden_dim] hidden states |
|
|
mask: [batch, seq_len] attention mask |
|
|
token_ids: [batch, seq_len] input token IDs |
|
|
""" |
|
|
inputs = self.tokenizer( |
|
|
texts, |
|
|
return_tensors='pt', |
|
|
padding=True, |
|
|
truncation=True, |
|
|
max_length=64 |
|
|
).to(self.device) |
|
|
|
|
|
if self.unfreeze_layers > 0: |
|
|
outputs = self.llm(**inputs, output_hidden_states=True) |
|
|
else: |
|
|
with torch.no_grad(): |
|
|
outputs = self.llm(**inputs, output_hidden_states=True) |
|
|
|
|
|
hidden = outputs.hidden_states[self.extract_layer].float() |
|
|
mask = inputs.attention_mask.float() |
|
|
token_ids = inputs.input_ids |
|
|
|
|
|
return hidden, mask, token_ids |
|
|
|
|
|
def forward(self, texts: list[str]): |
|
|
""" |
|
|
Full forward pass: text -> hidden states -> extractor -> circuits -> result |
|
|
|
|
|
Returns: |
|
|
result_bits, a_bits, b_bits, op_logits |
|
|
If digit_pred: also returns a_digit_logits, b_digit_logits |
|
|
If position_extract or positional_digit: also returns op_indices |
|
|
If positional_digit: also returns a_values, b_values |
|
|
""" |
|
|
hidden, mask, token_ids = self.get_hidden_states(texts) |
|
|
|
|
|
if self.hybrid or self.positional_digit or self.position_extract: |
|
|
extractor_out = self.extractor(hidden, mask, token_ids) |
|
|
else: |
|
|
extractor_out = self.extractor(hidden, mask) |
|
|
|
|
|
if self.hybrid: |
|
|
a_bits, b_bits, op_logits, a_values, b_values, used_lookup, a_digit_logits, b_digit_logits = extractor_out |
|
|
op_indices_from_tokens = None |
|
|
elif self.positional_digit: |
|
|
a_bits, b_bits, op_logits, op_indices_from_tokens, a_values, b_values, a_digit_logits, b_digit_logits = extractor_out |
|
|
used_lookup = None |
|
|
elif self.digit_pred: |
|
|
a_bits, b_bits, op_logits, a_digit_logits, b_digit_logits = extractor_out |
|
|
op_indices_from_tokens = None |
|
|
a_values, b_values = None, None |
|
|
used_lookup = None |
|
|
elif self.position_extract: |
|
|
a_bits, b_bits, op_logits, op_indices_from_tokens = extractor_out |
|
|
a_digit_logits, b_digit_logits = None, None |
|
|
a_values, b_values = None, None |
|
|
used_lookup = None |
|
|
else: |
|
|
a_bits, b_bits, op_logits = extractor_out |
|
|
a_digit_logits, b_digit_logits = None, None |
|
|
op_indices_from_tokens = None |
|
|
a_values, b_values = None, None |
|
|
used_lookup = None |
|
|
|
|
|
op_probs = torch.softmax(op_logits, dim=-1) |
|
|
|
|
|
result_bits = self.circuits(a_bits, b_bits, op_probs) |
|
|
|
|
|
if self.hybrid: |
|
|
return result_bits, a_bits, b_bits, op_logits, a_values, b_values, used_lookup, a_digit_logits, b_digit_logits |
|
|
if self.positional_digit: |
|
|
return result_bits, a_bits, b_bits, op_logits, op_indices_from_tokens, a_values, b_values, a_digit_logits, b_digit_logits |
|
|
if self.digit_pred: |
|
|
return result_bits, a_bits, b_bits, op_logits, a_digit_logits, b_digit_logits |
|
|
if self.position_extract: |
|
|
return result_bits, a_bits, b_bits, op_logits, op_indices_from_tokens |
|
|
return result_bits, a_bits, b_bits, op_logits |
|
|
|
|
|
def trainable_parameters(self): |
|
|
"""Return all trainable parameters for optimizer.""" |
|
|
params = list(self.extractor.parameters()) |
|
|
if self.unfreeze_layers > 0: |
|
|
params += [p for p in self.llm.parameters() if p.requires_grad] |
|
|
return params |
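# Usage sketch (illustrative; instantiating the model downloads SmolLM2-360M and,
# with the default arguments, expects a CUDA device):
#     model = ArithmeticModel(device='cuda', hybrid=True)
#     result_bits, a_bits, b_bits, op_logits, *extras = model(["47 + 86", "5 * 3"])
#     opt = torch.optim.AdamW(model.trainable_parameters(), lr=1e-4)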
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
import sys |
|
|
sys.path.insert(0, '.') |
|
|
from fitness import generate_batch, compute_fitness, OPERATIONS |
|
|
|
|
|
print("Testing model components...") |
|
|
|
|
|
device = 'cuda' |
|
|
batch = generate_batch(32, device) |
|
|
|
|
|
print("\n1. Testing DirectCircuitModel (should get ~100% fitness)...") |
|
|
direct_model = DirectCircuitModel(device=device) |
|
|
|
|
|
def direct_fn(a, b, op): |
|
|
return direct_model(a, b, op) |
|
|
|
|
|
fitness, details = compute_fitness(direct_fn, n_samples=2000, batch_size=128, |
|
|
device=device, return_details=True) |
|
|
print(f" Direct circuit fitness: {fitness:.4f}") |
|
|
for op in OPERATIONS: |
|
|
acc = details['by_op'][op]['accuracy'] |
|
|
print(f" {op}: {acc:.4f}") |
|
|
|
|
|
print("\n2. Testing ThresholdALU (trainable interface)...") |
|
|
model = ThresholdALU(device=device) |
|
|
|
|
|
x = torch.cat([batch['a_bits'], batch['b_bits'], batch['op_onehot']], dim=-1) |
|
|
a_enc, b_enc = model.encoder(x) |
|
|
print(f" Encoder output shapes: a={a_enc.shape}, b={b_enc.shape}") |
|
|
|
|
|
op_weights = model.router(x) |
|
|
print(f" Router output shape: {op_weights.shape}") |
|
|
print(f" Router output sample: {op_weights[0].tolist()}") |
|
|
|
|
|
result = model(batch['a_bits'], batch['b_bits'], batch['op_onehot']) |
|
|
print(f" Full model output shape: {result.shape}") |
|
|
|
|
|
print("\n3. Testing untrained ThresholdALU fitness...") |
|
|
|
|
|
def model_fn(a, b, op): |
|
|
return model(a, b, op) |
|
|
|
|
|
fitness = compute_fitness(model_fn, n_samples=1000, batch_size=128, device=device) |
|
|
print(f" Untrained model fitness: {fitness:.4f} (expected low)") |
|
|
|
|
|
print("\n4. Counting parameters...") |
|
|
total = sum(p.numel() for p in model.parameters() if p.requires_grad) |
|
|
encoder_params = sum(p.numel() for p in model.encoder.parameters()) |
|
|
router_params = sum(p.numel() for p in model.router.parameters()) |
|
|
print(f" Encoder: {encoder_params:,}") |
|
|
print(f" Router: {router_params:,}") |
|
|
print(f" Total trainable: {total:,}") |
|
|
|
|
|
print("\nDone.") |
|
|
|