""" Non-autoregressive iterative SAT solver. Learns to solve Boolean satisfiability via iterative refinement: - Input: clause membership + polarity per variable (from pre-generated .pt data) - Shared transformer body (bidirectional attention) - Output: T/F assignment per variable - QuerySAT-style feedback: current assignment's violation count fed back - Train with K=16 iterations, eval with K=16..256+ Trained on SR distribution (paired SAT/UNSAT instances differing by one literal). Classification task: predict SAT vs UNSAT, and for SAT instances, predict assignments. Usage: # Generate data first (CPU, one-time) python scripts/sat_data_gen.py --n-problems 100000 --output data/sat/train.pt python scripts/sat_data_gen.py --n-problems 2000 --output data/sat/eval.pt --seed 99999 # Train (GPU) python scripts/iterative_sat.py --train-data data/sat/train.pt --eval-data data/sat/eval.pt --compile # Quick local test python scripts/iterative_sat.py --train-data /tmp/sat_test.pt --device cpu --steps 200 --batch 32 """ import argparse import math import time from contextlib import nullcontext from dataclasses import dataclass import torch import torch.nn as nn import torch.nn.functional as F @dataclass class SATConfig: max_vars: int = 40 max_clauses: int = 256 d_model: int = 128 n_heads: int = 4 n_layers: int = 4 d_ff: int = 512 dropout: float = 0.1 train_iters: int = 16 rope_base: float = 10.0 n_scratch: int = 16 # number of scratchpad/register tokens # --------------------------------------------------------------------------- # RoPE # --------------------------------------------------------------------------- def build_rope_cache(seq_len, head_dim, base=10.0, device="cpu"): theta = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim)) freqs = torch.outer(torch.arange(seq_len, device=device).float(), theta) return freqs.cos(), freqs.sin() def apply_rope(x, cos, sin): d2 = x.shape[-1] // 2 x1, x2 = x[..., :d2], x[..., d2:] cos, sin = cos[:x.shape[2], :], sin[:x.shape[2], :] return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) # --------------------------------------------------------------------------- # Transformer layers # --------------------------------------------------------------------------- class MultiHeadAttention(nn.Module): def __init__(self, d_model, n_heads, dropout=0.1): super().__init__() self.n_heads = n_heads self.head_dim = d_model // n_heads self.wq = nn.Linear(d_model, d_model, bias=False) self.wk = nn.Linear(d_model, d_model, bias=False) self.wv = nn.Linear(d_model, d_model, bias=False) self.wo = nn.Linear(d_model, d_model, bias=False) self.dropout = nn.Dropout(dropout) def forward(self, x, cos, sin): B, N, D = x.shape q = self.wq(x).view(B, N, self.n_heads, self.head_dim).transpose(1, 2) k = self.wk(x).view(B, N, self.n_heads, self.head_dim).transpose(1, 2) v = self.wv(x).view(B, N, self.n_heads, self.head_dim).transpose(1, 2) q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin) attn = F.scaled_dot_product_attention( q, k, v, dropout_p=self.dropout.p if self.training else 0.0) return self.wo(attn.transpose(1, 2).contiguous().view(B, N, D)) class TransformerBlock(nn.Module): def __init__(self, d_model, n_heads, d_ff, dropout=0.1): super().__init__() self.norm1 = nn.RMSNorm(d_model) self.attn = MultiHeadAttention(d_model, n_heads, dropout) self.norm2 = nn.RMSNorm(d_model) self.ff = nn.Sequential( nn.Linear(d_model, d_ff, bias=False), nn.ReLU(), nn.Linear(d_ff, d_model, bias=False), nn.Dropout(dropout)) def forward(self, x, cos, sin): x = x + self.attn(self.norm1(x), cos, sin) x = x + self.ff(self.norm2(x)) return x # --------------------------------------------------------------------------- # SAT Model # --------------------------------------------------------------------------- class IterativeSATModel(nn.Module): """Sotaku-style iterative SAT solver. Key design (matching sotaku): - h_prev carries the full hidden state directly (residual across iterations) - pred_proj adds a small correction from detached predictions (not the hidden state) - Scratchpad tokens provide extra working memory positions - Gradients flow through h_prev, predictions are detached """ def __init__(self, config: SATConfig): super().__init__() self.config = config d = config.d_model N = config.max_vars S = config.n_scratch total_pos = N + S # Input encoder: clause structure → initial hidden state (one-time) self.input_proj = nn.Linear(2 * config.max_clauses, d, bias=False) # Prediction feedback: small correction from detached predictions # assign(1) + violation(1) → d_model (like sotaku's pred_proj on softmax preds) self.pred_proj = nn.Linear(2, d, bias=False) # Scratchpad tokens (extra working memory) if S > 0: self.scratch_embeds = nn.Parameter(torch.randn(S, d) * 0.02) # Shared transformer self.layers = nn.ModuleList([ TransformerBlock(d, config.n_heads, config.d_ff, config.dropout) for _ in range(config.n_layers) ]) self.final_norm = nn.RMSNorm(d) # Output head (variable positions only) self.assign_head = nn.Linear(d, 1, bias=False) cos, sin = build_rope_cache(total_pos, d // config.n_heads, config.rope_base) self.register_buffer("rope_cos", cos) self.register_buffer("rope_sin", sin) def forward(self, clause_mask, clause_sign, n_vars_batch=None, n_iters=None): if n_iters is None: n_iters = self.config.train_iters B = clause_mask.shape[0] N = self.config.max_vars S = self.config.n_scratch device = clause_mask.device # One-time encoding (re-added every iteration to prevent forgetting) features = torch.cat([clause_mask, clause_sign], dim=-1) h_init = self.input_proj(features) # (B, N, d) # Append scratchpad if S > 0: h_scratch = self.scratch_embeds.unsqueeze(0).expand(B, -1, -1) h_init = torch.cat([h_init, h_scratch], dim=1) # (B, N+S, d) h_prev = h_init # first iteration starts from input encoding all_logits = [] # Initial predictions: uniform preds = torch.zeros(B, N + S, 2, device=device) preds[:, :N, 0] = 0.5 # violation starts at 0 for _ in range(n_iters): # Clean carry + fresh input + prediction correction h = h_prev + h_init + self.pred_proj(preds) # Shared transformer for layer in self.layers: h = layer(h, self.rope_cos, self.rope_sin) h = self.final_norm(h) # h becomes h_prev for next iteration (direct carry, with gradients) h_prev = h # Predict assignments from variable positions only logits = self.assign_head(h[:, :N, :]).squeeze(-1) # (B, N) all_logits.append(logits) # Build detached prediction feedback for next iteration assign_prob = torch.sigmoid(logits).detach() violation = self._compute_violations(assign_prob, clause_mask, clause_sign) preds = torch.zeros(B, N + S, 2, device=device) preds[:, :N, 0] = assign_prob preds[:, :N, 1] = violation return all_logits def _compute_violations(self, assign_prob, clause_mask, clause_sign): """Compute per-variable violation signal (QuerySAT-style). For each variable, compute how many of its clauses are currently violated by the soft assignment. A clause is "violated" if all its literals are false. assign_prob: (B, N) — probability of each variable being True clause_mask: (B, N, max_clauses) — 1 if variable in clause clause_sign: (B, N, max_clauses) — polarity Returns: (B, N) — per-variable violation signal (0-1, higher = more violated) """ # Soft literal satisfaction: how much each literal contributes to its clause # If sign=+1 (positive literal): satisfaction = assign_prob # If sign=-1 (negative literal): satisfaction = 1 - assign_prob lit_sat = torch.where( clause_sign > 0, assign_prob.unsqueeze(-1), torch.where(clause_sign < 0, 1 - assign_prob.unsqueeze(-1), torch.zeros_like(clause_sign)) ) # (B, N, max_clauses) # Per-clause satisfaction: max over all literals in clause # Sum lit_sat across variables for each clause clause_sat = (lit_sat * clause_mask).sum(dim=1) # (B, max_clauses) # Normalize by clause size clause_size = clause_mask.sum(dim=1).clamp(min=1) # (B, max_clauses) clause_unsat = 1 - (clause_sat / clause_size).clamp(max=1) # (B, max_clauses) — 0=sat, 1=unsat # Per-variable: average unsatisfaction of clauses this variable appears in var_violation = (clause_unsat.unsqueeze(1) * clause_mask).sum(dim=-1) # (B, N) var_n_clauses = clause_mask.sum(dim=-1).clamp(min=1) # (B, N) return var_violation / var_n_clauses # --------------------------------------------------------------------------- # Training # --------------------------------------------------------------------------- def load_dataset(path, device="cpu"): """Load pre-generated .pt dataset.""" data = torch.load(path, weights_only=True) result = {} for k, v in data.items(): if isinstance(v, torch.Tensor): if k in ("n_clauses", "n_vars"): result[k] = v.long().to(device) # keep integer types for indexing else: result[k] = v.float().to(device) # float32 for model compatibility else: result[k] = v return result def train(config, args): device = args.device if device == "cuda": torch.set_float32_matmul_precision('high') torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True model = IterativeSATModel(config).to(device) n_params = sum(p.numel() for p in model.parameters()) print(f"Model params: {n_params:,} ({n_params/1e6:.2f}M)") print(f"Config: {config.n_layers}L, d={config.d_model}, h={config.n_heads}, " f"ff={config.d_ff}, iters={config.train_iters}") print(f"Max vars: {config.max_vars}, max clauses: {config.max_clauses}") print(f"Device: {device}") # Load pre-generated data print(f"\nLoading training data from {args.train_data}...") train_data = load_dataset(args.train_data, device) n_train = train_data["sat_mask"].shape[0] print(f" {n_train} problems loaded") eval_data = None if args.eval_data: print(f"Loading eval data from {args.eval_data}...") eval_data = load_dataset(args.eval_data, device) print(f" {eval_data['sat_mask'].shape[0]} problems loaded") optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=0.01) def lr_schedule(step): if step < args.warmup: return step / args.warmup progress = (step - args.warmup) / max(1, args.steps - args.warmup) return 0.01 + 0.99 * 0.5 * (1 + math.cos(math.pi * progress)) scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule) if args.compile and device == "cuda": print("Compiling...") model = torch.compile(model) print("Done.") use_amp = device == "cuda" scaler = torch.amp.GradScaler('cuda', enabled=use_amp) autocast_ctx = torch.amp.autocast('cuda', dtype=torch.bfloat16) if use_amp else nullcontext() t0 = time.time() for step in range(args.steps + 1): model.train() # Sample batch: randomly pick SAT or UNSAT (50/50) idx = torch.randint(0, n_train, (args.batch,), device=device) is_sat = torch.rand(args.batch, device=device) < 0.5 # Build input: pick sat or unsat version mask = torch.where(is_sat.view(-1, 1, 1), train_data["sat_mask"][idx], train_data["unsat_mask"][idx]) sign = torch.where(is_sat.view(-1, 1, 1), train_data["sat_sign"][idx], train_data["unsat_sign"][idx]) solutions = train_data["solutions"][idx] # only valid for SAT instances n_vars = train_data["n_vars"][idx] with autocast_ctx: all_logits = model(mask, sign, n_vars) # Multi-task loss at every iteration: # 1. For SAT instances: BCE on assignments # 2. For all instances: encourage convergence (later iterations should be better) loss = 0.0 var_mask = torch.arange(config.max_vars, device=device).unsqueeze(0) < n_vars.unsqueeze(1) for logits in all_logits: # Assignment loss (SAT instances only) assign_loss = F.binary_cross_entropy_with_logits( logits, solutions, reduction='none') # Mask: only SAT instances, only valid variables sat_mask = is_sat.unsqueeze(1) & var_mask if sat_mask.any(): loss += (assign_loss * sat_mask).sum() / sat_mask.sum() loss /= len(all_logits) optimizer.zero_grad() scaler.scale(loss).backward() scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) scaler.step(optimizer) scaler.update() scheduler.step() if step % args.log_interval == 0: elapsed = time.time() - t0 with torch.no_grad(): final_assign = (all_logits[-1] > 0).float() # Check which SAT instances are actually solved sat_solved = _check_sat(final_assign, mask, sign, n_vars, is_sat) print(f"Step {step:5d} | Loss: {loss.item():.4f} | " f"SAT solved: {sat_solved:.1%} | {elapsed:.1f}s") if step > 0 and step % args.eval_interval == 0 and eval_data is not None: evaluate(model, config, eval_data, device) print("\n" + "=" * 70) print("FINAL EVALUATION") print("=" * 70) if eval_data is not None: evaluate(model, config, eval_data, device, verbose=True) if args.save_path: raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model checkpoint = { "model_state_dict": raw_model.state_dict(), "config": vars(config), } torch.save(checkpoint, args.save_path) print(f"\nCheckpoint saved to {args.save_path}") if args.upload_hf: from huggingface_hub import HfApi import os api = HfApi() try: api.create_repo(args.upload_hf, exist_ok=True) except Exception as e: print(f"Warning: {e}") api.upload_file(path_or_fileobj=args.save_path, path_in_repo="model.pt", repo_id=args.upload_hf) api.upload_file(path_or_fileobj=os.path.abspath(__file__), path_in_repo="iterative_sat.py", repo_id=args.upload_hf) print(f"Uploaded to https://huggingface.co/{args.upload_hf}") return model def _check_sat(assignments, clause_mask, clause_sign, n_vars, is_sat_label): """Check what fraction of SAT-labeled instances are actually solved.""" B, N, M = clause_mask.shape device = assignments.device # Literal satisfaction: assignment matches polarity lit_sat = torch.where( clause_sign > 0, assignments.unsqueeze(-1), torch.where(clause_sign < 0, 1 - assignments.unsqueeze(-1), torch.ones_like(clause_sign)) ) # (B, N, M) # Clause is satisfied if ANY literal in it is satisfied clause_has_sat_lit = (lit_sat * clause_mask).sum(dim=1) > 0 # (B, M) clause_exists = clause_mask.sum(dim=1) > 0 # (B, M) # Formula satisfied = all existing clauses satisfied all_sat = (clause_has_sat_lit | ~clause_exists).all(dim=1) # (B,) # Only count SAT-labeled instances n_sat = is_sat_label.sum() if n_sat == 0: return 0.0 return (all_sat & is_sat_label).sum().float() / n_sat def evaluate(model, config, eval_data, device, verbose=False): """Evaluate with different iteration counts.""" model.eval() n_eval = eval_data["sat_mask"].shape[0] iter_counts = [config.train_iters, 32, 64, 128, 256] # Test on SAT instances mask = eval_data["sat_mask"] sign = eval_data["sat_sign"] solutions = eval_data["solutions"] n_vars = eval_data["n_vars"] print(f"\n SAT instances (n={n_eval})") print(f" {'Iters':>6s} | {'Solved':>8s} | {'Bit Acc':>8s}") print(f" {'-'*6} | {'-'*8} | {'-'*8}") for n_iters in iter_counts: with torch.no_grad(): all_logits = model(mask, sign, n_vars, n_iters=n_iters) final_assign = (all_logits[-1] > 0).float() is_sat = torch.ones(n_eval, dtype=torch.bool, device=device) solved = _check_sat(final_assign, mask, sign, n_vars, is_sat) # Bit accuracy (does each variable match the reference solution?) var_mask = torch.arange(config.max_vars, device=device).unsqueeze(0) < n_vars.unsqueeze(1) correct_bits = ((final_assign == solutions) & var_mask).sum().float() total_bits = var_mask.sum().float() bit_acc = correct_bits / total_bits print(f" {n_iters:6d} | {solved.item():>7.1%} | {bit_acc.item():>7.1%}") if verbose: # Show some examples with torch.no_grad(): all_logits = model(mask[:8], sign[:8], n_vars[:8], n_iters=256) final_assign = (all_logits[-1] > 0).float() print(f"\n Sample predictions (256 iters):") for i in range(min(8, n_eval)): nv = n_vars[i].item() pred = final_assign[i, :nv].long().tolist() true = solutions[i, :nv].long().tolist() nc = eval_data["n_clauses"][i].item() is_sat_i = torch.ones(1, dtype=torch.bool, device=device) solved_i = _check_sat( final_assign[i:i+1], mask[i:i+1], sign[i:i+1], n_vars[i:i+1], is_sat_i) status = "✓" if solved_i > 0.5 else "✗" n_diff = sum(p != t for p, t in zip(pred, true)) print(f" {status} vars={nv}, clauses={nc}, diff={n_diff}/{nv}") if nv <= 20: print(f" Pred: {pred}") print(f" True: {true}") def main(): parser = argparse.ArgumentParser(description="Iterative SAT solver") parser.add_argument("--train-data", required=True, help="Training .pt file") parser.add_argument("--eval-data", default=None, help="Eval .pt file") parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu") parser.add_argument("--steps", type=int, default=30000) parser.add_argument("--batch", type=int, default=512) parser.add_argument("--lr", type=float, default=2e-3) parser.add_argument("--warmup", type=int, default=1400) parser.add_argument("--log-interval", type=int, default=100) parser.add_argument("--eval-interval", type=int, default=5000) parser.add_argument("--compile", action="store_true") parser.add_argument("--save-path", default=None) parser.add_argument("--upload-hf", default=None) parser.add_argument("--d-model", type=int, default=128) parser.add_argument("--n-layers", type=int, default=4) parser.add_argument("--n-heads", type=int, default=4) parser.add_argument("--d-ff", type=int, default=512) parser.add_argument("--train-iters", type=int, default=16) parser.add_argument("--max-vars", type=int, default=40) parser.add_argument("--max-clauses", type=int, default=256) parser.add_argument("--dropout", type=float, default=0.1) parser.add_argument("--n-scratch", type=int, default=16, help="Number of scratchpad tokens") args = parser.parse_args() config = SATConfig( max_vars=args.max_vars, max_clauses=args.max_clauses, d_model=args.d_model, n_heads=args.n_heads, n_layers=args.n_layers, d_ff=args.d_ff, dropout=args.dropout, train_iters=args.train_iters, n_scratch=args.n_scratch, ) train(config, args) if __name__ == "__main__": main()