| """ |
| Qwen Variable Classifier for SAT Cube-and-Conquer |
| |
| This script trains a transformer-based policy to select the next branching variable |
| for SAT (Boolean Satisfiability) solving using the Cube-and-Conquer approach. |
| |
| == Problem Overview == |
| In Cube-and-Conquer SAT solving, we split a hard SAT problem into subproblems ("cubes") |
| by choosing variables to branch on. The quality of variable selection significantly |
| affects solving performance. This model learns to predict good branching variables |
| from expert demonstrations. |
| |
| == Architecture == |
| - Backbone: Qwen3-4B (pretrained causal language model) |
| - Head: LayerNorm + Linear classifier over variable IDs (1 to max_vars) |
| - The model reads a CNF formula as text and outputs logits for each possible variable |
| |
| == Training Approach == |
| - Supervised Fine-Tuning (SFT) on expert variable choices |
| - Masked classification: only variables appearing in the CNF are valid choices |
| - Loss: Cross-entropy with invalid variable logits masked to -infinity |
| |
| == Data Format == |
| JSONL with fields: |
| - "cnf": DIMACS-format CNF text (e.g., "p cnf 100 200\n1 -2 3 0\n...") |
| - "label": integer variable ID to branch on (1 to max_vars) |
| """ |

import os
import argparse
from dataclasses import dataclass
from typing import Any, Dict, List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    set_seed,
)


# Set DEBUG_TRAINING=1 in the environment to print intermediate tensor
# statistics for the first few forward/loss computations.
DEBUG_TRAINING = os.environ.get("DEBUG_TRAINING", "0") == "1"


def cnf_valid_mask(cnf_text: str, max_vars: int) -> List[int]:
    """
    Build a binary mask indicating which variable IDs appear in the CNF.

    This is crucial for masked classification:
    - A variable that doesn't appear in the (simplified) CNF cannot be branched on
    - By masking invalid variables, we ensure the model only learns over valid choices

    Args:
        cnf_text: DIMACS-format CNF string. Format example:
            p cnf 100 200    # header: 100 variables, 200 clauses
            1 -2 3 0         # clause: (x1 OR NOT x2 OR x3)
            -1 4 0           # clause: (NOT x1 OR x4)
            ...
        max_vars: Maximum variable ID supported (typically 500)

    Returns:
        List of length (max_vars + 1) where:
        - mask[0] = 0 (unused, variables are 1-indexed)
        - mask[v] = 1 if variable v appears in any clause
        - mask[v] = 0 if variable v does not appear

    Note: We skip the header line "p cnf ..." to avoid capturing the clause count
    as a valid variable (which was a bug in the original regex-based approach).
    """
    mask = [0] * (max_vars + 1)

    for line in cnf_text.split('\n'):
        line = line.strip()

        # Skip blank lines, comments ('c ...'), and the header ('p cnf ...').
        if not line or line.startswith('c') or line.startswith('p'):
            continue

        # Each remaining line is a clause: whitespace-separated literals
        # terminated by 0; abs(literal) recovers the variable ID.
        for tok in line.split():
            try:
                lit = int(tok)
                v = abs(lit)
                if 1 <= v <= max_vars:
                    mask[v] = 1
            except ValueError:
                continue

    # Fallback: if parsing found nothing (e.g., an empty or malformed CNF),
    # treat every variable as valid so the masked loss remains well-defined.
    if sum(mask) == 0:
        for v in range(1, max_vars + 1):
            mask[v] = 1

    return mask
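

# Example (illustrative): a 3-variable CNF with two clauses.
#
#     >>> cnf_valid_mask("p cnf 3 2\n1 -2 0\n-1 3 0\n", max_vars=5)
#     [0, 1, 1, 1, 0, 0]
#
# Index 0 is unused; variables 1-3 appear in some clause, 4 and 5 do not.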


class QwenVarClassifier(nn.Module):
    """
    Transformer-based variable classifier for SAT branching.

    Architecture:
        Input (CNF text)
        → Tokenize
        → Qwen3-4B backbone (fine-tuned end-to-end with a small LR)
        → Extract last token's hidden state (sequence pooling)
        → LayerNorm (stabilizes hidden state magnitude)
        → Linear head (hidden_dim → num_classes)
        → Logits for each variable ID

    Why this architecture?
    1. Pretrained LLM backbone understands text structure and can learn CNF patterns
    2. Last-token pooling: the final token has attended to the entire input
    3. LayerNorm: Qwen's hidden states have large magnitudes; normalizing prevents
       exploding gradients when combined with a randomly-initialized head
    4. Single linear head: simple, interpretable, efficient
    """

    def __init__(self, base_model_name: str, max_vars: int):
        """
        Initialize the classifier.

        Args:
            base_model_name: HuggingFace model ID (e.g., "Qwen/Qwen3-4B")
            max_vars: Maximum variable ID to classify (e.g., 500).
                Output dimension will be max_vars + 1 (index 0 unused).
        """
        super().__init__()
        self.max_vars = max_vars

        # Expose hidden states so forward() can pool the last layer.
        cfg = AutoConfig.from_pretrained(base_model_name)
        cfg.output_hidden_states = True

        # bf16 on GPU keeps the 4B backbone within memory; fp32 on CPU.
        self.backbone = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            config=cfg,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        )

        hidden = self.backbone.config.hidden_size

        # LayerNorm tames the large-magnitude backbone activations before they
        # reach the freshly-initialized head.
        self.head_ln = nn.LayerNorm(hidden)

        # Classification head over variable IDs; index 0 is unused because
        # DIMACS variables are 1-indexed.
        self.head = nn.Linear(hidden, max_vars + 1)

        # Small-std init keeps the initial logits near zero.
        nn.init.normal_(self.head.weight, std=0.02)
        nn.init.zeros_(self.head.bias)

        # Expose the backbone config so Trainer/DeepSpeed introspection works.
        self.config = self.backbone.config
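
    # Head size (illustrative; hidden_size is read from the loaded config at
    # runtime -- 2560 here is only an assumption for a ~4B Qwen backbone):
    # with hidden = 2560 and max_vars = 500, self.head is a 2560 -> 501 linear
    # map (~1.28M parameters), tiny next to the ~4B-parameter backbone.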

    def forward(self, input_ids, attention_mask, **kwargs):
        """
        Forward pass: CNF tokens → variable logits.

        Args:
            input_ids: [batch, seq_len] token IDs from the tokenizer
            attention_mask: [batch, seq_len] binary mask (1 = real token, 0 = padding)
            **kwargs: ignored (allows passing 'labels' without error during eval)

        Returns:
            dict with "logits": [batch, max_vars + 1] raw classification logits
        """
        out = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )

        # Final-layer hidden states: [batch, seq_len, hidden].
        h = out.hidden_states[-1]

        # Last-token pooling. The collator right-pads, so the number of real
        # tokens minus one is the index of the last non-padding token; the
        # clamp guards against a pathological all-padding row.
        last_idx = attention_mask.sum(dim=1) - 1
        last_idx = last_idx.clamp(min=0)

        # Gather each sequence's pooled hidden state: [batch, hidden].
        b = torch.arange(h.size(0), device=h.device)
        pooled = h[b, last_idx]

        if DEBUG_TRAINING:
            if not hasattr(self, '_debug_count'):
                self._debug_count = 0
            if self._debug_count < 3:
                print(f"[DEBUG {self._debug_count}] pooled dtype={pooled.dtype}, mean={pooled.float().mean():.2f}, std={pooled.float().std():.2f}")
            self._debug_count += 1

        pooled = self.head_ln(pooled)

        if DEBUG_TRAINING and hasattr(self, '_debug_count') and self._debug_count <= 3:
            print(f"[DEBUG] after LN: dtype={pooled.dtype}, mean={pooled.float().mean():.4f}, std={pooled.float().std():.4f}")

        logits = self.head(pooled)

        if DEBUG_TRAINING and hasattr(self, '_debug_count') and self._debug_count <= 3:
            print(f"[DEBUG] logits: dtype={logits.dtype}, mean={logits.float().mean():.2f}, std={logits.float().std():.2f}, min={logits.float().min():.2f}, max={logits.float().max():.2f}")

        return {"logits": logits}


@dataclass
class Collator:
    """
    Custom data collator for variable classification.

    Responsibilities:
    1. Pad variable-length token sequences to the same length within a batch
    2. Stack labels and valid_mask tensors
    3. Create proper attention masks for padded sequences

    Why a custom collator?
    - We have a custom field (valid_mask) that needs special handling
    - Standard HF collators don't know about our mask format
    """
    tokenizer: Any

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """
        Collate a list of examples into a batch.

        Args:
            features: List of dicts, each with:
                - input_ids: List[int] - token IDs
                - attention_mask: List[int] - attention mask
                - label: int - target variable ID
                - valid_mask: List[int] - binary mask of valid variables

        Returns:
            Dict with batched tensors:
                - input_ids: [batch, max_seq_len]
                - attention_mask: [batch, max_seq_len]
                - labels: [batch]
                - valid_mask: [batch, max_vars + 1]
        """
        input_ids = [torch.tensor(f["input_ids"], dtype=torch.long) for f in features]
        attention_mask = [torch.tensor(f["attention_mask"], dtype=torch.long) for f in features]
        labels = torch.tensor([f["label"] for f in features], dtype=torch.long)
        valid_mask = torch.tensor([f["valid_mask"] for f in features], dtype=torch.bool)

        # Right-pad every sequence to the longest one in the batch; padded
        # positions get attention_mask = 0 so the model ignores them.
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id
        )
        attention_mask = torch.nn.utils.rnn.pad_sequence(
            attention_mask,
            batch_first=True,
            padding_value=0
        )

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "valid_mask": valid_mask,
        }
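

# Collation example (illustrative): two features with 3 and 5 tokens become
#
#     input_ids      -> [[t1, t2, t3, PAD, PAD],
#                        [u1, u2, u3, u4, u5]]        # shape [2, 5]
#     attention_mask -> [[1, 1, 1, 0, 0],
#                        [1, 1, 1, 1, 1]]
#     labels         -> [7, 42]                       # shape [2]
#     valid_mask     -> bool tensor of shape [2, max_vars + 1]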


class MaskedVarTrainer(Trainer):
    """
    Custom HuggingFace Trainer with masked cross-entropy loss.

    The key modification: before computing cross-entropy, we mask out logits
    for invalid variables (those not appearing in the CNF). This ensures:
    1. The model cannot predict invalid variables
    2. No gradient flows to invalid variable logits
    3. Training focuses only on distinguishing valid choices

    NOTE on displayed metrics:
    - The 'loss' shown by Trainer is summed across GPUs (loss × world_size).
      We add 'true_loss', which is the actual per-step loss.
    - 'grad_norm' is the L2 norm across ALL ~4B parameters BEFORE clipping.
      Values of 100-200 are normal for large models; it gets clipped to max_grad_norm.
    """

    def __init__(self, *args, max_vars: int, **kwargs):
        """
        Args:
            max_vars: Maximum variable ID (for sanity-checking labels)
            *args, **kwargs: Passed to the parent Trainer
        """
        super().__init__(*args, **kwargs)
        self.max_vars = max_vars
        # Running sum/count of per-step losses, reported as 'true_loss' in log().
        self._accumulated_loss = 0.0
        self._loss_count = 0

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute masked cross-entropy loss for variable classification.

        Algorithm:
        1. Extract labels and valid_mask from inputs
        2. Forward pass to get logits
        3. Set logits for invalid variables to a large negative value
           (-1e4 rather than -inf, for bf16 stability)
        4. Compute cross-entropy loss

        Args:
            model: The QwenVarClassifier
            inputs: Dict with input_ids, attention_mask, labels, valid_mask
            return_outputs: If True, return a (loss, outputs) tuple
            num_items_in_batch: Unused (for API compatibility)

        Returns:
            loss: Scalar loss, or a (loss, outputs) tuple if return_outputs=True
        """
        labels = inputs.get("labels")
        valid_mask = inputs.get("valid_mask")

        # Strip our custom fields before calling the model.
        model_inputs = {k: v for k, v in inputs.items() if k not in ["labels", "valid_mask"]}

        outputs = model(**model_inputs)
        logits = outputs["logits"]

        if DEBUG_TRAINING:
            if not hasattr(self, '_loss_debug_count'):
                self._loss_debug_count = 0
            if self._loss_debug_count < 5:
                for i, (lbl, vmask) in enumerate(zip(labels, valid_mask)):
                    label_in_mask = vmask[lbl].item()
                    valid_count = vmask.sum().item()
                    logit_at_label = logits[i, lbl].item()
                    print(f"[LOSS DEBUG {self._loss_debug_count}] label={lbl.item()}, in_mask={label_in_mask}, valid_vars={valid_count}, logit_at_label={logit_at_label:.2f}")
            self._loss_debug_count += 1

        # Mask logits of variables absent from the CNF. -1e4 is effectively
        # -inf after softmax but avoids NaNs in reduced precision.
        logits = logits.masked_fill(~valid_mask.to(logits.device), -1e4)

        # Sanity check: labels must be valid 1..max_vars variable IDs.
        if torch.any(labels <= 0) or torch.any(labels > self.max_vars):
            bad = labels[(labels <= 0) | (labels > self.max_vars)].detach().cpu().tolist()
            raise ValueError(f"Out-of-range labels detected (showing up to 20): {bad[:20]}")

        if DEBUG_TRAINING and hasattr(self, '_loss_debug_count') and self._loss_debug_count <= 5:
            for i, lbl in enumerate(labels):
                masked_logit = logits[i, lbl].item()
                print(f"[LOSS DEBUG] after mask: logit_at_label={masked_logit:.2f}")

        loss = F.cross_entropy(logits, labels.to(logits.device))

        # Track the true per-step loss for logging (see log() below).
        self._accumulated_loss += loss.item()
        self._loss_count += 1

        if DEBUG_TRAINING and hasattr(self, '_loss_debug_count') and self._loss_debug_count <= 5:
            print(f"[LOSS DEBUG] loss={loss.item():.2f}")

        masked_outputs = {"logits": logits}
        return (loss, masked_outputs) if return_outputs else loss
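
    # Masking example (illustrative): with max_vars = 3 and
    # valid_mask = [0, 1, 0, 1] (only variables 1 and 3 occur in the CNF),
    # raw logits [0.5, 2.0, 3.0, 1.0] become [-1e4, 2.0, -1e4, 1.0]: softmax
    # mass is split between variables 1 and 3 only, and masked_fill blocks
    # any gradient to the masked positions.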

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """
        Override prediction_step to return masked loss and logits for evaluation.

        The default HF Trainer prediction_step doesn't play well with our custom
        compute_loss, so we compute the masked loss ourselves and hand back the
        (already-masked) logits for compute_metrics.
        """
        model.eval()

        with torch.no_grad():
            labels = inputs.get("labels")
            valid_mask = inputs.get("valid_mask")

            # Strip custom fields, forward, then mask exactly as in compute_loss.
            model_inputs = {k: v for k, v in inputs.items() if k not in ["labels", "valid_mask"]}
            outputs = model(**model_inputs)
            logits = outputs["logits"]

            logits = logits.masked_fill(~valid_mask.to(logits.device), -1e4)

            loss = F.cross_entropy(logits, labels.to(logits.device))

        if prediction_loss_only:
            return (loss, None, None)

        return (loss, logits.detach(), labels.detach())

    def log(self, logs: Dict[str, float], start_time: float = None) -> None:
        """
        Override log to add 'true_loss' alongside the standard metrics.

        The default 'loss' in HF Trainer is summed across GPUs under DDP/DeepSpeed.
        We track the actual per-step loss and report its running mean as 'true_loss'.
        """
        if self._loss_count > 0:
            true_loss = self._accumulated_loss / self._loss_count
            logs["true_loss"] = round(true_loss, 4)

            # Reset the running sum for the next logging window.
            self._accumulated_loss = 0.0
            self._loss_count = 0

        super().log(logs, start_time)


def compute_metrics(eval_pred):
    """
    Compute accuracy for evaluation.

    Args:
        eval_pred: (logits, labels) from the Trainer's prediction loop
            - logits: [num_samples, max_vars + 1] (already masked with -1e4 for invalid vars)
            - labels: [num_samples]

    Returns:
        Dict with "accuracy" (the Trainer will prefix it with "eval_")

    Note: eval_loss is computed automatically by the Trainer from prediction_step's
    loss; we don't need to compute it here. Since invalid variables have logits
    ≈ -1e4, argmax naturally avoids them.
    """
    logits, labels = eval_pred

    preds = np.argmax(logits, axis=-1)
    accuracy = float((preds == labels).mean())

    return {"accuracy": accuracy}


def get_wandb_report_to():
    """
    Determine whether this process should log to W&B.

    Only the main process (rank 0) should log to W&B, to avoid creating multiple
    runs; other ranks should not log to any external service.

    Returns:
        ["wandb"] for rank 0, [] for other ranks
    """
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    if local_rank == 0:
        return ["wandb"]
    else:
        return []


def main():
    """
    Main training function.

    Pipeline:
    1. Parse command-line arguments
    2. Load tokenizer and datasets
    3. Preprocess: tokenize CNF text, compute valid masks
    4. Initialize model with pretrained backbone + new classification head
    5. Configure training (optimizer, scheduler, logging, etc.)
    6. Train and evaluate
    """
    ap = argparse.ArgumentParser(
        description="Train a Qwen-based variable classifier for SAT branching"
    )

    # Model / data arguments
    ap.add_argument("--model_name", type=str, default="Qwen/Qwen3-4B",
                    help="HuggingFace model ID for the backbone")
    ap.add_argument("--train_jsonl", type=str, required=True,
                    help="Path to training data (JSONL with 'cnf' and 'label' fields)")
    ap.add_argument("--valid_jsonl", type=str, required=True,
                    help="Path to validation data (same format)")
    ap.add_argument("--output_dir", type=str, default="./out_qwen_var_sft",
                    help="Directory for checkpoints and logs")
    ap.add_argument("--max_vars", type=int, default=500,
                    help="Maximum variable ID (determines output dimension)")
    ap.add_argument("--max_length", type=int, default=8192,
                    help="Maximum sequence length in tokens (longer CNFs are truncated)")
    ap.add_argument("--seed", type=int, default=0,
                    help="Random seed for reproducibility")

    # Training hyperparameters
    ap.add_argument("--per_device_train_batch_size", type=int, default=1,
                    help="Batch size per GPU for training")
    ap.add_argument("--per_device_eval_batch_size", type=int, default=1,
                    help="Batch size per GPU for evaluation")
    ap.add_argument("--gradient_accumulation_steps", type=int, default=8,
                    help="Accumulate gradients over this many steps "
                         "(effective batch = this * batch_size * num_gpus)")
    ap.add_argument("--learning_rate", type=float, default=5e-6,
                    help="Peak learning rate (after warmup). Lower than typical "
                         "fine-tuning because of the fresh classification head")
    ap.add_argument("--num_train_epochs", type=float, default=3.0,
                    help="Total training epochs")
    ap.add_argument("--warmup_ratio", type=float, default=0.03,
                    help="Fraction of training steps for learning rate warmup")
    ap.add_argument("--weight_decay", type=float, default=0.0,
                    help="Weight decay (L2 regularization)")
    ap.add_argument("--logging_steps", type=int, default=10,
                    help="Log training metrics every N steps")
    ap.add_argument("--eval_steps", type=int, default=200,
                    help="Evaluate every N steps")
    ap.add_argument("--save_steps", type=int, default=200,
                    help="Save a checkpoint every N steps")
    ap.add_argument("--report_to", type=str, default="wandb",
                    choices=["wandb", "tensorboard", "none"],
                    help="Logging backend")
    ap.add_argument("--deepspeed", type=str, default=None,
                    help="Path to DeepSpeed config JSON for distributed training")
    ap.add_argument("--resume_from_checkpoint", type=str, default=None,
                    help="Path to a checkpoint directory to resume from. If a directory "
                         "is given, the latest checkpoint in that directory is used.")

    args = ap.parse_args()

    # Reproducibility across python/numpy/torch RNGs.
    set_seed(args.seed)

    # Qwen tokenizers may not define a pad token; reuse EOS for padding.
    tok = AutoTokenizer.from_pretrained(args.model_name, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    # Load train/validation JSONL files into a DatasetDict.
    ds = load_dataset(
        "json",
        data_files={"train": args.train_jsonl, "validation": args.valid_jsonl},
    )
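
    # Resulting structure (illustrative; row counts depend on your files):
    #
    #     DatasetDict({
    #         "train":      Dataset with features ["cnf", "label"],
    #         "validation": Dataset with features ["cnf", "label"],
    #     })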

    def preprocess(ex):
        """
        Preprocess a single example.

        Steps:
        1. Tokenize the CNF text
        2. Compute the valid-variable mask
        3. Return features for training

        Args:
            ex: Dict with 'cnf' (str) and 'label' (int)

        Returns:
            Dict with input_ids, attention_mask, label, valid_mask
        """
        cnf = ex["cnf"]
        label = int(ex["label"])

        # Tokenize without padding; the Collator pads per batch.
        enc = tok(
            cnf,
            truncation=True,
            max_length=args.max_length,
            padding=False
        )

        return {
            "input_ids": enc["input_ids"],
            "attention_mask": enc["attention_mask"],
            "label": label,
            "valid_mask": cnf_valid_mask(cnf, args.max_vars),
        }
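
    # Example input record (illustrative):
    #
    #     {"cnf": "p cnf 4 2\n1 -2 0\n3 4 0\n", "label": 3}
    #
    # preprocess() tokenizes the CNF string and attaches a valid_mask with
    # mask[1..4] = 1 for this formula (and zeros elsewhere).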

    # Tokenize and compute masks once up front; remove_columns drops the raw
    # 'cnf'/'label' columns, keeping only the map's outputs
    # (input_ids, attention_mask, label, valid_mask).
    ds = ds.map(preprocess, remove_columns=ds["train"].column_names)

    model = QwenVarClassifier(args.model_name, max_vars=args.max_vars)

    # Gradient checkpointing trades recompute for memory on the 4B backbone
    # (forward() already sets use_cache=False, which checkpointing requires).
    model.backbone.gradient_checkpointing_enable()

    # Respect --report_to; for W&B, only rank 0 reports, to avoid duplicate runs.
    if args.report_to == "wandb":
        report_to = get_wandb_report_to()
    elif args.report_to == "none":
        report_to = []
    else:
        report_to = [args.report_to]

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        overwrite_output_dir=True,

        # Mixed precision: bf16 compute with TF32 matmuls on Ampere+ GPUs.
        bf16=True,
        tf32=True,

        # Batching
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,

        # Optimization schedule
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        num_train_epochs=args.num_train_epochs,
        weight_decay=args.weight_decay,

        # Gradient clipping. Pre-clip norms of 100-200 are normal at this
        # scale; see the MaskedVarTrainer docstring.
        max_grad_norm=1.0,

        # Logging / evaluation
        logging_steps=args.logging_steps,
        eval_strategy="steps",
        eval_steps=args.eval_steps,

        # Checkpointing: keep the 3 most recent, reload the best by accuracy.
        save_strategy="steps",
        save_steps=args.save_steps,
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="eval_accuracy",
        greater_is_better=True,

        # Reporting
        report_to=report_to,
        run_name=os.environ.get("WANDB_RUN_NAME", "qwen-var-sft") if args.report_to == "wandb" else None,
        logging_dir=os.path.join(args.output_dir, "logs"),

        # Keep our custom columns (valid_mask) instead of letting Trainer drop them.
        remove_unused_columns=False,

        # Every parameter receives gradients, so skip the unused-parameter scan.
        ddp_find_unused_parameters=False,

        deepspeed=args.deepspeed,

        # Plain torch checkpoints (safetensors can refuse the tied/shared
        # tensors in a wrapper module like this).
        save_safetensors=False,
    )
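
    # Effective batch size (illustrative): with per_device_train_batch_size=1,
    # gradient_accumulation_steps=8, and 8 GPUs, each optimizer step sees
    # 1 * 8 * 8 = 64 examples.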

    trainer = MaskedVarTrainer(
        model=model,
        args=training_args,
        train_dataset=ds["train"],
        eval_dataset=ds["validation"],
        tokenizer=tok,
        data_collator=Collator(tok),
        compute_metrics=compute_metrics,
        max_vars=args.max_vars,
    )

    # Train (optionally resuming), then run a final evaluation; with
    # load_best_model_at_end=True this evaluates the best checkpoint.
    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)

    trainer.evaluate()


if __name__ == "__main__":
    main()