---
license: apache-2.0
base_model: Qwen/Qwen3-4B
tags:
  - SAT
  - combinatorial-optimization
  - classification
  - cube-and-conquer
  - data-augmentation
language:
  - en
pipeline_tag: text-classification
---

# Qwen3-4B-SAT-VarSelector-Sym-Aug

A Qwen3-4B model fine-tuned for **SAT branching variable selection** using **symmetry-based data augmentation**.

## Model Description

This model predicts which variable to branch/cube on next, given a SAT CNF formula state. It was trained with **5x augmented data** using CNF symmetry transformations, achieving **21.8% top-1 accuracy** (vs ~19% for the similarly augmented Qwen3-0.6B).

### Architecture

- **Base**: `Qwen/Qwen3-4B` (causal language model)
- **Head**: LayerNorm → Linear(hidden_size, 601)
- **Max Variables**: 600
- **Pooling**: Last non-pad token hidden state
- **Masking**: Invalid variables (not in CNF) are masked to -10000 before softmax
- **Size**: ~8GB (bfloat16)
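
The head and pooling described above can be sketched as follows. This is a minimal illustration with a random stand-in for the Qwen3-4B hidden states; the actual class is `QwenVarClassifier` in `sft_qwen_var_classifier.py`, and the names here (`VarSelectorHead`) are hypothetical:

```python
import torch
import torch.nn as nn

MAX_VARS = 600

class VarSelectorHead(nn.Module):
    """Sketch of the classification head: LayerNorm -> Linear(hidden, 601),
    applied to the hidden state of the last non-pad token."""

    def __init__(self, hidden_size: int, max_vars: int = MAX_VARS):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)
        self.proj = nn.Linear(hidden_size, max_vars + 1)  # 601 classes

    def forward(self, hidden_states, attention_mask):
        # Index of the last non-pad token in each sequence
        last_idx = attention_mask.sum(dim=1) - 1                  # (batch,)
        batch_idx = torch.arange(hidden_states.size(0))
        pooled = hidden_states[batch_idx, last_idx]               # (batch, hidden)
        return self.proj(self.norm(pooled))                      # (batch, 601)

# Random activations in place of the backbone, just to show the shapes
head = VarSelectorHead(hidden_size=32)
h = torch.randn(2, 10, 32)
mask = torch.ones(2, 10, dtype=torch.long)
logits = head(h, mask)
print(logits.shape)  # torch.Size([2, 601])
```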

### Training with Symmetry Augmentation

This model was trained with **5x data augmentation** using semantically-safe CNF transformations:

| Augmentation | Description | Effect |
|-------------|-------------|--------|
| **Variable Permutation** | Bijective remapping of variable IDs | Prevents memorizing specific variable numbers |
| **Clause Shuffling** | Random reordering of clauses | Teaches position-independence |
| **Literal Reordering** | Shuffle literals within clauses | Token-level variation |
| **Polarity Flipping** | Flip signs of random variable subset | Teaches structural vs. polarity features |
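
The four transformations can be sketched on a clause-list representation of the CNF. This is a hedged illustration only; the training pipeline's exact implementation may differ, and note that the branching-variable label must be remapped with the same permutation applied to the formula:

```python
import random

def permute_variables(clauses, num_vars, rng):
    """Bijective remapping of variable IDs; literal signs are preserved."""
    perm = list(range(1, num_vars + 1))
    rng.shuffle(perm)
    mapping = dict(enumerate(perm, start=1))  # old id -> new id
    new_clauses = [[(1 if lit > 0 else -1) * mapping[abs(lit)] for lit in c]
                   for c in clauses]
    return new_clauses, mapping  # apply `mapping` to the label as well

def shuffle_clauses(clauses, rng):
    """Random clause reordering (conjunction is order-independent)."""
    out = list(clauses)
    rng.shuffle(out)
    return out

def reorder_literals(clauses, rng):
    """Shuffle literals within each clause (disjunction is commutative)."""
    out = []
    for c in clauses:
        c = list(c)
        rng.shuffle(c)
        out.append(c)
    return out

def flip_polarities(clauses, num_vars, rng, p=0.5):
    """Flip the sign of every occurrence of a random subset of variables."""
    flipped = {v for v in range(1, num_vars + 1) if rng.random() < p}
    return [[-lit if abs(lit) in flipped else lit for lit in c] for c in clauses]

rng = random.Random(0)
cnf = [[1, -2, 3], [-1, 2, -4]]
aug, mapping = permute_variables(cnf, num_vars=4, rng=rng)
aug = flip_polarities(shuffle_clauses(reorder_literals(aug, rng), rng), 4, rng)
print(aug)
```

Each transformation preserves satisfiability (and, modulo the remapping, the optimal branching decision), which is what makes the augmented labels valid training targets.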

### Training Details

| Parameter | Value |
|-----------|-------|
| Original training samples | 8,110 |
| Augmented training samples | **40,550** (5x) |
| Validation samples | 902 (unaugmented) |
| Epochs | 3 |
| Hardware | 8×H100 GPUs |
| Training framework | DeepSpeed ZeRO-3 |
| Peak learning rate | 5e-6 |
| Training time | ~4 hours |
| Best checkpoint | Step 1850 (epoch 2.92) |

### Performance Comparison

| Model | Parameters | Training Data | Top-1 Accuracy |
|-------|-----------|--------------|----------------|
| Qwen3-0.6B (baseline) | 600M | 8,110 samples | ~12% |
| Qwen3-0.6B (augmented) | 600M | 40,550 samples | ~19% |
| **Qwen3-4B (augmented)** | **4B** | **40,550 samples** | **~22%** |

### Training Curve Highlights

- Peak accuracy: **22.0%** at epoch 2.76
- Final accuracy: **21.8%** at epoch 2.92
- Eval loss: 3.35 (vs 3.37 for 0.6B)

## Usage

```python
import torch
from transformers import AutoTokenizer
from sft_qwen_var_classifier import QwenVarClassifier, cnf_valid_mask

# Load model
model = QwenVarClassifier("Qwen/Qwen3-4B", max_vars=600)
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict, strict=False)
model = model.to("cuda", dtype=torch.bfloat16)
model.eval()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")

# Prepare CNF input
cnf_text = """p cnf 100 250
1 -2 3 0
-1 2 -4 0
...
"""

# Tokenize
inputs = tokenizer(cnf_text, return_tensors="pt", truncation=True, max_length=8192)
inputs = {k: v.to("cuda") for k, v in inputs.items()}

# Get valid variable mask
valid_mask = torch.tensor([cnf_valid_mask(cnf_text, max_vars=600)], dtype=torch.bool, device="cuda")

# Predict
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs["logits"]
    logits = logits.masked_fill(~valid_mask, -1e4)
    predicted_var = logits.argmax(dim=-1).item()

print(f"Predicted branching variable: {predicted_var}")
```
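
`cnf_valid_mask` ships with the repository in `sft_qwen_var_classifier.py`. A plausible stand-alone sketch of what it computes is below, assuming class index `v` corresponds to DIMACS variable `v` (so index 0 is never a valid prediction); the shipped implementation may differ:

```python
def cnf_valid_mask(cnf_text: str, max_vars: int = 600) -> list:
    """Hypothetical sketch: flag which of the 601 class indices name a
    variable that actually appears in the DIMACS CNF text."""
    present = set()
    for line in cnf_text.splitlines():
        line = line.strip()
        # Skip blanks, the problem line ("p cnf ..."), and comments ("c ...")
        if not line or line.startswith(("p", "c")):
            continue
        for tok in line.split():
            lit = int(tok)
            if lit != 0:  # 0 is the DIMACS clause terminator
                present.add(abs(lit))
    return [v in present and v <= max_vars for v in range(max_vars + 1)]

mask = cnf_valid_mask("p cnf 4 2\n1 -2 3 0\n-1 2 -4 0\n")
print(sum(mask))  # 4 (variables 1-4 appear)
```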

## Files

- `pytorch_model.bin` - Model weights (~8GB, bfloat16)
- `sft_qwen_var_classifier.py` - Model class definition (required for loading)

## When to Use This Model

- **Accuracy-sensitive settings**: roughly 3 percentage points higher top-1 accuracy than the 0.6B version
- **Production use**: when prediction quality matters more than inference speed or memory
- **Cube-and-conquer solving**: selecting cubing variables in a cube-and-conquer SAT pipeline

## Limitations

- Maximum 600 variables
- Maximum 8192 tokens for CNF input
- Larger model size (~8GB vs 1.2GB for 0.6B)
- Slower inference (~6x slower than 0.6B)

## Related Models

- [Qwen3-0.6B-SAT-VarSelector-Sym-Aug](https://huggingface.co/Yale-ROSE/Qwen3-0.6B-SAT-VarSelector-Sym-Aug) - Smaller, faster version
- [Qwen3-0.6B-SAT-VarSelector](https://huggingface.co/Yale-ROSE/Qwen3-0.6B-SAT-VarSelector) - Non-augmented baseline

## Citation

If you use this model, please cite the Transformer-CnC paper.

## License

Apache 2.0