# Qwen3-0.6B-SAT-VarSelector-Distilled
A Qwen3-0.6B model fine-tuned for SAT variable selection in the Cube-and-Conquer (CnC) framework. Given a CNF formula state, the model predicts which variable to branch/cube on next.
## Model Description
This model implements a masked classification head on top of Qwen3-0.6B to select branching variables for SAT solving. Unlike traditional heuristics (e.g., VSIDS), it learns from expert solver traces to make informed variable selection decisions.
### Key Features
- Task: Variable selection for SAT Cube-and-Conquer
- Architecture: Qwen3-0.6B backbone + classification head (601 classes for variables 0-600)
- Training: Supervised fine-tuning on distilled expert data
- Output: Integer variable ID to branch on
## Training Details
| Attribute | Value |
|---|---|
| Base Model | Qwen/Qwen3-0.6B |
| Training Dataset | Distilled from GPT expert traces |
| Best Checkpoint | Step 410 (Epoch ~6.7) |
| Eval Accuracy | 14.75% |
| Eval Loss | 3.789 |
| Training Time | ~53 minutes (8×H100 GPUs) |
### Performance Context
- Random Baseline: ~1-2% accuracy (picking uniformly among the N valid variables of an instance succeeds with probability 1/N)
- This Model: 14.75% exact-match accuracy, roughly 7-15× the random baseline
### Hyperparameters

```yaml
learning_rate: 5e-6
warmup_ratio: 0.1
num_train_epochs: 8
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
max_length: 8192
max_vars: 600
optimizer: AdamW
scheduler: cosine
deepspeed: ZeRO-3
```
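These settings correspond to standard `transformers.TrainingArguments` fields. A hedged sketch of the mapping is below; the actual training script is not included in this repo, and `output_dir` and the DeepSpeed config path are hypothetical:

```python
# Illustrative mapping of the hyperparameter table onto TrainingArguments.
# output_dir and the DeepSpeed config path are hypothetical placeholders.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="qwen3-sat-varselector",
    learning_rate=5e-6,
    warmup_ratio=0.1,
    num_train_epochs=8,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    lr_scheduler_type="cosine",
    optim="adamw_torch",
    deepspeed="ds_zero3.json",
)
```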
## Usage
### Loading the Model

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn as nn


class QwenVarClassifier(nn.Module):
    def __init__(self, base_model, max_vars=600):
        super().__init__()
        self.base = base_model
        hidden_size = base_model.config.hidden_size
        self.norm = nn.LayerNorm(hidden_size)
        # 601 classes: one logit per variable ID 0-600
        self.head = nn.Linear(hidden_size, max_vars + 1)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.base(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        hidden = outputs.hidden_states[-1]  # [B, seq, hidden]
        # Pool at the last non-pad token (assumes right-padded batches)
        if attention_mask is not None:
            lengths = attention_mask.sum(dim=1) - 1
            pooled = hidden[torch.arange(hidden.size(0)), lengths]
        else:
            pooled = hidden[:, -1, :]
        pooled = self.norm(pooled)
        logits = self.head(pooled)
        return logits
# Load backbone and tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")
model = QwenVarClassifier(base_model, max_vars=600)

# Load the fine-tuned weights (pytorch_model.bin from this repo)
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
```
### Inference

```python
def get_valid_vars(cnf_text, max_vars=600):
    """Extract the set of variable IDs that appear in the DIMACS CNF text."""
    valid = set()
    for line in cnf_text.strip().split('\n'):
        # Skip comment ('c') and problem ('p') header lines
        if line.startswith('c') or line.startswith('p'):
            continue
        for tok in line.split():
            try:
                lit = int(tok)
                if lit != 0:  # 0 terminates a clause
                    valid.add(abs(lit))
            except ValueError:
                pass
    return valid


def predict_variable(cnf_text, model, tokenizer, max_vars=600):
    """Predict the next variable to branch on."""
    inputs = tokenizer(cnf_text, return_tensors="pt", truncation=True, max_length=8192)
    with torch.no_grad():
        logits = model(inputs["input_ids"], inputs["attention_mask"])

    # Mask variables that are out of range or absent from the formula
    valid_vars = get_valid_vars(cnf_text, max_vars)
    mask = torch.zeros(max_vars + 1, dtype=torch.bool)
    for v in valid_vars:
        if 1 <= v <= max_vars:
            mask[v] = True
    logits[0, ~mask] = -1e4

    predicted_var = logits.argmax(dim=-1).item()
    return predicted_var
# Example
cnf_text = """p cnf 100 200
1 -2 3 0
-1 4 -5 0
2 5 6 0
"""
var = predict_variable(cnf_text, model, tokenizer)
print(f"Predicted variable: {var}")
## Architecture Details
### Why Masked Classification?
The valid action set is state-dependent: not all variables are valid at every step.
- Some variables may be eliminated during simplification
- Some may be out of range for the specific instance
We use masked softmax:
- Model outputs logits for all 601 classes (0-600)
- Invalid variables get logits set to `-1e4`
- Softmax only assigns probability to valid variables
- Training uses masked cross-entropy loss, as sketched below
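A minimal sketch of that masked loss, illustrative rather than the exact training code; `valid_mask` marks the valid classes for each example:

```python
import torch
import torch.nn.functional as F

def masked_cross_entropy(logits, labels, valid_mask):
    """logits: [B, max_vars + 1], labels: [B], valid_mask: [B, max_vars + 1] bool."""
    # Drive invalid classes to effectively zero probability before the
    # softmax inside cross_entropy, matching the -1e4 masking at inference
    masked_logits = logits.masked_fill(~valid_mask, -1e4)
    return F.cross_entropy(masked_logits, labels)
```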
### Why Pool the Last Token?
The last non-pad token has attended to the entire CNF sequence through causal attention, making it a natural summary representation. Note that indexing the last non-pad position in `forward` assumes right-padded batches; a toy demo of the pooling follows.
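A tiny self-contained demo of that pooling on toy tensors:

```python
import torch

hidden = torch.randn(2, 5, 4)                    # [B, seq, hidden]
attention_mask = torch.tensor([[1, 1, 1, 0, 0],  # real length 3
                               [1, 1, 1, 1, 1]]) # real length 5

lengths = attention_mask.sum(dim=1) - 1          # index of last real token
pooled = hidden[torch.arange(hidden.size(0)), lengths]
print(pooled.shape)  # torch.Size([2, 4])
```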
### Why LayerNorm Before the Head?
Qwen's hidden states can have large magnitudes. LayerNorm stabilizes the input to the classification head.
## Limitations
- Maximum 600 variables (configurable during training)
- Maximum sequence length 8192 tokens
- Trained on specific CNF distribution; may not generalize to all SAT instances
- Accuracy metric is strict exact-match; the model may predict "good" variables even when not matching the expert label exactly
## Citation

```bibtex
@misc{qwen-sat-varselector,
title={Qwen3-0.6B-SAT-VarSelector-Distilled},
author={Yale-ROSE},
year={2026},
publisher={Hugging Face},
url={https://huggingface.co/Yale-ROSE/Qwen3-0.6B-SAT-VarSelector-Distilled}
}
```
## Related Models
- Yale-ROSE/Qwen3-4B-SAT-VarSelector - larger 4B-parameter version of this model
## License
Apache 2.0 (following the base Qwen3 license)