Qwen3-4B-SAT-VarSelector-Sym-Aug

A Qwen3-4B model fine-tuned for SAT branching variable selection using symmetry-based data augmentation.

Model Description

This model predicts which variable to branch or cube on next, given a SAT CNF formula state. It was trained on 5x-augmented data produced by CNF symmetry transformations and reaches 21.8% top-1 accuracy, compared with ~19% for a Qwen3-0.6B model trained on the same augmented data.

Architecture

  • Base: Qwen/Qwen3-4B (causal language model)
  • Head: LayerNorm → Linear(hidden_size, 601)
  • Max Variables: 600
  • Pooling: Last non-pad token hidden state
  • Masking: Invalid variables (not in CNF) are masked to -10000 before softmax
  • Size: ~8GB (bfloat16)
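
A minimal sketch of how these pieces fit together, assuming Hugging Face Transformers. The real implementation is the QwenVarClassifier class in sft_qwen_var_classifier.py; the class name, attribute names, and index convention below are illustrative only:

import torch
import torch.nn as nn
from transformers import AutoModel

class VarSelectorSketch(nn.Module):
    """Illustrative sketch of the head described above, not the shipped implementation."""
    def __init__(self, base_name="Qwen/Qwen3-4B", max_vars=600):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(base_name)
        hidden_size = self.backbone.config.hidden_size
        self.norm = nn.LayerNorm(hidden_size)
        # 601 outputs for max_vars=600 (presumably index 0 is unused and indices 1..600 map to variable IDs)
        self.head = nn.Linear(hidden_size, max_vars + 1)

    def forward(self, input_ids, attention_mask):
        states = self.backbone(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        # Pool the hidden state of the last non-pad token in each sequence
        last_idx = attention_mask.sum(dim=1) - 1
        batch_idx = torch.arange(states.size(0), device=states.device)
        pooled = states[batch_idx, last_idx]
        return {"logits": self.head(self.norm(pooled))}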

Training with Symmetry Augmentation

This model was trained with 5x data augmentation using semantically safe CNF transformations; each one is listed below, with a code sketch after the list:

  • Variable Permutation: bijective remapping of variable IDs; prevents memorizing specific variable numbers
  • Clause Shuffling: random reordering of clauses; teaches position-independence
  • Literal Reordering: shuffle literals within clauses; adds token-level variation
  • Polarity Flipping: flip the signs of a random variable subset; teaches structural vs. polarity features
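
A rough sketch of these transformations on a clause list (illustrative only; the augmentation pipeline used for training is not included in this repository, and the supervision target must be remapped with the same variable permutation):

import random

def augment_cnf(clauses, num_vars, target_var):
    """Illustrative sketch of the four symmetry transformations on a list-of-literal-lists CNF."""
    # 1. Variable permutation: bijective remapping of variable IDs 1..num_vars
    new_ids = list(range(1, num_vars + 1))
    random.shuffle(new_ids)
    remap = dict(zip(range(1, num_vars + 1), new_ids))
    clauses = [[remap[abs(l)] if l > 0 else -remap[abs(l)] for l in c] for c in clauses]
    # 2. Polarity flipping: negate every occurrence of a random subset of variables
    flipped = {v for v in range(1, num_vars + 1) if random.random() < 0.5}
    clauses = [[-l if abs(l) in flipped else l for l in c] for c in clauses]
    # 3. Clause shuffling: random reordering of clauses
    random.shuffle(clauses)
    # 4. Literal reordering: shuffle literals within each clause
    for c in clauses:
        random.shuffle(c)
    # The target variable follows the same permutation (flipping its polarity does not change its ID)
    return clauses, remap[target_var]

aug_clauses, aug_target = augment_cnf([[1, -2, 3], [-1, 2, -4]], num_vars=4, target_var=3)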

Training Details

  • Original training samples: 8,110
  • Augmented training samples: 40,550 (5x)
  • Validation samples: 902 (unaugmented)
  • Epochs: 3
  • Hardware: 8×H100 GPUs
  • Training framework: DeepSpeed ZeRO-3
  • Peak learning rate: 5e-6
  • Training time: ~4 hours
  • Best checkpoint: step 1850 (epoch 2.92)
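
For reference, these settings map roughly onto a Hugging Face Trainer configuration like the sketch below; batch size, gradient accumulation, and the DeepSpeed config path were not reported, so those values are placeholders:

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="qwen3-4b-sat-varselector",  # placeholder name
    num_train_epochs=3,                     # reported: 3 epochs
    learning_rate=5e-6,                     # reported peak learning rate
    bf16=True,                              # weights are released in bfloat16
    per_device_train_batch_size=1,          # placeholder, not reported
    gradient_accumulation_steps=4,          # placeholder, not reported
    deepspeed="ds_zero3_config.json",       # reported: DeepSpeed ZeRO-3 on 8×H100 GPUs
)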

Performance Comparison

  • Qwen3-0.6B (baseline): 600M parameters, 8,110 training samples, ~12% top-1 accuracy
  • Qwen3-0.6B (augmented): 600M parameters, 40,550 training samples, ~19% top-1 accuracy
  • Qwen3-4B (augmented, this model): 4B parameters, 40,550 training samples, ~22% top-1 accuracy

Training Curve Highlights

  • Peak accuracy: 22.0% at epoch 2.76
  • Final accuracy: 21.8% at epoch 2.92
  • Eval loss: 3.35 (vs 3.37 for 0.6B)

Usage

import torch
from transformers import AutoTokenizer
from sft_qwen_var_classifier import QwenVarClassifier, cnf_valid_mask

# Load model
model = QwenVarClassifier("Qwen/Qwen3-4B", max_vars=600)
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict, strict=False)
model = model.to("cuda", dtype=torch.bfloat16)
model.eval()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")

# Prepare CNF input
cnf_text = """p cnf 100 250
1 -2 3 0
-1 2 -4 0
...
"""

# Tokenize
inputs = tokenizer(cnf_text, return_tensors="pt", truncation=True, max_length=8192)
inputs = {k: v.to("cuda") for k, v in inputs.items()}

# Get valid variable mask
valid_mask = torch.tensor([cnf_valid_mask(cnf_text, max_vars=600)], dtype=torch.bool, device="cuda")

# Predict
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs["logits"]
    logits = logits.masked_fill(~valid_mask, -1e4)
    predicted_var = logits.argmax(dim=-1).item()

print(f"Predicted branching variable: {predicted_var}")

Files

  • pytorch_model.bin - Model weights (~8GB, bfloat16)
  • sft_qwen_var_classifier.py - Model class definition (required for loading)

When to Use This Model

  • Higher accuracy than the 0.6B version (+3 percentage points)
  • Production use when accuracy matters more than speed
  • Cube-and-Conquer style SAT solving

Limitations

  • Maximum 600 variables
  • Maximum 8192 tokens for CNF input
  • Larger model size (~8GB vs 1.2GB for 0.6B)
  • Slower inference (~6x slower than 0.6B)

Citation

If you use this model, please cite the Transformer-CnC paper.

License

Apache 2.0
