"""
Qwen Variable Classifier for SAT Cube-and-Conquer
This script trains a transformer-based policy to select the next branching variable
for SAT (Boolean Satisfiability) solving using the Cube-and-Conquer approach.
== Problem Overview ==
In Cube-and-Conquer SAT solving, we split a hard SAT problem into subproblems ("cubes")
by choosing variables to branch on. The quality of variable selection significantly
affects solving performance. This model learns to predict good branching variables
from expert demonstrations.
== Architecture ==
- Backbone: Qwen3-4B (pretrained causal language model)
- Head: LayerNorm + Linear classifier over variable IDs (1 to max_vars)
- The model reads a CNF formula as text and outputs logits for each possible variable
== Training Approach ==
- Supervised Fine-Tuning (SFT) on expert variable choices
- Masked classification: only variables appearing in the CNF are valid choices
- Loss: Cross-entropy with invalid variable logits masked to -infinity
== Data Format ==
JSONL with fields:
- "cnf": DIMACS-format CNF text (e.g., "p cnf 100 200\n1 -2 3 0\n...")
- "label": integer variable ID to branch on (1 to max_vars)
"""
import os
import argparse
from dataclasses import dataclass
from typing import Any, Dict, List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import (
AutoConfig,
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
set_seed,
)
# =============================================================================
# DEBUG FLAG: Set to True to enable verbose debug output, False to disable
# Can also be controlled via environment variable: DEBUG_TRAINING=1
# =============================================================================
# Evaluated once at import time; only the exact string "1" enables debugging.
DEBUG_TRAINING = os.environ.get("DEBUG_TRAINING", "0") == "1"
# =============================================================================
# CNF PARSING: Extract valid variables from DIMACS CNF text
# =============================================================================
def cnf_valid_mask(cnf_text: str, max_vars: int) -> List[int]:
    """
    Mark which variable IDs (1..max_vars) occur in a DIMACS CNF string.

    Used for masked classification: a variable absent from the (simplified)
    CNF cannot be branched on, so its logit must be suppressed. Masking keeps
    the model's learning signal restricted to valid choices.

    Args:
        cnf_text: DIMACS-format CNF. Comment lines start with 'c', the header
            line starts with 'p' ("p cnf <num_vars> <num_clauses>"), and each
            clause is a space-separated list of signed literals ending in 0.
        max_vars: Largest variable ID supported (typically 500).

    Returns:
        List of length max_vars + 1; entry v is 1 iff variable v appears in
        some clause. Entry 0 is always 0 (variables are 1-indexed).

    Note:
        The 'p' header is skipped explicitly so its clause count is never
        mistaken for a variable (a bug in an earlier regex-based version).
    """
    present = [0] * (max_vars + 1)
    for raw_line in cnf_text.split('\n'):
        stripped = raw_line.strip()
        # Skip blanks, 'c' comment lines, and the 'p cnf ...' header.
        if not stripped or stripped[0] in 'cp':
            continue
        # Remaining lines are clauses: each token is a signed literal,
        # e.g. "1 -2 3 0" encodes (x1 OR NOT x2 OR x3).
        for token in stripped.split():
            try:
                var_id = abs(int(token))  # variable ID = |literal|
            except ValueError:
                continue  # tolerate stray non-integer tokens
            if 1 <= var_id <= max_vars:
                present[var_id] = 1
    # Degenerate input (truncated/malformed, no variables found): allow every
    # variable so downstream masking never produces an all-invalid output.
    if not any(present):
        present[1:] = [1] * max_vars
    return present
