| """ |
| Non-autoregressive iterative SAT solver. |
| |
| Learns to solve Boolean satisfiability via iterative refinement: |
| - Input: clause membership + polarity per variable (from pre-generated .pt data) |
| - Shared transformer body (bidirectional attention) |
| - Output: T/F assignment per variable |
| - QuerySAT-style feedback: current assignment's violation count fed back |
| - Train with K=16 iterations, eval with K=16..256+ |
| |
| Trained on SR distribution (paired SAT/UNSAT instances differing by one literal). |
| Training task: for SAT-labeled instances, predict a satisfying assignment
| (only assignment bits are supervised; there is no separate SAT/UNSAT
| classification head in this version).
| |
| Usage: |
| # Generate data first (CPU, one-time) |
| python scripts/sat_data_gen.py --n-problems 100000 --output data/sat/train.pt |
| python scripts/sat_data_gen.py --n-problems 2000 --output data/sat/eval.pt --seed 99999 |
| |
| # Train (GPU) |
| python scripts/iterative_sat.py --train-data data/sat/train.pt --eval-data data/sat/eval.pt --compile |
| |
| # Quick local test |
| python scripts/iterative_sat.py --train-data /tmp/sat_test.pt --device cpu --steps 200 --batch 32 |
| """ |
|
|
| import argparse |
| import math |
| import time |
| from contextlib import nullcontext |
| from dataclasses import dataclass |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
@dataclass
class SATConfig:
    """Hyperparameters for the iterative SAT solver model."""
    max_vars: int = 40        # padded number of variable tokens per instance
    max_clauses: int = 256    # padded number of clauses (input feature width)
    d_model: int = 128        # transformer hidden size
    n_heads: int = 4          # attention heads (d_model must divide evenly)
    n_layers: int = 4         # transformer blocks in the shared body
    d_ff: int = 512           # feed-forward inner dimension
    dropout: float = 0.1     # dropout in attention and feed-forward
    train_iters: int = 16     # refinement iterations used during training
    rope_base: float = 10.0   # RoPE frequency base (small: only ~max_vars+n_scratch positions)
    n_scratch: int = 16       # scratchpad tokens appended after variable tokens
|
|
|
|
| |
| |
| |
|
|
def build_rope_cache(seq_len, head_dim, base=10.0, device="cpu"):
    """Precompute RoPE cosine/sine tables.

    Returns two (seq_len, head_dim // 2) tensors where element [p, j]
    holds cos/sin of p * base**(-2j / head_dim).
    """
    exponents = torch.arange(0, head_dim, 2, device=device).float() / head_dim
    inv_freq = 1.0 / (base ** exponents)
    positions = torch.arange(seq_len, device=device).float()
    angles = positions[:, None] * inv_freq[None, :]
    return angles.cos(), angles.sin()
|
|
|
|
def apply_rope(x, cos, sin):
    """Apply rotary position embedding to a (B, H, N, D) tensor.

    Uses the "split halves" RoPE layout: channel i in the first D/2 is
    paired with channel i in the last D/2. `cos`/`sin` are (>=N, D/2)
    tables; only the first N rows are used.
    """
    seq_len = x.shape[2]
    half = x.shape[-1] // 2
    c = cos[:seq_len, :]
    s = sin[:seq_len, :]
    lo = x[..., :half]
    hi = x[..., half:]
    rotated_lo = lo * c - hi * s
    rotated_hi = hi * c + lo * s
    return torch.cat([rotated_lo, rotated_hi], dim=-1)
