---
license: apache-2.0
base_model: Qwen/Qwen3-0.6B
tags:
- SAT
- combinatorial-optimization
- classification
- cube-and-conquer
- data-augmentation
language:
- en
pipeline_tag: text-classification
---
# Qwen3-0.6B-SAT-VarSelector-Sym-Aug
A Qwen3-0.6B model fine-tuned for **SAT branching variable selection** using **symmetry-based data augmentation**.
## Model Description
This model predicts which variable to branch (or cube) on next, given the current state of a SAT formula in CNF. It was trained on **5x augmented data** produced by CNF symmetry transformations, which raised top-1 accuracy from ~12% to ~19% and top-5 accuracy from ~32% to ~42% over the non-augmented baseline.
### Architecture
- **Base**: `Qwen/Qwen3-0.6B` (causal language model)
- **Head**: LayerNorm → Linear(hidden_size, 601)
- **Max Variables**: 600
- **Pooling**: Last non-pad token hidden state
- **Masking**: Invalid variables (not in CNF) are masked to -10000 before softmax
- **Size**: ~1.2GB (bfloat16)
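
The masking step can be illustrated without the model itself. The following is a minimal sketch, not the `cnf_valid_mask` shipped in `sft_qwen_var_classifier.py`: the names `valid_variable_mask` and `masked_argmax` are hypothetical, and it assumes that output index `v` corresponds to variable `v`, with index 0 unused (which would account for 601 outputs for 600 variables).

```python
def valid_variable_mask(cnf_text, max_vars=600):
    """Return max_vars + 1 booleans; entry v is True if variable v occurs in the CNF.

    Assumption: index 0 is unused, so the mask lines up with a 601-way head.
    """
    mask = [False] * (max_vars + 1)
    for line in cnf_text.splitlines():
        line = line.strip()
        # Skip blank lines, DIMACS comments ("c"), and the problem line ("p").
        if not line or line[0] in "cp%":
            continue
        for token in line.split():
            try:
                literal = int(token)
            except ValueError:
                continue  # ignore anything that is not an integer literal
            var = abs(literal)
            # The clause terminator 0 gives var == 0 and is excluded here.
            if 1 <= var <= max_vars:
                mask[var] = True
    return mask


def masked_argmax(logits, mask, fill=-1e4):
    """Overwrite logits of invalid variables with `fill`, then take the argmax."""
    masked = [x if ok else fill for x, ok in zip(logits, mask)]
    return max(range(len(masked)), key=masked.__getitem__)
```

Filling with a large negative constant instead of negative infinity keeps the softmax finite in low-precision dtypes, while still pushing the probability of absent variables to effectively zero.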
### Training with Symmetry Augmentation
This model was trained with **5x data augmentation** using semantically safe CNF transformations, each of which preserves the satisfiability of the formula:
| Augmentation | Description | Effect |
|-------------|-------------|--------|
| **Variable Permutation** | Bijective remapping of variable IDs | Prevents memorizing specific variable numbers |
| **Clause Shuffling** | Random reordering of clauses | Teaches position-independence |
| **Literal Reordering** | Shuffle literals within clauses | Token-level variation |
| **Polarity Flipping** | Flip signs of random variable subset | Teaches structural vs. polarity features |
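
The four transformations in the table can be sketched in a few lines. This is an illustration only, not the code in `augment_sft_dataset.py`: the function name `augment`, the clause representation (lists of signed integers), and the `flip_prob` parameter are assumptions. The label handling is also an assumption: the target variable has to follow the permutation, and it is left unchanged by polarity flips because it names a variable, not a literal.

```python
import random


def augment(clauses, label, num_vars, rng, flip_prob=0.5):
    """Apply all four symmetry transformations to one (CNF, label) sample.

    clauses: list of clauses, each a list of non-zero signed ints.
    label:   target branching variable (1-based) in the original numbering.
    Returns the transformed clauses and the remapped label.
    """
    # Variable permutation: a random bijection on 1..num_vars.
    targets = list(range(1, num_vars + 1))
    rng.shuffle(targets)
    mapping = dict(zip(range(1, num_vars + 1), targets))

    # Polarity flipping: pick a random subset of variables whose sign is inverted.
    flipped = {v for v in range(1, num_vars + 1) if rng.random() < flip_prob}

    new_clauses = []
    for clause in clauses:
        new_clause = []
        for literal in clause:
            var = mapping[abs(literal)]
            sign = 1 if literal > 0 else -1
            if var in flipped:
                sign = -sign
            new_clause.append(sign * var)
        rng.shuffle(new_clause)  # literal reordering within the clause
        new_clauses.append(new_clause)

    rng.shuffle(new_clauses)  # clause shuffling
    return new_clauses, mapping[label]
```

Calling `augment(clauses, label, num_vars, random.Random(seed))` with different seeds yields distinct variants of the same sample; a fixed seed makes the augmentation reproducible.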
### Training Details
| Parameter | Value |
|-----------|-------|
| Original training samples | 8,110 |
| Augmented training samples | **40,550** (5x) |
| Validation samples | 902 (unaugmented) |
| Epochs | 3 |
| Hardware | 8×H100 GPUs |
| Training framework | DeepSpeed ZeRO-3 |
| Peak learning rate | 5e-6 |
| Best checkpoint | Step 1800 (epoch 2.84) |
### Performance Comparison
| Model | Training Data | Top-1 Accuracy | Top-5 Accuracy |
|-------|--------------|----------------|----------------|
| Qwen3-0.6B (baseline) | 8,110 samples | ~12% | ~32% |
| **Qwen3-0.6B (augmented)** | **40,550 samples** | **~19%** | **~42%** |
| Improvement | 5x data | **+7pp** | **+10pp** |
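
The card does not state how the accuracies were computed. Assuming the standard definition (a prediction counts as a top-k hit when the labeled variable is among the k highest-scoring outputs), the metric can be sketched as follows; `top_k_accuracy` is a hypothetical helper, not part of the released code.

```python
def top_k_accuracy(logit_rows, labels, k):
    """Fraction of samples whose label is among the k highest logits."""
    hits = 0
    for logits, label in zip(logit_rows, labels):
        # Indices sorted by descending score; the first k are the top-k guesses.
        ranked = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
        if label in ranked[:k]:
            hits += 1
    return hits / len(labels)
```

With up to 600 candidate variables per formula, top-5 is a useful complement to top-1 when several variables are plausible branching choices.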
### Key Insight: Why Validation Loss < Training Loss
During augmented training, **validation loss stays consistently below training loss**. This is expected and suggests the augmentation is working:
1. **Training data is harder**: augmented CNFs have permuted variables, shuffled clauses and literals, and flipped polarities.
2. **Validation data is clean**: the original CNFs, with no transformations applied.
3. **The model generalizes**: it has learned structural patterns rather than memorizing examples.
## Usage
```python
import torch
from transformers import AutoTokenizer
from sft_qwen_var_classifier import QwenVarClassifier, cnf_valid_mask
# Load model
model = QwenVarClassifier("Qwen/Qwen3-0.6B", max_vars=600)
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict, strict=False)
model = model.to("cuda", dtype=torch.bfloat16)
model.eval()
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
# Prepare CNF input
cnf_text = """p cnf 100 250
1 -2 3 0
-1 2 -4 0
...
"""
# Tokenize
inputs = tokenizer(cnf_text, return_tensors="pt", truncation=True, max_length=8192)
inputs = {k: v.to("cuda") for k, v in inputs.items()}
# Get valid variable mask
valid_mask = torch.tensor([cnf_valid_mask(cnf_text, max_vars=600)], dtype=torch.bool, device="cuda")
# Predict
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs["logits"]
logits = logits.masked_fill(~valid_mask, -1e4)
predicted_var = logits.argmax(dim=-1).item()
print(f"Predicted branching variable: {predicted_var}")
```
## Files
- `pytorch_model.bin` - Model weights (~1.2GB, bfloat16)
- `sft_qwen_var_classifier.py` - Model class definition (required for loading)
## When to Use This Model
- **Better generalization**: higher top-1 and top-5 accuracy than the non-augmented 0.6B baseline
- **Production/deployment**: when you want that accuracy gain at the same model size
- **When training data is limited**: augmentation turned 8,110 original samples into 40,550
## Augmentation Code
The augmentation script is available at:
```
Yale-ROSE/Transformer-SAT/new_transformer/augment_sft_dataset.py
```
Usage:
```bash
python augment_sft_dataset.py input.jsonl output.jsonl --multiplier 5
```
## Limitations
- Maximum 600 variables
- Maximum 8192 tokens for CNF input
- Trained on a specific CNF distribution; accuracy may not carry over to other instance families
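
The variable limit can be checked before inference by reading the DIMACS problem line. The helpers below are a hypothetical sketch (`declared_num_vars` and `fits_variable_limit` are not part of the released code) and assume the input carries a standard `p cnf <vars> <clauses>` header.

```python
def declared_num_vars(cnf_text):
    """Return the variable count declared on the DIMACS 'p cnf' line."""
    for line in cnf_text.splitlines():
        parts = line.split()
        if len(parts) >= 4 and parts[0] == "p" and parts[1] == "cnf":
            return int(parts[2])
    raise ValueError("no 'p cnf <vars> <clauses>' problem line found")


def fits_variable_limit(cnf_text, max_vars=600):
    """True if the declared variable count is within the model's output range."""
    return declared_num_vars(cnf_text) <= max_vars
```

The token limit cannot be checked this way because it depends on the tokenizer. Note that the usage example passes `truncation=True`, so an over-long CNF is cut off silently instead of raising an error.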
## Related Models
- [Qwen3-0.6B-SAT-VarSelector](https://huggingface.co/Yale-ROSE/Qwen3-0.6B-SAT-VarSelector) - Non-augmented baseline
- [Qwen3-4B-SAT-VarSelector](https://huggingface.co/Yale-ROSE/Qwen3-4B-SAT-VarSelector) - Higher accuracy, larger model
## Citation
If you use this model, please cite the Transformer-CnC paper.
## License
Apache 2.0