# =============================================================================
# MODEL: Qwen backbone with classification head for variable selection
# =============================================================================
class QwenVarClassifier(nn.Module):
    """
    Transformer-based variable classifier for SAT branching.

    Pipeline: CNF text -> tokenizer -> Qwen backbone -> last-token pooling
    -> LayerNorm -> linear head -> per-variable logits.

    Design notes:
    - A pretrained causal LM backbone can learn the structure of DIMACS text.
    - Last-token pooling: under causal attention only the final token has
      attended to the whole input, so its hidden state summarizes the CNF
      (analogous to BERT's [CLS] pooling).
    - LayerNorm between backbone and head keeps the freshly-initialized
      linear layer's inputs unit-scale; backbone activations can be large,
      and huge initial logits would mean huge loss and exploding gradients.
    - A single linear head keeps the classifier simple and efficient.
    """

    def __init__(self, base_model_name: str, max_vars: int):
        """
        Build the backbone + classification head.

        Args:
            base_model_name: HuggingFace model ID (e.g., "Qwen/Qwen3-4B").
            max_vars: Largest variable ID to classify. The head emits
                max_vars + 1 logits; index 0 is unused (DIMACS is 1-indexed).
        """
        super().__init__()
        self.max_vars = max_vars
        # Ask the backbone for hidden states; its LM logits are unused.
        backbone_cfg = AutoConfig.from_pretrained(base_model_name)
        backbone_cfg.output_hidden_states = True
        # bfloat16 on GPU (memory-efficient on A100/H100), float32 on CPU.
        model_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        self.backbone = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            config=backbone_cfg,
            torch_dtype=model_dtype,
        )
        hidden_dim = self.backbone.config.hidden_size  # e.g., 2560 for Qwen3-4B
        # Normalize pooled hidden states before the randomly-initialized head.
        self.head_ln = nn.LayerNorm(hidden_dim)
        # Classification head over variable IDs: outputs [batch, max_vars + 1].
        self.head = nn.Linear(hidden_dim, max_vars + 1)
        # Small-std init is appropriate because LayerNorm feeds unit-variance
        # inputs to the head.
        nn.init.normal_(self.head.weight, std=0.02)
        nn.init.zeros_(self.head.bias)
        # DeepSpeed auto-configuration reads model.config.hidden_size.
        self.config = self.backbone.config

    def forward(self, input_ids, attention_mask, **kwargs):
        """
        Map tokenized CNF input to variable logits.

        Args:
            input_ids: [batch, seq_len] token IDs.
            attention_mask: [batch, seq_len]; 1 = real token, 0 = padding.
            **kwargs: accepted and ignored (lets callers pass 'labels').

        Returns:
            {"logits": [batch, max_vars + 1]} raw classification logits.
        """
        backbone_out = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,  # we consume hidden states, not LM logits
            use_cache=False,            # KV cache is unnecessary during training
        )
        final_layer = backbone_out.hidden_states[-1]  # [batch, seq, hidden]
        # Pool at the last non-padding position of each sequence: its index is
        # (number of real tokens - 1), clamped for safety against empty rows.
        last_pos = (attention_mask.sum(dim=1) - 1).clamp(min=0)  # [batch]
        row_idx = torch.arange(final_layer.size(0), device=final_layer.device)
        pooled = final_layer[row_idx, last_pos]  # [batch, hidden]
        # DEBUG: report raw pooled-state statistics for the first few calls.
        if DEBUG_TRAINING:
            if not hasattr(self, '_debug_count'):
                self._debug_count = 0
            if self._debug_count < 3:
                print(f"[DEBUG {self._debug_count}] pooled dtype={pooled.dtype}, mean={pooled.float().mean():.2f}, std={pooled.float().std():.2f}")
                self._debug_count += 1
        # Normalize for a stable hand-off to the classification head.
        pooled = self.head_ln(pooled)
        if DEBUG_TRAINING and hasattr(self, '_debug_count') and self._debug_count <= 3:
            print(f"[DEBUG] after LN: dtype={pooled.dtype}, mean={pooled.float().mean():.4f}, std={pooled.float().std():.4f}")
        logits = self.head(pooled)  # [batch, max_vars + 1]
        if DEBUG_TRAINING and hasattr(self, '_debug_count') and self._debug_count <= 3:
            print(f"[DEBUG] logits: dtype={logits.dtype}, mean={logits.float().mean():.2f}, std={logits.float().std():.2f}, min={logits.float().min():.2f}, max={logits.float().max():.2f}")
        return {"logits": logits}
