---
license: apache-2.0
base_model: Qwen/Qwen3-0.6B
tags:
- SAT
- combinatorial-optimization
- classification
- cube-and-conquer
language:
- en
pipeline_tag: text-classification
---

# Qwen3-0.6B-SAT-VarSelector

A lightweight Qwen3-0.6B model fine-tuned for **SAT branching variable selection** in Cube-and-Conquer (CnC) solvers.

## Model Description

Given a SAT formula in conjunctive normal form (CNF), this model predicts which variable to branch (cube) on next. Instead of generating text, it outputs a **classification over variable IDs** (1-500).

### Architecture

- **Base**: `Qwen/Qwen3-0.6B` (causal language model)
- **Head**: LayerNorm → Linear(hidden_size, 501)
- **Pooling**: Hidden state of the last non-pad token
- **Masking**: Logits of invalid variables (those that do not appear in the CNF) are set to -10000 before softmax
- **Size**: ~1.2GB (bfloat16)

### Training

- **Dataset**: 3,898 training / 434 validation samples
- **Task**: Predict the expert-selected branching variable
- **Setup**: 8 epochs on 8×H100 GPUs with DeepSpeed ZeRO-3

### Comparison with 4B Model

| Model | Size (bf16) | Top-1 Accuracy | Top-5 Accuracy | Latency (per sample) |
|-------|-------------|----------------|----------------|----------------------|
| Qwen3-4B | 8GB | 24% | 48% | ~150 ms |
| **Qwen3-0.6B** | 1.2GB | ~12% | ~32% | ~45 ms |

## Usage

```python
import torch
from transformers import AutoTokenizer

# Model class and mask helper ship with this repo (sft_qwen_var_classifier.py)
from sft_qwen_var_classifier import QwenVarClassifier, cnf_valid_mask

# Load model: build the architecture from the base checkpoint, then load the fine-tuned weights
model = QwenVarClassifier("Qwen/Qwen3-0.6B", max_vars=500)
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
# strict=False tolerates key mismatches; check the reported keys to confirm the weights loaded
load_result = model.load_state_dict(state_dict, strict=False)
print("Missing keys:", load_result.missing_keys)
print("Unexpected keys:", load_result.unexpected_keys)
model = model.to("cuda", dtype=torch.bfloat16)
model.eval()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

# Prepare CNF input (DIMACS format)
cnf_text = """p cnf 100 250
1 -2 3 0
-1 2 -4 0
...
"""

# Tokenize (inputs longer than 8192 tokens are truncated)
inputs = tokenizer(cnf_text, return_tensors="pt", truncation=True, max_length=8192)
inputs = {k: v.to("cuda") for k, v in inputs.items()}

# Build the valid-variable mask: True for variables that appear in the CNF
valid_mask = torch.tensor([cnf_valid_mask(cnf_text, max_vars=500)], dtype=torch.bool, device="cuda")

# Predict
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs["logits"]
    # Exclude variables that do not occur in the formula
    logits = logits.masked_fill(~valid_mask, -1e4)
    predicted_var = logits.argmax(dim=-1).item()

print(f"Predicted branching variable: {predicted_var}")
```

## Files

- `pytorch_model.bin`: Model weights (~1.2GB, bfloat16)
- `sft_qwen_var_classifier.py`: Model class definition (required for loading)

## When to Use

- **Production/deployment**: Faster inference and a smaller memory footprint than the 4B model
- **Edge devices**: Runs on smaller GPUs
- **Rapid prototyping**: Quick experiments
- **CPU inference**: More practical than the 4B model

For maximum accuracy, use the [4B model](https://huggingface.co/Yale-ROSE/Qwen3-4B-SAT-VarSelector).

## Limitations

- Maximum of 500 variables
- Maximum of 8,192 tokens of CNF input (longer inputs are truncated)
- Lower accuracy than the 4B model (~12% vs. 24% top-1)
- Trained on a specific CNF distribution

## Citation

If you use this model, please cite the Transformer-CnC paper.

## License

Apache 2.0
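
## Appendix: Variable Masking Sketch

`sft_qwen_var_classifier.py` is not reproduced in this card, so the following is only a minimal pure-Python sketch of what the masking step described under Architecture might look like. `valid_var_mask` and `masked_softmax` are hypothetical names and are not the shipped `cnf_valid_mask` implementation. The sketch assumes DIMACS input (comment lines start with `c`, the header starts with `p`, each clause ends with `0`). It also assumes that class index 0 is unused, so index *v* corresponds to variable *v* in the 501-way head.

```python
import math
from typing import List


def valid_var_mask(cnf_text: str, max_vars: int = 500) -> List[bool]:
    """Hypothetical stand-in for cnf_valid_mask (not the shipped implementation).

    Returns a list of length max_vars + 1 where entry v is True if variable v
    occurs in at least one clause. Entry 0 stays False (assumed unused class).
    """
    mask = [False] * (max_vars + 1)
    for raw in cnf_text.splitlines():
        line = raw.strip()
        # Skip blank lines, DIMACS comments ("c ...") and the header ("p cnf ...").
        if not line or line.startswith(("c", "p")):
            continue
        for tok in line.split():
            try:
                lit = int(tok)
            except ValueError:
                continue  # ignore non-integer tokens such as "..."
            var = abs(lit)  # the literal's sign only encodes polarity
            if 1 <= var <= max_vars:  # 0 terminates a clause; larger IDs are out of range
                mask[var] = True
    return mask


def masked_softmax(logits: List[float], mask: List[bool], fill: float = -1e4) -> List[float]:
    """Replace invalid logits with a large negative value, then apply a stable softmax."""
    masked = [x if ok else fill for x, ok in zip(logits, mask)]
    peak = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in masked]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    cnf = "p cnf 5 2\n1 -2 3 0\n-1 2 -4 0\n"
    mask = valid_var_mask(cnf, max_vars=5)
    print(mask)  # → [False, True, True, True, True, False]
    logits = [9.0, 0.1, 0.5, 0.2, 0.3, 3.0]  # made-up logits for illustration
    probs = masked_softmax(logits, mask)
    best = max(range(len(probs)), key=probs.__getitem__)
    print(best)  # → 2
```

In the example, the unmasked logits at indices 0 and 5 are the largest, but both positions are masked (index 0 is unused, and variable 5 never appears in a clause), so variable 2 is selected. Because softmax is monotonic, taking the argmax of the masked logits gives the same answer as taking the argmax of the masked probabilities, which is why the Usage example skips the softmax. Filling with -10000 rather than negative infinity still drives masked probabilities to zero in practice, and it keeps the result finite even if every entry were masked.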