---
license: apache-2.0
base_model: Qwen/Qwen3-0.6B
tags:
  - sat
  - satisfiability
  - cube-and-conquer
  - variable-selection
  - combinatorial-optimization
datasets:
  - Yale-ROSE/SAT-VarSelector-Distilled
language:
  - en
pipeline_tag: text-classification
---

# Qwen3-0.6B-SAT-VarSelector-Distilled

A **Qwen3-0.6B** model fine-tuned for **SAT variable selection** in the Cube-and-Conquer (CnC) framework. Given a CNF formula state, the model predicts which variable to branch/cube on next.

## Model Description

This model implements a **masked classification head** on top of Qwen3-0.6B to select branching variables for SAT solving. Unlike traditional heuristics (e.g., VSIDS), it learns from expert solver traces to make informed variable selection decisions.

### Key Features

- **Task**: Variable selection for SAT Cube-and-Conquer
- **Architecture**: Qwen3-0.6B backbone + classification head (601 classes for variables 0-600)
- **Training**: Supervised fine-tuning on distilled expert data
- **Output**: Integer variable ID to branch on

## Training Details

| Attribute | Value |
|-----------|-------|
| Base Model | `Qwen/Qwen3-0.6B` |
| Training Dataset | Distilled from GPT expert traces |
| Best Checkpoint | Step 410 (Epoch ~6.7) |
| **Eval Accuracy** | **14.75%** |
| Eval Loss | 3.789 |
| Training Time | ~53 minutes (8×H100 GPUs) |

### Performance Context

- **Random Baseline**: ~1-2% accuracy, depending on the number of valid variables per instance (see the sketch below)
- **This Model**: 14.75% accuracy, roughly **7-15× better than random**
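
As a rough check, the random baseline can be estimated from the valid-variable sets of the eval examples. A minimal sketch (the helper and the example sets are illustrative, not the actual eval data):

```python
def random_baseline_accuracy(valid_sets):
    """Expected exact-match accuracy of uniform-random selection.

    valid_sets: one set of valid variable IDs per eval example.
    A uniform-random picker matches the expert label with probability
    1/|valid set| on each example; the baseline is the mean of those.
    """
    return sum(1.0 / len(s) for s in valid_sets) / len(valid_sets)

# Instances with ~50-100 valid variables give the quoted 1-2% range:
print(random_baseline_accuracy([set(range(1, 51)), set(range(1, 101))]))  # 0.015
```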

### Hyperparameters

```yaml
learning_rate: 5e-6
warmup_ratio: 0.1
num_train_epochs: 8
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
max_length: 8192
max_vars: 600
optimizer: AdamW
scheduler: cosine
deepspeed: ZeRO-3
```
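
With `per_device_train_batch_size: 1`, `gradient_accumulation_steps: 8`, and the 8 GPUs noted above, the effective batch size works out to 1 × 8 × 8 = 64 sequences per optimizer step.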

## Usage

### Loading the Model

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn as nn

class QwenVarClassifier(nn.Module):
    def __init__(self, base_model, max_vars=600):
        super().__init__()
        self.base = base_model
        hidden_size = base_model.config.hidden_size
        self.norm = nn.LayerNorm(hidden_size)
        self.head = nn.Linear(hidden_size, max_vars + 1)
    
    def forward(self, input_ids, attention_mask=None):
        outputs = self.base(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        hidden = outputs.hidden_states[-1]  # [B, seq, hidden]
        
        # Pool at last non-pad token
        if attention_mask is not None:
            lengths = attention_mask.sum(dim=1) - 1
            pooled = hidden[torch.arange(hidden.size(0)), lengths]
        else:
            pooled = hidden[:, -1, :]
        
        pooled = self.norm(pooled)
        logits = self.head(pooled)
        return logits

# Load
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")
model = QwenVarClassifier(base_model, max_vars=600)

# Load fine-tuned weights
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
```
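
If the fine-tuned weights are not already on disk, they can be fetched from the Hub first. A sketch using `huggingface_hub`, assuming the checkpoint file in this repo is named `pytorch_model.bin`:

```python
from huggingface_hub import hf_hub_download
import torch

# Download the fine-tuned classifier weights from the model repo.
weights_path = hf_hub_download(
    repo_id="Yale-ROSE/Qwen3-0.6B-SAT-VarSelector-Distilled",
    filename="pytorch_model.bin",
)
state_dict = torch.load(weights_path, map_location="cpu")
```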

### Inference

```python
def get_valid_vars(cnf_text, max_vars=600):
    """Extract valid variable IDs from CNF text."""
    valid = set()
    for line in cnf_text.strip().split('\n'):
        if line.startswith('c') or line.startswith('p'):
            continue
        for tok in line.split():
            try:
                lit = int(tok)
                if lit != 0 and abs(lit) <= max_vars:
                    valid.add(abs(lit))
            except ValueError:
                pass
    return valid

def predict_variable(cnf_text, model, tokenizer, max_vars=600):
    """Predict the next variable to branch on."""
    inputs = tokenizer(cnf_text, return_tensors="pt", truncation=True, max_length=8192)
    
    with torch.no_grad():
        logits = model(inputs["input_ids"], inputs["attention_mask"])
    
    # Mask invalid variables
    valid_vars = get_valid_vars(cnf_text, max_vars)
    mask = torch.zeros(max_vars + 1, dtype=torch.bool)
    for v in valid_vars:
        if 1 <= v <= max_vars:
            mask[v] = True
    
    logits[0, ~mask] = -1e4
    predicted_var = logits.argmax(dim=-1).item()
    
    return predicted_var

# Example
cnf_text = """p cnf 100 200
1 -2 3 0
-1 4 -5 0
2 5 6 0
"""

var = predict_variable(cnf_text, model, tokenizer)
print(f"Predicted variable: {var}")
```

## Architecture Details

### Why Masked Classification?

The valid action set is **state-dependent**: not all variables are valid at every step.

- Some variables may be eliminated during simplification
- Some may be out of range for the specific instance

We use **masked softmax**:
1. Model outputs logits for all 601 classes (0-600)
2. Invalid variables get logits set to `-1e4`
3. Softmax only assigns probability to valid variables
4. Training uses a masked cross-entropy loss, sketched below
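
A minimal sketch of what such a masked loss could look like (illustrative; the actual training code is not reproduced here):

```python
import torch
import torch.nn.functional as F

def masked_cross_entropy(logits, labels, valid_mask):
    """Cross-entropy restricted to valid variables.

    logits:     [B, max_vars + 1] raw head outputs
    labels:     [B] expert variable IDs
    valid_mask: [B, max_vars + 1] bool, True where a variable is valid
    """
    # Large negative logits give invalid classes ~zero softmax mass,
    # matching the -1e4 masking used at inference time.
    masked_logits = logits.masked_fill(~valid_mask, -1e4)
    return F.cross_entropy(masked_logits, labels)
```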

### Why Pool the Last Token?

The last non-pad token has attended to the entire CNF sequence through causal attention, making it a natural summary representation.

### Why LayerNorm Before the Head?

Qwen's hidden states can have large magnitudes. LayerNorm stabilizes the input to the classification head.

## Limitations

- Maximum 600 variables (configurable during training)
- Maximum sequence length 8192 tokens
- Trained on specific CNF distribution; may not generalize to all SAT instances
- Accuracy metric is strict exact-match; the model may predict "good" variables even when not matching the expert label exactly

## Citation

```bibtex
@misc{qwen-sat-varselector,
  title={Qwen3-0.6B-SAT-VarSelector-Distilled},
  author={Yale-ROSE},
  year={2026},
  publisher={Hugging Face},
  url={https://huggingface.co/Yale-ROSE/Qwen3-0.6B-SAT-VarSelector-Distilled}
}
```

## Related Models

- [Yale-ROSE/Qwen3-4B-SAT-VarSelector](https://huggingface.co/Yale-ROSE/Qwen3-4B-SAT-VarSelector) - Larger 4B parameter version

## License

Apache 2.0 (following the base Qwen3 license)