# =============================================================================
# DATA COLLATOR: Batch preparation with padding and mask handling
# =============================================================================
@dataclass
class Collator:
    """
    Batch collator for variable classification.

    Pads the variable-length token sequences in a batch to a common length,
    and stacks the per-example label and valid-variable mask. A custom
    collator is needed because the extra `valid_mask` field is unknown to
    the stock HuggingFace collators.
    """
    tokenizer: Any  # provides pad_token_id for input padding

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """
        Assemble a batch from per-example feature dicts.

        Args:
            features: Each element holds 'input_ids' (List[int]),
                'attention_mask' (List[int]), 'label' (int), and
                'valid_mask' (List[int], length max_vars + 1).

        Returns:
            Dict of batched tensors:
                input_ids      [batch, max_seq_len]  (padded with pad_token_id)
                attention_mask [batch, max_seq_len]  (padded with 0)
                labels         [batch]
                valid_mask     [batch, max_vars + 1] (bool)
        """
        pad_id = self.tokenizer.pad_token_id
        id_seqs = [torch.as_tensor(ex["input_ids"], dtype=torch.long) for ex in features]
        mask_seqs = [torch.as_tensor(ex["attention_mask"], dtype=torch.long) for ex in features]
        return {
            # pad_sequence right-pads every sequence to the batch maximum.
            "input_ids": torch.nn.utils.rnn.pad_sequence(
                id_seqs, batch_first=True, padding_value=pad_id
            ),
            "attention_mask": torch.nn.utils.rnn.pad_sequence(
                mask_seqs, batch_first=True, padding_value=0
            ),
            "labels": torch.tensor([ex["label"] for ex in features], dtype=torch.long),
            "valid_mask": torch.tensor([ex["valid_mask"] for ex in features], dtype=torch.bool),
        }
