---
license: apache-2.0
base_model: Qwen/Qwen3-4B
tags:
- SAT
- combinatorial-optimization
- classification
- cube-and-conquer
- data-augmentation
language:
- en
pipeline_tag: text-classification
---

# Qwen3-4B-SAT-VarSelector-Sym-Aug

A Qwen3-4B model fine-tuned for **SAT branching variable selection** using **symmetry-based data augmentation**.

## Model Description

This model predicts which variable to branch/cube on next, given a SAT CNF formula state. It was trained on **5× augmented data** generated with CNF symmetry transformations, reaching **21.8% top-1 accuracy** (vs. ~19% for the augmented Qwen3-0.6B).

### Architecture

- **Base**: `Qwen/Qwen3-4B` (causal language model)
- **Head**: LayerNorm → Linear(hidden_size, 601)
- **Max variables**: 600
- **Pooling**: hidden state of the last non-pad token
- **Masking**: variables not present in the CNF are masked to -10000 before the softmax
- **Size**: ~8 GB (bfloat16)

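The head and pooling described above can be sketched as follows. This is an illustrative assumption of the architecture, not the actual `QwenVarClassifier` code; `VarSelectorHead` and its signature are hypothetical:

```python
import torch
import torch.nn as nn

class VarSelectorHead(nn.Module):
    """Hypothetical sketch: LayerNorm -> Linear over the pooled hidden state,
    projecting to max_vars + 1 logits (index 0 unused; variables are 1-indexed)."""

    def __init__(self, hidden_size, max_vars=600):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)
        self.proj = nn.Linear(hidden_size, max_vars + 1)

    def forward(self, hidden_states, attention_mask, valid_mask):
        # Pool the hidden state of the last non-pad token in each sequence.
        last = attention_mask.sum(dim=1) - 1                       # (batch,)
        pooled = hidden_states[torch.arange(hidden_states.size(0)), last]
        logits = self.proj(self.norm(pooled))                      # (batch, max_vars + 1)
        # Mask variables that do not occur in the CNF before the softmax.
        return logits.masked_fill(~valid_mask, -1e4)

head = VarSelectorHead(hidden_size=8, max_vars=10)
h = torch.randn(2, 5, 8)                       # (batch, seq, hidden)
attn = torch.ones(2, 5, dtype=torch.long)      # no padding in this toy batch
valid = torch.zeros(2, 11, dtype=torch.bool)
valid[:, 1:5] = True                           # pretend only variables 1..4 occur
logits = head(h, attn, valid)
```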
### Training with Symmetry Augmentation

This model was trained with **5× data augmentation** using semantically safe CNF transformations:

| Augmentation | Description | Effect |
|--------------|-------------|--------|
| **Variable Permutation** | Bijective remapping of variable IDs | Prevents memorizing specific variable numbers |
| **Clause Shuffling** | Random reordering of clauses | Teaches position-independence |
| **Literal Reordering** | Shuffle literals within clauses | Token-level variation |
| **Polarity Flipping** | Flip signs of a random variable subset | Separates structural from polarity features |

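A minimal sketch of how these four transformations might be applied to a clause list. The function name and the 50% flip probability are illustrative assumptions, and the sketch omits the step (needed in training) of remapping the branching-variable label through the same permutation:

```python
import random

def augment(clauses, num_vars, seed=0):
    """Apply the four symmetry transformations to a CNF given as a list of
    integer-literal clauses (DIMACS-style: positive/negative variable IDs)."""
    rng = random.Random(seed)
    # Variable permutation: bijective remap of IDs 1..num_vars (signs preserved).
    perm = list(range(1, num_vars + 1))
    rng.shuffle(perm)
    remap = dict(zip(range(1, num_vars + 1), perm))
    out = [[(1 if lit > 0 else -1) * remap[abs(lit)] for lit in c] for c in clauses]
    # Clause shuffling: random reordering of clauses.
    rng.shuffle(out)
    # Literal reordering: shuffle literals within each clause.
    out = [rng.sample(c, len(c)) for c in out]
    # Polarity flipping: negate all occurrences of a random variable subset.
    flip = {v for v in range(1, num_vars + 1) if rng.random() < 0.5}
    return [[-lit if abs(lit) in flip else lit for lit in c] for c in out]

cnf = [[1, -2, 3], [-1, 2, -4], [2, 4]]
aug = augment(cnf, num_vars=4, seed=7)
```

Each transformation preserves satisfiability and clause structure, so the augmented formula is a valid training sample with the same difficulty profile as the original.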
### Training Details

| Parameter | Value |
|-----------|-------|
| Original training samples | 8,110 |
| Augmented training samples | **40,550** (5×) |
| Validation samples | 902 (unaugmented) |
| Epochs | 3 |
| Hardware | 8×H100 GPUs |
| Training framework | DeepSpeed ZeRO-3 |
| Peak learning rate | 5e-6 |
| Training time | ~4 hours |
| Best checkpoint | Step 1850 (epoch 2.92) |

### Performance Comparison

| Model | Parameters | Training Data | Top-1 Accuracy |
|-------|------------|---------------|----------------|
| Qwen3-0.6B (baseline) | 600M | 8,110 samples | ~12% |
| Qwen3-0.6B (augmented) | 600M | 40,550 samples | ~19% |
| **Qwen3-4B (augmented)** | **4B** | **40,550 samples** | **~22%** |

### Training Curve Highlights

- Peak accuracy: **22.0%** at epoch 2.76
- Final accuracy: **21.8%** at epoch 2.92
- Eval loss: 3.35 (vs. 3.37 for the 0.6B model)

## Usage

```python
import torch
from transformers import AutoTokenizer

from sft_qwen_var_classifier import QwenVarClassifier, cnf_valid_mask

# Load the classifier and its fine-tuned weights
model = QwenVarClassifier("Qwen/Qwen3-4B", max_vars=600)
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict, strict=False)
model = model.to("cuda", dtype=torch.bfloat16)
model.eval()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")

# Prepare CNF input (DIMACS format)
cnf_text = """p cnf 100 250
1 -2 3 0
-1 2 -4 0
...
"""

# Tokenize
inputs = tokenizer(cnf_text, return_tensors="pt", truncation=True, max_length=8192)
inputs = {k: v.to("cuda") for k, v in inputs.items()}

# Build the valid-variable mask (True only for variables that occur in the CNF)
valid_mask = torch.tensor([cnf_valid_mask(cnf_text, max_vars=600)], dtype=torch.bool, device="cuda")

# Predict: mask invalid variables, then take the argmax over variable logits
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs["logits"]
    logits = logits.masked_fill(~valid_mask, -1e4)
    predicted_var = logits.argmax(dim=-1).item()

print(f"Predicted branching variable: {predicted_var}")
```

## Files

- `pytorch_model.bin` - Model weights (~8 GB, bfloat16)
- `sft_qwen_var_classifier.py` - Model class definition (required for loading)

## When to Use This Model

- **Higher accuracy** than the 0.6B version (+3 pp top-1)
- **Production use** where accuracy matters more than speed
- **Cube-and-Conquer**-style SAT solving

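For Cube-and-Conquer, one natural pattern (a sketch, not part of the released code) is to take the top-k valid variables from the masked logits as cube candidates rather than only the argmax:

```python
import torch

def top_k_cube_vars(logits, valid_mask, k=2):
    # Rank valid variables by logit and return the k best as cube candidates,
    # highest-scoring first. Invalid variables are excluded via -inf masking.
    masked = logits.masked_fill(~valid_mask, float("-inf"))
    return masked.topk(k, dim=-1).indices.tolist()

# Toy example: 5 variables (index 0 unused), variables 1..4 occur in the CNF.
logits = torch.tensor([[0.0, 2.0, 1.0, 3.0, -1.0, 0.5]])
valid = torch.tensor([[False, True, True, True, True, False]])
cands = top_k_cube_vars(logits, valid, k=2)  # → [[3, 1]]
```

Branching on each candidate in both polarities then yields the cubes handed to the conquer-phase solver.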
## Limitations

- Maximum of 600 variables
- CNF input truncated at 8,192 tokens
- Larger on disk (~8 GB vs. 1.2 GB for the 0.6B model)
- Slower inference (~6× slower than the 0.6B model)

## Related Models

- [Qwen3-0.6B-SAT-VarSelector-Sym-Aug](https://huggingface.co/Yale-ROSE/Qwen3-0.6B-SAT-VarSelector-Sym-Aug) - Smaller, faster version
- [Qwen3-0.6B-SAT-VarSelector](https://huggingface.co/Yale-ROSE/Qwen3-0.6B-SAT-VarSelector) - Non-augmented baseline

## Citation

If you use this model, please cite the Transformer-CnC paper.

## License

Apache 2.0