---
license: apache-2.0
base_model: Qwen/Qwen3-0.6B
tags:
  - SAT
  - combinatorial-optimization
  - classification
  - cube-and-conquer
language:
  - en
pipeline_tag: text-classification
---

# Qwen3-0.6B-SAT-VarSelector

A lightweight Qwen3-0.6B model fine-tuned for **SAT branching variable selection** in Cube-and-Conquer (CnC) solvers.

## Model Description

This model predicts which variable to branch/cube on next, given a SAT CNF formula state. Instead of generating text, it outputs a **classification over variable IDs** (1-500).

### Architecture

- **Base**: `Qwen/Qwen3-0.6B` (causal language model)
- **Head**: LayerNorm → Linear(hidden_size, 501)
- **Pooling**: Last non-pad token hidden state
- **Masking**: Invalid variables (not in CNF) are masked to -10000 before softmax
- **Size**: ~1.2GB (bfloat16)
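
The masking step above can be illustrated in plain Python (a simplified sketch, not the actual `QwenVarClassifier` implementation; the function names here are hypothetical, but the −10000 fill value and the 501-way logit layout follow the bullets above):

```python
# Sketch of the logit-masking step: variables that do not occur in the
# CNF get logit -10000, so argmax can never select them.
MASK_VALUE = -10000.0

def mask_invalid_logits(logits, valid_mask):
    """logits: list of 501 floats (index = variable ID, index 0 unused).
    valid_mask: list of 501 bools, True where the variable occurs in the CNF."""
    return [x if ok else MASK_VALUE for x, ok in zip(logits, valid_mask)]

def predict_variable(logits, valid_mask):
    """Return the variable ID with the highest masked logit."""
    masked = mask_invalid_logits(logits, valid_mask)
    return max(range(len(masked)), key=lambda i: masked[i])
```

The real model does the same thing on GPU tensors via `masked_fill`, as shown in the Usage section below.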

### Training

- **Dataset**: 3,898 training / 434 validation samples
- **Task**: Predict expert-selected branching variable
- **Training**: 8 epochs, 8×H100 GPUs, DeepSpeed ZeRO-3

### Comparison with 4B Model

| Model | Size | Top-1 Acc | Top-5 Acc | Inference Speed |
|-------|------|-----------|-----------|-----------------|
| Qwen3-4B | 8GB | 24% | 48% | ~150ms/sample |
| **Qwen3-0.6B** | 1.2GB | ~12% | ~32% | ~45ms/sample |
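Since Top-5 accuracy is substantially higher than Top-1, a CnC driver may want several candidate variables per call rather than a single argmax. A stdlib sketch of that selection (a hypothetical helper, not part of the released code):

```python
def topk_variables(masked_logits, k=5):
    """Return the k variable IDs with the highest masked logits,
    best first. masked_logits: list indexed by variable ID, with
    invalid variables already filled with a large negative value."""
    ranked = sorted(range(len(masked_logits)),
                    key=lambda i: masked_logits[i], reverse=True)
    return ranked[:k]
```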

## Usage

```python
import torch
from transformers import AutoTokenizer
from sft_qwen_var_classifier import QwenVarClassifier, cnf_valid_mask

# Load model
model = QwenVarClassifier("Qwen/Qwen3-0.6B", max_vars=500)
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict, strict=False)
model = model.to("cuda", dtype=torch.bfloat16)
model.eval()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

# Prepare CNF input
cnf_text = """p cnf 100 250
1 -2 3 0
-1 2 -4 0
...
"""

# Tokenize
inputs = tokenizer(cnf_text, return_tensors="pt", truncation=True, max_length=8192)
inputs = {k: v.to("cuda") for k, v in inputs.items()}

# Get valid variable mask
valid_mask = torch.tensor([cnf_valid_mask(cnf_text, max_vars=500)], dtype=torch.bool, device="cuda")

# Predict
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs["logits"]
    logits = logits.masked_fill(~valid_mask, -1e4)
    predicted_var = logits.argmax(dim=-1).item()

print(f"Predicted branching variable: {predicted_var}")
```
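
The `cnf_valid_mask` helper used above ships in `sft_qwen_var_classifier.py`. Its likely behavior — marking which variable IDs actually occur in the DIMACS text — can be sketched with the standard library alone (an illustrative reimplementation under that assumption, not the released helper):

```python
def dimacs_valid_mask(cnf_text, max_vars=500):
    """Mark which variable IDs occur in a DIMACS CNF string.
    Index 0 is unused; index v is True iff variable v appears as a literal."""
    mask = [False] * (max_vars + 1)
    for line in cnf_text.splitlines():
        line = line.strip()
        if not line or line.startswith(("c", "p")):  # skip comments and header
            continue
        for tok in line.split():
            try:
                lit = int(tok)
            except ValueError:
                continue  # tolerate stray non-numeric tokens
            var = abs(lit)
            if 1 <= var <= max_vars:  # 0 is the clause terminator
                mask[var] = True
    return mask
```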

## Files

- `pytorch_model.bin` - Model weights (~1.2GB, bfloat16)
- `sft_qwen_var_classifier.py` - Model class definition (required for loading)

## When to Use

- **Production/Deployment**: Faster inference, smaller memory footprint
- **Edge devices**: Can run on smaller GPUs
- **Rapid prototyping**: Quick experiments
- **CPU inference**: More practical than the 4B model

For maximum accuracy, use the [4B model](https://huggingface.co/Yale-ROSE/Qwen3-4B-SAT-VarSelector).

## Limitations

- Maximum 500 variables
- Maximum 8192 tokens for CNF input
- Lower accuracy than the 4B model (see the comparison table above)
- Trained on specific CNF distribution

## Citation

If you use this model, please cite the Transformer-CnC paper.

## License

Apache 2.0