# =============================================================================
# TRAINER: Custom loss computation with variable masking
# =============================================================================
class MaskedVarTrainer(Trainer):
    """
    Custom HuggingFace Trainer with masked cross-entropy loss.

    The key modification: before computing cross-entropy, we mask out logits
    for invalid variables (those not appearing in the CNF). This ensures:
    1. The model cannot predict invalid variables
    2. No gradient flows to invalid variable logits
    3. Training focuses only on distinguishing valid choices

    NOTE on displayed metrics:
    - 'loss' shown by Trainer is summed across GPUs (loss x world_size).
      We add 'true_loss' which is the actual per-sample loss.
    - 'grad_norm' is the L2 norm across ALL ~4B parameters BEFORE clipping.
      Values of 100-200 are normal for large models; it gets clipped to
      max_grad_norm.
    """

    def __init__(self, *args, max_vars: int, **kwargs):
        """
        Args:
            max_vars: Maximum variable ID (for sanity checking labels)
            *args, **kwargs: Passed to parent Trainer
        """
        super().__init__(*args, **kwargs)
        self.max_vars = max_vars
        # Running sum/count of per-step losses on this device; reported as
        # 'true_loss' in log() and reset at each logging interval.
        self._accumulated_loss = 0.0
        self._loss_count = 0

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute masked cross-entropy loss for variable classification.

        Algorithm:
        1. Extract labels and valid_mask from inputs
        2. Forward pass to get logits
        3. Set logits of invalid variables to -1e4 (a bf16-safe "-inf")
        4. Compute cross-entropy loss

        Args:
            model: The QwenVarClassifier
            inputs: Dict with input_ids, attention_mask, labels, valid_mask
            return_outputs: If True, return (loss, outputs) tuple
            num_items_in_batch: Unused (for API compatibility)

        Returns:
            loss, or (loss, outputs) if return_outputs=True

        Raises:
            ValueError: if any label falls outside 1..max_vars (data bug).
        """
        # Use get(), not pop(): the evaluation loop still needs these keys.
        labels = inputs.get("labels")  # [batch]
        valid_mask = inputs.get("valid_mask")  # [batch, max_vars + 1] bool
        # Strip extras before calling the model, which doesn't accept them.
        model_inputs = {k: v for k, v in inputs.items() if k not in ["labels", "valid_mask"]}
        outputs = model(**model_inputs)
        logits = outputs["logits"]  # [batch, max_vars + 1]
        # DEBUG: verify each label is among its sample's valid variables.
        if DEBUG_TRAINING:
            if not hasattr(self, '_loss_debug_count'):
                self._loss_debug_count = 0
            if self._loss_debug_count < 5:
                for i, (lbl, vmask) in enumerate(zip(labels, valid_mask)):
                    label_in_mask = vmask[lbl].item()
                    valid_count = vmask.sum().item()
                    logit_at_label = logits[i, lbl].item()
                    print(f"[LOSS DEBUG {self._loss_debug_count}] label={lbl.item()}, in_mask={label_in_mask}, valid_vars={valid_count}, logit_at_label={logit_at_label:.2f}")
                self._loss_debug_count += 1
        # Mask invalid variables: after softmax their probability is ~0.
        # -1e4 rather than -inf or -1e9 because bfloat16's narrow dynamic
        # range can otherwise yield NaNs in softmax/cross-entropy.
        logits = logits.masked_fill(~valid_mask.to(logits.device), -1e4)
        # Sanity check: labels must be valid variable IDs (1..max_vars).
        # Catches data bugs early with an explicit error.
        if torch.any(labels <= 0) or torch.any(labels > self.max_vars):
            bad = labels[(labels <= 0) | (labels > self.max_vars)].detach().cpu().tolist()
            raise ValueError(f"Out-of-range labels detected (showing up to 20): {bad[:20]}")
        # DEBUG: logit at the label position after masking.
        if DEBUG_TRAINING and hasattr(self, '_loss_debug_count') and self._loss_debug_count <= 5:
            for i, lbl in enumerate(labels):
                masked_logit = logits[i, lbl].item()
                print(f"[LOSS DEBUG] after mask: logit_at_label={masked_logit:.2f}")
        # Standard cross-entropy over the masked logits.
        loss = F.cross_entropy(logits, labels.to(logits.device))
        # Track the actual per-step loss for accurate 'true_loss' logging.
        self._accumulated_loss += loss.item()
        self._loss_count += 1
        if DEBUG_TRAINING and hasattr(self, '_loss_debug_count') and self._loss_debug_count <= 5:
            print(f"[LOSS DEBUG] loss={loss.item():.2f}")
        # Return masked logits so compute_metrics sees masked predictions.
        masked_outputs = {"logits": logits}
        return (loss, masked_outputs) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """
        Evaluation step returning (loss, logits, labels) with masking applied.

        The default Trainer prediction_step does not interact well with our
        custom compute_loss, so we reimplement it.

        Fix vs. earlier version: `prediction_loss_only` is now honored per
        the Trainer contract — when True we return (loss, None, None) so the
        evaluation loop does not accumulate logits it will never use.
        """
        model.eval()
        with torch.no_grad():
            labels = inputs.get("labels")
            valid_mask = inputs.get("valid_mask")
            # Forward pass without the extra supervision keys.
            model_inputs = {k: v for k, v in inputs.items() if k not in ["labels", "valid_mask"]}
            outputs = model(**model_inputs)
            logits = outputs["logits"]
            # Same masking as training so eval metrics match the train setup.
            logits = logits.masked_fill(~valid_mask.to(logits.device), -1e4)
            loss = F.cross_entropy(logits, labels.to(logits.device))
        if prediction_loss_only:
            return (loss, None, None)
        # (loss, logits, labels) is the shape compute_metrics expects.
        return (loss, logits.detach(), labels.detach())

    def log(self, logs: Dict[str, float], start_time: float = None) -> None:
        """
        Override log to add true_loss alongside the standard metrics.

        The default 'loss' in HF Trainer is summed across GPUs under
        DDP/DeepSpeed; 'true_loss' is the actual per-sample average on this
        device since the last logging interval.
        """
        if self._loss_count > 0:
            true_loss = self._accumulated_loss / self._loss_count
            logs["true_loss"] = round(true_loss, 4)
            # Reset the accumulator for the next logging interval.
            self._accumulated_loss = 0.0
            self._loss_count = 0
        # Delegate to HF Trainer, which manages W&B step ordering correctly.
        super().log(logs, start_time)
def compute_metrics(eval_pred):
    """
    Evaluation metric: top-1 accuracy over variable predictions.

    Args:
        eval_pred: (logits, labels) pair from the Trainer's prediction loop.
            logits: [num_samples, max_vars + 1]; invalid variables already
                carry ~-1e4 logits, so argmax avoids them naturally.
            labels: [num_samples] true variable IDs.

    Returns:
        {"accuracy": float} (the Trainer prefixes it as "eval_accuracy").

    Note: eval_loss comes from prediction_step's loss; not computed here.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    num_correct = np.sum(predictions == labels)
    return {"accuracy": float(num_correct / len(labels))}
