---
license: apache-2.0
base_model: Qwen/Qwen3-0.6B
tags:
- SAT
- combinatorial-optimization
- classification
- cube-and-conquer
- data-augmentation
language:
- en
pipeline_tag: text-classification
---
# Qwen3-0.6B-SAT-VarSelector-Sym-Aug
A Qwen3-0.6B model fine-tuned for **SAT branching variable selection** using **symmetry-based data augmentation**.
## Model Description
Given the current state of a SAT CNF formula, this model predicts which variable to branch (cube) on next. It was trained on **5x augmented data** produced by CNF symmetry transformations, which markedly improved generalization over the non-augmented baseline (see the performance comparison below).
### Architecture
- **Base**: `Qwen/Qwen3-0.6B` (causal language model)
- **Head**: LayerNorm → Linear(hidden_size, 601)
- **Max Variables**: 600
- **Pooling**: Last non-pad token hidden state
- **Masking**: Invalid variables (not in CNF) are masked to -10000 before softmax
- **Size**: ~1.2GB (bfloat16)
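The full class definition ships in `sft_qwen_var_classifier.py` (see Files below). For orientation, here is a minimal sketch of the head and pooling described above; the use of `AutoModel` and the internal attribute names are assumptions, not the shipped implementation:

```python
import torch
import torch.nn as nn
from transformers import AutoModel

class QwenVarClassifier(nn.Module):
    """Sketch: pool the last non-pad token, then LayerNorm -> Linear(hidden, 601)."""
    def __init__(self, base_model_name: str, max_vars: int = 600):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(base_model_name)
        hidden = self.backbone.config.hidden_size
        self.norm = nn.LayerNorm(hidden)
        # 601 outputs: variable IDs 1..600, plus index 0 kept unused for 1-based IDs
        self.head = nn.Linear(hidden, max_vars + 1)

    def forward(self, input_ids, attention_mask, **kwargs):
        hidden_states = self.backbone(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state                              # (B, T, H)
        # Index of the last non-pad token in each sequence
        last = attention_mask.sum(dim=1) - 1             # (B,)
        batch = torch.arange(hidden_states.size(0), device=hidden_states.device)
        pooled = hidden_states[batch, last]              # (B, H)
        return {"logits": self.head(self.norm(pooled))}  # (B, 601)
```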
### Training with Symmetry Augmentation
This model was trained with **5x data augmentation** using four semantically safe CNF transformations:
| Augmentation | Description | Effect |
|-------------|-------------|--------|
| **Variable Permutation** | Bijective remapping of variable IDs | Prevents memorizing specific variable numbers |
| **Clause Shuffling** | Random reordering of clauses | Teaches position-independence |
| **Literal Reordering** | Shuffle literals within clauses | Token-level variation |
| **Polarity Flipping** | Flip signs of random variable subset | Teaches structural vs. polarity features |
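The released script (see Augmentation Code below) implements all four transformations. As an illustration, here is a self-contained sketch of variable permutation on a DIMACS string; `permute_variables` is a hypothetical name, and note that during training the target variable label must be remapped through the same permutation:

```python
import random

def permute_variables(cnf_text: str, seed: int = 0) -> str:
    """Remap variable IDs 1..n through a random bijection, preserving polarity."""
    rng = random.Random(seed)
    lines = [ln.strip() for ln in cnf_text.strip().splitlines() if ln.strip()]
    header = next(ln for ln in lines if ln.startswith("p cnf"))
    num_vars = int(header.split()[2])
    perm = list(range(1, num_vars + 1))
    rng.shuffle(perm)                    # perm[v - 1] is the new ID of variable v
    out = []
    for ln in lines:
        if ln.startswith(("p", "c")):    # header and comment lines pass through
            out.append(ln)
            continue
        lits = [int(tok) for tok in ln.split()]
        out.append(" ".join(
            str(0 if l == 0 else (1 if l > 0 else -1) * perm[abs(l) - 1])
            for l in lits))
    return "\n".join(out)
```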
### Training Details
| Parameter | Value |
|-----------|-------|
| Original training samples | 8,110 |
| Augmented training samples | **40,550** (5x) |
| Validation samples | 902 (unaugmented) |
| Epochs | 3 |
| Hardware | 8×H100 GPUs |
| Training framework | DeepSpeed ZeRO-3 |
| Peak learning rate | 5e-6 |
| Best checkpoint | Step 1800 (epoch 2.84) |
### Performance Comparison
| Model | Training Data | Top-1 Accuracy | Top-5 Accuracy |
|-------|--------------|----------------|----------------|
| Qwen3-0.6B (baseline) | 8,110 samples | ~12% | ~32% |
| **Qwen3-0.6B (augmented)** | **40,550 samples** | **~19%** | **~42%** |
| Improvement | +5x data | **+7pp** | **+10pp** |
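Top-k accuracy here has the standard meaning: a sample counts as correct when the true branching variable is among the k highest-scoring outputs (after masking invalid variables, as in the usage snippet below). A minimal sketch of the metric, with the function name assumed:

```python
import torch

def topk_accuracy(logits: torch.Tensor, labels: torch.Tensor, k: int) -> float:
    """Fraction of samples whose true label is among the k largest logits.
    logits: (B, 601), already masked; labels: (B,) true variable IDs."""
    topk = logits.topk(k, dim=-1).indices               # (B, k)
    return (topk == labels.unsqueeze(-1)).any(dim=-1).float().mean().item()
```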
### Key Insight: Why Validation Loss < Training Loss
During augmented training, you'll observe **validation loss consistently lower than training loss**. This is expected and indicates the augmentation is working:
1. **Training data is harder** — augmented CNFs with permuted variables, shuffled clauses
2. **Validation data is clean** — original CNFs without transformations
3. **Model generalizes well** — learned structural patterns, not memorized examples
## Usage
```python
import torch
from transformers import AutoTokenizer
from sft_qwen_var_classifier import QwenVarClassifier, cnf_valid_mask
# Load model
model = QwenVarClassifier("Qwen/Qwen3-0.6B", max_vars=600)
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict, strict=False)
model = model.to("cuda", dtype=torch.bfloat16)
model.eval()
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
# Prepare CNF input
cnf_text = """p cnf 100 250
1 -2 3 0
-1 2 -4 0
...
"""
# Tokenize
inputs = tokenizer(cnf_text, return_tensors="pt", truncation=True, max_length=8192)
inputs = {k: v.to("cuda") for k, v in inputs.items()}
# Get valid variable mask
valid_mask = torch.tensor([cnf_valid_mask(cnf_text, max_vars=600)], dtype=torch.bool, device="cuda")
# Predict
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs["logits"]
    logits = logits.masked_fill(~valid_mask, -1e4)
    predicted_var = logits.argmax(dim=-1).item()
print(f"Predicted branching variable: {predicted_var}")
```
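`cnf_valid_mask` is defined in `sft_qwen_var_classifier.py`; a plausible sketch of what it computes (the shipped implementation may differ) is a length-601 boolean list that is `True` exactly for the variable IDs occurring in the formula:

```python
def cnf_valid_mask(cnf_text: str, max_vars: int = 600) -> list[bool]:
    """True at index v iff variable v appears in the CNF; index 0 stays False."""
    mask = [False] * (max_vars + 1)
    for line in cnf_text.strip().splitlines():
        line = line.strip()
        if not line or line.startswith(("p", "c")):  # skip header/comment lines
            continue
        for tok in line.split():
            v = abs(int(tok))
            if 1 <= v <= max_vars:
                mask[v] = True
    return mask
```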
## Files
- `pytorch_model.bin` - Model weights (~1.2GB, bfloat16)
- `sft_qwen_var_classifier.py` - Model class definition (required for loading)
## When to Use This Model
- **Better generalization** than non-augmented version
- **Production/deployment** with improved accuracy
- **When training data is limited** — augmentation effectively multiplies your data
## Augmentation Code
The augmentation script is available at:
```
Yale-ROSE/Transformer-SAT/new_transformer/augment_sft_dataset.py
```
Usage:
```bash
python augment_sft_dataset.py input.jsonl output.jsonl --multiplier 5
```
## Limitations
- Maximum 600 variables
- Maximum 8192 tokens for CNF input
- Trained on a specific CNF distribution; accuracy may degrade on out-of-distribution formulas
## Related Models
- [Qwen3-0.6B-SAT-VarSelector](https://huggingface.co/Yale-ROSE/Qwen3-0.6B-SAT-VarSelector) - Non-augmented baseline
- [Qwen3-4B-SAT-VarSelector](https://huggingface.co/Yale-ROSE/Qwen3-4B-SAT-VarSelector) - Higher accuracy, larger model
## Citation
If you use this model, please cite the Transformer-CnC paper.
## License
Apache 2.0