# Qwen3-4B-SAT-VarSelector-Sym-Aug
A Qwen3-4B model fine-tuned for SAT branching variable selection using symmetry-based data augmentation.
## Model Description
This model predicts which variable to branch or cube on next, given a SAT CNF formula state. It was trained on 5x-augmented data produced by CNF symmetry transformations, reaching 21.8% top-1 accuracy on validation (vs. ~19% for the augmented Qwen3-0.6B).
## Architecture

- Base: `Qwen/Qwen3-4B` (causal language model)
- Head: LayerNorm → Linear(hidden_size, 601), as sketched below
- Max Variables: 600
- Pooling: Last non-pad token hidden state
- Masking: Invalid variables (not in CNF) are masked to -10000 before softmax
- Size: ~8GB (bfloat16)
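The head and pooling are simple enough to show in a few lines. The following is a minimal sketch for illustration, not the shipped implementation (which lives in `sft_qwen_var_classifier.py`); the class name and the assumption that index 0 of the 601-way output is unused are ours.

```python
import torch
import torch.nn as nn
from transformers import AutoModel

class VarClassifierSketch(nn.Module):
    """Illustrative stand-in for QwenVarClassifier; hypothetical class name."""

    def __init__(self, base_name="Qwen/Qwen3-4B", max_vars=600):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(base_name)
        hidden = self.backbone.config.hidden_size
        # 601 outputs: DIMACS variables are 1-indexed, so index 0 is presumably unused.
        self.head = nn.Sequential(nn.LayerNorm(hidden), nn.Linear(hidden, max_vars + 1))

    def forward(self, input_ids, attention_mask, valid_mask=None):
        hidden = self.backbone(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        # Pool the hidden state of the last non-pad token (assumes right padding).
        last = attention_mask.sum(dim=1) - 1
        pooled = hidden[torch.arange(hidden.size(0), device=hidden.device), last]
        logits = self.head(pooled)
        if valid_mask is not None:
            # Variables not present in the CNF are masked to -10000 before softmax.
            logits = logits.masked_fill(~valid_mask, -10000.0)
        return logits
```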
## Training with Symmetry Augmentation

This model was trained with 5x data augmentation using semantically safe CNF transformations (a reconstruction sketch follows the table):
| Augmentation | Description | Effect |
|---|---|---|
| Variable Permutation | Bijective remapping of variable IDs | Prevents memorizing specific variable numbers |
| Clause Shuffling | Random reordering of clauses | Teaches position-independence |
| Literal Reordering | Shuffle literals within clauses | Token-level variation |
| Polarity Flipping | Flip signs of random variable subset | Teaches structural vs. polarity features |
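All four transformations preserve satisfiability and can be composed in a single pass. The sketch below is our reconstruction, not the training pipeline's code; the 0.5 flip probability and the decision to apply every transformation at once are assumptions.

```python
import random

def augment_cnf(clauses, num_vars, rng=random):
    """Apply all four symmetry transformations to a CNF given as a list of
    literal lists (DIMACS literals without the trailing 0). Illustrative
    reconstruction, not the actual training pipeline."""
    # Variable permutation: bijective remapping of variable IDs.
    targets = list(range(1, num_vars + 1))
    rng.shuffle(targets)
    remap = dict(zip(range(1, num_vars + 1), targets))
    # Polarity flipping: each variable's sign flips with probability 0.5 (assumed).
    flipped = {v for v in range(1, num_vars + 1) if rng.random() < 0.5}

    out = []
    for clause in clauses:
        new_clause = []
        for lit in clause:
            v = abs(lit)
            sign = -1 if ((lit < 0) != (v in flipped)) else 1
            new_clause.append(sign * remap[v])
        rng.shuffle(new_clause)  # literal reordering within the clause
        out.append(new_clause)
    rng.shuffle(out)             # clause shuffling
    return out, remap
```

Note that the gold branching variable must be remapped with the same permutation; its polarity is irrelevant, since the label is a variable, not a literal.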
## Training Details
| Parameter | Value |
|---|---|
| Original training samples | 8,110 |
| Augmented training samples | 40,550 (5x) |
| Validation samples | 902 (unaugmented) |
| Epochs | 3 |
| Hardware | 8×H100 GPUs |
| Training framework | DeepSpeed ZeRO-3 |
| Peak learning rate | 5e-6 |
| Training time | ~4 hours |
| Best checkpoint | Step 1850 (epoch 2.92) |
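For reference, a DeepSpeed configuration consistent with the table might look like the following. This is a hedged sketch: only the ZeRO stage, bf16 precision, and peak learning rate come from the table; batch sizes are assumptions.

```python
# Hypothetical DeepSpeed config; only stage 3, bf16, and lr=5e-6 are documented above.
ds_config = {
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 3},
    "optimizer": {"type": "AdamW", "params": {"lr": 5e-6}},
    "train_micro_batch_size_per_gpu": 2,  # assumption, not documented above
    "gradient_accumulation_steps": 4,     # assumption, not documented above
}
```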
## Performance Comparison
| Model | Parameters | Training Data | Top-1 Accuracy |
|---|---|---|---|
| Qwen3-0.6B (baseline) | 600M | 8,110 samples | ~12% |
| Qwen3-0.6B (augmented) | 600M | 40,550 samples | ~19% |
| Qwen3-4B (augmented) | 4B | 40,550 samples | ~22% |
### Training Curve Highlights
- Peak accuracy: 22.0% at epoch 2.76
- Final accuracy: 21.8% at epoch 2.92
- Eval loss: 3.35 (vs 3.37 for 0.6B)
## Usage

```python
import torch
from transformers import AutoTokenizer
from sft_qwen_var_classifier import QwenVarClassifier, cnf_valid_mask

# Load model
model = QwenVarClassifier("Qwen/Qwen3-4B", max_vars=600)
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict, strict=False)
model = model.to("cuda", dtype=torch.bfloat16)
model.eval()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")

# Prepare CNF input
cnf_text = """p cnf 100 250
1 -2 3 0
-1 2 -4 0
...
"""

# Tokenize
inputs = tokenizer(cnf_text, return_tensors="pt", truncation=True, max_length=8192)
inputs = {k: v.to("cuda") for k, v in inputs.items()}

# Get valid variable mask
valid_mask = torch.tensor([cnf_valid_mask(cnf_text, max_vars=600)], dtype=torch.bool, device="cuda")

# Predict
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs["logits"]
logits = logits.masked_fill(~valid_mask, -1e4)
predicted_var = logits.argmax(dim=-1).item()
print(f"Predicted branching variable: {predicted_var}")
```
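For cube-and-conquer you typically want several candidate variables rather than a single argmax. A top-k read-out falls straight out of the masked logits (the `5` below is an arbitrary choice):

```python
# Top-k candidate branching variables; masked (invalid) entries sit at -1e4,
# so they receive near-zero probability under softmax.
probs = torch.softmax(logits, dim=-1)
top = probs.topk(5, dim=-1)
for var, p in zip(top.indices[0].tolist(), top.values[0].tolist()):
    print(f"variable {var}: p={p:.3f}")
```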
## Files

- `pytorch_model.bin` - Model weights (~8GB, bfloat16)
- `sft_qwen_var_classifier.py` - Model class definition (required for loading)
## When to Use This Model
- Higher accuracy than 0.6B version (+3pp)
- Production use when accuracy matters more than speed
- Cube-and-Conquer style SAT solving (a minimal cubing loop is sketched below)
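As a concrete, hypothetical illustration of the cube-and-conquer use case, the helper below splits a formula into 2^depth cubes by repeatedly fixing the predicted variable. It simply appends unit clauses and does not update the DIMACS header or simplify the formula.

```python
def expand_cubes(cnf_text, depth, select_var):
    """Hypothetical cubing loop: select_var is any callable returning the
    model's predicted branching variable for a DIMACS string."""
    if depth == 0:
        return [cnf_text]
    var = select_var(cnf_text)  # e.g. the prediction code from the Usage section
    cubes = []
    for literal in (var, -var):
        # Fix the chosen variable in both polarities via a unit clause.
        cubes.extend(expand_cubes(cnf_text + f"{literal} 0\n", depth - 1, select_var))
    return cubes
```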
## Limitations
- Maximum 600 variables
- Maximum 8192 tokens for CNF input
- Larger model size (~8GB vs 1.2GB for 0.6B)
- Slower inference (~6x slower than 0.6B)
## Related Models
- Qwen3-0.6B-SAT-VarSelector-Sym-Aug - Smaller, faster version
- Qwen3-0.6B-SAT-VarSelector - Non-augmented baseline
## Citation
If you use this model, please cite the Transformer-CnC paper.
## License
Apache 2.0