---
license: apache-2.0
base_model: Qwen/Qwen3-0.6B
tags:
- sat
- satisfiability
- cube-and-conquer
- variable-selection
- combinatorial-optimization
datasets:
- Yale-ROSE/SAT-VarSelector-Distilled
language:
- en
pipeline_tag: text-classification
---
# Qwen3-0.6B-SAT-VarSelector-Distilled
A **Qwen3-0.6B** model fine-tuned for **SAT variable selection** in the Cube-and-Conquer (CnC) framework. Given a CNF formula state, the model predicts which variable to branch/cube on next.
## Model Description
This model implements a **masked classification head** on top of Qwen3-0.6B to select branching variables for SAT solving. Unlike traditional heuristics (e.g., VSIDS), it learns from expert solver traces to make informed variable selection decisions.
### Key Features
- **Task**: Variable selection for SAT Cube-and-Conquer
- **Architecture**: Qwen3-0.6B backbone + classification head (601 classes for variables 0-600)
- **Training**: Supervised fine-tuning on distilled expert data
- **Output**: Integer variable ID to branch on
## Training Details
| Attribute | Value |
|-----------|-------|
| Base Model | `Qwen/Qwen3-0.6B` |
| Training Dataset | Distilled from GPT expert traces |
| Best Checkpoint | Step 410 (Epoch ~6.7) |
| **Eval Accuracy** | **14.75%** |
| Eval Loss | 3.789 |
| Training Time | ~53 minutes (8×H100 GPUs) |
### Performance Context
- **Random Baseline**: ~1-2% accuracy (depends on the number of valid variables per instance)
- **This Model**: 14.75% accuracy, roughly **7-15× better than random**
### Hyperparameters
```yaml
learning_rate: 5e-6
warmup_ratio: 0.1
num_train_epochs: 8
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
max_length: 8192
max_vars: 600
optimizer: AdamW
scheduler: cosine
deepspeed: ZeRO-3
```
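These settings map directly onto Hugging Face `TrainingArguments`. A minimal sketch, assuming a ZeRO-3 JSON config file, an output directory name, and bf16 precision (all assumptions; the actual training script is not released with this card):

```python
from transformers import TrainingArguments

# Sketch only: reproduces the hyperparameters listed above. The DeepSpeed
# config path, output_dir, and bf16 flag are assumptions, not from this card.
training_args = TrainingArguments(
    output_dir="qwen3-sat-varselector",
    learning_rate=5e-6,
    warmup_ratio=0.1,
    num_train_epochs=8,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    lr_scheduler_type="cosine",
    optim="adamw_torch",
    deepspeed="ds_zero3_config.json",  # hypothetical ZeRO-3 config file
    bf16=True,                         # assumed for H100 training
)
```

Note that `max_length` and `max_vars` are data/model-side settings: they are applied in the tokenizer call and the classification head, not in `TrainingArguments`.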
## Usage
### Loading the Model
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn as nn

class QwenVarClassifier(nn.Module):
    def __init__(self, base_model, max_vars=600):
        super().__init__()
        self.base = base_model
        hidden_size = base_model.config.hidden_size
        self.norm = nn.LayerNorm(hidden_size)
        self.head = nn.Linear(hidden_size, max_vars + 1)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.base(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        hidden = outputs.hidden_states[-1]  # [B, seq, hidden]
        # Pool at the last non-pad token
        if attention_mask is not None:
            lengths = attention_mask.sum(dim=1) - 1
            pooled = hidden[torch.arange(hidden.size(0)), lengths]
        else:
            pooled = hidden[:, -1, :]
        pooled = self.norm(pooled)
        logits = self.head(pooled)
        return logits

# Load the backbone and wrap it with the classification head
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")
model = QwenVarClassifier(base_model, max_vars=600)

# Load fine-tuned weights
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
```
### Inference
```python
def get_valid_vars(cnf_text, max_vars=600):
    """Extract valid variable IDs from CNF text."""
    valid = set()
    for line in cnf_text.strip().split('\n'):
        if line.startswith('c') or line.startswith('p'):
            continue  # skip comment and problem lines
        for tok in line.split():
            try:
                lit = int(tok)
                if lit != 0:
                    valid.add(abs(lit))
            except ValueError:
                pass
    return valid

def predict_variable(cnf_text, model, tokenizer, max_vars=600):
    """Predict the next variable to branch on."""
    inputs = tokenizer(cnf_text, return_tensors="pt", truncation=True, max_length=8192)
    with torch.no_grad():
        logits = model(inputs["input_ids"], inputs["attention_mask"])

    # Mask invalid variables
    valid_vars = get_valid_vars(cnf_text, max_vars)
    mask = torch.zeros(max_vars + 1, dtype=torch.bool)
    for v in valid_vars:
        if 1 <= v <= max_vars:
            mask[v] = True
    logits[0, ~mask] = -1e4

    predicted_var = logits.argmax(dim=-1).item()
    return predicted_var

# Example
cnf_text = """p cnf 100 200
1 -2 3 0
-1 4 -5 0
2 5 6 0
"""
var = predict_variable(cnf_text, model, tokenizer)
print(f"Predicted variable: {var}")
```
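Cube-and-Conquer typically explores several candidate cubes per node rather than committing to a single branch. A hedged extension of `predict_variable` that returns the top-k valid variables (the `predict_top_k` helper and its `k` parameter are illustrative, not part of the released code):

```python
def predict_top_k(cnf_text, model, tokenizer, k=4, max_vars=600):
    """Illustrative helper: return the k highest-scoring valid variables."""
    inputs = tokenizer(cnf_text, return_tensors="pt", truncation=True, max_length=8192)
    with torch.no_grad():
        logits = model(inputs["input_ids"], inputs["attention_mask"])

    # Same masking as predict_variable
    valid_vars = get_valid_vars(cnf_text, max_vars)
    mask = torch.zeros(max_vars + 1, dtype=torch.bool)
    for v in valid_vars:
        if 1 <= v <= max_vars:
            mask[v] = True
    logits[0, ~mask] = -1e4

    k = min(k, len(valid_vars))
    _, top_vars = logits[0].topk(k)
    return top_vars.tolist()

# e.g. hand several candidates to the cube generator instead of one split
print(predict_top_k(cnf_text, model, tokenizer, k=3))
```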
## Architecture Details
### Why Masked Classification?
The valid action set is **state-dependent**: not all variables are valid at every step.
- Some variables may be eliminated during simplification
- Some may be out of range for the specific instance
We use **masked softmax**:
1. Model outputs logits for all 601 classes (0-600)
2. Invalid variables get logits set to `-1e4`
3. Softmax only assigns probability to valid variables
4. Training uses masked cross-entropy loss (see the sketch below)
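A minimal sketch of step 4, assuming `logits` of shape `[B, 601]`, a boolean validity `mask` of the same shape, and integer `labels` of shape `[B]` (the function name and shapes are illustrative):

```python
import torch
import torch.nn.functional as F

def masked_ce_loss(logits, mask, labels):
    # Invalid classes get a large negative logit, so softmax assigns them
    # (near-)zero probability before the cross-entropy is computed.
    masked_logits = logits.masked_fill(~mask, -1e4)
    return F.cross_entropy(masked_logits, labels)
```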
### Why Pool the Last Token?
The last non-pad token has attended to the entire CNF sequence through causal attention, making it a natural summary representation.
### Why LayerNorm Before the Head?
Qwen's hidden states can have large magnitudes. LayerNorm stabilizes the input to the classification head.
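An illustrative toy check of that effect, using 1024 as the hidden size (Qwen3-0.6B's reported width; the values here are synthetic):

```python
import torch
import torch.nn as nn

# Synthetic pooled hidden state with exaggerated magnitude
pooled = 50.0 * torch.randn(1, 1024)    # 1024 = assumed Qwen3-0.6B hidden size
norm = nn.LayerNorm(1024)
print(pooled.abs().max().item())        # large, on the order of hundreds
print(norm(pooled).abs().max().item())  # rescaled to roughly unit range
```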
## Limitations
- Maximum 600 variables (configurable during training)
- Maximum sequence length 8192 tokens
- Trained on specific CNF distribution; may not generalize to all SAT instances
- Accuracy metric is strict exact-match; the model may predict "good" variables even when not matching the expert label exactly
## Citation
```bibtex
@misc{qwen-sat-varselector,
title={Qwen3-0.6B-SAT-VarSelector-Distilled},
author={Yale-ROSE},
year={2026},
publisher={Hugging Face},
url={https://huggingface.co/Yale-ROSE/Qwen3-0.6B-SAT-VarSelector-Distilled}
}
```
## Related Models
- [Yale-ROSE/Qwen3-4B-SAT-VarSelector](https://huggingface.co/Yale-ROSE/Qwen3-4B-SAT-VarSelector) - Larger 4B parameter version
## License
Apache 2.0 (following the base Qwen3 license)