---
license: apache-2.0
base_model: Qwen/Qwen3-0.6B
tags:
  - SAT
  - combinatorial-optimization
  - classification
  - cube-and-conquer
  - data-augmentation
language:
  - en
pipeline_tag: text-classification
---

Qwen3-0.6B-SAT-VarSelector-Sym-Aug

A Qwen3-0.6B model fine-tuned for SAT branching variable selection using symmetry-based data augmentation.

Model Description

This model predicts which variable to branch/cube on next, given a SAT CNF formula state. It was trained with 5x augmented data using CNF symmetry transformations, resulting in significantly improved generalization.

Architecture

  • Base: Qwen/Qwen3-0.6B (causal language model)
  • Head: LayerNorm → Linear(hidden_size, 601)
  • Max Variables: 600
  • Pooling: Last non-pad token hidden state
  • Masking: Invalid variables (not in CNF) are masked to -10000 before softmax
  • Size: ~1.2GB (bfloat16)
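The head and pooling described above can be sketched roughly as follows. This is a minimal illustration, not the actual `QwenVarClassifier` source: the class name `VarClassifierHead` and the right-padding assumption for "last non-pad token" pooling are mine.

```python
import torch
import torch.nn as nn

class VarClassifierHead(nn.Module):
    """Sketch of the classification head: LayerNorm -> Linear(hidden_size, max_vars + 1)."""

    def __init__(self, hidden_size: int, max_vars: int = 600):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)
        # 601 logits for max_vars = 600; index 0 corresponds to no valid variable
        self.proj = nn.Linear(hidden_size, max_vars + 1)

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Pool the hidden state of the last non-pad token per sequence
        # (assumes right padding, so the last real token sits at mask.sum() - 1)
        last_idx = attention_mask.sum(dim=1) - 1                       # (batch,)
        batch_idx = torch.arange(hidden_states.size(0), device=hidden_states.device)
        pooled = hidden_states[batch_idx, last_idx]                   # (batch, hidden)
        return self.proj(self.norm(pooled))                           # (batch, max_vars + 1)
```

At inference time the logits for variables not present in the CNF are masked to -10000 before taking the argmax or softmax, as shown in the Usage section below.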

Training with Symmetry Augmentation

This model was trained with 5x data augmentation using semantically safe CNF transformations:

| Augmentation | Description | Effect |
|---|---|---|
| Variable Permutation | Bijective remapping of variable IDs | Prevents memorizing specific variable numbers |
| Clause Shuffling | Random reordering of clauses | Teaches position-independence |
| Literal Reordering | Shuffle literals within clauses | Token-level variation |
| Polarity Flipping | Flip signs of a random variable subset | Teaches structural vs. polarity features |
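The four transformations can be sketched as below. This is an illustrative sketch, not the repository's augmentation script; the function name and details are assumptions.

```python
import random

def augment_cnf(clauses, num_vars, rng=random):
    """Apply the four symmetry transformations to a CNF given as a list of
    integer-literal clauses (DIMACS style, without the trailing 0).

    Returns the transformed clauses plus the variable permutation and the
    set of polarity-flipped variables, so labels can be remapped to match.
    """
    # 1. Variable permutation: bijective remapping of variable IDs
    perm = list(range(1, num_vars + 1))
    rng.shuffle(perm)
    remap = {v: perm[v - 1] for v in range(1, num_vars + 1)}
    clauses = [[remap[abs(l)] * (1 if l > 0 else -1) for l in c] for c in clauses]
    # 2. Polarity flipping: negate every occurrence of a random variable subset
    flipped = {v for v in range(1, num_vars + 1) if rng.random() < 0.5}
    clauses = [[-l if abs(l) in flipped else l for l in c] for c in clauses]
    # 3. Clause shuffling: reorder the clause list
    rng.shuffle(clauses)
    # 4. Literal reordering: shuffle literals within each clause
    for c in clauses:
        rng.shuffle(c)
    return clauses, remap, flipped
```

Note that the branching-variable label must be passed through the same permutation (`remap`); polarity flipping does not change the label, since the label is a variable, not a literal.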

Training Details

| Parameter | Value |
|---|---|
| Original training samples | 8,110 |
| Augmented training samples | 40,550 (5x) |
| Validation samples | 902 (unaugmented) |
| Epochs | 3 |
| Hardware | 8×H100 GPUs |
| Training framework | DeepSpeed ZeRO-3 |
| Peak learning rate | 5e-6 |
| Best checkpoint | Step 1800 (epoch 2.84) |

Performance Comparison

| Model | Training Data | Top-1 Accuracy | Top-5 Accuracy |
|---|---|---|---|
| Qwen3-0.6B (baseline) | 8,110 samples | ~12% | ~32% |
| Qwen3-0.6B (augmented) | 40,550 samples | ~19% | ~42% |
| Improvement | +5x data | +7 pp | +10 pp |

Key Insight: Why Validation Loss < Training Loss

During augmented training, you will observe a validation loss that is consistently lower than the training loss. This is expected and indicates that the augmentation is working:

  1. Training data is harder — augmented CNFs with permuted variables, shuffled clauses
  2. Validation data is clean — original CNFs without transformations
  3. Model generalizes well — learned structural patterns, not memorized examples

Usage

import torch
from transformers import AutoTokenizer
from sft_qwen_var_classifier import QwenVarClassifier, cnf_valid_mask

# Load model
model = QwenVarClassifier("Qwen/Qwen3-0.6B", max_vars=600)
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict, strict=False)
model = model.to("cuda", dtype=torch.bfloat16)
model.eval()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

# Prepare CNF input
cnf_text = """p cnf 100 250
1 -2 3 0
-1 2 -4 0
...
"""

# Tokenize
inputs = tokenizer(cnf_text, return_tensors="pt", truncation=True, max_length=8192)
inputs = {k: v.to("cuda") for k, v in inputs.items()}

# Get valid variable mask
valid_mask = torch.tensor([cnf_valid_mask(cnf_text, max_vars=600)], dtype=torch.bool, device="cuda")

# Predict
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs["logits"]
    logits = logits.masked_fill(~valid_mask, -1e4)
    predicted_var = logits.argmax(dim=-1).item()

print(f"Predicted branching variable: {predicted_var}")

Files

  • pytorch_model.bin - Model weights (~1.2GB, bfloat16)
  • sft_qwen_var_classifier.py - Model class definition (required for loading)

When to Use This Model

  • Better generalization than non-augmented version
  • Production/deployment with improved accuracy
  • When training data is limited — augmentation effectively multiplies your data

Augmentation Code

The augmentation script is available at:

Yale-ROSE/Transformer-SAT/new_transformer/augment_sft_dataset.py

Usage:

python augment_sft_dataset.py input.jsonl output.jsonl --multiplier 5

Limitations

  • Maximum 600 variables
  • Maximum 8192 tokens for CNF input
  • Trained on specific CNF distribution

Related Models

Citation

If you use this model, please cite the Transformer-CnC paper.

License

Apache 2.0