def get_wandb_report_to():
    """
    Decide whether this process should log to Weights & Biases.

    Only the main process (LOCAL_RANK == 0) logs to W&B; letting every rank
    log would create duplicate runs. All other ranks log nowhere.

    Returns:
        ["wandb"] on rank 0 (or when LOCAL_RANK is unset), else [].
    """
    rank = int(os.environ.get("LOCAL_RANK", 0))
    return ["wandb"] if rank == 0 else []
# =============================================================================
# MAIN: Training pipeline
# =============================================================================
def main():
    """
    Main training function.

    Pipeline:
    1. Parse command line arguments
    2. Load tokenizer and datasets
    3. Preprocess: tokenize CNF text, compute valid masks
    4. Initialize model with pretrained backbone + new classification head
    5. Configure training (optimizer, scheduler, logging, etc.)
    6. Train and evaluate
    """
    ap = argparse.ArgumentParser(
        description="Train a Qwen-based variable classifier for SAT branching"
    )
    # Model and data arguments
    ap.add_argument("--model_name", type=str, default="Qwen/Qwen3-4B",
                    help="HuggingFace model ID for the backbone")
    ap.add_argument("--train_jsonl", type=str, required=True,
                    help="Path to training data (JSONL with 'cnf' and 'label' fields)")
    ap.add_argument("--valid_jsonl", type=str, required=True,
                    help="Path to validation data (same format)")
    ap.add_argument("--output_dir", type=str, default="./out_qwen_var_sft",
                    help="Directory for checkpoints and logs")
    ap.add_argument("--max_vars", type=int, default=500,
                    help="Maximum variable ID (determines output dimension)")
    ap.add_argument("--max_length", type=int, default=8192,
                    help="Maximum sequence length in tokens (truncates longer CNFs)")
    ap.add_argument("--seed", type=int, default=0,
                    help="Random seed for reproducibility")
    # Training hyperparameters
    ap.add_argument("--per_device_train_batch_size", type=int, default=1,
                    help="Batch size per GPU for training")
    ap.add_argument("--per_device_eval_batch_size", type=int, default=1,
                    help="Batch size per GPU for evaluation")
    ap.add_argument("--gradient_accumulation_steps", type=int, default=8,
                    help="Accumulate gradients over this many steps (effective batch = this * batch_size * num_gpus)")
    ap.add_argument("--learning_rate", type=float, default=5e-6,
                    help="Peak learning rate (after warmup). Lower than typical fine-tuning due to classification head")
    ap.add_argument("--num_train_epochs", type=float, default=3.0,
                    help="Total training epochs")
    ap.add_argument("--warmup_ratio", type=float, default=0.03,
                    help="Fraction of training steps for learning rate warmup")
    ap.add_argument("--weight_decay", type=float, default=0.0,
                    help="Weight decay (L2 regularization)")
    ap.add_argument("--logging_steps", type=int, default=10,
                    help="Log training metrics every N steps")
    ap.add_argument("--eval_steps", type=int, default=200,
                    help="Evaluate every N steps")
    ap.add_argument("--save_steps", type=int, default=200,
                    help="Save checkpoint every N steps")
    ap.add_argument("--report_to", type=str, default="wandb",
                    choices=["wandb", "tensorboard", "none"],
                    help="Logging backend")
    ap.add_argument("--deepspeed", type=str, default=None,
                    help="Path to DeepSpeed config JSON for distributed training")
    ap.add_argument("--resume_from_checkpoint", type=str, default=None,
                    help="Path to checkpoint directory to resume from. If a directory is given, "
                         "the latest checkpoint in that directory will be used.")
    args = ap.parse_args()
    # Set random seeds for reproducibility
    set_seed(args.seed)
    # Load tokenizer (Qwen uses a byte-level BPE tokenizer)
    tok = AutoTokenizer.from_pretrained(args.model_name, use_fast=True)
    if tok.pad_token is None:
        # Qwen doesn't have a dedicated pad token; use eos as pad
        tok.pad_token = tok.eos_token
    # Load datasets from JSONL files
    ds = load_dataset(
        "json",
        data_files={"train": args.train_jsonl, "validation": args.valid_jsonl},
    )

    def preprocess(ex):
        """
        Preprocess a single example.

        Steps:
        1. Tokenize the CNF text (no special prompt/instruction - the model
           learns to interpret raw DIMACS)
        2. Compute the valid variable mask

        Args:
            ex: Dict with 'cnf' (str) and 'label' (int)

        Returns:
            Dict with input_ids, attention_mask, label, valid_mask
        """
        cnf = ex["cnf"]
        label = int(ex["label"])
        enc = tok(
            cnf,
            truncation=True,
            max_length=args.max_length,
            padding=False,  # padding is handled per-batch in the collator
        )
        return {
            "input_ids": enc["input_ids"],
            "attention_mask": enc["attention_mask"],
            "label": label,
            "valid_mask": cnf_valid_mask(cnf, args.max_vars),
        }

    # Apply preprocessing; drop original columns (cnf, label) once extracted
    ds = ds.map(preprocess, remove_columns=ds["train"].column_names)
    # Initialize model
    model = QwenVarClassifier(args.model_name, max_vars=args.max_vars)
    # Gradient checkpointing trades compute for memory on long sequences
    model.backbone.gradient_checkpointing_enable()
    # BUG FIX: honor --report_to instead of unconditionally logging to W&B.
    # Previously get_wandb_report_to() was used for every backend choice, so
    # "--report_to tensorboard" or "none" still created a W&B run on rank 0.
    if args.report_to == "wandb":
        # Only rank 0 logs to W&B, avoiding duplicate runs in DDP
        report_to = get_wandb_report_to()
    elif args.report_to == "none":
        report_to = []
    else:
        report_to = [args.report_to]
    # Configure training
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        overwrite_output_dir=True,
        # Precision settings for modern GPUs
        bf16=True,  # bfloat16 training (good for H100/A100)
        tf32=True,  # TF32 matmuls on Ampere+
        # Batch configuration
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        # Optimizer settings
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        num_train_epochs=args.num_train_epochs,
        weight_decay=args.weight_decay,
        # Gradient clipping for training stability: prevents exploding
        # gradients from destabilizing training
        max_grad_norm=1.0,
        # Logging and evaluation
        logging_steps=args.logging_steps,
        eval_strategy="steps",
        eval_steps=args.eval_steps,
        # Checkpointing - keep best checkpoints based on validation accuracy
        save_strategy="steps",
        save_steps=args.save_steps,
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="eval_accuracy",
        greater_is_better=True,
        # Logging backend
        report_to=report_to,
        run_name=os.environ.get("WANDB_RUN_NAME", "qwen-var-sft") if args.report_to == "wandb" else None,
        logging_dir=os.path.join(args.output_dir, "logs"),
        # Important: keep the valid_mask column (needed in compute_loss)
        remove_unused_columns=False,
        # DDP settings (for multi-GPU)
        ddp_find_unused_parameters=False,
        # DeepSpeed for efficient distributed training
        deepspeed=args.deepspeed,
        # Pickle format for saving (safetensors has issues with some
        # weight-tying configurations)
        save_safetensors=False,
    )
    # Create trainer with custom loss computation
    trainer = MaskedVarTrainer(
        model=model,
        args=training_args,
        train_dataset=ds["train"],
        eval_dataset=ds["validation"],
        tokenizer=tok,
        data_collator=Collator(tok),
        compute_metrics=compute_metrics,
        max_vars=args.max_vars,
    )
    # Train (resume from checkpoint if specified)
    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
    # Final evaluation
    trainer.evaluate()
# Script entry point: run training only when executed directly (not on import).
if __name__ == "__main__":
    main()