---
license: apache-2.0
base_model: Qwen/Qwen3-0.6B
tags:
  - SAT
  - combinatorial-optimization
  - classification
  - cube-and-conquer
  - data-augmentation
language:
  - en
pipeline_tag: text-classification
---

# Qwen3-0.6B-SAT-VarSelector-Sym-Aug

A Qwen3-0.6B model fine-tuned for **SAT branching variable selection** using **symmetry-based data augmentation**.

## Model Description

This model predicts which variable to branch/cube on next, given a SAT CNF formula state. It was trained with **5x augmented data** using CNF symmetry transformations, resulting in significantly improved generalization.

### Architecture

- **Base**: `Qwen/Qwen3-0.6B` (causal language model)
- **Head**: LayerNorm → Linear(hidden_size, 601)
- **Max Variables**: 600
- **Pooling**: Last non-pad token hidden state
- **Masking**: Invalid variables (not in CNF) are masked to -10000 before softmax
- **Size**: ~1.2GB (bfloat16)
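
In code, that architecture looks roughly like the sketch below. The class and variable names here are illustrative; the shipped implementation is `QwenVarClassifier` in `sft_qwen_var_classifier.py`. The last-token pooling assumes right-padded inputs.

```python
import torch
import torch.nn as nn
from transformers import AutoModel

class VarClassifierSketch(nn.Module):
    """Illustrative sketch of the architecture above (the shipped class
    is QwenVarClassifier in sft_qwen_var_classifier.py)."""

    def __init__(self, base_name: str = "Qwen/Qwen3-0.6B", max_vars: int = 600):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(base_name)
        hidden = self.backbone.config.hidden_size
        # Head: LayerNorm -> Linear(hidden_size, max_vars + 1) = 601 classes
        self.norm = nn.LayerNorm(hidden)
        self.head = nn.Linear(hidden, max_vars + 1)

    def forward(self, input_ids, attention_mask, valid_mask=None):
        h = self.backbone(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state                               # (B, T, H)
        # Pool: hidden state of the last non-pad token (right padding assumed)
        last = attention_mask.sum(dim=1) - 1              # (B,)
        pooled = h[torch.arange(h.size(0), device=h.device), last]
        logits = self.head(self.norm(pooled))             # (B, max_vars + 1)
        if valid_mask is not None:
            # Variables not present in the CNF are masked to -10000 before softmax
            logits = logits.masked_fill(~valid_mask, -10000.0)
        return logits
```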

### Training with Symmetry Augmentation

This model was trained with **5x data augmentation** using semantically safe CNF transformations:

| Augmentation | Description | Effect |
|-------------|-------------|--------|
| **Variable Permutation** | Bijective remapping of variable IDs | Prevents memorizing specific variable numbers |
| **Clause Shuffling** | Random reordering of clauses | Teaches position-independence |
| **Literal Reordering** | Shuffle literals within clauses | Token-level variation |
| **Polarity Flipping** | Flip signs of random variable subset | Teaches structural vs. polarity features |
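
The shipped script is `augment_sft_dataset.py` (see the Augmentation Code section below). As a rough illustration of what these four transformations do, here is a minimal sketch assuming a CNF stored DIMACS-style as a list of clauses, each a list of nonzero signed variable IDs:

```python
import random

def augment(clauses: list[list[int]], num_vars: int, rng: random.Random) -> list[list[int]]:
    """Sketch of the four symmetry transformations on a DIMACS-style CNF."""
    # 1. Variable permutation: bijective remapping of variable IDs
    perm = list(range(1, num_vars + 1))
    rng.shuffle(perm)
    remap = dict(zip(range(1, num_vars + 1), perm))

    # 2. Polarity flipping: negate a random subset of variables everywhere
    flipped = {v for v in range(1, num_vars + 1) if rng.random() < 0.5}

    def xform(lit: int) -> int:
        v, sign = abs(lit), (1 if lit > 0 else -1)
        if v in flipped:
            sign = -sign
        return sign * remap[v]

    out = [[xform(lit) for lit in clause] for clause in clauses]

    # 3. Literal reordering: shuffle literals within each clause
    for clause in out:
        rng.shuffle(clause)

    # 4. Clause shuffling: reorder the clauses themselves
    rng.shuffle(out)
    return out
```

Note that variable permutation changes the ID of the correct branching variable, so the training label (and the valid-variable mask) must be remapped with the same permutation; the 0.5 flip probability here is illustrative and may differ from the actual script.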

### Training Details

| Parameter | Value |
|-----------|-------|
| Original training samples | 8,110 |
| Augmented training samples | **40,550** (5x) |
| Validation samples | 902 (unaugmented) |
| Epochs | 3 |
| Hardware | 8×H100 GPUs |
| Training framework | DeepSpeed ZeRO-3 |
| Peak learning rate | 5e-6 |
| Best checkpoint | Step 1800 (epoch 2.84) |
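
The exact DeepSpeed configuration is not included in this repository; for orientation, a minimal ZeRO-3 setup consistent with the table above (bf16, stage 3) might look like this with the Hugging Face Trainer integration:

```python
# Illustrative ZeRO-3 settings; the actual training config is not published.
ds_config = {
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": True,
        "stage3_gather_16bit_weights_on_model_save": True,
    },
}
# Pass via transformers.TrainingArguments(deepspeed=ds_config, ...)
```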

### Performance Comparison

| Model | Training Data | Top-1 Accuracy | Top-5 Accuracy |
|-------|--------------|----------------|----------------|
| Qwen3-0.6B (baseline) | 8,110 samples | ~12% | ~32% |
| **Qwen3-0.6B (augmented)** | **40,550 samples** | **~19%** | **~42%** |
| Improvement | 5x data | **+7pp** | **+10pp** |

### Key Insight: Why Validation Loss < Training Loss

During augmented training, you'll observe **validation loss consistently lower than training loss**. This is expected and indicates the augmentation is working:

1. **Training data is harder** — augmented CNFs with permuted variables, shuffled clauses
2. **Validation data is clean** — original CNFs without transformations
3. **Model generalizes well** — learned structural patterns, not memorized examples

## Usage

```python
import torch
from transformers import AutoTokenizer
from sft_qwen_var_classifier import QwenVarClassifier, cnf_valid_mask

# Load model
model = QwenVarClassifier("Qwen/Qwen3-0.6B", max_vars=600)
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict, strict=False)
model = model.to("cuda", dtype=torch.bfloat16)
model.eval()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

# Prepare CNF input
cnf_text = """p cnf 100 250
1 -2 3 0
-1 2 -4 0
...
"""

# Tokenize
inputs = tokenizer(cnf_text, return_tensors="pt", truncation=True, max_length=8192)
inputs = {k: v.to("cuda") for k, v in inputs.items()}

# Get valid variable mask
valid_mask = torch.tensor([cnf_valid_mask(cnf_text, max_vars=600)], dtype=torch.bool, device="cuda")

# Predict
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs["logits"]
    logits = logits.masked_fill(~valid_mask, -1e4)
    predicted_var = logits.argmax(dim=-1).item()

print(f"Predicted branching variable: {predicted_var}")
```

## Files

- `pytorch_model.bin` - Model weights (~1.2GB, bfloat16)
- `sft_qwen_var_classifier.py` - Model class definition (required for loading)

## When to Use This Model

- **Better generalization** than the non-augmented version
- **Production/deployment** with improved accuracy
- **When training data is limited** — augmentation effectively multiplies your data

## Augmentation Code

The augmentation script is available at:
```
Yale-ROSE/Transformer-SAT/new_transformer/augment_sft_dataset.py
```

Usage:
```bash
python augment_sft_dataset.py input.jsonl output.jsonl --multiplier 5
```

## Limitations

- Maximum 600 variables
- Maximum 8192 tokens for CNF input
- Trained on a specific CNF distribution; accuracy may be lower on out-of-distribution formulas
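
A small pre-flight check for the first two limits, using the declared variable count from the DIMACS `p cnf` header (the helper below is hypothetical, not part of this repo):

```python
def check_limits(cnf_text: str, tokenizer, max_vars: int = 600, max_tokens: int = 8192):
    """Hypothetical helper: validate a CNF against the model's input limits."""
    # The DIMACS header "p cnf <num_vars> <num_clauses>" declares the variable count
    header = next(line for line in cnf_text.splitlines() if line.startswith("p cnf"))
    num_vars = int(header.split()[2])
    if num_vars > max_vars:
        raise ValueError(f"{num_vars} variables exceeds the {max_vars}-variable limit")
    num_tokens = len(tokenizer(cnf_text)["input_ids"])
    if num_tokens > max_tokens:
        print(f"warning: {num_tokens} tokens; input will be truncated to {max_tokens}")
```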

## Related Models

- [Qwen3-0.6B-SAT-VarSelector](https://huggingface.co/Yale-ROSE/Qwen3-0.6B-SAT-VarSelector) - Non-augmented baseline
- [Qwen3-4B-SAT-VarSelector](https://huggingface.co/Yale-ROSE/Qwen3-4B-SAT-VarSelector) - Higher accuracy, larger model

## Citation

If you use this model, please cite the Transformer-CnC paper.

## License

Apache 2.0