"""
Qwen Variable Classifier for SAT Cube-and-Conquer

This script trains a transformer-based policy to select the next branching
variable for SAT (Boolean Satisfiability) solving using the Cube-and-Conquer
approach.

== Problem Overview ==
In Cube-and-Conquer SAT solving, we split a hard SAT problem into subproblems
("cubes") by choosing variables to branch on. The quality of variable
selection significantly affects solving performance. This model learns to
predict good branching variables from expert demonstrations.

== Architecture ==
- Backbone: Qwen3-4B (pretrained causal language model)
- Head: LayerNorm + Linear classifier over variable IDs (1 to max_vars)
- The model reads a CNF formula as text and outputs logits for each possible
  variable

== Training Approach ==
- Supervised Fine-Tuning (SFT) on expert variable choices
- Masked classification: only variables appearing in the CNF are valid choices
- Loss: Cross-entropy with invalid variable logits masked to a large negative
  value (-1e4)

== Data Format ==
JSONL with fields:
- "cnf": DIMACS-format CNF text (e.g., "p cnf 100 200\\n1 -2 3 0\\n...")
- "label": integer variable ID to branch on (1 to max_vars)
"""

import argparse
import inspect
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    set_seed,
)

# =============================================================================
# DEBUG FLAG: Set to True to enable verbose debug output, False to disable
# Can also be controlled via environment variable: DEBUG_TRAINING=1
# =============================================================================
DEBUG_TRAINING = os.environ.get("DEBUG_TRAINING", "0") == "1"

# Value written into the logits of invalid variables before softmax.
# -1e4 is representable in bfloat16 and yields ~0 probability, whereas much
# larger magnitudes (e.g. -1e9) or -inf can produce NaNs in low precision.
MASKED_LOGIT_VALUE = -1e4


# =============================================================================
# CNF PARSING: Extract valid variables from DIMACS CNF text
# =============================================================================

def cnf_valid_mask(cnf_text: str, max_vars: int) -> List[int]:
    """
    Build a binary mask indicating which variable IDs appear in the CNF.

    This is crucial for masked classification:
    - A variable that doesn't appear in the (simplified) CNF cannot be
      branched on
    - By masking invalid variables, the model only learns over valid choices

    Args:
        cnf_text: DIMACS-format CNF string. Format example:
            p cnf 100 200    # header: 100 variables, 200 clauses
            1 -2 3 0         # clause: (x1 OR NOT x2 OR x3)
            -1 4 0           # clause: (NOT x1 OR x4)
        max_vars: Maximum variable ID supported (typically 500)

    Returns:
        List of length (max_vars + 1) where:
        - mask[0] = 0 (unused, variables are 1-indexed)
        - mask[v] = 1 if variable v appears in any clause
        - mask[v] = 0 if variable v does not appear
        If no variable in range is found at all, every entry 1..max_vars is
        set to 1 so the model never ends up with zero valid outputs.

    Note:
        The header line "p cnf ..." is skipped so that the variable/clause
        counts are not mistaken for variables.
    """
    mask = [0] * (max_vars + 1)

    for raw_line in cnf_text.split("\n"):
        line = raw_line.strip()
        # Skip blanks, comment lines ('c ...') and the header ('p cnf ...').
        if not line or line.startswith(("c", "p")):
            continue

        # A clause is a run of integer literals terminated by 0; the variable
        # ID is the absolute value of the literal.
        for tok in line.split():
            try:
                var = abs(int(tok))
            except ValueError:
                # Non-integer token (should not occur in valid DIMACS).
                continue
            if 1 <= var <= max_vars:
                mask[var] = 1

    # Fallback for truncated/malformed input: allow every variable.
    if not any(mask):
        mask[1:] = [1] * max_vars

    return mask


# =============================================================================
# MODEL: Qwen backbone with classification head for variable selection
# =============================================================================

