---
license: apache-2.0
base_model: Qwen/Qwen3-4B
tags:
- SAT
- combinatorial-optimization
- classification
- cube-and-conquer
language:
- en
pipeline_tag: text-classification
---
# Qwen3-4B-SAT-VarSelector
A Qwen3-4B model fine-tuned for **SAT branching variable selection** in Cube-and-Conquer (CnC) solvers.
## Model Description
Given the current state of a SAT formula in CNF, this model predicts which variable to branch (cube) on next. Instead of generating text, it outputs a **classification over variable IDs** (1-500).
### Architecture
- **Base**: `Qwen/Qwen3-4B` (causal language model)
- **Head**: LayerNorm → Linear(hidden_size, 501)
- **Pooling**: Last non-pad token hidden state
- **Masking**: Invalid variables (not in CNF) are masked to -10000 before softmax
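The masking step can be illustrated without loading the model. The following is a minimal sketch, assuming DIMACS CNF text as input. `valid_var_mask` and `masked_argmax` are hypothetical helpers written for illustration; they are not the `cnf_valid_mask` shipped in `sft_qwen_var_classifier.py`, whose exact behavior may differ.

```python
def valid_var_mask(cnf_text, max_vars=500):
    """Return a list of max_vars + 1 booleans; index v is True if variable v occurs."""
    mask = [False] * (max_vars + 1)  # index 0 is never a valid variable
    for line in cnf_text.splitlines():
        line = line.strip()
        # Skip blank lines, comments ("c ..."), and the problem line ("p cnf ...").
        if not line or line[0] in "cp%":
            continue
        for tok in line.split():
            try:
                lit = int(tok)
            except ValueError:
                continue  # tolerate non-numeric placeholders
            var = abs(lit)  # a literal and its negation refer to the same variable
            if 1 <= var <= max_vars:
                mask[var] = True
    return mask


def masked_argmax(logits, mask, fill=-1e4):
    """Replace logits of invalid variables with `fill`, then take the argmax."""
    masked = [x if ok else fill for x, ok in zip(logits, mask)]
    return max(range(len(masked)), key=masked.__getitem__)


cnf = "p cnf 4 2\n1 -2 3 0\n-1 2 -4 0\n"
mask = valid_var_mask(cnf, max_vars=5)
# Index 0 and variable 5 have the highest raw logits but are masked out.
best = masked_argmax([9.0, 0.1, 0.5, 0.2, 0.3, 8.0], mask)
```

Masking before the softmax means a variable that does not appear in the formula can never be selected, however large its raw logit.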
### Training
- **Dataset**: 3,898 training / 434 validation samples
- **Task**: Predict expert-selected branching variable
- **Best validation accuracy**: 16.36% (about 16x the ~1% random baseline)
- **Training**: 8 epochs, 8×H100 GPUs, DeepSpeed ZeRO-3
## Usage
```python
import torch
from transformers import AutoTokenizer
from sft_qwen_var_classifier import QwenVarClassifier, cnf_valid_mask
# Load model
model = QwenVarClassifier("Qwen/Qwen3-4B", max_vars=500)
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict, strict=False)
model = model.to("cuda", dtype=torch.bfloat16)
model.eval()
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
# Prepare CNF input
cnf_text = """p cnf 100 250
1 -2 3 0
-1 2 -4 0
...
"""
# Tokenize
inputs = tokenizer(cnf_text, return_tensors="pt", truncation=True, max_length=8192)
inputs = {k: v.to("cuda") for k, v in inputs.items()}
# Get valid variable mask
valid_mask = torch.tensor([cnf_valid_mask(cnf_text, max_vars=500)], dtype=torch.bool, device="cuda")
# Predict
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs["logits"]
    # Mask variables that do not appear in the CNF before taking the argmax
    logits = logits.masked_fill(~valid_mask, -1e4)
    predicted_var = logits.argmax(dim=-1).item()
print(f"Predicted branching variable: {predicted_var}")
```
## Files
- `pytorch_model.bin` - Model weights (8 GB, bfloat16)
- `sft_qwen_var_classifier.py` - Model class definition (required for loading)
- `inference_demo.py` - Example inference script
## Metrics
| Metric | Value |
|--------|-------|
| Validation Accuracy | 16.36% |
| Validation Loss | 3.87 |
| Random Baseline | ~1% |
| Improvement | 16x |
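The arithmetic behind these numbers can be checked with a short script. This is a sketch under one stated assumption: a ~1% random baseline corresponds to roughly 100 candidate variables per instance, which is inferred from the table and not stated in this card.

```python
import math

val_acc = 0.1636      # validation accuracy from the table
val_loss = 3.87       # validation cross-entropy loss (natural log) from the table
n_candidates = 100    # ASSUMPTION: implied by the ~1% random baseline

# Uniform guessing over the candidates gives accuracy 1/n and loss ln(n).
random_acc = 1.0 / n_candidates
uniform_loss = math.log(n_candidates)

# Ratio of model accuracy to random accuracy (the "Improvement" row).
improvement = val_acc / random_acc

# exp(-loss) is the geometric-mean probability assigned to the expert's variable.
geo_mean_prob = math.exp(-val_loss)

print(f"improvement over random: {improvement:.2f}x")
print(f"uniform loss: {uniform_loss:.3f} vs. model loss: {val_loss:.2f}")
print(f"geometric-mean probability on expert choice: {geo_mean_prob:.4f}")
```

Under that assumption, the validation loss sits below the uniform-guess loss, which is consistent with the accuracy gain.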
## Limitations
- Maximum 500 variables
- Maximum 8192 tokens for CNF input
- Trained on a specific CNF distribution (may not generalize to all SAT instances)
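A pre-flight check can catch inputs that exceed the variable limit before they reach the model. The sketch below assumes DIMACS CNF text; `max_variable_id` and `fits_variable_limit` are hypothetical helper names, not part of the released code.

```python
def max_variable_id(cnf_text):
    """Return the largest variable ID declared or used in DIMACS CNF text."""
    largest = 0
    for line in cnf_text.splitlines():
        parts = line.split()
        if not parts or parts[0] == "c":
            continue  # blank line or comment
        if parts[0] == "p":
            # Problem line: "p cnf <num_vars> <num_clauses>"
            if len(parts) >= 3 and parts[2].isdigit():
                largest = max(largest, int(parts[2]))
            continue
        for tok in parts:
            try:
                largest = max(largest, abs(int(tok)))
            except ValueError:
                pass  # ignore non-numeric tokens
    return largest


def fits_variable_limit(cnf_text, max_vars=500):
    """True if every variable ID is within the classifier's output range."""
    return max_variable_id(cnf_text) <= max_vars
```

The token limit can be checked in a similar way with `len(tokenizer(cnf_text)["input_ids"])`. Because the usage example passes `truncation=True`, a longer input would be cut at 8192 tokens instead of being rejected.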
## Citation
If you use this model, please cite the Transformer-CnC paper.
## License
Apache 2.0