---
license: apache-2.0
base_model: Qwen/Qwen3-4B
tags:
- SAT
- combinatorial-optimization
- classification
- cube-and-conquer
- data-augmentation
language:
- en
pipeline_tag: text-classification
---
# Qwen3-4B-SAT-VarSelector-Sym-Aug
A Qwen3-4B model fine-tuned for **SAT branching variable selection** using **symmetry-based data augmentation**.
## Model Description
This model predicts which variable to branch/cube on next, given a SAT CNF formula state. It was trained on **5x augmented data** generated by CNF symmetry transformations, achieving **21.8% top-1 accuracy** (vs ~19% for the similarly augmented Qwen3-0.6B).
### Architecture
- **Base**: `Qwen/Qwen3-4B` (causal language model)
- **Head**: LayerNorm → Linear(hidden_size, 601)
- **Max Variables**: 600
- **Pooling**: Last non-pad token hidden state
- **Masking**: Invalid variables (not in CNF) are masked to -10000 before softmax
- **Size**: ~8GB (bfloat16)
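The head and pooling described above can be sketched roughly as follows. This is a minimal illustration, not the repository's actual `QwenVarClassifier`; the class and attribute names here are assumptions, only the dimensions and layer types come from the card:

```python
import torch
import torch.nn as nn

class VarSelectionHead(nn.Module):
    """Hypothetical sketch of the head described above:
    LayerNorm -> Linear(hidden_size, 601), applied to the
    hidden state of the last non-pad token."""
    def __init__(self, hidden_size: int, max_vars: int = 600):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)
        # 601 outputs: one logit per variable ID 1..600, plus index 0 (unused)
        self.proj = nn.Linear(hidden_size, max_vars + 1)

    def forward(self, hidden_states: torch.Tensor,
                attention_mask: torch.Tensor) -> torch.Tensor:
        # Pool the hidden state of the last non-pad token in each sequence
        last_idx = attention_mask.sum(dim=1) - 1                # (batch,)
        batch_idx = torch.arange(hidden_states.size(0),
                                 device=hidden_states.device)
        pooled = hidden_states[batch_idx, last_idx]             # (batch, hidden)
        return self.proj(self.norm(pooled))                     # (batch, 601)
```

At inference time, logits at positions whose variable does not appear in the CNF are masked to -10000 before the argmax, as shown in the usage example below.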
### Training with Symmetry Augmentation
This model was trained with **5x data augmentation** using semantics-preserving CNF transformations:
| Augmentation | Description | Effect |
|-------------|-------------|--------|
| **Variable Permutation** | Bijective remapping of variable IDs | Prevents memorizing specific variable numbers |
| **Clause Shuffling** | Random reordering of clauses | Teaches position-independence |
| **Literal Reordering** | Shuffle literals within clauses | Token-level variation |
| **Polarity Flipping** | Flip signs of random variable subset | Teaches structural vs. polarity features |
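The four augmentations above can be sketched on a DIMACS-style clause list as follows. This is an illustrative sketch, not the actual training code; the function and argument names are assumptions:

```python
import random

def augment_cnf(clauses, num_vars, rng=random):
    """Apply the four symmetry augmentations to a CNF given as a list
    of clauses, each clause a list of signed variable IDs (DIMACS style,
    without the trailing 0). Returns the transformed clauses and the
    variable permutation, which must also be applied to the label."""
    # 1. Variable permutation: bijective remapping of variable IDs
    perm = list(range(1, num_vars + 1))
    rng.shuffle(perm)
    remap = {v: perm[v - 1] for v in range(1, num_vars + 1)}
    # 4. Polarity flipping: pick a random subset of variables to negate
    flipped = {v for v in range(1, num_vars + 1) if rng.random() < 0.5}
    out = []
    for clause in clauses:
        new_clause = []
        for lit in clause:
            v, sign = abs(lit), (1 if lit > 0 else -1)
            if v in flipped:
                sign = -sign
            new_clause.append(sign * remap[v])
        # 3. Literal reordering within each clause
        rng.shuffle(new_clause)
        out.append(new_clause)
    # 2. Clause shuffling
    rng.shuffle(out)
    return out, remap
```

Note that the supervision target (the chosen branching variable) must be mapped through the same permutation, otherwise the augmented sample's label no longer matches its formula.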
### Training Details
| Parameter | Value |
|-----------|-------|
| Original training samples | 8,110 |
| Augmented training samples | **40,550** (5x) |
| Validation samples | 902 (unaugmented) |
| Epochs | 3 |
| Hardware | 8×H100 GPUs |
| Training framework | DeepSpeed ZeRO-3 |
| Peak learning rate | 5e-6 |
| Training time | ~4 hours |
| Best checkpoint | Step 1850 (epoch 2.92) |
### Performance Comparison
| Model | Parameters | Training Data | Top-1 Accuracy |
|-------|-----------|--------------|----------------|
| Qwen3-0.6B (baseline) | 600M | 8,110 samples | ~12% |
| Qwen3-0.6B (augmented) | 600M | 40,550 samples | ~19% |
| **Qwen3-4B (augmented)** | **4B** | **40,550 samples** | **~22%** |
### Training Curve Highlights
- Peak accuracy: **22.0%** at epoch 2.76
- Final accuracy: **21.8%** at epoch 2.92
- Eval loss: 3.35 (vs 3.37 for 0.6B)
## Usage
```python
import torch
from transformers import AutoTokenizer
from sft_qwen_var_classifier import QwenVarClassifier, cnf_valid_mask
# Load model
model = QwenVarClassifier("Qwen/Qwen3-4B", max_vars=600)
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict, strict=False)
model = model.to("cuda", dtype=torch.bfloat16)
model.eval()
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
# Prepare CNF input
cnf_text = """p cnf 100 250
1 -2 3 0
-1 2 -4 0
...
"""
# Tokenize
inputs = tokenizer(cnf_text, return_tensors="pt", truncation=True, max_length=8192)
inputs = {k: v.to("cuda") for k, v in inputs.items()}
# Get valid variable mask
valid_mask = torch.tensor([cnf_valid_mask(cnf_text, max_vars=600)], dtype=torch.bool, device="cuda")
# Predict
with torch.no_grad():
outputs = model(**inputs)
logits = outputs["logits"]
logits = logits.masked_fill(~valid_mask, -1e4)
predicted_var = logits.argmax(dim=-1).item()
print(f"Predicted branching variable: {predicted_var}")
```
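The `cnf_valid_mask` helper used above ships in `sft_qwen_var_classifier.py`. Its behavior might plausibly look like the following sketch, which only illustrates the idea (a boolean mask over logit positions, true where the variable appears in the DIMACS text); the real implementation may differ:

```python
def cnf_valid_mask_sketch(cnf_text: str, max_vars: int = 600) -> list:
    """Hypothetical re-implementation for illustration: return a
    length-(max_vars + 1) boolean list where index v is True iff
    variable v occurs in the DIMACS CNF text."""
    mask = [False] * (max_vars + 1)
    for line in cnf_text.splitlines():
        line = line.strip()
        # Skip blank lines, comments ("c ..."), and the problem line ("p cnf ...")
        if not line or line.startswith(("c", "p")):
            continue
        for tok in line.split():
            v = abs(int(tok))  # literal sign is irrelevant; 0 terminates a clause
            if 1 <= v <= max_vars:
                mask[v] = True
    return mask
```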
## Files
- `pytorch_model.bin` - Model weights (~8GB, bfloat16)
- `sft_qwen_var_classifier.py` - Model class definition (required for loading)
## When to Use This Model
- **Higher accuracy** than the 0.6B version (+3 percentage points)
- **Production use** when accuracy matters more than speed
- **Cube-and-Conquer** style SAT solving
## Limitations
- Maximum 600 variables
- Maximum 8192 tokens for CNF input
- Larger model size (~8GB vs 1.2GB for 0.6B)
- Slower inference (~6x slower than 0.6B)
## Related Models
- [Qwen3-0.6B-SAT-VarSelector-Sym-Aug](https://huggingface.co/Yale-ROSE/Qwen3-0.6B-SAT-VarSelector-Sym-Aug) - Smaller, faster version
- [Qwen3-0.6B-SAT-VarSelector](https://huggingface.co/Yale-ROSE/Qwen3-0.6B-SAT-VarSelector) - Non-augmented baseline
## Citation
If you use this model, please cite the Transformer-CnC paper.
## License
Apache 2.0