---
license: apache-2.0
base_model: Qwen/Qwen3-0.6B
tags:
- sat
- satisfiability
- cube-and-conquer
- variable-selection
- combinatorial-optimization
datasets:
- Yale-ROSE/SAT-VarSelector-Distilled
language:
- en
pipeline_tag: text-classification
---

# Qwen3-0.6B-SAT-VarSelector-Distilled

A **Qwen3-0.6B** model fine-tuned for **SAT variable selection** in the Cube-and-Conquer (CnC) framework. Given a CNF formula state, the model predicts which variable to branch/cube on next.

## Model Description

This model implements a **masked classification head** on top of Qwen3-0.6B to select branching variables for SAT solving. Unlike traditional heuristics (e.g., VSIDS), it learns from expert solver traces to make informed variable selection decisions.

### Key Features

- **Task**: Variable selection for SAT Cube-and-Conquer
- **Architecture**: Qwen3-0.6B backbone + classification head (601 classes for variables 0-600)
- **Training**: Supervised fine-tuning on distilled expert data
- **Output**: Integer variable ID to branch on

## Training Details

| Attribute | Value |
|-----------|-------|
| Base Model | `Qwen/Qwen3-0.6B` |
| Training Dataset | Distilled from GPT expert traces |
| Best Checkpoint | Step 410 (Epoch ~6.7) |
| **Eval Accuracy** | **14.75%** |
| Eval Loss | 3.789 |
| Training Time | ~53 minutes (8×H100 GPUs) |

### Performance Context

- **Random Baseline**: ~1-2% accuracy (depends on the number of valid variables)
- **This Model**: 14.75% accuracy = **~7-15× better than random**

### Hyperparameters

```yaml
learning_rate: 5e-6
warmup_ratio: 0.1
num_train_epochs: 8
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
max_length: 8192
max_vars: 600
optimizer: AdamW
scheduler: cosine
deepspeed: ZeRO-3
```

## Usage

### Loading the Model

```python
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM


class QwenVarClassifier(nn.Module):
    def __init__(self, base_model, max_vars=600):
        super().__init__()
        self.base = base_model
        hidden_size = base_model.config.hidden_size
        self.norm = nn.LayerNorm(hidden_size)
        self.head = nn.Linear(hidden_size, max_vars + 1)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.base(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        hidden = outputs.hidden_states[-1]  # [B, seq, hidden]
        # Pool at the last non-pad token
        if attention_mask is not None:
            lengths = attention_mask.sum(dim=1) - 1
            pooled = hidden[torch.arange(hidden.size(0)), lengths]
        else:
            pooled = hidden[:, -1, :]
        pooled = self.norm(pooled)
        logits = self.head(pooled)
        return logits


# Load the backbone and wrap it with the classification head
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")
model = QwenVarClassifier(base_model, max_vars=600)

# Load the fine-tuned weights from this repository
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
```

### Inference

```python
def get_valid_vars(cnf_text, max_vars=600):
    """Extract valid variable IDs from CNF text."""
    valid = set()
    for line in cnf_text.strip().split('\n'):
        # Skip DIMACS comment ('c') and problem ('p') lines
        if line.startswith('c') or line.startswith('p'):
            continue
        for tok in line.split():
            try:
                lit = int(tok)
                if lit != 0 and abs(lit) <= max_vars:
                    valid.add(abs(lit))
            except ValueError:
                pass
    return valid


def predict_variable(cnf_text, model, tokenizer, max_vars=600):
    """Predict the next variable to branch on."""
    inputs = tokenizer(cnf_text, return_tensors="pt", truncation=True, max_length=8192)
    with torch.no_grad():
        logits = model(inputs["input_ids"], inputs["attention_mask"])

    # Mask invalid variables
    valid_vars = get_valid_vars(cnf_text, max_vars)
    mask = torch.zeros(max_vars + 1, dtype=torch.bool)
    for v in valid_vars:
        mask[v] = True
    logits[0, ~mask] = -1e4

    predicted_var = logits.argmax(dim=-1).item()
    return predicted_var


# Example
cnf_text = """p cnf 6 3
1 -2 3 0
-1 4 -5 0
2 5 6 0
"""
var = predict_variable(cnf_text, model, tokenizer)
print(f"Predicted variable: {var}")
```

## Architecture Details

### Why Masked Classification?

The valid action set is **state-dependent**: not all variables are valid at every step.

- Some variables may be eliminated during simplification
- Some may be out of range for the specific instance

We use a **masked softmax**, illustrated in the sketch after this list:

1. The model outputs logits for all 601 classes (0-600; class 0 is never valid, since DIMACS variable IDs start at 1)
2. Invalid variables get their logits set to `-1e4`
3. Softmax assigns probability only to valid variables
4. Training uses a masked cross-entropy loss
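As a concrete illustration of steps 2-4, here is a minimal, self-contained sketch on toy tensors. The random `logits`, the `valid_vars` set, and the `expert_var` label are illustrative stand-ins, not outputs of the actual model or dataset:

```python
import torch
import torch.nn.functional as F

max_vars = 600
logits = torch.randn(1, max_vars + 1)  # stand-in for the head's output [B, 601]
valid_vars = {3, 17, 42}               # stand-in for get_valid_vars(cnf_text)
expert_var = 42                        # stand-in for the expert's chosen variable

# Build the validity mask; class 0 and all unseen variables stay invalid
mask = torch.zeros(max_vars + 1, dtype=torch.bool)
for v in valid_vars:
    mask[v] = True

# Steps 2-3: push invalid logits to -1e4 so softmax gives them ~zero mass
masked_logits = logits.masked_fill(~mask, -1e4)
probs = torch.softmax(masked_logits, dim=-1)
print(probs[0, sorted(valid_vars)].sum())  # ~1.0: all mass on valid variables

# Step 4: the masked cross-entropy loss is ordinary cross-entropy
# computed on the masked logits against the expert label
loss = F.cross_entropy(masked_logits, torch.tensor([expert_var]))
```

Because `masked_fill` overwrites the invalid positions with a constant, gradients flow only through the valid logits, so training learns to rank valid variables against one another rather than wasting capacity on impossible choices.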
### Why Pool the Last Token?

Under causal attention, the last non-pad token has attended to the entire CNF sequence, making it a natural summary representation.

### Why LayerNorm Before the Head?

Qwen's hidden states can have large magnitudes. LayerNorm stabilizes the input to the classification head.

## Limitations

- Maximum of 600 variables (configurable during training)
- Maximum sequence length of 8192 tokens
- Trained on a specific CNF distribution; may not generalize to all SAT instances
- The accuracy metric is strict exact-match; the model may predict "good" variables even when they do not match the expert label exactly

## Citation

```bibtex
@misc{qwen-sat-varselector,
  title={Qwen3-0.6B-SAT-VarSelector-Distilled},
  author={Yale-ROSE},
  year={2026},
  publisher={Hugging Face},
  url={https://huggingface.co/Yale-ROSE/Qwen3-0.6B-SAT-VarSelector-Distilled}
}
```

## Related Models

- [Yale-ROSE/Qwen3-4B-SAT-VarSelector](https://huggingface.co/Yale-ROSE/Qwen3-4B-SAT-VarSelector) - a larger, 4B-parameter version

## License

Apache 2.0 (following the base Qwen3 license)