Qwen3-4B-SAT-VarSelector-Sym-Aug-8-Epoch

A Qwen3-4B-based variable selector for SAT Cube-and-Conquer solving. Given a CNF formula state, the model predicts which variable to branch on next. It is trained to reproduce expert branching choices, which enables learned splitting strategies for parallel SAT solving.

Model Description

This model is trained for variable selection in SAT Cube-and-Conquer (CnC) solving. Given a simplified CNF formula, it outputs the variable ID that should be branched on next.

Architecture

  • Backbone: Qwen3-4B (3.8B parameters, causal language model)
  • Head: LayerNorm + Linear classifier
  • Output: 601 logits indexed by variable ID (0 to 600). Index 0 is unused, so valid predictions are IDs 1 to 600
  • Pooling: Last non-padding token hidden state

Input (CNF text) → Tokenize → Qwen3-4B → Last Token Hidden State → LayerNorm → Linear → Variable Logits
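
The pooling step can be checked on its own. The `forward` method in the loading code takes `attention_mask.sum(dim=1) - 1` as the last position, and that is only correct when padding is on the right. The PyTorch sketch below is a padding-agnostic alternative. `last_token_pool` is a hypothetical helper and is not part of the released code.

```python
import torch

def last_token_pool(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Return the hidden state of the last non-padding token in each row.

    hidden: (batch, seq_len, hidden_size); attention_mask: (batch, seq_len) of 0/1.
    Finds the largest position whose mask value is 1, so it works for both
    right- and left-padded batches.
    """
    seq_len = attention_mask.size(1)
    positions = torch.arange(seq_len, device=attention_mask.device)
    # Padded slots become 0 after multiplying by the mask; the max is the last real token.
    last_idx = (positions.unsqueeze(0) * attention_mask.long()).max(dim=1).values
    batch_idx = torch.arange(hidden.size(0), device=hidden.device)
    return hidden[batch_idx, last_idx]
```

For a right-padded batch, this returns the same rows as the `sum - 1` rule. For a left-padded batch, it still selects the final real token.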

Training Approach

  • Supervised Fine-Tuning (SFT) on expert variable selections
  • Masked Classification: Only variables appearing in the CNF are valid outputs
  • Loss: Cross-entropy with invalid variable logits masked to -∞
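
The masked loss can be written in a few lines of PyTorch. This is a minimal sketch, not the authors' training code. It assumes each target is a variable ID that actually appears in the formula, and `masked_cross_entropy` is a hypothetical name.

```python
import torch
import torch.nn.functional as F

def masked_cross_entropy(logits: torch.Tensor, valid_mask: torch.Tensor,
                         target: torch.Tensor) -> torch.Tensor:
    """Cross-entropy over variable IDs with invalid variables excluded.

    logits: (batch, max_vars + 1) raw scores; index 0 is never a variable.
    valid_mask: (batch, max_vars + 1) bool, True where the variable occurs in the CNF.
    target: (batch,) expert variable IDs, assumed to be valid.
    """
    # -inf logits get zero probability after softmax, so the normalization
    # runs only over variables that appear in the formula.
    masked = logits.masked_fill(~valid_mask, float("-inf"))
    return F.cross_entropy(masked, target)
```

The inference example masks with `-1e4` rather than `-inf`. For `argmax`, both choices select the same variable as long as valid logits are well above `-1e4`.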

Training Details

Parameter | Value
--- | ---
Base Model | Qwen/Qwen3-4B
Training Samples | 40,550 (8,110 originals × 5 via augmentation)
Validation Samples | 902
Epochs | 8
Batch Size | 64 effective (1 per GPU × 8 GPUs × 8 gradient accumulation steps)
Learning Rate | 5e-6
Warmup Ratio | 0.1
Max Sequence Length | 8,192 tokens
Max Variables | 600
Precision | bfloat16
Hardware | 8× H100 GPUs
Distributed Training | DeepSpeed ZeRO-3
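
The step counts implied by these settings can be worked out directly. The sketch assumes that the final partial batch of each epoch still counts as one optimizer step and that warmup is rounded up. A given trainer may round differently, so treat the counts as approximate.

```python
import math

# Values copied from the training details above.
num_originals, aug_factor = 8_110, 5
num_samples = 40_550
per_gpu_batch, num_gpus, grad_accum = 1, 8, 8
epochs = 8
warmup_ratio = 0.1

assert num_originals * aug_factor == num_samples

effective_batch = per_gpu_batch * num_gpus * grad_accum
# Assumption: a trailing partial batch still produces one optimizer step.
steps_per_epoch = math.ceil(num_samples / effective_batch)
total_steps = steps_per_epoch * epochs
warmup_steps = math.ceil(total_steps * warmup_ratio)

print(effective_batch, steps_per_epoch, total_steps, warmup_steps)  # 64 634 5072 508
```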

Data Augmentation

The training data was augmented 5× using semantically safe transformations, each of which preserves satisfiability:

  1. Variable Permutation (50%): Bijective remapping of variable IDs
  2. Clause Shuffling (100%): Random reordering of clauses
  3. Literal Reordering (100%): Shuffle literals within each clause
  4. Polarity Flipping (30%): Flip polarity of random variable subset
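
The four transformations can be sketched on a clause list (lists of non-zero ints) with an expert label (a variable ID). This is a plain-Python illustration under several assumptions. The percentages are read as per-copy application probabilities. The flipped subset picks each variable with probability 0.5, which is an arbitrary choice. A permutation must also remap the label, while a polarity flip leaves the label's ID unchanged. `augment` is a hypothetical name.

```python
import random

def augment(clauses, label, num_vars, rng, p_permute=0.5, p_flip=0.3):
    """Return one augmented (clauses, label) pair.

    Assumption: p_permute / p_flip are the probabilities of applying each step;
    clause shuffling and literal reordering are always applied.
    """
    # Variable permutation: bijective remap of IDs 1..num_vars; the label moves with it.
    if rng.random() < p_permute:
        perm = list(range(1, num_vars + 1))
        rng.shuffle(perm)
        mapping = {v: perm[v - 1] for v in range(1, num_vars + 1)}
        clauses = [[mapping[abs(l)] * (1 if l > 0 else -1) for l in c] for c in clauses]
        label = mapping[label]
    # Polarity flipping: negate every occurrence of a random subset of variables.
    # Assumption: each variable joins the subset with probability 0.5.
    if rng.random() < p_flip:
        flipped = {v for v in range(1, num_vars + 1) if rng.random() < 0.5}
        clauses = [[-l if abs(l) in flipped else l for l in c] for c in clauses]
    # Literal reordering: shuffle literals inside each clause.
    clauses = [rng.sample(c, len(c)) for c in clauses]
    # Clause shuffling: random order of clauses.
    clauses = rng.sample(clauses, len(clauses))
    return clauses, label
```

Producing five copies of one example would then look like `[augment(clauses, label, num_vars, rng) for _ in range(5)]`.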

Usage

Loading the Model

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM

class QwenVarClassifier(nn.Module):
    def __init__(self, base_model_name: str, max_vars: int = 600):
        super().__init__()
        self.max_vars = max_vars
        
        cfg = AutoConfig.from_pretrained(base_model_name)
        cfg.output_hidden_states = True
        
        self.backbone = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            config=cfg,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        )
        
        hidden = self.backbone.config.hidden_size
        self.head_ln = nn.LayerNorm(hidden)
        self.head = nn.Linear(hidden, max_vars + 1)
        self.config = self.backbone.config

    def forward(self, input_ids, attention_mask, **kwargs):
        out = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )
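        h = out.hidden_states[-1]
        # Index of the last non-padding token; this assumes right padding
        last_idx = attention_mask.sum(dim=1) - 1
        last_idx = last_idx.clamp(min=0)
        b = torch.arange(h.size(0), device=h.device)
        pooled = h[b, last_idx]
        # The backbone may run in bfloat16 while the head stays float32; match dtypes
        pooled = pooled.to(self.head_ln.weight.dtype)
        pooled = self.head_ln(pooled)
        logits = self.head(pooled)
        return {"logits": logits}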

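# Load model (downloads pytorch_model.bin from this repository; a local path also works)
from huggingface_hub import hf_hub_download

model = QwenVarClassifier("Qwen/Qwen3-4B", max_vars=600)
ckpt_path = hf_hub_download(
    repo_id="Yale-ROSE/Qwen3-4B-SAT-VarSelector-Sym-Aug-8-Epoch",
    filename="pytorch_model.bin",
)
checkpoint = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(checkpoint)

# Move to GPU if available and switch to inference mode
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")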

Inference

def cnf_valid_mask(cnf_text: str, max_vars: int = 600):
    """Build mask of valid variables from CNF text."""
    mask = [0] * (max_vars + 1)
    for line in cnf_text.split('\n'):
        line = line.strip()
        if not line or line.startswith('c') or line.startswith('p'):
            continue
        for tok in line.split():
            try:
                lit = int(tok)
                v = abs(lit)
                if 1 <= v <= max_vars:
                    mask[v] = 1
            except ValueError:
                continue
    if sum(mask) == 0:
        for v in range(1, max_vars + 1):
            mask[v] = 1
    return mask

# Example CNF (DIMACS format)
cnf_text = """p cnf 100 3
1 -2 3 0
-1 4 5 0
2 -3 -4 0"""

# Tokenize and move tensors to the model's device
device = next(model.parameters()).device
inputs = tokenizer(cnf_text, return_tensors="pt", truncation=True, max_length=8192).to(device)

# Get valid variable mask (shape: 1 × 601)
valid_mask = torch.tensor([cnf_valid_mask(cnf_text, 600)], dtype=torch.bool, device=device)

# Predict
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs["logits"]

    # Mask invalid variables so they can never be selected
    logits = logits.masked_fill(~valid_mask, -1e4)

    # Get prediction
    predicted_var = logits.argmax(dim=-1).item()

print(f"Predicted branching variable: {predicted_var}")

Input Format

The model expects CNF formulas in DIMACS format:

p cnf <num_vars> <num_clauses>
<lit1> <lit2> ... 0
<lit1> <lit2> ... 0
...

  • Header line: p cnf <variables> <clauses>
  • Each clause is a space-separated list of literals ending with 0
  • Literals are integers: positive = variable, negative = negated variable
  • Example: 1 -2 3 0 represents (x₁ ∨ ¬x₂ ∨ x₃)
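
If formulas are held in memory as clause lists, serializing them into this format is straightforward. The helper below is a hypothetical sketch that emits only the header and clause lines, with no `c` comment lines. Whether the training data contained comments is not stated here.

```python
def to_dimacs(clauses, num_vars=None):
    """Serialize clauses (lists of non-zero ints) as DIMACS CNF text."""
    if num_vars is None:
        # Default the header to the largest variable ID actually used.
        num_vars = max((abs(l) for c in clauses for l in c), default=0)
    lines = [f"p cnf {num_vars} {len(clauses)}"]
    for clause in clauses:
        if any(l == 0 for l in clause):
            raise ValueError("0 is the clause terminator, not a literal")
        lines.append(" ".join(str(l) for l in clause) + " 0")
    return "\n".join(lines)
```

Called as `to_dimacs([[1, -2, 3], [-1, 4, 5], [2, -3, -4]], num_vars=100)`, it reproduces the `cnf_text` string from the inference example.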

Evaluation

The model is evaluated using exact-match accuracy: the predicted variable must exactly match the expert's choice.

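Random baseline: a uniform guess among k valid variables is correct with probability 1/k, so formulas with ~100 valid variables give a baseline of ~1%.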
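
Both numbers can be computed with a small amount of code. The sketch below is hypothetical. It assumes predictions and expert labels are plain variable IDs and that each mask is the 0/1 list returned by `cnf_valid_mask`.

```python
def exact_match_accuracy(predictions, targets):
    """Fraction of examples where the predicted variable equals the expert's choice."""
    if len(predictions) != len(targets) or not predictions:
        raise ValueError("need equally sized, non-empty lists")
    return sum(p == t for p, t in zip(predictions, targets)) / len(targets)

def random_baseline(valid_masks):
    """Expected accuracy of a uniform guess among each example's valid variables.

    Each example contributes 1/k, where k is its number of valid variables.
    """
    return sum(1 / sum(mask) for mask in valid_masks) / len(valid_masks)
```

Averaging 1/k per example, rather than taking 1 over the average k, gives the expected accuracy when formulas differ in size.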

Intended Use

This model is designed for:

  • SAT Cube-and-Conquer solving: Selecting split variables for parallel SAT solving
  • Research: Studying learned heuristics for combinatorial optimization
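
The cubing step of Cube-and-Conquer can be sketched as repeated splitting on the selected variable. In the plain-Python illustration below, `select_var` is a hypothetical callable standing in for the model. It could, for example, be a wrapper that serializes the simplified clauses, runs the classifier, and applies the valid-variable mask. The sketch uses a fixed depth and no unit propagation. Practical cubers typically add propagation and an adaptive cutoff.

```python
from typing import Callable, List, Optional

Clause = List[int]

def assign(clauses: List[Clause], lit: int) -> Optional[List[Clause]]:
    """Simplify clauses assuming `lit` is true; return None on an empty clause."""
    out = []
    for c in clauses:
        if lit in c:
            continue  # clause satisfied, drop it
        reduced = [l for l in c if l != -lit]  # remove the falsified literal
        if not reduced:
            return None  # conflict: this branch is refuted
        out.append(reduced)
    return out

def make_cubes(clauses: List[Clause],
               select_var: Callable[[List[Clause]], int],
               depth: int) -> List[List[int]]:
    """Split a formula into cubes (lists of literals) by branching `depth` times."""
    if depth == 0 or not clauses:
        return [[]]
    var = select_var(clauses)  # hypothetical stand-in for the learned selector
    cubes = []
    for lit in (var, -var):
        sub = assign(clauses, lit)
        if sub is None:
            continue  # refuted branch needs no cube
        cubes.extend([lit] + rest for rest in make_cubes(sub, select_var, depth - 1))
    return cubes
```

Each returned cube is a set of assumptions that a conquer-phase solver can process independently, which is where the parallelism comes from.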

Limitations

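  • Maximum of 600 variables: the classifier head's output size is fixed at training time, so variable IDs above 600 cannot be predicted
  • Maximum input length of 8,192 tokens
  • Trained on specific CNF distribution; may not generalize to all SAT instances
  • Requires valid-variable masking at inference; without it, the model can predict variables that do not appear in the formula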
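
A formula can be checked against the first two limits before it is sent to the model. The helper below is a hypothetical sketch. The token count is optional and uses the standard Hugging Face call `tokenizer(text)["input_ids"]`, so it approximates the length the model will see. Note that `cnf_valid_mask` silently ignores IDs above `max_vars`, and this check surfaces them instead.

```python
def check_limits(cnf_text, max_vars=600, max_tokens=8192, tokenizer=None):
    """Return human-readable problems that would affect prediction (empty if none)."""
    problems = []
    ids = set()
    for line in cnf_text.splitlines():
        line = line.strip()
        if not line or line.startswith(("c", "p")):
            continue  # skip comments and the header, as cnf_valid_mask does
        for tok in line.split():
            try:
                v = abs(int(tok))
            except ValueError:
                continue
            if v:  # 0 is the clause terminator
                ids.add(v)
    too_big = sorted(v for v in ids if v > max_vars)
    if too_big:
        problems.append(f"{len(too_big)} variable(s) exceed max_vars={max_vars} "
                        f"(largest: {too_big[-1]})")
    if tokenizer is not None:
        n = len(tokenizer(cnf_text)["input_ids"])
        if n > max_tokens:
            problems.append(f"{n} tokens exceed max_tokens={max_tokens}; input will be truncated")
    return problems
```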

Citation

If you use this model, please cite:

@misc{qwen3-sat-varselector,
  title={Qwen3-4B-SAT-VarSelector: Learned Variable Selection for SAT Cube-and-Conquer},
  author={Yale-ROSE},
  year={2026},
  publisher={Hugging Face},
  url={https://huggingface.co/Yale-ROSE/Qwen3-4B-SAT-VarSelector-Sym-Aug-8-Epoch}
}

License

Apache 2.0

Model Tree

Qwen/Qwen3-4B-Base → Qwen/Qwen3-4B (fine-tuned) → Yale-ROSE/Qwen3-4B-SAT-VarSelector-Sym-Aug-8-Epoch (fine-tuned)