RameshArvind committed on
Commit
a8ccddf
·
verified ·
1 Parent(s): 0caf6af

Upload iterative_sat.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. iterative_sat.py +44 -48
iterative_sat.py CHANGED
@@ -112,26 +112,33 @@ class TransformerBlock(nn.Module):
112
  # ---------------------------------------------------------------------------
113
 
114
  class IterativeSATModel(nn.Module):
 
 
 
 
 
 
 
 
115
  def __init__(self, config: SATConfig):
116
  super().__init__()
117
  self.config = config
118
  d = config.d_model
119
  N = config.max_vars
120
  S = config.n_scratch
121
- total_pos = N + S # variable positions + scratchpad positions
122
 
123
- # Input: clause membership (max_clauses) + polarity (max_clauses) = 2 * max_clauses
124
  self.input_proj = nn.Linear(2 * config.max_clauses, d, bias=False)
125
 
126
- # (hidden carry handled via pred_proj — no separate gate needed)
 
 
127
 
128
- # Buffer 2: Scratchpad / register tokens (learned initial embeddings)
129
  if S > 0:
130
  self.scratch_embeds = nn.Parameter(torch.randn(S, d) * 0.02)
131
 
132
- # Buffer 3: Rich feedback — previous hidden projected back (not just 2 scalars)
133
- self.pred_proj = nn.Linear(d + 2, d, bias=False) # prev_hidden(d) + assign(1) + violation(1)
134
-
135
  # Shared transformer
136
  self.layers = nn.ModuleList([
137
  TransformerBlock(d, config.n_heads, config.d_ff, config.dropout)
@@ -139,7 +146,7 @@ class IterativeSATModel(nn.Module):
139
  ])
140
  self.final_norm = nn.RMSNorm(d)
141
 
142
- # Output: assignment logit per variable (T/F) — only for variable positions, not scratch
143
  self.assign_head = nn.Linear(d, 1, bias=False)
144
 
