# Qwen3-4B SAT Variable Selector (GRPO 2x)
A transformer-based model for selecting branching variables in SAT (Boolean Satisfiability) Cube-and-Conquer solving. This model was fine-tuned using GRPO (Group Relative Policy Optimization) on top of an SFT checkpoint.
## Model Description
This model predicts which variable to branch on given a CNF formula in DIMACS format. It was trained using reinforcement learning with solver-based rewards.
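The model consumes raw DIMACS text: a `p cnf <num_vars> <num_clauses>` header followed by clauses of signed integer literals, each terminated by `0`. A minimal parser illustrating the format (a hypothetical helper for clarity, not part of the released code):

```python
def parse_dimacs(text):
    """Parse a DIMACS CNF string into (num_vars, clauses)."""
    num_vars, clauses = 0, []
    for line in text.strip().splitlines():
        line = line.strip()
        if not line or line.startswith("c"):
            continue  # skip comments and blank lines
        if line.startswith("p cnf"):
            num_vars = int(line.split()[2])
            continue
        lits = [int(tok) for tok in line.split()]
        clauses.append(lits[:-1])  # drop the trailing 0 terminator
    return num_vars, clauses

num_vars, clauses = parse_dimacs("p cnf 5 3\n1 -2 3 0\n-1 2 0\n2 -5 0")
# num_vars == 5; clauses == [[1, -2, 3], [-1, 2], [2, -5]]
```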
## Training Details

### Base Model
- Architecture: Qwen3-4B with classification head
- SFT Checkpoint: `out_qwen_4b_sft_augmented/checkpoint-5000` (25.06% validation accuracy)
### GRPO Training Configuration
| Parameter | Value |
|---|---|
| Training data | 2x augmented dataset (16,220 samples) |
| Epochs | 1 |
| Learning rate | 5e-6 (aggressive) |
| Effective batch size | 8 (1 per device × 1 accumulation step × 8 GPUs) |
| Num samples per CNF | 4 |
| Entropy coefficient | 0.25 (high for diversity) |
| Temperature | 1.8 (high for exploration) |
| Solver timeout | 100ms |
| Max variables | 600 |
| Max sequence length | 8192 |
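GRPO scores each of the 4 sampled variables per CNF with a solver-based reward, then normalizes rewards within that group so each sample's advantage is relative to its siblings. A minimal sketch of the group-relative advantage computation (illustrative; the actual reward function and training loop live in the repository):

```python
import statistics

def group_relative_advantages(rewards):
    """GRPO-style advantages: normalize each sample's reward against
    the mean and std of its own group (the samples for one CNF)."""
    mean = statistics.mean(rewards)
    std = statistics.pstdev(rewards) or 1.0  # guard against all-equal rewards
    return [(r - mean) / std for r in rewards]

# 4 samples for one CNF; e.g. reward = negative solver time within the timeout
rewards = [-0.08, -0.05, -0.09, -0.02]
advs = group_relative_advantages(rewards)
# advantages sum to ~0; the fastest-solving sample gets the largest advantage
```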
### Hardware
- 8× NVIDIA H100 GPUs
- DeepSpeed ZeRO-3 (not used for GRPO)
- Gradient checkpointing enabled
## Usage

```python
import torch
from transformers import AutoTokenizer

# Load the fine-tuned model and the base tokenizer
model = torch.load("model.pt", map_location="cpu")
model.eval()
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")

# Example CNF formula in DIMACS format
cnf = """p cnf 5 3
1 -2 3 0
-1 2 0
2 -5 0"""

# Tokenize
inputs = tokenizer(cnf, return_tensors="pt", truncation=True, max_length=8192)

# Predict a branching variable
with torch.no_grad():
    outputs = model(**inputs)

logits = outputs["logits"]
predicted_var = logits[0, 1:].argmax().item() + 1  # variables are 1-indexed
print(f"Recommended branching variable: {predicted_var}")
```
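Because the classification head scores up to 600 variable positions but a given formula typically uses fewer, it can help to restrict the argmax to variables that actually appear in the CNF. A hedged sketch (a hypothetical helper, assuming logits of shape `(1, max_vars + 1)` with index 0 unused):

```python
import torch

def masked_argmax(logits, num_vars):
    """Pick the highest-scoring variable among positions 1..num_vars,
    ignoring logits for variables outside the formula."""
    masked = logits[0, 1:num_vars + 1]
    return masked.argmax().item() + 1  # back to 1-indexed variable ids

logits = torch.tensor([[0.0, 0.1, 2.0, 0.5, 9.0, 0.3]])
# with num_vars=3, variable 2 wins even though position 4 has the largest logit
assert masked_argmax(logits, 3) == 2
```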
## Training Script

This model was trained using:

```bash
bash run_grpo_2x.bash
```
See the full training script in the Transformer-SAT repository.
## Citation
If you use this model, please cite the Transformer-SAT project.
## License
Apache 2.0