---
license: apache-2.0
base_model: Qwen/Qwen3-0.6B
tags:
  - sat
  - satisfiability
  - cube-and-conquer
  - variable-selection
  - combinatorial-optimization
datasets:
  - Yale-ROSE/SAT-VarSelector-Distilled
language:
  - en
pipeline_tag: text-classification
---

# Qwen3-0.6B-SAT-VarSelector-Distilled

A Qwen3-0.6B model fine-tuned for SAT variable selection in the Cube-and-Conquer (CnC) framework. Given a CNF formula state, the model predicts which variable to branch/cube on next.

## Model Description

This model implements a masked classification head on top of Qwen3-0.6B to select branching variables for SAT solving. Unlike traditional heuristics (e.g., VSIDS), it learns from expert solver traces to make informed variable selection decisions.

### Key Features

- **Task:** Variable selection for SAT Cube-and-Conquer
- **Architecture:** Qwen3-0.6B backbone + classification head (601 classes covering indices 0-600; valid variables are 1-600)
- **Training:** Supervised fine-tuning on distilled expert data
- **Output:** Integer variable ID to branch on

## Training Details

| Attribute | Value |
|---|---|
| Base Model | Qwen/Qwen3-0.6B |
| Training Dataset | Distilled from GPT expert traces |
| Best Checkpoint | Step 410 (epoch ~6.7) |
| Eval Accuracy | 14.75% |
| Eval Loss | 3.789 |
| Training Time | ~53 minutes (8×H100 GPUs) |

### Performance Context

- **Random baseline:** ~1-2% accuracy (depends on the number of valid variables per instance)
- **This model:** 14.75% accuracy, roughly 7-15× better than random

### Hyperparameters

```yaml
learning_rate: 5e-6
warmup_ratio: 0.1
num_train_epochs: 8
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
max_length: 8192
max_vars: 600
optimizer: AdamW
scheduler: cosine
deepspeed: ZeRO-3
```
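For reference, and assuming a standard Hugging Face `Trainer` setup (a sketch, not the actual training script), these map onto `TrainingArguments` roughly as follows; the output directory and `ds_zero3.json` are placeholder names:

```python
from transformers import TrainingArguments

# Sketch only: maps the listed hyperparameters onto TrainingArguments.
# "ds_zero3.json" is a placeholder DeepSpeed ZeRO-3 config, not shipped here.
args = TrainingArguments(
    output_dir="qwen3-sat-varselector",  # placeholder
    learning_rate=5e-6,
    warmup_ratio=0.1,
    num_train_epochs=8,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    optim="adamw_torch",
    lr_scheduler_type="cosine",
    deepspeed="ds_zero3.json",
)
```

With 8 GPUs this yields an effective batch size of 64 (1 per device × 8 accumulation steps × 8 devices).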

## Usage

### Loading the Model

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn as nn

class QwenVarClassifier(nn.Module):
    def __init__(self, base_model, max_vars=600):
        super().__init__()
        self.base = base_model
        hidden_size = base_model.config.hidden_size
        self.norm = nn.LayerNorm(hidden_size)
        self.head = nn.Linear(hidden_size, max_vars + 1)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.base(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        hidden = outputs.hidden_states[-1]  # [B, seq, hidden]

        # Pool at the last non-pad token (assumes right padding)
        if attention_mask is not None:
            lengths = attention_mask.sum(dim=1) - 1
            pooled = hidden[torch.arange(hidden.size(0)), lengths]
        else:
            pooled = hidden[:, -1, :]

        pooled = self.norm(pooled)
        logits = self.head(pooled)
        return logits

# Load the backbone and wrap it with the classification head
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")
model = QwenVarClassifier(base_model, max_vars=600)

# Load the fine-tuned weights
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
```

### Inference

```python
def get_valid_vars(cnf_text, max_vars=600):
    """Extract valid variable IDs from DIMACS CNF text."""
    valid = set()
    for line in cnf_text.strip().split('\n'):
        # Skip comment ('c') and problem ('p') lines
        if line.startswith('c') or line.startswith('p'):
            continue
        for tok in line.split():
            try:
                lit = int(tok)
                if lit != 0:  # 0 is the clause terminator, not a literal
                    valid.add(abs(lit))
            except ValueError:
                pass
    return valid

def predict_variable(cnf_text, model, tokenizer, max_vars=600):
    """Predict the next variable to branch on."""
    inputs = tokenizer(cnf_text, return_tensors="pt", truncation=True, max_length=8192)

    with torch.no_grad():
        logits = model(inputs["input_ids"], inputs["attention_mask"])

    # Mask out variables that do not occur in the formula
    valid_vars = get_valid_vars(cnf_text, max_vars)
    mask = torch.zeros(max_vars + 1, dtype=torch.bool)
    for v in valid_vars:
        if 1 <= v <= max_vars:
            mask[v] = True

    logits[0, ~mask] = -1e4
    predicted_var = logits.argmax(dim=-1).item()

    return predicted_var

# Example
cnf_text = """p cnf 100 200
1 -2 3 0
-1 4 -5 0
2 5 6 0
"""

var = predict_variable(cnf_text, model, tokenizer)
print(f"Predicted variable: {var}")
```
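In a Cube-and-Conquer pipeline, this prediction is applied recursively: split the formula on the chosen variable's two phases and repeat until the resulting cubes are easy enough to hand to a CDCL solver. Below is a minimal sketch of that loop, reusing `predict_variable` from above; the `assign` helper is hypothetical and stands in for a real simplifier (unit propagation, clause cleanup, and header updates are all omitted):

```python
def assign(cnf_text, lit):
    """Hypothetical stand-in for a real simplifier: just append a unit
    clause fixing `lit`. A production cuber would also propagate and
    simplify the formula."""
    return cnf_text + f"{lit} 0\n"

def build_cubes(cnf_text, depth, model, tokenizer, prefix=()):
    """Recursively branch on model-selected variables, collecting cubes."""
    if depth == 0:
        return [prefix]
    var = predict_variable(cnf_text, model, tokenizer)
    cubes = []
    for lit in (var, -var):  # try both phases of the chosen variable
        cubes += build_cubes(assign(cnf_text, lit), depth - 1,
                             model, tokenizer, prefix + (lit,))
    return cubes

# Depth-3 split -> up to 8 cubes, each a tuple of literals for the conquer phase
cubes = build_cubes(cnf_text, depth=3, model=model, tokenizer=tokenizer)
```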

## Architecture Details

### Why Masked Classification?

The valid action set is state-dependent: not all variables are valid at every step.

- Some variables may be eliminated during simplification
- Some may be out of range for the specific instance

We use a masked softmax (see the sketch after this list):

1. The model outputs logits for all 601 classes (indices 0-600)
2. Invalid variables have their logits set to -1e4
3. Softmax therefore assigns probability only to valid variables
4. Training uses a masked cross-entropy loss
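A minimal sketch of that masked loss, assuming a `labels` tensor of expert variable IDs and a boolean `valid_mask` of shape `[batch, 601]` (both names are illustrative, not part of this repo's training code):

```python
import torch
import torch.nn.functional as F

def masked_cross_entropy(logits, labels, valid_mask):
    """Cross-entropy restricted to valid variables.

    logits:     [batch, 601] raw scores from the classification head
    labels:     [batch] expert-chosen variable IDs
    valid_mask: [batch, 601] True where the variable is selectable
    """
    # Push invalid classes to effectively zero probability before softmax
    masked_logits = logits.masked_fill(~valid_mask, -1e4)
    return F.cross_entropy(masked_logits, labels)
```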

### Why Pool the Last Token?

The last non-pad token has attended to the entire CNF sequence through causal attention, making it a natural summary representation.

### Why LayerNorm Before the Head?

Qwen's hidden states can have large magnitudes. LayerNorm stabilizes the input to the classification head.

## Limitations

- Maximum 600 variables (configurable during training)
- Maximum sequence length 8192 tokens
- Trained on a specific CNF distribution; may not generalize to all SAT instances
- Accuracy metric is strict exact-match; the model may predict "good" variables even when not matching the expert label exactly

## Citation

```bibtex
@misc{qwen-sat-varselector,
  title={Qwen3-0.6B-SAT-VarSelector-Distilled},
  author={Yale-ROSE},
  year={2026},
  publisher={Hugging Face},
  url={https://huggingface.co/Yale-ROSE/Qwen3-0.6B-SAT-VarSelector-Distilled}
}
```

## License

Apache 2.0 (following the base Qwen3 license)