|
|
""" |
|
|
Qwen Variable Classifier for SAT Cube-and-Conquer |
|
|
|
|
|
This script trains a transformer-based policy to select the next branching variable |
|
|
for SAT (Boolean Satisfiability) solving using the Cube-and-Conquer approach. |
|
|
|
|
|
== Problem Overview == |
|
|
In Cube-and-Conquer SAT solving, we split a hard SAT problem into subproblems ("cubes") |
|
|
by choosing variables to branch on. The quality of variable selection significantly |
|
|
affects solving performance. This model learns to predict good branching variables |
|
|
from expert demonstrations. |
|
|
|
|
|
== Architecture == |
|
|
- Backbone: Qwen3-4B (pretrained causal language model) |
|
|
- Head: LayerNorm + Linear classifier over variable IDs (1 to max_vars) |
|
|
- The model reads a CNF formula as text and outputs logits for each possible variable |
|
|
|
|
|
== Training Approach == |
|
|
- Supervised Fine-Tuning (SFT) on expert variable choices |
|
|
- Masked classification: only variables appearing in the CNF are valid choices |
|
|
- Loss: Cross-entropy with invalid variable logits masked to -infinity |
|
|
|
|
|
== Data Format == |
|
|
JSONL with fields: |
|
|
- "cnf": DIMACS-format CNF text (e.g., "p cnf 100 200\n1 -2 3 0\n...") |
|
|
- "label": integer variable ID to branch on (1 to max_vars) |
|
|
""" |
|
|
|
|
|
import os |
|
|
import argparse |
|
|
from dataclasses import dataclass |
|
|
from typing import Any, Dict, List |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from datasets import load_dataset |
|
|
from transformers import ( |
|
|
AutoConfig, |
|
|
AutoTokenizer, |
|
|
AutoModelForCausalLM, |
|
|
TrainingArguments, |
|
|
Trainer, |
|
|
set_seed, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Verbose debug printing toggle: export DEBUG_TRAINING=1 to enable.
DEBUG_TRAINING = os.getenv("DEBUG_TRAINING", "0") == "1"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def cnf_valid_mask(cnf_text: str, max_vars: int) -> List[int]:
    """
    Build a binary mask of the variable IDs that occur in a DIMACS CNF string.

    Masked classification requires knowing which variables are legal branch
    choices: a variable absent from the (simplified) CNF cannot be branched on,
    so its logit must be suppressed during training and inference.

    Args:
        cnf_text: DIMACS-format CNF, e.g.:
            p cnf 100 200   (header: 100 variables, 200 clauses)
            1 -2 3 0        (clause: x1 OR NOT x2 OR x3)
            -1 4 0          (clause: NOT x1 OR x4)
        max_vars: Largest supported variable ID (typically 500).

    Returns:
        List of length max_vars + 1:
            index 0        -> always 0 (variables are 1-indexed)
            index v        -> 1 if variable v occurs in some clause, else 0
        If no variable in range is found at all, every variable 1..max_vars is
        marked valid as a safe fallback.

    Note: comment lines ('c ...') and the header ('p cnf ...') are skipped so
    that the clause count in the header is never mistaken for a variable.
    """
    mask = [0] * (max_vars + 1)

    for raw_line in cnf_text.split('\n'):
        stripped = raw_line.strip()

        # Ignore blanks, comments ('c ...') and the problem header ('p ...').
        if not stripped or stripped[0] in ('c', 'p'):
            continue

        for token in stripped.split():
            try:
                var_id = abs(int(token))
            except ValueError:
                # Non-numeric junk inside a clause line: skip the token.
                continue
            # The clause terminator '0' falls outside [1, max_vars] and is
            # naturally excluded here.
            if 1 <= var_id <= max_vars:
                mask[var_id] = 1

    # Degenerate CNF with no in-range variables: allow everything rather than
    # produce an all-invalid mask (which would make the loss undefined).
    if not any(mask):
        mask = [0] + [1] * max_vars

    return mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class QwenVarClassifier(nn.Module):
    """
    Variable classifier for SAT branching built on a pretrained causal LM.

    Data flow:
        CNF text -> tokens -> Qwen backbone -> last-token hidden state
                 -> LayerNorm -> Linear head -> per-variable logits

    Design rationale:
        - A pretrained LLM backbone already parses text structure well and can
          pick up CNF patterns during fine-tuning.
        - Last-token pooling: in a causal model the final token has attended
          to the whole input, so its hidden state summarizes the sequence.
        - LayerNorm before the head: the backbone's hidden states can be large
          in magnitude; normalizing keeps gradients stable against the
          randomly-initialized classification head.
        - A single Linear head keeps the classifier simple and efficient.
    """

    def __init__(self, base_model_name: str, max_vars: int):
        """
        Build the backbone + classification head.

        Args:
            base_model_name: HuggingFace model ID (e.g., "Qwen/Qwen3-4B").
            max_vars: Largest variable ID to classify; the head outputs
                max_vars + 1 logits (index 0 is unused, variables are 1-based).
        """
        super().__init__()
        self.max_vars = max_vars

        # Ask the backbone to expose hidden states so we can pool the last layer.
        cfg = AutoConfig.from_pretrained(base_model_name)
        cfg.output_hidden_states = True

        # bf16 on GPU keeps the ~4B-parameter backbone memory-friendly;
        # fall back to fp32 on CPU where bf16 support is inconsistent.
        backbone_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        self.backbone = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            config=cfg,
            torch_dtype=backbone_dtype,
        )

        hidden_dim = self.backbone.config.hidden_size

        # Normalize pooled hidden states before the freshly-initialized head.
        self.head_ln = nn.LayerNorm(hidden_dim)

        # One logit per variable ID (plus the unused index 0).
        self.head = nn.Linear(hidden_dim, max_vars + 1)
        nn.init.normal_(self.head.weight, std=0.02)
        nn.init.zeros_(self.head.bias)

        # Expose the backbone config so HF Trainer utilities can locate it.
        self.config = self.backbone.config

    def forward(self, input_ids, attention_mask, **kwargs):
        """
        Map tokenized CNF text to classification logits.

        Args:
            input_ids: [batch, seq_len] token IDs.
            attention_mask: [batch, seq_len] mask (1 = real token, 0 = padding).
            **kwargs: ignored (lets callers pass e.g. 'labels' harmlessly).

        Returns:
            dict with "logits": [batch, max_vars + 1] raw logits.
        """
        backbone_out = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )

        # Final transformer layer's hidden states: [batch, seq_len, hidden].
        final_hidden = backbone_out.hidden_states[-1]

        # Index of the last real (non-padding) token per sequence; assumes
        # right-padding (the Collator pads on the right). clamp guards against
        # an all-padding row producing index -1.
        last_idx = (attention_mask.sum(dim=1) - 1).clamp(min=0)

        row_idx = torch.arange(final_hidden.size(0), device=final_hidden.device)
        pooled = final_hidden[row_idx, last_idx]

        if DEBUG_TRAINING:
            if not hasattr(self, '_debug_count'):
                self._debug_count = 0
            if self._debug_count < 3:
                print(f"[DEBUG {self._debug_count}] pooled dtype={pooled.dtype}, mean={pooled.float().mean():.2f}, std={pooled.float().std():.2f}")
                self._debug_count += 1

        pooled = self.head_ln(pooled)

        if DEBUG_TRAINING and hasattr(self, '_debug_count') and self._debug_count <= 3:
            print(f"[DEBUG] after LN: dtype={pooled.dtype}, mean={pooled.float().mean():.4f}, std={pooled.float().std():.4f}")

        logits = self.head(pooled)

        if DEBUG_TRAINING and hasattr(self, '_debug_count') and self._debug_count <= 3:
            print(f"[DEBUG] logits: dtype={logits.dtype}, mean={logits.float().mean():.2f}, std={logits.float().std():.2f}, min={logits.float().min():.2f}, max={logits.float().max():.2f}")

        return {"logits": logits}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class Collator:
    """
    Batch collator for the variable-classification task.

    Duties:
        1. Right-pad variable-length token sequences to one batch length.
        2. Stack scalar labels and per-example valid-variable masks.
        3. Zero-pad the attention masks to match the padded sequences.

    A custom collator is needed because the stock HF collators do not know
    about the extra 'valid_mask' field.
    """
    tokenizer: Any  # any object exposing pad_token_id

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """
        Collate a list of preprocessed examples into one batch.

        Args:
            features: each element contains
                - input_ids: List[int]
                - attention_mask: List[int]
                - label: int (target variable ID)
                - valid_mask: List[int] (binary, length max_vars + 1)

        Returns:
            Dict of tensors:
                - input_ids:      [batch, max_seq_len]
                - attention_mask: [batch, max_seq_len]
                - labels:         [batch]
                - valid_mask:     [batch, max_vars + 1] (bool)
        """
        pad_id = self.tokenizer.pad_token_id
        pad_sequence = torch.nn.utils.rnn.pad_sequence

        token_seqs = [torch.tensor(f["input_ids"], dtype=torch.long) for f in features]
        mask_seqs = [torch.tensor(f["attention_mask"], dtype=torch.long) for f in features]

        return {
            # Pad token IDs with the tokenizer's pad token...
            "input_ids": pad_sequence(token_seqs, batch_first=True, padding_value=pad_id),
            # ...and pad attention with 0 so padded positions are ignored.
            "attention_mask": pad_sequence(mask_seqs, batch_first=True, padding_value=0),
            "labels": torch.tensor([f["label"] for f in features], dtype=torch.long),
            "valid_mask": torch.tensor([f["valid_mask"] for f in features], dtype=torch.bool),
        }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MaskedVarTrainer(Trainer):
    """
    Custom HuggingFace Trainer with masked cross-entropy loss.

    The key modification: before computing cross-entropy, we mask out logits
    for invalid variables (those not appearing in the CNF). This ensures:
    1. The model cannot predict invalid variables
    2. No gradient flows to invalid variable logits
    3. Training focuses only on distinguishing valid choices

    NOTE on displayed metrics:
    - 'loss' shown by Trainer is summed across GPUs (loss × world_size)
      We add 'true_loss' which is the actual per-sample loss
    - 'grad_norm' is the L2 norm across ALL ~4B parameters BEFORE clipping
      Values of 100-200 are normal for large models; it gets clipped to max_grad_norm
    """

    def __init__(self, *args, max_vars: int, **kwargs):
        """
        Args:
            max_vars: Maximum variable ID (for sanity checking labels)
            *args, **kwargs: Passed to parent Trainer
        """
        super().__init__(*args, **kwargs)
        self.max_vars = max_vars
        # Running sums consumed by log() to report the true mean per-step loss
        # (Trainer's displayed 'loss' is scaled by world size under DDP/DeepSpeed).
        self._accumulated_loss = 0.0
        self._loss_count = 0

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute masked cross-entropy loss for variable classification.

        Algorithm:
        1. Extract labels and valid_mask from inputs
        2. Forward pass to get logits
        3. Set logits for invalid variables to -inf (or -1e4 for bf16 stability)
        4. Compute cross-entropy loss

        Args:
            model: The QwenVarClassifier
            inputs: Dict with input_ids, attention_mask, labels, valid_mask
            return_outputs: If True, return (loss, outputs) tuple
            num_items_in_batch: Unused (for API compatibility)

        Returns:
            loss: Scalar loss value, or (loss, outputs) tuple if return_outputs=True
        """
        labels = inputs.get("labels")
        valid_mask = inputs.get("valid_mask")

        # Strip the custom fields; the model's forward() only takes token inputs.
        model_inputs = {k: v for k, v in inputs.items() if k not in ["labels", "valid_mask"]}

        outputs = model(**model_inputs)
        logits = outputs["logits"]

        # Optional diagnostics for the first few batches (enabled via env var).
        if DEBUG_TRAINING:
            if not hasattr(self, '_loss_debug_count'):
                self._loss_debug_count = 0
            if self._loss_debug_count < 5:
                for i, (lbl, vmask) in enumerate(zip(labels, valid_mask)):
                    label_in_mask = vmask[lbl].item()
                    valid_count = vmask.sum().item()
                    logit_at_label = logits[i, lbl].item()
                    print(f"[LOSS DEBUG {self._loss_debug_count}] label={lbl.item()}, in_mask={label_in_mask}, valid_vars={valid_count}, logit_at_label={logit_at_label:.2f}")
                self._loss_debug_count += 1

        # -1e4 (rather than -inf) keeps bf16 arithmetic finite while still
        # driving the softmax probability of invalid variables to ~0.
        logits = logits.masked_fill(~valid_mask.to(logits.device), -1e4)

        # Fail fast on corrupt data: labels must be in [1, max_vars].
        if torch.any(labels <= 0) or torch.any(labels > self.max_vars):
            bad = labels[(labels <= 0) | (labels > self.max_vars)].detach().cpu().tolist()
            raise ValueError(f"Out-of-range labels detected (showing up to 20): {bad[:20]}")

        if DEBUG_TRAINING and hasattr(self, '_loss_debug_count') and self._loss_debug_count <= 5:
            for i, lbl in enumerate(labels):
                masked_logit = logits[i, lbl].item()
                print(f"[LOSS DEBUG] after mask: logit_at_label={masked_logit:.2f}")

        loss = F.cross_entropy(logits, labels.to(logits.device))

        # Accumulate the unscaled loss so log() can report a true mean.
        # NOTE(review): .item() forces a GPU sync each step — acceptable here
        # since batches are large relative to the sync cost.
        self._accumulated_loss += loss.item()
        self._loss_count += 1

        if DEBUG_TRAINING and hasattr(self, '_loss_debug_count') and self._loss_debug_count <= 5:
            print(f"[LOSS DEBUG] loss={loss.item():.2f}")

        masked_outputs = {"logits": logits}
        return (loss, masked_outputs) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """
        Override prediction_step to properly return loss and logits for evaluation.

        The default HF Trainer prediction_step doesn't work well with custom compute_loss,
        so we implement our own that properly computes masked loss and returns logits.
        """
        model.eval()

        with torch.no_grad():
            labels = inputs.get("labels")
            valid_mask = inputs.get("valid_mask")

            # Same input filtering as compute_loss: drop the custom fields.
            model_inputs = {k: v for k, v in inputs.items() if k not in ["labels", "valid_mask"]}
            outputs = model(**model_inputs)
            logits = outputs["logits"]

            # Mask invalid variables so eval metrics (argmax accuracy) only
            # consider legal branch choices, mirroring the training loss.
            logits = logits.masked_fill(~valid_mask.to(logits.device), -1e4)

            loss = F.cross_entropy(logits, labels.to(logits.device))

        # Trainer's evaluation loop expects (loss, logits, labels).
        return (loss, logits.detach(), labels.detach())

    def log(self, logs: Dict[str, float], start_time: float = None) -> None:
        """
        Override log to add true_loss and ensure eval metrics are logged to W&B.

        The default 'loss' in HF Trainer is summed across GPUs in DDP/DeepSpeed.
        We track the actual per-sample loss and report it as 'true_loss'.
        """
        if self._loss_count > 0:
            # Mean of the raw per-step losses accumulated since the last log.
            true_loss = self._accumulated_loss / self._loss_count
            logs["true_loss"] = round(true_loss, 4)

            # Reset the window for the next logging interval.
            self._accumulated_loss = 0.0
            self._loss_count = 0

        super().log(logs, start_time)
|
|
|
|
|
|
|
|
def compute_metrics(eval_pred):
    """
    Top-1 accuracy over the evaluation set.

    Args:
        eval_pred: (logits, labels) pair produced by the Trainer's
            prediction loop:
            - logits: [num_samples, max_vars + 1], already masked so invalid
              variables sit at ~-1e4 and are never picked by argmax
            - labels: [num_samples] gold variable IDs

    Returns:
        {"accuracy": float} — the Trainer prefixes it as "eval_accuracy".

    eval_loss is produced by prediction_step and handled by the Trainer
    itself, so it is not computed here.
    """
    logits, labels = eval_pred

    # Invalid variables carry ~-1e4 logits, so argmax stays on valid choices.
    predictions = logits.argmax(axis=-1)
    return {"accuracy": float((predictions == labels).mean())}
|
|
|
|
|
|
|
|
def get_wandb_report_to():
    """
    Decide whether the current process should report to Weights & Biases.

    Under multi-GPU launchers only rank 0 may create the W&B run; every
    other rank logs nowhere, preventing duplicate runs.

    Returns:
        ["wandb"] when LOCAL_RANK is 0 (or unset, i.e. single-process),
        [] otherwise.
    """
    rank = int(os.environ.get("LOCAL_RANK", 0))
    return ["wandb"] if rank == 0 else []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """
    Train and evaluate the Qwen-based SAT branching-variable classifier.

    Pipeline:
    1. Parse command line arguments
    2. Load tokenizer and datasets
    3. Preprocess: tokenize CNF text, compute valid masks
    4. Initialize model with pretrained backbone + new classification head
    5. Configure training (optimizer, scheduler, logging, etc.)
    6. Train and evaluate
    """
    ap = argparse.ArgumentParser(
        description="Train a Qwen-based variable classifier for SAT branching"
    )

    # --- Model / data / preprocessing arguments ---
    ap.add_argument("--model_name", type=str, default="Qwen/Qwen3-4B",
                    help="HuggingFace model ID for the backbone")
    ap.add_argument("--train_jsonl", type=str, required=True,
                    help="Path to training data (JSONL with 'cnf' and 'label' fields)")
    ap.add_argument("--valid_jsonl", type=str, required=True,
                    help="Path to validation data (same format)")
    ap.add_argument("--output_dir", type=str, default="./out_qwen_var_sft",
                    help="Directory for checkpoints and logs")
    ap.add_argument("--max_vars", type=int, default=500,
                    help="Maximum variable ID (determines output dimension)")
    ap.add_argument("--max_length", type=int, default=8192,
                    help="Maximum sequence length in tokens (truncates longer CNFs)")
    ap.add_argument("--seed", type=int, default=0,
                    help="Random seed for reproducibility")

    # --- Optimization / logging arguments ---
    ap.add_argument("--per_device_train_batch_size", type=int, default=1,
                    help="Batch size per GPU for training")
    ap.add_argument("--per_device_eval_batch_size", type=int, default=1,
                    help="Batch size per GPU for evaluation")
    ap.add_argument("--gradient_accumulation_steps", type=int, default=8,
                    help="Accumulate gradients over this many steps (effective batch = this * batch_size * num_gpus)")
    ap.add_argument("--learning_rate", type=float, default=5e-6,
                    help="Peak learning rate (after warmup). Lower than typical fine-tuning due to classification head")
    ap.add_argument("--num_train_epochs", type=float, default=3.0,
                    help="Total training epochs")
    ap.add_argument("--warmup_ratio", type=float, default=0.03,
                    help="Fraction of training steps for learning rate warmup")
    ap.add_argument("--weight_decay", type=float, default=0.0,
                    help="Weight decay (L2 regularization)")
    ap.add_argument("--logging_steps", type=int, default=10,
                    help="Log training metrics every N steps")
    ap.add_argument("--eval_steps", type=int, default=200,
                    help="Evaluate every N steps")
    ap.add_argument("--save_steps", type=int, default=200,
                    help="Save checkpoint every N steps")
    ap.add_argument("--report_to", type=str, default="wandb",
                    choices=["wandb", "tensorboard", "none"],
                    help="Logging backend")
    ap.add_argument("--deepspeed", type=str, default=None,
                    help="Path to DeepSpeed config JSON for distributed training")

    args = ap.parse_args()

    set_seed(args.seed)

    tok = AutoTokenizer.from_pretrained(args.model_name, use_fast=True)
    if tok.pad_token is None:
        # Qwen tokenizers ship without a pad token; reuse EOS for padding.
        tok.pad_token = tok.eos_token

    ds = load_dataset(
        "json",
        data_files={"train": args.train_jsonl, "validation": args.valid_jsonl},
    )

    def preprocess(ex):
        """
        Preprocess a single example.

        Steps:
        1. Tokenize the CNF text
        2. Compute valid variable mask
        3. Return features for training

        Args:
            ex: Dict with 'cnf' (str) and 'label' (int)

        Returns:
            Dict with input_ids, attention_mask, label, valid_mask
        """
        cnf = ex["cnf"]
        label = int(ex["label"])

        # No padding here: the Collator pads dynamically per batch.
        enc = tok(
            cnf,
            truncation=True,
            max_length=args.max_length,
            padding=False
        )

        return {
            "input_ids": enc["input_ids"],
            "attention_mask": enc["attention_mask"],
            "label": label,
            "valid_mask": cnf_valid_mask(cnf, args.max_vars),
        }

    ds = ds.map(preprocess, remove_columns=ds["train"].column_names)

    model = QwenVarClassifier(args.model_name, max_vars=args.max_vars)

    # Trade compute for memory: recompute activations in backward
    # (necessary to fit the ~4B backbone on a single device).
    model.backbone.gradient_checkpointing_enable()

    # BUG FIX: honor --report_to. Previously report_to was always set from
    # get_wandb_report_to(), so "--report_to none" and "--report_to tensorboard"
    # were silently ignored and W&B was enabled on rank 0 regardless.
    if args.report_to == "wandb":
        report_to = get_wandb_report_to()  # ["wandb"] on rank 0, [] on other ranks
    elif args.report_to == "none":
        report_to = []
    else:
        report_to = [args.report_to]

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        overwrite_output_dir=True,

        # Mixed precision: bf16 matches the backbone dtype; tf32 speeds up matmuls.
        bf16=True,
        tf32=True,

        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,

        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        num_train_epochs=args.num_train_epochs,
        weight_decay=args.weight_decay,

        # Clip gradients; raw norms of 100-200 are normal for a model this size.
        max_grad_norm=1.0,

        logging_steps=args.logging_steps,
        eval_strategy="steps",
        eval_steps=args.eval_steps,

        save_strategy="steps",
        save_steps=args.save_steps,
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="eval_accuracy",
        greater_is_better=True,

        report_to=report_to,
        run_name=os.environ.get("WANDB_RUN_NAME", "qwen-var-sft") if args.report_to == "wandb" else None,
        logging_dir=os.path.join(args.output_dir, "logs"),

        # Keep the custom 'valid_mask'/'label' columns; the Trainer would
        # otherwise drop columns not named in the model's forward signature.
        remove_unused_columns=False,

        ddp_find_unused_parameters=False,

        deepspeed=args.deepspeed,

        # NOTE(review): .bin checkpoints instead of safetensors — presumably
        # because of shared/tied weights in the wrapper; confirm before changing.
        save_safetensors=False,
    )

    trainer = MaskedVarTrainer(
        model=model,
        args=training_args,
        train_dataset=ds["train"],
        eval_dataset=ds["validation"],
        tokenizer=tok,
        data_collator=Collator(tok),
        compute_metrics=compute_metrics,
        max_vars=args.max_vars,
    )

    trainer.train()

    # Final evaluation with the best checkpoint (load_best_model_at_end=True).
    trainer.evaluate()
|
|
|
|
|
|
|
|
# Script entry point: run training only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
|
|
|