class QwenVarClassifier(nn.Module):
    """
    Transformer-based variable classifier for SAT branching.

    Architecture:
        Input (CNF text)
        → Tokenize
        → Qwen backbone (fine-tuned with a small LR)
        → Hidden state of the last non-padding token (sequence pooling)
        → LayerNorm (stabilizes hidden state magnitude)
        → Linear head (hidden_dim → max_vars + 1)
        → Logits for each variable ID

    Why this architecture?
    1. A pretrained LLM backbone understands text structure and can learn
       CNF patterns
    2. Last-token pooling: with causal attention the final token has attended
       to the entire input
    3. LayerNorm: backbone hidden states can have large magnitudes;
       normalizing avoids huge initial logits from the freshly initialized
       head
    4. Single linear head: simple, interpretable, efficient
    """

    # forward() takes **kwargs only so stray batch keys are tolerated.
    # NOTE(review): recent transformers Trainer versions treat a **kwargs
    # forward as "the model normalizes its own loss by num_items_in_batch"
    # and then skip dividing the loss by gradient_accumulation_steps. This
    # model does no such normalization, so opt out explicitly. Versions that
    # do not know this attribute simply ignore it.
    accepts_loss_kwargs = False

    def __init__(self, base_model_name: str, max_vars: int):
        """
        Initialize the classifier.

        Args:
            base_model_name: HuggingFace model ID (e.g., "Qwen/Qwen3-4B")
            max_vars: Maximum variable ID to classify (e.g., 500).
                Output dimension is max_vars + 1 (index 0 unused).
        """
        super().__init__()
        self.max_vars = max_vars

        cfg = AutoConfig.from_pretrained(base_model_name)

        # bfloat16 on GPU for memory efficiency, float32 on CPU.
        self.backbone = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            config=cfg,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        )

        hidden = self.backbone.config.hidden_size  # e.g., 2560 for Qwen3-4B

        # Normalize pooled hidden states before the randomly initialized head
        # so that initial logits (and therefore the initial loss) stay small.
        self.head_ln = nn.LayerNorm(hidden)

        # Classification head over variable IDs; index 0 is unused because
        # DIMACS variables are 1-indexed.
        self.head = nn.Linear(hidden, max_vars + 1)
        nn.init.normal_(self.head.weight, std=0.02)
        nn.init.zeros_(self.head.bias)

        # Expose the backbone config on the wrapper (DeepSpeed
        # auto-configuration reads model.config.hidden_size).
        self.config = self.backbone.config

    def forward(self, input_ids, attention_mask, **kwargs):
        """
        Forward pass: CNF tokens → variable logits.

        Args:
            input_ids: [batch, seq_len] token IDs from tokenizer
            attention_mask: [batch, seq_len] binary mask (1 = real token,
                0 = padding). Padding is assumed to be on the right, which is
                what ``Collator`` produces.
            **kwargs: ignored (allows extra batch keys without error)

        Returns:
            dict with "logits": [batch, max_vars + 1] raw classification logits
        """
        # Run only the transformer stack. Going through the causal-LM wrapper
        # would also compute vocabulary logits for every position and keep
        # every layer's hidden states, none of which are used here.
        encoder = self.backbone.base_model
        out = encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=False,
            use_cache=False,  # KV cache is not needed for training
        )
        h = out[0]  # final hidden states: [batch, seq_len, hidden_dim]

        # Index of the last real token per sequence (right padding assumed);
        # clamp guards against an all-zero attention mask.
        last_idx = (attention_mask.sum(dim=1) - 1).clamp(min=0)  # [batch]
        batch_idx = torch.arange(h.size(0), device=h.device)
        pooled = h[batch_idx, last_idx]  # [batch, hidden_dim]

        # Emit debug statistics for the first three forward passes only.
        debug_now = False
        if DEBUG_TRAINING:
            if not hasattr(self, "_debug_count"):
                self._debug_count = 0
            if self._debug_count < 3:
                debug_now = True
                print(f"[DEBUG {self._debug_count}] pooled dtype={pooled.dtype}, mean={pooled.float().mean():.2f}, std={pooled.float().std():.2f}")
                self._debug_count += 1

        # The backbone may run in bf16 while the head keeps fp32 parameters;
        # align dtypes so LayerNorm/Linear also work outside autocast.
        pooled = pooled.to(self.head_ln.weight.dtype)
        pooled = self.head_ln(pooled)

        if debug_now:
            print(f"[DEBUG] after LN: dtype={pooled.dtype}, mean={pooled.float().mean():.4f}, std={pooled.float().std():.4f}")

        logits = self.head(pooled)  # [batch, max_vars + 1]

        if debug_now:
            print(f"[DEBUG] logits: dtype={logits.dtype}, mean={logits.float().mean():.2f}, std={logits.float().std():.2f}, min={logits.float().min():.2f}, max={logits.float().max():.2f}")

        return {"logits": logits}


