--- license: apache-2.0 base_model: Qwen/Qwen3-0.6B tags: - SAT - combinatorial-optimization - classification - cube-and-conquer - data-augmentation language: - en pipeline_tag: text-classification --- # Qwen3-0.6B-SAT-VarSelector-Sym-Aug A Qwen3-0.6B model fine-tuned for **SAT branching variable selection** using **symmetry-based data augmentation**. ## Model Description This model predicts which variable to branch/cube on next, given a SAT CNF formula state. It was trained with **5x augmented data** using CNF symmetry transformations, resulting in significantly improved generalization. ### Architecture - **Base**: `Qwen/Qwen3-0.6B` (causal language model) - **Head**: LayerNorm → Linear(hidden_size, 601) - **Max Variables**: 600 - **Pooling**: Last non-pad token hidden state - **Masking**: Invalid variables (not in CNF) are masked to -10000 before softmax - **Size**: ~1.2GB (bfloat16) ### Training with Symmetry Augmentation This model was trained with **5x data augmentation** using semantically-safe CNF transformations: | Augmentation | Description | Effect | |-------------|-------------|--------| | **Variable Permutation** | Bijective remapping of variable IDs | Prevents memorizing specific variable numbers | | **Clause Shuffling** | Random reordering of clauses | Teaches position-independence | | **Literal Reordering** | Shuffle literals within clauses | Token-level variation | | **Polarity Flipping** | Flip signs of random variable subset | Teaches structural vs. polarity features | ### Training Details | Parameter | Value | |-----------|-------| | Original training samples | 8,110 | | Augmented training samples | **40,550** (5x) | | Validation samples | 902 (unaugmented) | | Epochs | 3 | | Hardware | 8×H100 GPUs | | Training framework | DeepSpeed ZeRO-3 | | Peak learning rate | 5e-6 | | Best checkpoint | Step 1800 (epoch 2.84) | ### Performance Comparison | Model | Training Data | Top-1 Accuracy | Top-5 Accuracy | |-------|--------------|----------------|----------------| | Qwen3-0.6B (baseline) | 8,110 samples | ~12% | ~32% | | **Qwen3-0.6B (augmented)** | **40,550 samples** | **~19%** | **~42%** | | Improvement | +5x data | **+7pp** | **+10pp** | ### Key Insight: Why Validation Loss < Training Loss During augmented training, you'll observe **validation loss consistently lower than training loss**. This is expected and indicates the augmentation is working: 1. **Training data is harder** — augmented CNFs with permuted variables, shuffled clauses 2. **Validation data is clean** — original CNFs without transformations 3. **Model generalizes well** — learned structural patterns, not memorized examples ## Usage ```python import torch from transformers import AutoTokenizer from sft_qwen_var_classifier import QwenVarClassifier, cnf_valid_mask # Load model model = QwenVarClassifier("Qwen/Qwen3-0.6B", max_vars=600) state_dict = torch.load("pytorch_model.bin", map_location="cpu") model.load_state_dict(state_dict, strict=False) model = model.to("cuda", dtype=torch.bfloat16) model.eval() # Load tokenizer tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") # Prepare CNF input cnf_text = """p cnf 100 250 1 -2 3 0 -1 2 -4 0 ... """ # Tokenize inputs = tokenizer(cnf_text, return_tensors="pt", truncation=True, max_length=8192) inputs = {k: v.to("cuda") for k, v in inputs.items()} # Get valid variable mask valid_mask = torch.tensor([cnf_valid_mask(cnf_text, max_vars=600)], dtype=torch.bool, device="cuda") # Predict with torch.no_grad(): outputs = model(**inputs) logits = outputs["logits"] logits = logits.masked_fill(~valid_mask, -1e4) predicted_var = logits.argmax(dim=-1).item() print(f"Predicted branching variable: {predicted_var}") ``` ## Files - `pytorch_model.bin` - Model weights (~1.2GB, bfloat16) - `sft_qwen_var_classifier.py` - Model class definition (required for loading) ## When to Use This Model - **Better generalization** than non-augmented version - **Production/deployment** with improved accuracy - **When training data is limited** — augmentation effectively multiplies your data ## Augmentation Code The augmentation script is available at: ``` Yale-ROSE/Transformer-SAT/new_transformer/augment_sft_dataset.py ``` Usage: ```bash python augment_sft_dataset.py input.jsonl output.jsonl --multiplier 5 ``` ## Limitations - Maximum 600 variables - Maximum 8192 tokens for CNF input - Trained on specific CNF distribution ## Related Models - [Qwen3-0.6B-SAT-VarSelector](https://huggingface.co/Yale-ROSE/Qwen3-0.6B-SAT-VarSelector) - Non-augmented baseline - [Qwen3-4B-SAT-VarSelector](https://huggingface.co/Yale-ROSE/Qwen3-4B-SAT-VarSelector) - Higher accuracy, larger model ## Citation If you use this model, please cite the Transformer-CnC paper. ## License Apache 2.0