# Qwen3-4B-SAT-VarSelector-Sym-Aug-8-Epoch
A Qwen3-4B-based variable selector for SAT Cube-and-Conquer solving. This model predicts the optimal branching variable given a CNF formula state, enabling learned splitting strategies for parallel SAT solving.
## Model Description
This model is trained for variable selection in SAT Cube-and-Conquer (CnC) solving. Given a simplified CNF formula, it outputs the variable ID that should be branched on next.
### Architecture
- Backbone: Qwen3-4B (3.8B parameters, causal language model)
- Head: LayerNorm + Linear classifier
- Output: Logits over variable IDs (1 to 600)
- Pooling: Last non-padding token hidden state
```
Input (CNF text) → Tokenize → Qwen3-4B → Last Token Hidden State → LayerNorm → Linear → Variable Logits
```
### Training Approach
- Supervised Fine-Tuning (SFT) on expert variable selections
- Masked Classification: Only variables appearing in the CNF are valid outputs
- Loss: Cross-entropy with invalid variable logits masked to -∞
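A minimal sketch of that loss, assuming batched `logits` of shape `(batch, max_vars + 1)`, a boolean `valid_mask` of the same shape, and integer `labels` (illustrative code, not the exact training script):

```python
import torch
import torch.nn.functional as F

def masked_var_loss(logits, valid_mask, labels):
    """Cross-entropy over variable IDs; variables absent from the CNF
    (and the unused index 0) are masked to -inf before the softmax."""
    masked_logits = logits.masked_fill(~valid_mask, float("-inf"))
    return F.cross_entropy(masked_logits, labels)
```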
## Training Details
| Parameter | Value |
|---|---|
| Base Model | Qwen/Qwen3-4B |
| Training Samples | 40,550 (8,110 original × 5× augmentation) |
| Validation Samples | 902 |
| Epochs | 8 |
| Batch Size | 64 effective (1 per GPU × 8 GPUs × 8 gradient-accumulation steps) |
| Learning Rate | 5e-6 |
| Warmup Ratio | 0.1 |
| Max Sequence Length | 8,192 tokens |
| Max Variables | 600 |
| Precision | bfloat16 |
| Hardware | 8× H100 GPUs |
| Distributed Training | DeepSpeed ZeRO-3 |
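As a reference, a minimal DeepSpeed configuration matching the table above might look like the following (illustrative values only; the actual training config is not included with this card):

```python
# Illustrative ZeRO-3 config dict of the kind DeepSpeed consumes (not the released config)
ds_config = {
    "bf16": {"enabled": True},             # bfloat16 precision
    "train_micro_batch_size_per_gpu": 1,   # 1 sample per GPU
    "gradient_accumulation_steps": 8,      # 1 × 8 GPUs × 8 steps = 64 effective
    "zero_optimization": {"stage": 3},     # ZeRO-3 parameter/optimizer sharding
}
```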
### Data Augmentation
The training data was augmented 5× using semantically-safe transformations, sketched in code after this list:
- Variable Permutation (50%): Bijective remapping of variable IDs
- Clause Shuffling (100%): Random reordering of clauses
- Literal Reordering (100%): Shuffle literals within each clause
- Polarity Flipping (30%): Flip polarity of random variable subset
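The percentages above are read here as per-sample application probabilities. The helper below is a hypothetical sketch of these transformations on a clause-list representation (`augment`, `remap`, and `flip` are illustrative names, not the released pipeline); note that the expert label must follow the same variable permutation:

```python
import random

def augment(clauses, num_vars, label):
    """Apply the semantics-preserving transformations to a clause list
    (each clause a list of signed DIMACS literals) and remap the expert label."""
    # Variable permutation (50% of samples): bijective remap of variable IDs
    remap = {v: v for v in range(1, num_vars + 1)}
    if random.random() < 0.5:
        perm = list(range(1, num_vars + 1))
        random.shuffle(perm)
        remap = dict(zip(range(1, num_vars + 1), perm))
    # Polarity flipping (30% of samples): negate every occurrence of a random variable subset
    flip = set()
    if random.random() < 0.3:
        flip = {v for v in range(1, num_vars + 1) if random.random() < 0.5}
    out = []
    for clause in clauses:
        new_clause = []
        for lit in clause:
            v = abs(lit)
            sign = 1 if lit > 0 else -1
            if v in flip:
                sign = -sign            # flipping every occurrence preserves satisfiability
            new_clause.append(sign * remap[v])
        random.shuffle(new_clause)      # literal reordering (100% of samples)
        out.append(new_clause)
    random.shuffle(out)                 # clause shuffling (100% of samples)
    return out, remap[label]            # the target variable follows the permutation
```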
## Usage
### Loading the Model
```python
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM

class QwenVarClassifier(nn.Module):
    def __init__(self, base_model_name: str, max_vars: int = 600):
        super().__init__()
        self.max_vars = max_vars
        cfg = AutoConfig.from_pretrained(base_model_name)
        cfg.output_hidden_states = True
        self.backbone = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            config=cfg,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        )
        hidden = self.backbone.config.hidden_size
        self.head_ln = nn.LayerNorm(hidden)
        self.head = nn.Linear(hidden, max_vars + 1)  # index 0 unused; variable IDs 1..max_vars
        self.config = self.backbone.config

    def forward(self, input_ids, attention_mask, **kwargs):
        out = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )
        h = out.hidden_states[-1]
        # Pool the hidden state of the last non-padding token
        last_idx = attention_mask.sum(dim=1) - 1
        last_idx = last_idx.clamp(min=0)
        b = torch.arange(h.size(0), device=h.device)
        pooled = h[b, last_idx]
        pooled = self.head_ln(pooled)
        logits = self.head(pooled)
        return {"logits": logits}

# Load model and fine-tuned weights
model = QwenVarClassifier("Qwen/Qwen3-4B", max_vars=600)
checkpoint = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(checkpoint)
model.eval()

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
```
### Inference
```python
def cnf_valid_mask(cnf_text: str, max_vars: int = 600):
    """Build a 0/1 mask of the variables that actually appear in the CNF text."""
    mask = [0] * (max_vars + 1)
    for line in cnf_text.split('\n'):
        line = line.strip()
        # Skip blank lines, comments ('c ...'), and the header ('p cnf ...')
        if not line or line.startswith('c') or line.startswith('p'):
            continue
        for tok in line.split():
            try:
                lit = int(tok)
                v = abs(lit)
                if 1 <= v <= max_vars:
                    mask[v] = 1
            except ValueError:
                continue
    # Fall back to allowing every variable if parsing found none
    if sum(mask) == 0:
        for v in range(1, max_vars + 1):
            mask[v] = 1
    return mask

# Example CNF (DIMACS format)
cnf_text = """p cnf 100 3
1 -2 3 0
-1 4 5 0
2 -3 -4 0"""

# Tokenize
inputs = tokenizer(cnf_text, return_tensors="pt", truncation=True, max_length=8192)

# Get valid variable mask
valid_mask = torch.tensor([cnf_valid_mask(cnf_text, 600)], dtype=torch.bool)

# Predict
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs["logits"]
    # Mask invalid variables so argmax cannot pick one absent from the formula
    logits = logits.masked_fill(~valid_mask, -1e4)

predicted_var = logits.argmax(dim=-1).item()
print(f"Predicted branching variable: {predicted_var}")
```
## Input Format
The model expects CNF formulas in DIMACS format:
```
p cnf <num_vars> <num_clauses>
<lit1> <lit2> ... 0
<lit1> <lit2> ... 0
...
```
- Header line: `p cnf <variables> <clauses>`
- Each clause is a space-separated list of literals ending with `0`
- Literals are integers: positive = variable, negative = negated variable
- Example: `1 -2 3 0` represents (x₁ ∨ ¬x₂ ∨ x₃)
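For instance, a clause list can be serialized into this format with a few lines of Python (`to_dimacs` is a hypothetical helper, following the conventions above):

```python
def to_dimacs(clauses, num_vars):
    """Serialize a list of clauses (lists of signed ints) into DIMACS CNF text."""
    lines = [f"p cnf {num_vars} {len(clauses)}"]
    for clause in clauses:
        lines.append(" ".join(str(lit) for lit in clause) + " 0")
    return "\n".join(lines)

print(to_dimacs([[1, -2, 3], [-1, 4, 5], [2, -3, -4]], 100))
# p cnf 100 3
# 1 -2 3 0
# -1 4 5 0
# 2 -3 -4 0
```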
## Evaluation
The model is evaluated using exact-match accuracy: the predicted variable must exactly match the expert's choice.
For reference, a uniform random choice among ~100 valid variables yields ~1% accuracy.
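A sketch of the metric, assuming batched `logits`, boolean `valid_masks`, and gold `labels` tensors (illustrative names, not an official evaluation script):

```python
import torch

def exact_match_accuracy(logits, valid_masks, labels):
    """Fraction of samples whose top masked logit equals the expert's variable."""
    masked = logits.masked_fill(~valid_masks, float("-inf"))
    preds = masked.argmax(dim=-1)
    return (preds == labels).float().mean().item()
```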
## Intended Use
This model is designed for:
- SAT Cube-and-Conquer solving: Selecting split variables for parallel SAT solving
- Research: Studying learned heuristics for combinatorial optimization
## Limitations
- Maximum 600 variables (configurable at training time)
- Maximum 8,192 tokens input length
- Trained on specific CNF distribution; may not generalize to all SAT instances
- Requires valid-variable masking at inference time; without it, the argmax may select a variable that does not appear in the formula
## Citation
If you use this model, please cite:
```bibtex
@misc{qwen3-sat-varselector,
  title={Qwen3-4B-SAT-VarSelector: Learned Variable Selection for SAT Cube-and-Conquer},
  author={Yale-ROSE},
  year={2026},
  publisher={Hugging Face},
  url={https://huggingface.co/Yale-ROSE/Qwen3-4B-SAT-VarSelector-Sym-Aug-8-Epoch}
}
```
## License
Apache 2.0