# =============================================================================
# DATA COLLATOR: Batch preparation with padding and mask handling
# =============================================================================

@dataclass
class Collator:
    """
    Custom data collator for variable classification.

    Responsibilities:
    1. Right-pad variable-length token sequences to the same length within a
       batch
    2. Stack labels and valid_mask tensors
    3. Build the matching attention masks (0 on padded positions)

    A custom collator is needed because of the extra ``valid_mask`` field.
    """

    tokenizer: Any  # Only ``pad_token_id`` is read from it

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """
        Collate a list of examples into a batch.

        Args:
            features: List of dicts, each with:
                - input_ids: List[int] - token IDs
                - attention_mask: List[int] - attention mask
                - label: int - target variable ID
                - valid_mask: List[int] - binary mask of valid variables

        Returns:
            Dict with batched tensors:
                - input_ids: [batch, max_seq_len] (long)
                - attention_mask: [batch, max_seq_len] (long)
                - labels: [batch] (long)
                - valid_mask: [batch, max_vars + 1] (bool)
        """
        pad = torch.nn.utils.rnn.pad_sequence

        token_seqs = [torch.tensor(f["input_ids"], dtype=torch.long) for f in features]
        mask_seqs = [torch.tensor(f["attention_mask"], dtype=torch.long) for f in features]

        return {
            "input_ids": pad(
                token_seqs,
                batch_first=True,
                padding_value=self.tokenizer.pad_token_id,
            ),
            # Padding positions get 0 attention.
            "attention_mask": pad(mask_seqs, batch_first=True, padding_value=0),
            "labels": torch.tensor([f["label"] for f in features], dtype=torch.long),
            "valid_mask": torch.tensor([f["valid_mask"] for f in features], dtype=torch.bool),
        }


# =============================================================================
# TRAINER: Custom loss computation with variable masking
# =============================================================================

