---
license: apache-2.0
base_model: Qwen/Qwen3-0.6B
tags:
- sat
- satisfiability
- cube-and-conquer
- variable-selection
- combinatorial-optimization
datasets:
- Yale-ROSE/SAT-VarSelector-Distilled
language:
- en
pipeline_tag: text-classification
---

# Qwen3-0.6B-SAT-VarSelector-Distilled

A **Qwen3-0.6B** model fine-tuned for **SAT variable selection** in the Cube-and-Conquer (CnC) framework. Given a CNF formula state, the model predicts which variable to branch/cube on next.

## Model Description

This model implements a **masked classification head** on top of Qwen3-0.6B to select branching variables for SAT solving. Unlike traditional heuristics (e.g., VSIDS), it learns from expert solver traces to make informed variable selection decisions.

### Key Features

- **Task**: Variable selection for SAT Cube-and-Conquer
- **Architecture**: Qwen3-0.6B backbone + classification head (601 classes indexed 0-600; DIMACS variable IDs start at 1, so class 0 is effectively unused)
- **Training**: Supervised fine-tuning on distilled expert data
- **Output**: Integer variable ID to branch on

## Training Details

| Attribute | Value |
|-----------|-------|
| Base Model | `Qwen/Qwen3-0.6B` |
| Training Dataset | Distilled from GPT expert traces |
| Best Checkpoint | Step 410 (Epoch ~6.7) |
| **Eval Accuracy** | **14.75%** |
| Eval Loss | 3.789 |
| Training Time | ~53 minutes (8×H100 GPUs) |

### Performance Context

- **Random Baseline**: ~1-2% accuracy. Uniformly random selection hits the expert label with probability 1/(number of valid variables), e.g., 1% for an instance with 100 valid variables.
- **This Model**: 14.75% accuracy, i.e., roughly **7-15× better than random**

### Hyperparameters

```yaml
learning_rate: 5e-6
warmup_ratio: 0.1
num_train_epochs: 8
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
max_length: 8192
max_vars: 600
optimizer: AdamW
scheduler: cosine
deepspeed: ZeRO-3
```
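
The training script itself is not published; as a rough illustration, the Hugging Face equivalent of this configuration might look like the sketch below (`output_dir` and the DeepSpeed config path are hypothetical, and `max_length`/`max_vars` are model-specific settings handled outside `TrainingArguments`):

```python
from transformers import TrainingArguments

# Hedged reconstruction of the configuration above, not the actual training script.
args = TrainingArguments(
    output_dir="qwen3-sat-varselector",  # hypothetical
    learning_rate=5e-6,
    warmup_ratio=0.1,
    num_train_epochs=8,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    lr_scheduler_type="cosine",
    optim="adamw_torch",
    deepspeed="ds_zero3_config.json",  # hypothetical ZeRO-3 config path
)
```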

## Usage

### Loading the Model

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn as nn


class QwenVarClassifier(nn.Module):
    def __init__(self, base_model, max_vars=600):
        super().__init__()
        self.base = base_model
        hidden_size = base_model.config.hidden_size
        self.norm = nn.LayerNorm(hidden_size)
        # One logit per class 0..max_vars (601 classes for max_vars=600)
        self.head = nn.Linear(hidden_size, max_vars + 1)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.base(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        hidden = outputs.hidden_states[-1]  # [B, seq, hidden]

        # Pool at the last non-pad token (assumes right padding)
        if attention_mask is not None:
            lengths = attention_mask.sum(dim=1) - 1
            pooled = hidden[torch.arange(hidden.size(0)), lengths]
        else:
            pooled = hidden[:, -1, :]

        pooled = self.norm(pooled)
        logits = self.head(pooled)
        return logits


# Load the backbone and wrap it with the classification head
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")
model = QwenVarClassifier(base_model, max_vars=600)

# Load the fine-tuned weights
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
```

### Inference

```python
def get_valid_vars(cnf_text, max_vars=600):
    """Extract valid variable IDs from CNF text."""
    valid = set()
    for line in cnf_text.strip().split('\n'):
        if line.startswith('c') or line.startswith('p'):
            continue
        for tok in line.split():
            try:
                lit = int(tok)
                if lit != 0:
                    valid.add(abs(lit))
            except ValueError:
                pass
    return valid


def predict_variable(cnf_text, model, tokenizer, max_vars=600):
    """Predict the next variable to branch on."""
    inputs = tokenizer(cnf_text, return_tensors="pt", truncation=True, max_length=8192)

    with torch.no_grad():
        logits = model(inputs["input_ids"], inputs["attention_mask"])

    # Mask invalid variables
    valid_vars = get_valid_vars(cnf_text, max_vars)
    mask = torch.zeros(max_vars + 1, dtype=torch.bool)
    for v in valid_vars:
        if 1 <= v <= max_vars:
            mask[v] = True

    logits[0, ~mask] = -1e4
    predicted_var = logits.argmax(dim=-1).item()

    return predicted_var


# Example
cnf_text = """p cnf 100 200
1 -2 3 0
-1 4 -5 0
2 5 6 0
"""

var = predict_variable(cnf_text, model, tokenizer)
print(f"Predicted variable: {var}")
```

## Architecture Details

### Why Masked Classification?

The valid action set is **state-dependent**: not all variables are valid at every step.

- Some variables may be eliminated during simplification
- Some may be out of range for the specific instance

We use **masked softmax**:

1. The model outputs logits for all 601 classes (0-600)
2. Invalid variables get their logits set to `-1e4`
3. Softmax assigns probability only to valid variables
4. Training uses a masked cross-entropy loss (see the sketch below)
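
The training loss is not shown elsewhere in this card; the following is a minimal sketch of such a masked cross-entropy, assuming batched head logits, expert labels, and a boolean validity mask:

```python
import torch
import torch.nn.functional as F

def masked_cross_entropy(logits, labels, valid_mask):
    """Cross-entropy restricted to valid variables.

    logits:     [B, max_vars + 1] raw head outputs
    labels:     [B] expert variable IDs
    valid_mask: [B, max_vars + 1] True where the variable occurs in the instance
    """
    # Large negative logits give invalid classes ~0 probability after softmax
    masked_logits = logits.masked_fill(~valid_mask, -1e4)
    return F.cross_entropy(masked_logits, labels)
```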

### Why Pool the Last Token?

The last non-pad token has attended to the entire CNF sequence through causal attention, making it a natural summary representation.

### Why LayerNorm Before the Head?

Qwen's hidden states can have large magnitudes. LayerNorm stabilizes the input to the classification head.

## Limitations

- Maximum of 600 variables (configurable during training)
- Maximum sequence length of 8192 tokens
- Trained on a specific CNF distribution; may not generalize to all SAT instances
- The accuracy metric is strict exact-match; the model may predict "good" variables even when they don't match the expert label exactly (a top-k evaluation, sketched below, gives a more forgiving view)
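
For the last point, a top-k metric is a natural complement; a minimal sketch, assuming masked logits and expert labels collected over an evaluation set:

```python
import torch

def topk_accuracy(logits, labels, k=5):
    """Fraction of examples whose expert label is among the model's top-k picks.

    logits: [N, max_vars + 1] masked logits; labels: [N] expert variable IDs.
    """
    topk = logits.topk(k, dim=-1).indices              # [N, k]
    hits = (topk == labels.unsqueeze(-1)).any(dim=-1)  # [N]
    return hits.float().mean().item()
```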

## Citation

```bibtex
@misc{qwen-sat-varselector,
  title={Qwen3-0.6B-SAT-VarSelector-Distilled},
  author={Yale-ROSE},
  year={2026},
  publisher={Hugging Face},
  url={https://huggingface.co/Yale-ROSE/Qwen3-0.6B-SAT-VarSelector-Distilled}
}
```

## Related Models

- [Yale-ROSE/Qwen3-4B-SAT-VarSelector](https://huggingface.co/Yale-ROSE/Qwen3-4B-SAT-VarSelector) - a larger 4B-parameter version

## License

Apache 2.0 (following the base Qwen3 license)