# ---------------------------------------------------------------------------


class IterativeSATModel(nn.Module):
    """Sotaku-style iterative SAT solver.

    Key design (matching sotaku):
    - h_prev carries the full hidden state directly (residual across iterations)
    - pred_proj adds a small correction from detached predictions (not the hidden state)
    - Scratchpad tokens provide extra working memory positions
    - Gradients flow through h_prev, predictions are detached
    """

    def __init__(self, config: SATConfig):
        super().__init__()
        self.config = config
        d = config.d_model
        N = config.max_vars
        S = config.n_scratch
        total_pos = N + S  # variable positions + scratchpad positions

        # Input encoder: clause structure -> initial hidden state (one-time).
        # Features per variable are [membership row, polarity row], hence 2 * max_clauses.
        self.input_proj = nn.Linear(2 * config.max_clauses, d, bias=False)

        # Prediction feedback: small correction from detached predictions.
        # assign(1) + violation(1) -> d_model (like sotaku's pred_proj on softmax preds)
        self.pred_proj = nn.Linear(2, d, bias=False)

        # Scratchpad tokens (extra working memory); only allocated when requested.
        if S > 0:
            self.scratch_embeds = nn.Parameter(torch.randn(S, d) * 0.02)

        # Shared transformer, applied repeatedly across iterations.
        self.layers = nn.ModuleList([
            TransformerBlock(d, config.n_heads, config.d_ff, config.dropout)
            # NOTE(review): this loop clause was hidden in the diff view; a depth
            # attribute on SATConfig is the only consistent reading — TODO confirm name.
            for _ in range(config.n_layers)
        ])
        self.final_norm = nn.RMSNorm(d)

        # Output head (variable positions only)
        self.assign_head = nn.Linear(d, 1, bias=False)

        # RoPE tables cover all positions (variables + scratchpad).
        cos, sin = build_rope_cache(total_pos, d // config.n_heads, config.rope_base)
        # NOTE(review): the rope_cos register line was hidden in the diff view, but
        # forward() reads self.rope_cos, so it must be registered here — TODO confirm.
        self.register_buffer("rope_cos", cos)
        self.register_buffer("rope_sin", sin)

    def forward(self, clause_mask, clause_sign, n_vars_batch=None, n_iters=None):
        """Iteratively refine a SAT assignment.

        Args:
            clause_mask: (B, max_vars, max_clauses) — 1 if variable appears in clause
                (shapes per the pre-diff docstring — TODO confirm against callers).
            clause_sign: (B, max_vars, max_clauses) — polarity (+1/-1, 0 = not present).
            n_vars_batch: unused here; kept for interface compatibility with callers.
            n_iters: number of refinement iterations; defaults to config.train_iters.

        Returns:
            List of assignment logits, one (B, max_vars) tensor per iteration.
        """
        if n_iters is None:
            n_iters = self.config.train_iters

        B = clause_mask.shape[0]
        N = self.config.max_vars
        S = self.config.n_scratch
        device = clause_mask.device

        # One-time encoding (re-added every iteration to prevent forgetting)
        features = torch.cat([clause_mask, clause_sign], dim=-1)
        h_init = self.input_proj(features)  # (B, N, d)

        # Append scratchpad
        if S > 0:
            h_scratch = self.scratch_embeds.unsqueeze(0).expand(B, -1, -1)
            h_init = torch.cat([h_init, h_scratch], dim=1)  # (B, N+S, d)

        h_prev = h_init  # first iteration starts from input encoding

        all_logits = []
        # Initial predictions: uniform assignment probability, zero violation.
        preds = torch.zeros(B, N + S, 2, device=device)
        preds[:, :N, 0] = 0.5
        # violation starts at 0

        for _ in range(n_iters):
            # Clean carry + fresh input + prediction correction
            h = h_prev + h_init + self.pred_proj(preds)

            # Shared transformer
            for layer in self.layers:
                h = layer(h, self.rope_cos, self.rope_sin)
            h = self.final_norm(h)

            # h becomes h_prev for next iteration (direct carry, with gradients)
            h_prev = h

            # Predict assignments from variable positions only
            logits = self.assign_head(h[:, :N, :]).squeeze(-1)  # (B, N)
            all_logits.append(logits)

            # Build detached prediction feedback for next iteration: gradients
            # flow through h_prev, NOT through the prediction pathway.
            assign_prob = torch.sigmoid(logits).detach()
            violation = self._compute_violations(assign_prob, clause_mask, clause_sign)
            preds = torch.zeros(B, N + S, 2, device=device)
            preds[:, :N, 0] = assign_prob
            preds[:, :N, 1] = violation

        return all_logits