class MaskedVarTrainer(Trainer):
    """
    Custom HuggingFace Trainer with masked cross-entropy loss.

    Before computing cross-entropy, logits of invalid variables (those not
    appearing in the CNF) are replaced by a large negative constant, so:
    1. The model cannot predict invalid variables
    2. No gradient flows to invalid variable logits
    3. Training focuses only on distinguishing valid choices

    NOTE on displayed metrics:
    - How the Trainer scales its own 'loss' (e.g. with respect to gradient
      accumulation) depends on the transformers version. 'true_loss' is
      added as the plain mean of the raw per-micro-batch losses seen by this
      process since the previous training log.
    - 'grad_norm' is the norm over all parameters BEFORE clipping; it is
      clipped to max_grad_norm afterwards.
    """

    def __init__(self, *args, max_vars: int, **kwargs):
        """
        Args:
            max_vars: Maximum variable ID (for sanity checking labels)
            *args, **kwargs: Passed to parent Trainer
        """
        super().__init__(*args, **kwargs)
        self.max_vars = max_vars
        self._accumulated_loss = 0.0
        self._loss_count = 0

    # ------------------------------------------------------------------
    # Helpers shared by compute_loss and prediction_step
    # ------------------------------------------------------------------

    @staticmethod
    def _forward_logits(model, inputs):
        """Run the model without the loss-only keys; return (raw_logits, labels, valid_mask)."""
        labels = inputs.get("labels")          # [batch]
        valid_mask = inputs.get("valid_mask")  # [batch, max_vars + 1] bool
        model_inputs = {k: v for k, v in inputs.items() if k not in ("labels", "valid_mask")}
        raw_logits = model(**model_inputs)["logits"]
        return raw_logits, labels, valid_mask

    @staticmethod
    def _apply_valid_mask(logits, valid_mask):
        """Upcast logits to float32 and overwrite invalid-variable entries with MASKED_LOGIT_VALUE."""
        return logits.float().masked_fill(~valid_mask.to(logits.device), MASKED_LOGIT_VALUE)

    def _check_labels(self, labels):
        """Raise ValueError if any label lies outside 1..max_vars."""
        out_of_range = (labels <= 0) | (labels > self.max_vars)
        if torch.any(out_of_range):
            bad = labels[out_of_range].detach().cpu().tolist()
            raise ValueError(f"Out-of-range labels detected (showing up to 20): {bad[:20]}")

    # ------------------------------------------------------------------
    # Trainer overrides
    # ------------------------------------------------------------------

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute masked cross-entropy loss for variable classification.

        Algorithm:
        1. Extract labels and valid_mask from inputs (without removing them)
        2. Forward pass to get logits
        3. Set logits for invalid variables to -1e4
        4. Compute cross-entropy loss in float32

        Args:
            model: The QwenVarClassifier
            inputs: Dict with input_ids, attention_mask, labels, valid_mask
            return_outputs: If True, return (loss, outputs) tuple
            num_items_in_batch: Unused (for API compatibility)

        Returns:
            Scalar loss, or (loss, {"logits": masked_logits}) if
            return_outputs=True

        Raises:
            ValueError: if a label is outside 1..max_vars
        """
        raw_logits, labels, valid_mask = self._forward_logits(model, inputs)

        # Catch data bugs early, before any label is used as an index.
        self._check_labels(labels)

        # Verbose diagnostics for the first five calls only.
        debug_now = False
        if DEBUG_TRAINING:
            if not hasattr(self, "_loss_debug_count"):
                self._loss_debug_count = 0
            if self._loss_debug_count < 5:
                debug_now = True
                for i, (lbl, vmask) in enumerate(zip(labels, valid_mask)):
                    label_in_mask = vmask[lbl].item()
                    valid_count = vmask.sum().item()
                    logit_at_label = raw_logits[i, lbl].item()
                    print(f"[LOSS DEBUG {self._loss_debug_count}] label={lbl.item()}, in_mask={label_in_mask}, valid_vars={valid_count}, logit_at_label={logit_at_label:.2f}")
                self._loss_debug_count += 1

        logits = self._apply_valid_mask(raw_logits, valid_mask)

        if debug_now:
            for i, lbl in enumerate(labels):
                masked_logit = logits[i, lbl].item()
                print(f"[LOSS DEBUG] after mask: logit_at_label={masked_logit:.2f}")

        # cross_entropy expects raw logits, not probabilities.
        loss = F.cross_entropy(logits, labels.to(logits.device))

        # Accumulate as a detached tensor; converting to a Python float here
        # would force a device synchronization on every micro-step.
        self._accumulated_loss = self._accumulated_loss + loss.detach().float()
        self._loss_count += 1

        if debug_now:
            print(f"[LOSS DEBUG] loss={loss.item():.2f}")

        # Hand back the masked logits so downstream consumers see the same
        # distribution the loss was computed on.
        masked_outputs = {"logits": logits}
        return (loss, masked_outputs) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """
        Evaluate one batch and return (loss, masked_logits, labels).

        Mirrors compute_loss (same masking, same loss) but runs under
        ``torch.no_grad()`` and does not touch the training-loss accumulator.

        Args:
            model: The QwenVarClassifier
            inputs: Dict with input_ids, attention_mask, labels, valid_mask
            prediction_loss_only: If True, logits and labels are returned as
                None
            ignore_keys: Unused (for API compatibility)

        Raises:
            ValueError: if a label is outside 1..max_vars
        """
        # Same device/dtype preparation the stock prediction_step performs.
        inputs = self._prepare_inputs(inputs)

        model.eval()
        with torch.no_grad():
            raw_logits, labels, valid_mask = self._forward_logits(model, inputs)
            self._check_labels(labels)
            logits = self._apply_valid_mask(raw_logits, valid_mask)
            loss = F.cross_entropy(logits, labels.to(logits.device))

        loss = loss.detach()
        if prediction_loss_only:
            return (loss, None, None)
        return (loss, logits.detach(), labels.detach())

    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
        """
        Add 'true_loss' to training logs, then defer to the parent logger.

        'true_loss' is the mean of the raw losses returned by compute_loss on
        this process since the last training log. It is not attached to
        evaluation logs (keys prefixed with "eval_").
        """
        is_eval_log = any(key.startswith("eval_") for key in logs)
        if self._loss_count > 0 and not is_eval_log:
            true_loss = float(self._accumulated_loss) / self._loss_count
            logs["true_loss"] = round(true_loss, 4)
            # Start a fresh window for the next logging interval.
            self._accumulated_loss = 0.0
            self._loss_count = 0

        # Older Trainer.log() accepts only ``logs``; forward ``start_time``
        # only when the caller actually supplied it.
        if start_time is None:
            super().log(logs)
        else:
            super().log(logs, start_time)


def compute_metrics(eval_pred):
    """
    Compute accuracy for evaluation.

    Args:
        eval_pred: (logits, labels) pair
            - logits: [num_samples, max_vars + 1], already masked with -1e4
              for invalid variables by prediction_step
            - labels: [num_samples]

    Returns:
        Dict with "accuracy" (the Trainer prefixes it with "eval_")

    Since invalid variables carry logits of about -1e4, argmax naturally
    avoids them.
    """
    logits, labels = eval_pred
    predicted = np.argmax(logits, axis=-1)
    return {"accuracy": float((predicted == labels).mean())}


def get_wandb_report_to(backend: str = "wandb") -> List[str]:
    """
    Decide which logging integrations this process should report to.

    Only the main process (rank 0) reports, to avoid creating duplicate runs.
    The global ``RANK`` environment variable is preferred; ``LOCAL_RANK`` is
    the fallback, since on multi-node jobs every node has a local rank 0.

    Args:
        backend: Logging backend name ("wandb", "tensorboard" or "none").
            Defaults to "wandb", the historical behaviour of this function.

    Returns:
        [backend] for rank 0, [] for other ranks or when backend is "none"
    """
    if backend == "none":
        return []
    rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", 0)))
    return [backend] if rank == 0 else []


# =============================================================================
# MAIN: Training pipeline
# =============================================================================

def main():
    """
    Main training function.

    Pipeline:
    1. Parse command line arguments
    2. Load tokenizer and datasets
    3. Preprocess: tokenize CNF text, compute valid masks
    4. Initialize model with pretrained backbone + new classification head
    5. Configure training (optimizer, scheduler, logging, etc.)
    6. Train and evaluate
    """
    ap = argparse.ArgumentParser(
        description="Train a Qwen-based variable classifier for SAT branching"
    )

    # Model and data arguments
    ap.add_argument("--model_name", type=str, default="Qwen/Qwen3-4B",
                    help="HuggingFace model ID for the backbone")
    ap.add_argument("--train_jsonl", type=str, required=True,
                    help="Path to training data (JSONL with 'cnf' and 'label' fields)")
    ap.add_argument("--valid_jsonl", type=str, required=True,
                    help="Path to validation data (same format)")
    ap.add_argument("--output_dir", type=str, default="./out_qwen_var_sft",
                    help="Directory for checkpoints and logs")
    ap.add_argument("--max_vars", type=int, default=500,
                    help="Maximum variable ID (determines output dimension)")
    ap.add_argument("--max_length", type=int, default=8192,
                    help="Maximum sequence length in tokens (truncates longer CNFs)")
    ap.add_argument("--seed", type=int, default=0,
                    help="Random seed for reproducibility")

    # Training hyperparameters
    ap.add_argument("--per_device_train_batch_size", type=int, default=1,
                    help="Batch size per GPU for training")
    ap.add_argument("--per_device_eval_batch_size", type=int, default=1,
                    help="Batch size per GPU for evaluation")
    ap.add_argument("--gradient_accumulation_steps", type=int, default=8,
                    help="Accumulate gradients over this many steps (effective batch = this * batch_size * num_gpus)")
    ap.add_argument("--learning_rate", type=float, default=5e-6,
                    help="Peak learning rate (after warmup). Lower than typical fine-tuning due to classification head")
    ap.add_argument("--num_train_epochs", type=float, default=3.0,
                    help="Total training epochs")
    ap.add_argument("--warmup_ratio", type=float, default=0.03,
                    help="Fraction of training steps for learning rate warmup")
    ap.add_argument("--weight_decay", type=float, default=0.0,
                    help="Weight decay (L2 regularization)")
    ap.add_argument("--logging_steps", type=int, default=10,
                    help="Log training metrics every N steps")
    ap.add_argument("--eval_steps", type=int, default=200,
                    help="Evaluate every N steps")
    ap.add_argument("--save_steps", type=int, default=200,
                    help="Save checkpoint every N steps")
    ap.add_argument("--report_to", type=str, default="wandb",
                    choices=["wandb", "tensorboard", "none"],
                    help="Logging backend")
    ap.add_argument("--deepspeed", type=str, default=None,
                    help="Path to DeepSpeed config JSON for distributed training")

    args = ap.parse_args()

    # Set random seeds for reproducibility
    set_seed(args.seed)

    # Load tokenizer; fall back to EOS as the pad token when none is defined.
    tok = AutoTokenizer.from_pretrained(args.model_name, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    # Load datasets from JSONL files
    ds = load_dataset(
        "json",
        data_files={"train": args.train_jsonl, "validation": args.valid_jsonl},
    )

    def preprocess(ex):
        """
        Preprocess a single example.

        Steps:
        1. Validate the label range
        2. Tokenize the CNF text (truncated to --max_length, unpadded)
        3. Compute the valid variable mask from the full CNF text

        Args:
            ex: Dict with 'cnf' (str) and 'label' (int)

        Returns:
            Dict with input_ids, attention_mask, label, valid_mask

        Raises:
            ValueError: if the label is outside 1..max_vars
        """
        cnf = ex["cnf"]
        label = int(ex["label"])
        if not 1 <= label <= args.max_vars:
            raise ValueError(
                f"Label {label} is outside the valid variable range 1..{args.max_vars}"
            )

        # No special prompt/instruction - the model learns to interpret raw
        # CNF. Padding is handled later by the collator.
        enc = tok(
            cnf,
            truncation=True,
            max_length=args.max_length,
            padding=False,
        )

        valid_mask = cnf_valid_mask(cnf, args.max_vars)
        # Never mask out the expert's choice: a masked label logit would be
        # pinned at -1e4 and produce a loss of roughly 1e4 for the example.
        valid_mask[label] = 1

        return {
            "input_ids": enc["input_ids"],
            "attention_mask": enc["attention_mask"],
            "label": label,
            "valid_mask": valid_mask,
        }

    # remove_columns drops the original fields (cnf, label); everything
    # needed is re-emitted by preprocess.
    ds = ds.map(preprocess, remove_columns=ds["train"].column_names)

    # Initialize model
    model = QwenVarClassifier(args.model_name, max_vars=args.max_vars)

    # Gradient checkpointing trades compute for memory by recomputing
    # activations during the backward pass (important for long sequences).
    model.backbone.gradient_checkpointing_enable()

    # Honour --report_to, but only on rank 0 to avoid duplicate runs.
    report_to = get_wandb_report_to(args.report_to)

    # Configure training
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        overwrite_output_dir=True,

        # Precision settings for modern GPUs
        bf16=True,  # bfloat16 training (H100/A100)
        tf32=True,  # TF32 matmuls on Ampere+

        # Batch configuration
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,

        # Optimizer settings
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        num_train_epochs=args.num_train_epochs,
        weight_decay=args.weight_decay,

        # Clip the gradient norm to this value to keep training stable
        max_grad_norm=1.0,

        # Logging and evaluation
        logging_steps=args.logging_steps,
        eval_strategy="steps",
        eval_steps=args.eval_steps,

        # Checkpointing
        save_strategy="steps",
        save_steps=args.save_steps,
        # At most 3 checkpoints are kept on disk; with
        # load_best_model_at_end the best one is retained alongside the
        # most recent ones.
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="eval_accuracy",
        greater_is_better=True,

        # Logging backend
        report_to=report_to,
        run_name=os.environ.get("WANDB_RUN_NAME", "qwen-var-sft") if args.report_to == "wandb" else None,
        logging_dir=os.path.join(args.output_dir, "logs"),

        # Keep the valid_mask column (needed in compute_loss)
        remove_unused_columns=False,

        # DDP settings (for multi-GPU)
        ddp_find_unused_parameters=False,

        # DeepSpeed for efficient distributed training
        deepspeed=args.deepspeed,

        # Use pickle format for saving (safetensors has issues with some
        # weight tying configs)
        save_safetensors=False,
    )

    # Newer Trainer versions take the tokenizer as ``processing_class``;
    # older ones only know ``tokenizer``. Pick whichever is supported.
    trainer_params = inspect.signature(Trainer.__init__).parameters
    tokenizer_kwarg = "processing_class" if "processing_class" in trainer_params else "tokenizer"

    # Create trainer with custom loss computation
    trainer = MaskedVarTrainer(
        model=model,
        args=training_args,
        train_dataset=ds["train"],
        eval_dataset=ds["validation"],
        data_collator=Collator(tok),
        compute_metrics=compute_metrics,
        max_vars=args.max_vars,
        **{tokenizer_kwarg: tok},
    )

    # Train!
    trainer.train()

    # Final evaluation
    trainer.evaluate()


if __name__ == "__main__":
    main()