145
  cos, sin = build_rope_cache(total_pos, d // config.n_heads, config.rope_base)
@@ -147,64 +154,53 @@ class IterativeSATModel(nn.Module):
147
  self.register_buffer("rope_sin", sin)
148
 
149
  def forward(self, clause_mask, clause_sign, n_vars_batch=None, n_iters=None):
150
- """
151
- clause_mask: (B, max_vars, max_clauses) — 1 if variable appears in clause
152
- clause_sign: (B, max_vars, max_clauses) — polarity (+1/-1, 0=not present)
153
-
154
- Returns: list of assignment logits (B, max_vars), one per iteration
155
- """
156
  if n_iters is None:
157
  n_iters = self.config.train_iters
158
 
159
  B = clause_mask.shape[0]
160
  N = self.config.max_vars
161
  S = self.config.n_scratch
162
- d = self.config.d_model
163
  device = clause_mask.device
164
 
165
- # Encode variable features
166
- features = torch.cat([clause_mask, clause_sign], dim=-1) # (B, N, 2*max_clauses)
167
- h_vars = self.input_proj(features) # (B, N, d)
168
 
169
- # Append scratchpad tokens
170
  if S > 0:
171
- h_scratch = self.scratch_embeds.unsqueeze(0).expand(B, -1, -1) # (B, S, d)
172
- h = torch.cat([h_vars, h_scratch], dim=1) # (B, N+S, d)
173
- else:
174
- h = h_vars
175
 
176
  all_logits = []
177
- assign_prob = torch.full((B, N), 0.5, device=device)
178
- prev_hidden = torch.zeros(B, N + S, d, device=device) # buffer 1: carried hidden state
 
 
179
 
180
  for _ in range(n_iters):
181
- # Buffer 3: rich feedback — full prev hidden + scalar signals
182
- violation = self._compute_violations(assign_prob, clause_mask, clause_sign) # (B, N)
183
- # Pad violation/assign for scratch positions
184
- assign_padded = F.pad(assign_prob, (0, S)) # (B, N+S)
185
- violation_padded = F.pad(violation, (0, S)) # (B, N+S)
186
- feedback = torch.cat([
187
- prev_hidden, # (B, N+S, d) — full hidden buffer
188
- assign_padded.unsqueeze(-1), # (B, N+S, 1)
189
- violation_padded.unsqueeze(-1), # (B, N+S, 1)
190
- ], dim=-1) # (B, N+S, d+2)
191
- h_input = h + self.pred_proj(feedback)
192
-
193
- # Transformer step (all positions: variables + scratchpad)
194
- x = h_input
195
- for layer in self.layers:
196
- x = layer(x, self.rope_cos, self.rope_sin)
197
- x = self.final_norm(x)
198
 
199
- # Carry full hidden state to next iteration (with gradients)
200
- prev_hidden = x
 
 
201
 
202
- # Output: only variable positions predict assignments (not scratchpad)
203
- var_out = x[:, :N, :] # (B, N, d)
204
- logits = self.assign_head(var_out).squeeze(-1) # (B, N)
205
 
 
 
206
  all_logits.append(logits)
207
- assign_prob = torch.sigmoid(logits) # gradients flow through
 
 
 
 
 
 
208
 
209
  return all_logits
210
 
 
112
  # ---------------------------------------------------------------------------
113
 
114
  class IterativeSATModel(nn.Module):
115
+ """Sotaku-style iterative SAT solver.
116
+
117
+ Key design (matching sotaku):
118
+ - h_prev carries the full hidden state directly (residual across iterations)
119
+ - pred_proj adds a small correction from detached predictions (not the hidden state)
120
+ - Scratchpad tokens provide extra working memory positions
121
+ - Gradients flow through h_prev, predictions are detached
122
+ """
123
  def __init__(self, config: SATConfig):
124
  super().__init__()
125
  self.config = config
126
  d = config.d_model
127
  N = config.max_vars
128
  S = config.n_scratch
129
+ total_pos = N + S
130
 
131
+ # Input encoder: clause structure → initial hidden state (one-time)
132
  self.input_proj = nn.Linear(2 * config.max_clauses, d, bias=False)
133
 
134
+ # Prediction feedback: small correction from detached predictions
135
+ # assign(1) + violation(1) → d_model (like sotaku's pred_proj on softmax preds)
136
+ self.pred_proj = nn.Linear(2, d, bias=False)
137
 
138
+ # Scratchpad tokens (extra working memory)
139
  if S > 0:
140
  self.scratch_embeds = nn.Parameter(torch.randn(S, d) * 0.02)
141
 
 
 
 
142
  # Shared transformer
143
  self.layers = nn.ModuleList([
144
  TransformerBlock(d, config.n_heads, config.d_ff, config.dropout)
 
146
  ])
147
  self.final_norm = nn.RMSNorm(d)
148
 
149
+ # Output head (variable positions only)
150
  self.assign_head = nn.Linear(d, 1, bias=False)
151
 
152
  cos, sin = build_rope_cache(total_pos, d // config.n_heads, config.rope_base)
 
154
  self.register_buffer("rope_sin", sin)
155
 
156
  def forward(self, clause_mask, clause_sign, n_vars_batch=None, n_iters=None):
 
 
 
 
 
 
157
  if n_iters is None:
158
  n_iters = self.config.train_iters
159
 
160
  B = clause_mask.shape[0]
161
  N = self.config.max_vars
162
  S = self.config.n_scratch
 
163
  device = clause_mask.device
164
 
165
+ # One-time encoding (re-added every iteration to prevent forgetting)
166
+ features = torch.cat([clause_mask, clause_sign], dim=-1)
167
+ h_init = self.input_proj(features) # (B, N, d)
168
 
169
+ # Append scratchpad
170
  if S > 0:
171
+ h_scratch = self.scratch_embeds.unsqueeze(0).expand(B, -1, -1)
172
+ h_init = torch.cat([h_init, h_scratch], dim=1) # (B, N+S, d)
173
+
174
+ h_prev = h_init # first iteration starts from input encoding
175
 
176
  all_logits = []
177
+ # Initial predictions: uniform
178
+ preds = torch.zeros(B, N + S, 2, device=device)
179
+ preds[:, :N, 0] = 0.5
180
+ # violation starts at 0
181
 
182
  for _ in range(n_iters):
183
+ # Clean carry + fresh input + prediction correction
184
+ h = h_prev + h_init + self.pred_proj(preds)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
+ # Shared transformer
187
+ for layer in self.layers:
188
+ h = layer(h, self.rope_cos, self.rope_sin)
189
+ h = self.final_norm(h)
190
 
191
+ # h becomes h_prev for next iteration (direct carry, with gradients)
192
+ h_prev = h
 
193
 
194
+ # Predict assignments from variable positions only
195
+ logits = self.assign_head(h[:, :N, :]).squeeze(-1) # (B, N)
196
  all_logits.append(logits)
197
+
198
+ # Build detached prediction feedback for next iteration
199
+ assign_prob = torch.sigmoid(logits).detach()
200
+ violation = self._compute_violations(assign_prob, clause_mask, clause_sign)
201
+ preds = torch.zeros(B, N + S, 2, device=device)
202
+ preds[:, :N, 0] = assign_prob
203
+ preds[:, :N, 1] = violation
204
 
205
  return all_logits
206