|
|
|
|
| |
| |
| |
|
|
class MultiHeadAttention(nn.Module):
    """Bidirectional multi-head self-attention with RoPE position encoding."""

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.wq = nn.Linear(d_model, d_model, bias=False)
        self.wk = nn.Linear(d_model, d_model, bias=False)
        self.wv = nn.Linear(d_model, d_model, bias=False)
        self.wo = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, cos, sin):
        batch, seq, dim = x.shape

        def split_heads(proj):
            # (B, N, D) -> (B, H, N, D/H)
            return proj(x).view(batch, seq, self.n_heads, self.head_dim).transpose(1, 2)

        # RoPE is applied to queries and keys only; values are unrotated.
        q = apply_rope(split_heads(self.wq), cos, sin)
        k = apply_rope(split_heads(self.wk), cos, sin)
        v = split_heads(self.wv)

        drop_p = self.dropout.p if self.training else 0.0
        attn_out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p)
        merged = attn_out.transpose(1, 2).contiguous().view(batch, seq, dim)
        return self.wo(merged)
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual."""

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.norm1 = nn.RMSNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.norm2 = nn.RMSNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(d_ff, d_model, bias=False),
            nn.Dropout(dropout),
        )

    def forward(self, x, cos, sin):
        attn_out = self.attn(self.norm1(x), cos, sin)
        x = x + attn_out
        ff_out = self.ff(self.norm2(x))
        return x + ff_out
|
|
|
|
| |
| |
| |
|
|
class IterativeSATModel(nn.Module):
    """Sotaku-style iterative SAT solver.

    Key design (matching sotaku):
    - h_prev carries the full hidden state directly (residual across iterations)
    - pred_proj adds a small correction from detached predictions (not the hidden state)
    - Scratchpad tokens provide extra working memory positions
    - Gradients flow through h_prev, predictions are detached
    """
    def __init__(self, config: SATConfig):
        super().__init__()
        self.config = config
        d = config.d_model
        N = config.max_vars
        S = config.n_scratch
        # Sequence length seen by the transformer: variable tokens + scratchpad.
        total_pos = N + S

        # Per-variable input: clause membership row concatenated with polarity
        # row (2 * max_clauses features), projected into the model dimension.
        self.input_proj = nn.Linear(2 * config.max_clauses, d, bias=False)

        # Projects the 2-channel per-token feedback (assignment probability,
        # violation score) from the previous iteration into the hidden space.
        self.pred_proj = nn.Linear(2, d, bias=False)

        # Learned embeddings for the scratchpad positions (extra working memory).
        if S > 0:
            self.scratch_embeds = nn.Parameter(torch.randn(S, d) * 0.02)

        # Shared transformer body, re-applied at every refinement iteration.
        self.layers = nn.ModuleList([
            TransformerBlock(d, config.n_heads, config.d_ff, config.dropout)
            for _ in range(config.n_layers)
        ])
        self.final_norm = nn.RMSNorm(d)

        # One logit per variable token; logit > 0 means "assign True".
        self.assign_head = nn.Linear(d, 1, bias=False)

        # RoPE tables as buffers: they follow .to(device) but are not trained.
        cos, sin = build_rope_cache(total_pos, d // config.n_heads, config.rope_base)
        self.register_buffer("rope_cos", cos)
        self.register_buffer("rope_sin", sin)

    def forward(self, clause_mask, clause_sign, n_vars_batch=None, n_iters=None):
        """Run n_iters refinement passes and return a list of per-iteration logits.

        clause_mask: (B, max_vars, max_clauses) — 1.0 where variable i occurs in clause j
        clause_sign: (B, max_vars, max_clauses) — literal polarity (+1/-1, 0 = absent)
        n_vars_batch: unused here; accepted for caller-interface compatibility
        n_iters: number of refinement iterations (defaults to config.train_iters)

        Returns: list of n_iters tensors, each (B, max_vars) of assignment logits.
        """
        if n_iters is None:
            n_iters = self.config.train_iters

        B = clause_mask.shape[0]
        N = self.config.max_vars
        S = self.config.n_scratch
        device = clause_mask.device

        # Static encoding of the formula; re-added every iteration below.
        features = torch.cat([clause_mask, clause_sign], dim=-1)
        h_init = self.input_proj(features)

        # Append scratchpad tokens after the variable tokens.
        if S > 0:
            h_scratch = self.scratch_embeds.unsqueeze(0).expand(B, -1, -1)
            h_init = torch.cat([h_init, h_scratch], dim=1)

        h_prev = h_init

        all_logits = []
        # Feedback channels: [..., 0] = assignment prob, [..., 1] = violation.
        # Start from an uninformative 0.5 assignment and zero violation;
        # scratchpad positions always keep zero feedback.
        preds = torch.zeros(B, N + S, 2, device=device)
        preds[:, :N, 0] = 0.5

        for _ in range(n_iters):
            # Recombine the carried hidden state, the static formula encoding,
            # and the (detached) feedback from the previous iteration.
            h = h_prev + h_init + self.pred_proj(preds)

            for layer in self.layers:
                h = layer(h, self.rope_cos, self.rope_sin)
            h = self.final_norm(h)

            # Gradients flow across iterations through the hidden state only.
            h_prev = h

            # Read out logits from the variable tokens (scratchpad excluded).
            logits = self.assign_head(h[:, :N, :]).squeeze(-1)
            all_logits.append(logits)

            # QuerySAT-style feedback: detach so learning goes through h_prev,
            # not through the prediction path.
            assign_prob = torch.sigmoid(logits).detach()
            violation = self._compute_violations(assign_prob, clause_mask, clause_sign)
            preds = torch.zeros(B, N + S, 2, device=device)
            preds[:, :N, 0] = assign_prob
            preds[:, :N, 1] = violation

        return all_logits

    def _compute_violations(self, assign_prob, clause_mask, clause_sign):
        """Compute per-variable violation signal (QuerySAT-style).

        For each variable, compute how many of its clauses are currently violated
        by the soft assignment. A clause is "violated" if all its literals are false.

        assign_prob: (B, N) — probability of each variable being True
        clause_mask: (B, N, max_clauses) — 1 if variable in clause
        clause_sign: (B, N, max_clauses) — polarity

        Returns: (B, N) — per-variable violation signal (0-1, higher = more violated)
        """
        # Soft truth value of each literal: p for positive, 1-p for negative,
        # 0 where the variable is absent (masked out again below anyway).
        lit_sat = torch.where(
            clause_sign > 0,
            assign_prob.unsqueeze(-1),
            torch.where(clause_sign < 0, 1 - assign_prob.unsqueeze(-1), torch.zeros_like(clause_sign))
        )

        # Soft "how satisfied" per clause: sum of literal truths over variables.
        clause_sat = (lit_sat * clause_mask).sum(dim=1)
        # Normalize by clause size so short and long clauses are comparable;
        # clamp(min=1) keeps padding clauses from dividing by zero.
        clause_size = clause_mask.sum(dim=1).clamp(min=1)
        clause_unsat = 1 - (clause_sat / clause_size).clamp(max=1)

        # Each variable averages the unsat scores of the clauses it appears in.
        var_violation = (clause_unsat.unsqueeze(1) * clause_mask).sum(dim=-1)
        var_n_clauses = clause_mask.sum(dim=-1).clamp(min=1)
        return var_violation / var_n_clauses
|
|
|
|
| |
| |
| |
|
|
def load_dataset(path, device="cpu"):
    """Load a pre-generated .pt dataset and move tensors to `device`.

    Count fields ("n_clauses", "n_vars") are cast to int64; every other
    tensor is cast to float32. Non-tensor entries pass through unchanged.
    """
    raw = torch.load(path, weights_only=True)
    integer_keys = ("n_clauses", "n_vars")

    def convert(key, value):
        if not isinstance(value, torch.Tensor):
            return value
        caster = value.long if key in integer_keys else value.float
        return caster().to(device)

    return {key: convert(key, value) for key, value in raw.items()}
|
|
|
|
def train(config, args):
    """Train the iterative SAT model; run periodic and final evaluation.

    config: SATConfig with model hyperparameters.
    args: argparse namespace (data paths, optimizer/schedule settings, device,
          compile/save/upload flags).
    Returns the trained model (possibly wrapped by torch.compile).
    """
    device = args.device

    # Allow TF32 matmuls on Ampere+ GPUs for throughput.
    if device == "cuda":
        torch.set_float32_matmul_precision('high')
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    model = IterativeSATModel(config).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model params: {n_params:,} ({n_params/1e6:.2f}M)")
    print(f"Config: {config.n_layers}L, d={config.d_model}, h={config.n_heads}, "
          f"ff={config.d_ff}, iters={config.train_iters}")
    print(f"Max vars: {config.max_vars}, max clauses: {config.max_clauses}")
    print(f"Device: {device}")

    # Entire dataset is loaded to the target device up front; batches are
    # formed by random indexing below (no DataLoader).
    print(f"\nLoading training data from {args.train_data}...")
    train_data = load_dataset(args.train_data, device)
    n_train = train_data["sat_mask"].shape[0]
    print(f" {n_train} problems loaded")

    eval_data = None
    if args.eval_data:
        print(f"Loading eval data from {args.eval_data}...")
        eval_data = load_dataset(args.eval_data, device)
        print(f" {eval_data['sat_mask'].shape[0]} problems loaded")

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=0.01)

    def lr_schedule(step):
        # Linear warmup, then cosine decay to 1% of the peak LR.
        if step < args.warmup:
            return step / args.warmup
        progress = (step - args.warmup) / max(1, args.steps - args.warmup)
        return 0.01 + 0.99 * 0.5 * (1 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)

    if args.compile and device == "cuda":
        print("Compiling...")
        model = torch.compile(model)
        print("Done.")

    # bf16 autocast on CUDA; GradScaler is enabled only there as well.
    use_amp = device == "cuda"
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
    autocast_ctx = torch.amp.autocast('cuda', dtype=torch.bfloat16) if use_amp else nullcontext()

    t0 = time.time()

    for step in range(args.steps + 1):
        model.train()

        # Sample a batch; each row is drawn 50/50 from the SAT or UNSAT side
        # of its paired instance.
        idx = torch.randint(0, n_train, (args.batch,), device=device)
        is_sat = torch.rand(args.batch, device=device) < 0.5

        mask = torch.where(is_sat.view(-1, 1, 1), train_data["sat_mask"][idx], train_data["unsat_mask"][idx])
        sign = torch.where(is_sat.view(-1, 1, 1), train_data["sat_sign"][idx], train_data["unsat_sign"][idx])
        solutions = train_data["solutions"][idx]
        n_vars = train_data["n_vars"][idx]

        with autocast_ctx:
            all_logits = model(mask, sign, n_vars)

            # Deep supervision: average the assignment BCE over every
            # refinement iteration, not just the last one.
            loss = 0.0
            # Valid-variable mask: padding positions beyond n_vars are excluded.
            var_mask = torch.arange(config.max_vars, device=device).unsqueeze(0) < n_vars.unsqueeze(1)

            for logits in all_logits:
                assign_loss = F.binary_cross_entropy_with_logits(
                    logits, solutions, reduction='none')
                # Only SAT-labeled rows have a target assignment to supervise.
                sat_mask = is_sat.unsqueeze(1) & var_mask
                if sat_mask.any():
                    loss += (assign_loss * sat_mask).sum() / sat_mask.sum()

            loss /= len(all_logits)

        # Standard AMP step: scale -> backward -> unscale -> clip -> step.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        if step % args.log_interval == 0:
            elapsed = time.time() - t0
            with torch.no_grad():
                # Hard-threshold the last iteration's logits for a quick
                # solved-rate check on this training batch.
                final_assign = (all_logits[-1] > 0).float()
                sat_solved = _check_sat(final_assign, mask, sign, n_vars, is_sat)

            print(f"Step {step:5d} | Loss: {loss.item():.4f} | "
                  f"SAT solved: {sat_solved:.1%} | {elapsed:.1f}s")

        if step > 0 and step % args.eval_interval == 0 and eval_data is not None:
            evaluate(model, config, eval_data, device)

    print("\n" + "=" * 70)
    print("FINAL EVALUATION")
    print("=" * 70)
    if eval_data is not None:
        evaluate(model, config, eval_data, device, verbose=True)

    if args.save_path:
        # Unwrap torch.compile so the checkpoint has clean parameter names.
        raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
        checkpoint = {
            "model_state_dict": raw_model.state_dict(),
            "config": vars(config),
        }
        torch.save(checkpoint, args.save_path)
        print(f"\nCheckpoint saved to {args.save_path}")

    if args.upload_hf:
        # Optional push to the Hugging Face Hub; imported lazily so the
        # dependency is only needed when the flag is used.
        # NOTE(review): assumes --save-path was also given; upload_file would
        # fail with a None path otherwise — confirm intended usage.
        from huggingface_hub import HfApi
        import os
        api = HfApi()
        try:
            api.create_repo(args.upload_hf, exist_ok=True)
        except Exception as e:
            print(f"Warning: {e}")
        api.upload_file(path_or_fileobj=args.save_path, path_in_repo="model.pt", repo_id=args.upload_hf)
        api.upload_file(path_or_fileobj=os.path.abspath(__file__), path_in_repo="iterative_sat.py", repo_id=args.upload_hf)
        print(f"Uploaded to https://huggingface.co/{args.upload_hf}")

    return model
|
|
|
|
| def _check_sat(assignments, clause_mask, clause_sign, n_vars, is_sat_label): |
| """Check what fraction of SAT-labeled instances are actually solved.""" |
| B, N, M = clause_mask.shape |
| device = assignments.device |
|
|
| |
| lit_sat = torch.where( |
| clause_sign > 0, assignments.unsqueeze(-1), |
| torch.where(clause_sign < 0, 1 - assignments.unsqueeze(-1), torch.ones_like(clause_sign)) |
| ) |
|
|
| |
| clause_has_sat_lit = (lit_sat * clause_mask).sum(dim=1) > 0 |
| clause_exists = clause_mask.sum(dim=1) > 0 |
|
|
| |
| all_sat = (clause_has_sat_lit | ~clause_exists).all(dim=1) |
|
|
| |
| n_sat = is_sat_label.sum() |
| if n_sat == 0: |
| return 0.0 |
| return (all_sat & is_sat_label).sum().float() / n_sat |
|
|
|
|
def evaluate(model, config, eval_data, device, verbose=False):
    """Evaluate with different iteration counts.

    Runs the model on the SAT side of the eval set at increasing numbers of
    refinement iterations (train_iters up to 256) and prints the solved rate
    and per-bit agreement with the stored reference solutions. With
    verbose=True, additionally prints per-instance details for the first
    8 problems at 256 iterations.
    """
    model.eval()

    n_eval = eval_data["sat_mask"].shape[0]
    # Test iteration-count generalization well beyond the training setting.
    iter_counts = [config.train_iters, 32, 64, 128, 256]

    # Only the SAT instances are evaluated (they have reference solutions).
    mask = eval_data["sat_mask"]
    sign = eval_data["sat_sign"]
    solutions = eval_data["solutions"]
    n_vars = eval_data["n_vars"]

    print(f"\n SAT instances (n={n_eval})")
    print(f" {'Iters':>6s} | {'Solved':>8s} | {'Bit Acc':>8s}")
    print(f" {'-'*6} | {'-'*8} | {'-'*8}")

    for n_iters in iter_counts:
        with torch.no_grad():
            all_logits = model(mask, sign, n_vars, n_iters=n_iters)
            # Hard assignment from the final iteration's logits.
            final_assign = (all_logits[-1] > 0).float()

        # All eval rows are SAT-labeled, so _check_sat averages over all of them.
        is_sat = torch.ones(n_eval, dtype=torch.bool, device=device)
        solved = _check_sat(final_assign, mask, sign, n_vars, is_sat)

        # Bit accuracy vs the stored solution, counting only real (non-padding)
        # variable positions. Note a formula can be solved by an assignment
        # that differs from the stored solution, so solved rate can exceed
        # what bit accuracy alone would suggest.
        var_mask = torch.arange(config.max_vars, device=device).unsqueeze(0) < n_vars.unsqueeze(1)
        correct_bits = ((final_assign == solutions) & var_mask).sum().float()
        total_bits = var_mask.sum().float()
        bit_acc = correct_bits / total_bits

        print(f" {n_iters:6d} | {solved.item():>7.1%} | {bit_acc.item():>7.1%}")

    if verbose:
        # Detailed dump for the first 8 problems at the deepest setting.
        with torch.no_grad():
            all_logits = model(mask[:8], sign[:8], n_vars[:8], n_iters=256)
            final_assign = (all_logits[-1] > 0).float()

        print(f"\n Sample predictions (256 iters):")
        for i in range(min(8, n_eval)):
            nv = n_vars[i].item()
            pred = final_assign[i, :nv].long().tolist()
            true = solutions[i, :nv].long().tolist()
            nc = eval_data["n_clauses"][i].item()

            # Verify this single instance's assignment against its clauses.
            is_sat_i = torch.ones(1, dtype=torch.bool, device=device)
            solved_i = _check_sat(
                final_assign[i:i+1], mask[i:i+1], sign[i:i+1], n_vars[i:i+1], is_sat_i)
            status = "✓" if solved_i > 0.5 else "✗"

            # Hamming distance to the stored reference solution.
            n_diff = sum(p != t for p, t in zip(pred, true))
            print(f" {status} vars={nv}, clauses={nc}, diff={n_diff}/{nv}")
            if nv <= 20:
                print(f" Pred: {pred}")
                print(f" True: {true}")
|
|
|
def main():
    """CLI entry point: parse arguments, assemble the config, and train."""
    ap = argparse.ArgumentParser(description="Iterative SAT solver")

    # Data / run-control flags.
    ap.add_argument("--train-data", required=True, help="Training .pt file")
    ap.add_argument("--eval-data", default=None, help="Eval .pt file")
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    ap.add_argument("--steps", type=int, default=30000)
    ap.add_argument("--batch", type=int, default=512)
    ap.add_argument("--lr", type=float, default=2e-3)
    ap.add_argument("--warmup", type=int, default=1400)
    ap.add_argument("--log-interval", type=int, default=100)
    ap.add_argument("--eval-interval", type=int, default=5000)
    ap.add_argument("--compile", action="store_true")
    ap.add_argument("--save-path", default=None)
    ap.add_argument("--upload-hf", default=None)

    # Model hyperparameters (mirrored into SATConfig below).
    ap.add_argument("--d-model", type=int, default=128)
    ap.add_argument("--n-layers", type=int, default=4)
    ap.add_argument("--n-heads", type=int, default=4)
    ap.add_argument("--d-ff", type=int, default=512)
    ap.add_argument("--train-iters", type=int, default=16)
    ap.add_argument("--max-vars", type=int, default=40)
    ap.add_argument("--max-clauses", type=int, default=256)
    ap.add_argument("--dropout", type=float, default=0.1)
    ap.add_argument("--n-scratch", type=int, default=16, help="Number of scratchpad tokens")

    args = ap.parse_args()

    model_fields = ("max_vars", "max_clauses", "d_model", "n_heads", "n_layers",
                    "d_ff", "dropout", "train_iters", "n_scratch")
    config = SATConfig(**{name: getattr(args, name) for name in model_fields})

    train(config, args)
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|