"""
Non-autoregressive iterative SAT solver.
Learns to solve Boolean satisfiability via iterative refinement:
- Input: clause membership + polarity per variable (from pre-generated .pt data)
- Shared transformer body (bidirectional attention)
- Output: T/F assignment per variable
- QuerySAT-style feedback: current assignment's violation count fed back
- Train with K=16 iterations, eval with K=16..256+
Trained on SR distribution (paired SAT/UNSAT instances differing by one literal).
Training objective: for SAT instances, predict a satisfying assignment (paired UNSAT instances are sampled as inputs but add no loss term; there is no explicit SAT/UNSAT classification head).
Usage:
# Generate data first (CPU, one-time)
python scripts/sat_data_gen.py --n-problems 100000 --output data/sat/train.pt
python scripts/sat_data_gen.py --n-problems 2000 --output data/sat/eval.pt --seed 99999
# Train (GPU)
python scripts/iterative_sat.py --train-data data/sat/train.pt --eval-data data/sat/eval.pt --compile
# Quick local test
python scripts/iterative_sat.py --train-data /tmp/sat_test.pt --device cpu --steps 200 --batch 32
"""
import argparse
import math
import time
from contextlib import nullcontext
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclass
class SATConfig:
max_vars: int = 40
max_clauses: int = 256
d_model: int = 128
n_heads: int = 4
n_layers: int = 4
d_ff: int = 512
dropout: float = 0.1
train_iters: int = 16
rope_base: float = 10.0
n_scratch: int = 16 # number of scratchpad/register tokens
# ---------------------------------------------------------------------------
# RoPE
# ---------------------------------------------------------------------------
def build_rope_cache(seq_len, head_dim, base=10.0, device="cpu"):
theta = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
freqs = torch.outer(torch.arange(seq_len, device=device).float(), theta)
return freqs.cos(), freqs.sin()
def apply_rope(x, cos, sin):
d2 = x.shape[-1] // 2
x1, x2 = x[..., :d2], x[..., d2:]
cos, sin = cos[:x.shape[2], :], sin[:x.shape[2], :]
return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
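# Illustrative sketch (hypothetical helper, not used elsewhere): shows the tensor
# layouts the RoPE helpers above expect. The cache is built per head_dim and applied
# to (batch, n_heads, positions, head_dim) tensors, which is how MultiHeadAttention
# below calls it; the rotation preserves the tensor's shape.
def _rope_shape_example():
    cos, sin = build_rope_cache(seq_len=8, head_dim=16)
    x = torch.randn(2, 4, 8, 16)  # (batch, heads, positions, head_dim)
    out = apply_rope(x, cos, sin)
    assert out.shape == x.shape
    return out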
# ---------------------------------------------------------------------------
# Transformer layers
# ---------------------------------------------------------------------------
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads, dropout=0.1):
super().__init__()
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.wq = nn.Linear(d_model, d_model, bias=False)
self.wk = nn.Linear(d_model, d_model, bias=False)
self.wv = nn.Linear(d_model, d_model, bias=False)
self.wo = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x, cos, sin):
B, N, D = x.shape
q = self.wq(x).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
k = self.wk(x).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
v = self.wv(x).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)
attn = F.scaled_dot_product_attention(
q, k, v, dropout_p=self.dropout.p if self.training else 0.0)
return self.wo(attn.transpose(1, 2).contiguous().view(B, N, D))
class TransformerBlock(nn.Module):
def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
super().__init__()
self.norm1 = nn.RMSNorm(d_model)
self.attn = MultiHeadAttention(d_model, n_heads, dropout)
self.norm2 = nn.RMSNorm(d_model)
self.ff = nn.Sequential(
nn.Linear(d_model, d_ff, bias=False), nn.ReLU(),
nn.Linear(d_ff, d_model, bias=False), nn.Dropout(dropout))
def forward(self, x, cos, sin):
x = x + self.attn(self.norm1(x), cos, sin)
x = x + self.ff(self.norm2(x))
return x
# ---------------------------------------------------------------------------
# SAT Model
# ---------------------------------------------------------------------------
class IterativeSATModel(nn.Module):
"""Sotaku-style iterative SAT solver.
Key design (matching sotaku):
- h_prev carries the full hidden state directly (residual across iterations)
- pred_proj adds a small correction from detached predictions (not the hidden state)
- Scratchpad tokens provide extra working memory positions
- Gradients flow through h_prev, predictions are detached
"""
def __init__(self, config: SATConfig):
super().__init__()
self.config = config
d = config.d_model
N = config.max_vars
S = config.n_scratch
total_pos = N + S
# Input encoder: clause structure → initial hidden state (one-time)
self.input_proj = nn.Linear(2 * config.max_clauses, d, bias=False)
# Prediction feedback: small correction from detached predictions
# assign(1) + violation(1) → d_model (like sotaku's pred_proj on softmax preds)
self.pred_proj = nn.Linear(2, d, bias=False)
# Scratchpad tokens (extra working memory)
if S > 0:
self.scratch_embeds = nn.Parameter(torch.randn(S, d) * 0.02)
# Shared transformer
self.layers = nn.ModuleList([
TransformerBlock(d, config.n_heads, config.d_ff, config.dropout)
for _ in range(config.n_layers)
])
self.final_norm = nn.RMSNorm(d)
# Output head (variable positions only)
self.assign_head = nn.Linear(d, 1, bias=False)
cos, sin = build_rope_cache(total_pos, d // config.n_heads, config.rope_base)
self.register_buffer("rope_cos", cos)
self.register_buffer("rope_sin", sin)
    def forward(self, clause_mask, clause_sign, n_vars_batch=None, n_iters=None):
        # n_vars_batch is accepted for call-site symmetry but is unused here;
        # padding variables are masked out in the training loss instead.
if n_iters is None:
n_iters = self.config.train_iters
B = clause_mask.shape[0]
N = self.config.max_vars
S = self.config.n_scratch
device = clause_mask.device
# One-time encoding (re-added every iteration to prevent forgetting)
features = torch.cat([clause_mask, clause_sign], dim=-1)
h_init = self.input_proj(features) # (B, N, d)
# Append scratchpad
if S > 0:
h_scratch = self.scratch_embeds.unsqueeze(0).expand(B, -1, -1)
h_init = torch.cat([h_init, h_scratch], dim=1) # (B, N+S, d)
h_prev = h_init # first iteration starts from input encoding
all_logits = []
# Initial predictions: uniform
preds = torch.zeros(B, N + S, 2, device=device)
preds[:, :N, 0] = 0.5
# violation starts at 0
for _ in range(n_iters):
# Clean carry + fresh input + prediction correction
h = h_prev + h_init + self.pred_proj(preds)
# Shared transformer
for layer in self.layers:
h = layer(h, self.rope_cos, self.rope_sin)
h = self.final_norm(h)
# h becomes h_prev for next iteration (direct carry, with gradients)
h_prev = h
# Predict assignments from variable positions only
logits = self.assign_head(h[:, :N, :]).squeeze(-1) # (B, N)
all_logits.append(logits)
# Build detached prediction feedback for next iteration
assign_prob = torch.sigmoid(logits).detach()
violation = self._compute_violations(assign_prob, clause_mask, clause_sign)
preds = torch.zeros(B, N + S, 2, device=device)
preds[:, :N, 0] = assign_prob
preds[:, :N, 1] = violation
return all_logits
def _compute_violations(self, assign_prob, clause_mask, clause_sign):
"""Compute per-variable violation signal (QuerySAT-style).
        For each variable, average how violated its clauses are under the current
        soft assignment. A clause is fully "violated" when all of its literals are false.
assign_prob: (B, N) — probability of each variable being True
clause_mask: (B, N, max_clauses) — 1 if variable in clause
clause_sign: (B, N, max_clauses) — polarity
Returns: (B, N) — per-variable violation signal (0-1, higher = more violated)
"""
# Soft literal satisfaction: how much each literal contributes to its clause
# If sign=+1 (positive literal): satisfaction = assign_prob
# If sign=-1 (negative literal): satisfaction = 1 - assign_prob
lit_sat = torch.where(
clause_sign > 0,
assign_prob.unsqueeze(-1),
torch.where(clause_sign < 0, 1 - assign_prob.unsqueeze(-1), torch.zeros_like(clause_sign))
) # (B, N, max_clauses)
        # Per-clause satisfaction: soft score for each clause, obtained by summing
        # literal satisfactions over its variables (normalized by clause size below)
        # rather than taking a hard max over literals
clause_sat = (lit_sat * clause_mask).sum(dim=1) # (B, max_clauses)
# Normalize by clause size
clause_size = clause_mask.sum(dim=1).clamp(min=1) # (B, max_clauses)
clause_unsat = 1 - (clause_sat / clause_size).clamp(max=1) # (B, max_clauses) — 0=sat, 1=unsat
# Per-variable: average unsatisfaction of clauses this variable appears in
var_violation = (clause_unsat.unsqueeze(1) * clause_mask).sum(dim=-1) # (B, N)
var_n_clauses = clause_mask.sum(dim=-1).clamp(min=1) # (B, N)
return var_violation / var_n_clauses
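# Illustrative sketch (hypothetical helper, not part of the original pipeline): runs a
# tiny forward pass on a hand-built formula to show the per-iteration outputs and the
# QuerySAT-style violation feedback described in the class docstring above.
def _model_smoke_test():
    cfg = SATConfig(max_vars=2, max_clauses=4, d_model=32, n_heads=2,
                    n_layers=1, d_ff=64, dropout=0.0, train_iters=2, n_scratch=2)
    model = IterativeSATModel(cfg).eval()
    # Formula: a single clause (x0 OR NOT x1) over 2 variables, padded to 4 clause slots.
    mask = torch.zeros(1, 2, 4)
    sign = torch.zeros(1, 2, 4)
    mask[0, 0, 0], sign[0, 0, 0] = 1.0, 1.0    # x0 appears positively in clause 0
    mask[0, 1, 0], sign[0, 1, 0] = 1.0, -1.0   # x1 appears negatively in clause 0
    with torch.no_grad():
        logits = model(mask, sign, n_iters=2)  # list of per-iteration (B, N) logit tensors
    assert len(logits) == 2 and logits[-1].shape == (1, 2)
    # Violation feedback for the all-False soft assignment: the clause has one of its
    # two literals satisfied (NOT x1), so both variables get a soft violation of 0.5.
    v = model._compute_violations(torch.zeros(1, 2), mask, sign)
    assert torch.allclose(v, torch.full_like(v, 0.5))
    return logits, v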
# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
def load_dataset(path, device="cpu"):
"""Load pre-generated .pt dataset."""
data = torch.load(path, weights_only=True)
result = {}
for k, v in data.items():
if isinstance(v, torch.Tensor):
if k in ("n_clauses", "n_vars"):
result[k] = v.long().to(device) # keep integer types for indexing
else:
result[k] = v.float().to(device) # float32 for model compatibility
else:
result[k] = v
return result
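# Illustrative sketch (assumption: _write_toy_dataset is a hypothetical stand-in for
# local smoke tests, NOT scripts/sat_data_gen.py). It writes a toy .pt file with the
# keys this script reads: sat_mask, sat_sign, unsat_mask, unsat_sign, solutions,
# n_vars, n_clauses. Each "sat" instance carries a planted solution (every clause keeps
# at least one literal the planted assignment satisfies); the "unsat" copies are plain
# duplicates, which is tolerable here only because the training loss below ignores
# UNSAT-labeled instances.
def _write_toy_dataset(path, n_problems=64, n_vars=10, n_clauses=30,
                       max_vars=40, max_clauses=256, seed=0):
    g = torch.Generator().manual_seed(seed)
    mask = torch.zeros(n_problems, max_vars, max_clauses)
    sign = torch.zeros(n_problems, max_vars, max_clauses)
    solutions = torch.zeros(n_problems, max_vars)
    for p in range(n_problems):
        planted = torch.randint(0, 2, (n_vars,), generator=g).float()
        solutions[p, :n_vars] = planted
        for c in range(n_clauses):
            chosen = torch.randperm(n_vars, generator=g)[:3]  # 3-literal clauses
            satisfied = False
            for v in chosen.tolist():
                polarity = 1.0 if torch.rand(1, generator=g).item() < 0.5 else -1.0
                mask[p, v, c] = 1.0
                sign[p, v, c] = polarity
                satisfied = satisfied or (polarity > 0) == bool(planted[v])
            if not satisfied:  # force one literal to agree with the planted assignment
                v = chosen[0].item()
                sign[p, v, c] = 1.0 if planted[v] > 0 else -1.0
    torch.save({
        "sat_mask": mask, "sat_sign": sign,
        "unsat_mask": mask.clone(), "unsat_sign": sign.clone(),  # placeholders, not truly UNSAT
        "solutions": solutions,
        "n_vars": torch.full((n_problems,), n_vars, dtype=torch.long),
        "n_clauses": torch.full((n_problems,), n_clauses, dtype=torch.long),
    }, path)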
def train(config, args):
device = args.device
if device == "cuda":
torch.set_float32_matmul_precision('high')
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
model = IterativeSATModel(config).to(device)
n_params = sum(p.numel() for p in model.parameters())
print(f"Model params: {n_params:,} ({n_params/1e6:.2f}M)")
print(f"Config: {config.n_layers}L, d={config.d_model}, h={config.n_heads}, "
f"ff={config.d_ff}, iters={config.train_iters}")
print(f"Max vars: {config.max_vars}, max clauses: {config.max_clauses}")
print(f"Device: {device}")
# Load pre-generated data
print(f"\nLoading training data from {args.train_data}...")
train_data = load_dataset(args.train_data, device)
n_train = train_data["sat_mask"].shape[0]
print(f" {n_train} problems loaded")
eval_data = None
if args.eval_data:
print(f"Loading eval data from {args.eval_data}...")
eval_data = load_dataset(args.eval_data, device)
print(f" {eval_data['sat_mask'].shape[0]} problems loaded")
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=0.01)
def lr_schedule(step):
if step < args.warmup:
return step / args.warmup
progress = (step - args.warmup) / max(1, args.steps - args.warmup)
return 0.01 + 0.99 * 0.5 * (1 + math.cos(math.pi * progress))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)
if args.compile and device == "cuda":
print("Compiling...")
model = torch.compile(model)
print("Done.")
use_amp = device == "cuda"
scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
autocast_ctx = torch.amp.autocast('cuda', dtype=torch.bfloat16) if use_amp else nullcontext()
t0 = time.time()
for step in range(args.steps + 1):
model.train()
# Sample batch: randomly pick SAT or UNSAT (50/50)
idx = torch.randint(0, n_train, (args.batch,), device=device)
is_sat = torch.rand(args.batch, device=device) < 0.5
# Build input: pick sat or unsat version
mask = torch.where(is_sat.view(-1, 1, 1), train_data["sat_mask"][idx], train_data["unsat_mask"][idx])
sign = torch.where(is_sat.view(-1, 1, 1), train_data["sat_sign"][idx], train_data["unsat_sign"][idx])
solutions = train_data["solutions"][idx] # only valid for SAT instances
n_vars = train_data["n_vars"][idx]
with autocast_ctx:
all_logits = model(mask, sign, n_vars)
            # Deep supervision: assignment BCE at every iteration, averaged over
            # iterations below. Only SAT-labeled instances (and only their valid
            # variables) contribute; UNSAT instances carry no target assignment.
            loss = torch.zeros((), device=device)  # tensor, so backward() works even if no SAT instance is sampled
var_mask = torch.arange(config.max_vars, device=device).unsqueeze(0) < n_vars.unsqueeze(1)
for logits in all_logits:
# Assignment loss (SAT instances only)
assign_loss = F.binary_cross_entropy_with_logits(
logits, solutions, reduction='none')
# Mask: only SAT instances, only valid variables
sat_mask = is_sat.unsqueeze(1) & var_mask
if sat_mask.any():
loss += (assign_loss * sat_mask).sum() / sat_mask.sum()
loss /= len(all_logits)
optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)
scaler.update()
scheduler.step()
if step % args.log_interval == 0:
elapsed = time.time() - t0
with torch.no_grad():
final_assign = (all_logits[-1] > 0).float()
# Check which SAT instances are actually solved
sat_solved = _check_sat(final_assign, mask, sign, n_vars, is_sat)
print(f"Step {step:5d} | Loss: {loss.item():.4f} | "
f"SAT solved: {sat_solved:.1%} | {elapsed:.1f}s")
if step > 0 and step % args.eval_interval == 0 and eval_data is not None:
evaluate(model, config, eval_data, device)
print("\n" + "=" * 70)
print("FINAL EVALUATION")
print("=" * 70)
if eval_data is not None:
evaluate(model, config, eval_data, device, verbose=True)
if args.save_path:
raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
checkpoint = {
"model_state_dict": raw_model.state_dict(),
"config": vars(config),
}
torch.save(checkpoint, args.save_path)
print(f"\nCheckpoint saved to {args.save_path}")
if args.upload_hf:
from huggingface_hub import HfApi
import os
api = HfApi()
try:
api.create_repo(args.upload_hf, exist_ok=True)
except Exception as e:
print(f"Warning: {e}")
api.upload_file(path_or_fileobj=args.save_path, path_in_repo="model.pt", repo_id=args.upload_hf)
api.upload_file(path_or_fileobj=os.path.abspath(__file__), path_in_repo="iterative_sat.py", repo_id=args.upload_hf)
print(f"Uploaded to https://huggingface.co/{args.upload_hf}")
return model
def _check_sat(assignments, clause_mask, clause_sign, n_vars, is_sat_label):
"""Check what fraction of SAT-labeled instances are actually solved."""
B, N, M = clause_mask.shape
device = assignments.device
# Literal satisfaction: assignment matches polarity
lit_sat = torch.where(
clause_sign > 0, assignments.unsqueeze(-1),
torch.where(clause_sign < 0, 1 - assignments.unsqueeze(-1), torch.ones_like(clause_sign))
) # (B, N, M)
# Clause is satisfied if ANY literal in it is satisfied
clause_has_sat_lit = (lit_sat * clause_mask).sum(dim=1) > 0 # (B, M)
clause_exists = clause_mask.sum(dim=1) > 0 # (B, M)
# Formula satisfied = all existing clauses satisfied
all_sat = (clause_has_sat_lit | ~clause_exists).all(dim=1) # (B,)
# Only count SAT-labeled instances
n_sat = is_sat_label.sum()
if n_sat == 0:
return 0.0
return (all_sat & is_sat_label).sum().float() / n_sat
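# Illustrative sketch (hypothetical helper, not called by the pipeline): hard-assignment
# check on the tiny formula (x0 OR NOT x1). The assignment [1, 1] satisfies it through
# x0, while [0, 1] falsifies both literals, so exactly half of this two-instance batch
# counts as solved.
def _check_sat_example():
    mask = torch.zeros(2, 2, 4)
    sign = torch.zeros(2, 2, 4)
    mask[:, 0, 0], sign[:, 0, 0] = 1.0, 1.0    # x0 positive in clause 0
    mask[:, 1, 0], sign[:, 1, 0] = 1.0, -1.0   # x1 negative in clause 0
    assignments = torch.tensor([[1.0, 1.0], [0.0, 1.0]])
    n_vars = torch.tensor([2, 2])
    is_sat = torch.ones(2, dtype=torch.bool)
    frac = _check_sat(assignments, mask, sign, n_vars, is_sat)
    assert abs(frac.item() - 0.5) < 1e-6
    return frac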
def evaluate(model, config, eval_data, device, verbose=False):
"""Evaluate with different iteration counts."""
model.eval()
n_eval = eval_data["sat_mask"].shape[0]
iter_counts = [config.train_iters, 32, 64, 128, 256]
# Test on SAT instances
mask = eval_data["sat_mask"]
sign = eval_data["sat_sign"]
solutions = eval_data["solutions"]
n_vars = eval_data["n_vars"]
print(f"\n SAT instances (n={n_eval})")
print(f" {'Iters':>6s} | {'Solved':>8s} | {'Bit Acc':>8s}")
print(f" {'-'*6} | {'-'*8} | {'-'*8}")
for n_iters in iter_counts:
with torch.no_grad():
all_logits = model(mask, sign, n_vars, n_iters=n_iters)
final_assign = (all_logits[-1] > 0).float()
is_sat = torch.ones(n_eval, dtype=torch.bool, device=device)
solved = _check_sat(final_assign, mask, sign, n_vars, is_sat)
# Bit accuracy (does each variable match the reference solution?)
var_mask = torch.arange(config.max_vars, device=device).unsqueeze(0) < n_vars.unsqueeze(1)
correct_bits = ((final_assign == solutions) & var_mask).sum().float()
total_bits = var_mask.sum().float()
bit_acc = correct_bits / total_bits
print(f" {n_iters:6d} | {solved.item():>7.1%} | {bit_acc.item():>7.1%}")
if verbose:
# Show some examples
with torch.no_grad():
all_logits = model(mask[:8], sign[:8], n_vars[:8], n_iters=256)
final_assign = (all_logits[-1] > 0).float()
print(f"\n Sample predictions (256 iters):")
for i in range(min(8, n_eval)):
nv = n_vars[i].item()
pred = final_assign[i, :nv].long().tolist()
true = solutions[i, :nv].long().tolist()
nc = eval_data["n_clauses"][i].item()
is_sat_i = torch.ones(1, dtype=torch.bool, device=device)
solved_i = _check_sat(
final_assign[i:i+1], mask[i:i+1], sign[i:i+1], n_vars[i:i+1], is_sat_i)
status = "✓" if solved_i > 0.5 else "✗"
n_diff = sum(p != t for p, t in zip(pred, true))
print(f" {status} vars={nv}, clauses={nc}, diff={n_diff}/{nv}")
if nv <= 20:
print(f" Pred: {pred}")
print(f" True: {true}")
def main():
parser = argparse.ArgumentParser(description="Iterative SAT solver")
parser.add_argument("--train-data", required=True, help="Training .pt file")
parser.add_argument("--eval-data", default=None, help="Eval .pt file")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
parser.add_argument("--steps", type=int, default=30000)
parser.add_argument("--batch", type=int, default=512)
parser.add_argument("--lr", type=float, default=2e-3)
parser.add_argument("--warmup", type=int, default=1400)
parser.add_argument("--log-interval", type=int, default=100)
parser.add_argument("--eval-interval", type=int, default=5000)
parser.add_argument("--compile", action="store_true")
parser.add_argument("--save-path", default=None)
parser.add_argument("--upload-hf", default=None)
parser.add_argument("--d-model", type=int, default=128)
parser.add_argument("--n-layers", type=int, default=4)
parser.add_argument("--n-heads", type=int, default=4)
parser.add_argument("--d-ff", type=int, default=512)
parser.add_argument("--train-iters", type=int, default=16)
parser.add_argument("--max-vars", type=int, default=40)
parser.add_argument("--max-clauses", type=int, default=256)
parser.add_argument("--dropout", type=float, default=0.1)
parser.add_argument("--n-scratch", type=int, default=16, help="Number of scratchpad tokens")
args = parser.parse_args()
config = SATConfig(
max_vars=args.max_vars,
max_clauses=args.max_clauses,
d_model=args.d_model,
n_heads=args.n_heads,
n_layers=args.n_layers,
d_ff=args.d_ff,
dropout=args.dropout,
train_iters=args.train_iters,
n_scratch=args.n_scratch,
)
train(config, args)
if __name__ == "__main__":
main()