# Qwen3-4B-SAT-VarSelector / sft_qwen_var_classifier.py
# (Hugging Face Hub upload metadata: uploaded by "erata" via huggingface_hub,
#  commit 24a71c4, verified — commented out so the file remains valid Python.)
"""
Qwen Variable Classifier for SAT Cube-and-Conquer
This script trains a transformer-based policy to select the next branching variable
for SAT (Boolean Satisfiability) solving using the Cube-and-Conquer approach.
== Problem Overview ==
In Cube-and-Conquer SAT solving, we split a hard SAT problem into subproblems ("cubes")
by choosing variables to branch on. The quality of variable selection significantly
affects solving performance. This model learns to predict good branching variables
from expert demonstrations.
== Architecture ==
- Backbone: Qwen3-4B (pretrained causal language model)
- Head: LayerNorm + Linear classifier over variable IDs (1 to max_vars)
- The model reads a CNF formula as text and outputs logits for each possible variable
== Training Approach ==
- Supervised Fine-Tuning (SFT) on expert variable choices
- Masked classification: only variables appearing in the CNF are valid choices
- Loss: Cross-entropy with invalid variable logits masked to -infinity
== Data Format ==
JSONL with fields:
- "cnf": DIMACS-format CNF text (e.g., "p cnf 100 200\n1 -2 3 0\n...")
- "label": integer variable ID to branch on (1 to max_vars)
"""
import os
import argparse
from dataclasses import dataclass
from typing import Any, Dict, List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import (
AutoConfig,
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
set_seed,
)
# =============================================================================
# DEBUG FLAG: Set to True to enable verbose debug output, False to disable
# Can also be controlled via environment variable: DEBUG_TRAINING=1
# NOTE: read once at import time; changing the env var later has no effect.
# =============================================================================
DEBUG_TRAINING = os.environ.get("DEBUG_TRAINING", "0") == "1"
# =============================================================================
# CNF PARSING: Extract valid variables from DIMACS CNF text
# =============================================================================
def cnf_valid_mask(cnf_text: str, max_vars: int) -> List[int]:
    """
    Build a 0/1 mask of the variable IDs that occur in a DIMACS CNF string.

    Masked classification requires knowing which variables are legal branch
    choices: a variable absent from the (simplified) formula cannot be
    branched on, so its logit must be suppressed downstream.

    Args:
        cnf_text: DIMACS-format CNF. Comment lines start with 'c', the header
            line starts with 'p' ("p cnf <num_vars> <num_clauses>"); every
            other line is a clause of space-separated literals ending in 0.
        max_vars: Largest variable ID the classifier supports.

    Returns:
        List of length max_vars + 1; entry v is 1 iff variable v occurs in
        some clause (entry 0 is always 0 — variables are 1-indexed).
        If no variable is found at all (malformed/truncated input), every
        variable is marked valid so the model never has an empty choice set.
    """
    seen = set()
    for raw in cnf_text.split('\n'):
        stripped = raw.strip()
        # Skip blanks, comments ('c ...') and the 'p cnf ...' header; the
        # header's clause count must not be mistaken for a variable ID.
        if not stripped or stripped.startswith(('c', 'p')):
            continue
        # A clause line is space-separated integer literals terminated by 0;
        # the variable ID is the absolute value of each literal.
        for token in stripped.split():
            try:
                var = abs(int(token))
            except ValueError:
                continue  # Non-integer token; shouldn't occur in valid DIMACS.
            if 1 <= var <= max_vars:
                seen.add(var)
    # Fallback: a formula with no recognizable variables allows everything,
    # preventing a zero-valid-output situation.
    if not seen:
        seen = set(range(1, max_vars + 1))
    return [1 if v in seen else 0 for v in range(max_vars + 1)]
# =============================================================================
# MODEL: Qwen backbone with classification head for variable selection
# =============================================================================
class QwenVarClassifier(nn.Module):
    """
    Transformer-based variable classifier for SAT branching.

    Architecture:
        Input (CNF text)
        → Tokenize
        → Qwen3-4B backbone (pretrained causal LM, fine-tuned with small LR)
        → Extract last token's hidden state (sequence pooling)
        → LayerNorm (stabilizes hidden state magnitude)
        → Linear head (hidden_dim → num_classes)
        → Logits for each variable ID

    Why this architecture?
        1. Pretrained LLM backbone understands text structure and can learn CNF patterns
        2. Last-token pooling: the final token has attended to the entire input
        3. LayerNorm: Qwen's hidden states have large magnitudes; normalizing prevents
           exploding gradients when combined with a randomly-initialized head
        4. Single linear head: simple, interpretable, efficient
    """
    def __init__(self, base_model_name: str, max_vars: int) -> None:
        """
        Initialize the classifier.

        Args:
            base_model_name: HuggingFace model ID (e.g., "Qwen/Qwen3-4B")
            max_vars: Maximum variable ID to classify (e.g., 500).
                Output dimension will be max_vars + 1 (index 0 unused).
        """
        super().__init__()
        self.max_vars = max_vars
        # Load Qwen configuration and enable hidden state output.
        cfg = AutoConfig.from_pretrained(base_model_name)
        cfg.output_hidden_states = True  # We need hidden states, not just LM logits.
        # Load the pretrained backbone. bfloat16 saves memory on modern GPUs
        # (H100, A100); falls back to float32 when CUDA is unavailable.
        self.backbone = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            config=cfg,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        )
        hidden = self.backbone.config.hidden_size  # e.g., 2560 for Qwen3-4B
        # LayerNorm normalizes hidden states before classification. Critical
        # for stable training:
        #   - Qwen's hidden states can have large magnitude (std >> 1)
        #   - A randomly initialized linear head expects normalized inputs
        #   - Without it, initial logits can be huge → high loss → exploding gradients
        self.head_ln = nn.LayerNorm(hidden)
        # Classification head: hidden state → variable logits.
        # Output shape: [batch, max_vars + 1]; index 0 is unused because
        # DIMACS variables are 1-indexed.
        self.head = nn.Linear(hidden, max_vars + 1)
        # Standard small-weight init; LayerNorm provides unit-variance inputs,
        # so std=0.02 keeps initial logits small.
        nn.init.normal_(self.head.weight, std=0.02)
        nn.init.zeros_(self.head.bias)
        # Expose backbone config so DeepSpeed auto-configuration (which reads
        # model.config.hidden_size) works with this wrapper module.
        self.config = self.backbone.config

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs: Any) -> Dict[str, torch.Tensor]:
        """
        Forward pass: CNF tokens → variable logits.

        Args:
            input_ids: [batch, seq_len] token IDs from the tokenizer
            attention_mask: [batch, seq_len] binary mask (1 = real token, 0 = padding)
            **kwargs: ignored (allows passing 'labels' without error during eval)

        Returns:
            dict with "logits": [batch, max_vars + 1] raw classification logits
        """
        # Run through the Qwen backbone.
        out = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,  # Need hidden states, not LM logits.
            use_cache=False,            # KV cache is useless during training.
        )
        # Hidden states from the last transformer layer: [batch, seq_len, hidden_dim]
        h = out.hidden_states[-1]
        # Pool by taking the last non-padding token's hidden state — the causal
        # analogue of BERT's [CLS]: under causal attention, only the final real
        # token has attended to the entire input, so it summarizes the sequence.
        last_idx = attention_mask.sum(dim=1) - 1  # [batch] index of last real token
        last_idx = last_idx.clamp(min=0)  # Safety: an all-padding row would yield -1.
        # Gather the hidden state at that position for each batch element.
        b = torch.arange(h.size(0), device=h.device)
        pooled = h[b, last_idx]  # [batch, hidden_dim]
        # DEBUG: hidden-state statistics for the first few forward passes.
        # NOTE(review): the counter-increment indentation was reconstructed from
        # a whitespace-mangled source; it is assumed to advance once per
        # debug-enabled forward pass (which makes the later `<= 3` gates
        # terminate) — confirm against the original file.
        if DEBUG_TRAINING:
            if not hasattr(self, '_debug_count'):
                self._debug_count = 0
            if self._debug_count < 3:
                print(f"[DEBUG {self._debug_count}] pooled dtype={pooled.dtype}, mean={pooled.float().mean():.2f}, std={pooled.float().std():.2f}")
            self._debug_count += 1
        # Normalize hidden states for stable classification.
        pooled = self.head_ln(pooled)
        # DEBUG: statistics after LayerNorm.
        if DEBUG_TRAINING and hasattr(self, '_debug_count') and self._debug_count <= 3:
            print(f"[DEBUG] after LN: dtype={pooled.dtype}, mean={pooled.float().mean():.4f}, std={pooled.float().std():.4f}")
        # Project to variable logits.
        logits = self.head(pooled)  # [batch, max_vars + 1]
        # DEBUG: raw (pre-masking) logit statistics.
        if DEBUG_TRAINING and hasattr(self, '_debug_count') and self._debug_count <= 3:
            print(f"[DEBUG] logits: dtype={logits.dtype}, mean={logits.float().mean():.2f}, std={logits.float().std():.2f}, min={logits.float().min():.2f}, max={logits.float().max():.2f}")
        return {"logits": logits}
# =============================================================================
# DATA COLLATOR: Batch preparation with padding and mask handling
# =============================================================================
@dataclass
class Collator:
    """
    Batch collator for the variable-classification task.

    Pads variable-length token sequences to a common length within the batch,
    stacks labels, and stacks the per-example valid-variable masks. A custom
    collator is required because the standard HF collators do not know about
    the extra `valid_mask` field.
    """
    tokenizer: Any  # Supplies pad_token_id used for sequence padding.

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """
        Collate a list of examples into one batch.

        Args:
            features: list of dicts, each carrying `input_ids` and
                `attention_mask` (lists of int), `label` (int) and
                `valid_mask` (0/1 list over variable IDs).

        Returns:
            Dict of tensors: `input_ids` / `attention_mask` of shape
            [batch, max_seq_len], `labels` of shape [batch], and a boolean
            `valid_mask` of shape [batch, max_vars + 1].
        """
        pad = torch.nn.utils.rnn.pad_sequence
        id_seqs = [torch.as_tensor(f["input_ids"], dtype=torch.long) for f in features]
        mask_seqs = [torch.as_tensor(f["attention_mask"], dtype=torch.long) for f in features]
        return {
            # Shorter sequences are right-padded with the tokenizer's pad id.
            "input_ids": pad(id_seqs, batch_first=True, padding_value=self.tokenizer.pad_token_id),
            # Padding positions receive zero attention.
            "attention_mask": pad(mask_seqs, batch_first=True, padding_value=0),
            "labels": torch.tensor([f["label"] for f in features], dtype=torch.long),
            "valid_mask": torch.tensor([f["valid_mask"] for f in features], dtype=torch.bool),
        }
# =============================================================================
# TRAINER: Custom loss computation with variable masking
# =============================================================================
class MaskedVarTrainer(Trainer):
    """
    Custom HuggingFace Trainer with masked cross-entropy loss.

    The key modification: before computing cross-entropy, logits of invalid
    variables (those not appearing in the CNF) are masked out. This ensures:
        1. The model cannot predict invalid variables
        2. No gradient flows to invalid variable logits
        3. Training focuses only on distinguishing valid choices

    NOTE on displayed metrics:
        - 'loss' shown by Trainer is summed across GPUs (loss × world_size);
          'true_loss' is added here as the actual per-call loss on this device
        - 'grad_norm' is the L2 norm across ALL ~4B parameters BEFORE clipping.
          Values of 100-200 are normal for large models; it gets clipped to
          max_grad_norm
    """
    def __init__(self, *args, max_vars: int, **kwargs):
        """
        Args:
            max_vars: Maximum variable ID (for sanity-checking labels)
            *args, **kwargs: Passed through to the parent Trainer
        """
        super().__init__(*args, **kwargs)
        self.max_vars = max_vars
        # Running totals consumed by log() to report 'true_loss' per interval.
        self._accumulated_loss = 0.0
        self._loss_count = 0

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute masked cross-entropy loss for variable classification.

        Algorithm:
            1. Extract labels and valid_mask from inputs
            2. Forward pass to get logits
            3. Set logits of invalid variables to -1e4 (≈ -inf, bf16-safe)
            4. Cross-entropy over the masked logits

        Args:
            model: The QwenVarClassifier
            inputs: Dict with input_ids, attention_mask, labels, valid_mask
            return_outputs: If True, return a (loss, outputs) tuple
            num_items_in_batch: Unused (HF Trainer API compatibility)

        Returns:
            Scalar loss, or (loss, outputs) when return_outputs=True.

        Raises:
            ValueError: if any label falls outside [1, max_vars].
        """
        # Use .get, not .pop — the prediction loop still needs labels for
        # compute_metrics.
        labels = inputs.get("labels")          # [batch]
        valid_mask = inputs.get("valid_mask")  # [batch, max_vars + 1] boolean
        # Strip the custom fields before calling model.forward.
        model_inputs = {k: v for k, v in inputs.items() if k not in ["labels", "valid_mask"]}
        # Forward pass
        outputs = model(**model_inputs)
        logits = outputs["logits"]  # [batch, max_vars + 1]
        # DEBUG: verify each label sits inside its example's valid mask.
        # NOTE(review): the counter-increment indentation was reconstructed
        # from a whitespace-mangled source; it is assumed to advance once per
        # debug-enabled call (making the later `<= 5` gates terminate) —
        # confirm against the original file.
        if DEBUG_TRAINING:
            if not hasattr(self, '_loss_debug_count'):
                self._loss_debug_count = 0
            if self._loss_debug_count < 5:
                for i, (lbl, vmask) in enumerate(zip(labels, valid_mask)):
                    label_in_mask = vmask[lbl].item()
                    valid_count = vmask.sum().item()
                    logit_at_label = logits[i, lbl].item()
                    print(f"[LOSS DEBUG {self._loss_debug_count}] label={lbl.item()}, in_mask={label_in_mask}, valid_vars={valid_count}, logit_at_label={logit_at_label:.2f}")
            self._loss_debug_count += 1
        # Mask invalid variables with a large negative value: after softmax
        # they get probability ≈ 0 and receive no meaningful gradient.
        #
        # Why -1e4 instead of -inf or -1e9?
        #   - bfloat16 has limited dynamic range
        #   - -1e9 can cause NaN issues in softmax/cross-entropy
        #   - -1e4 gives ~0 probability while staying numerically stable
        logits = logits.masked_fill(~valid_mask.to(logits.device), -1e4)
        # Sanity check: labels must be valid variable IDs (1 to max_vars).
        # This catches data bugs early.
        if torch.any(labels <= 0) or torch.any(labels > self.max_vars):
            bad = labels[(labels <= 0) | (labels > self.max_vars)].detach().cpu().tolist()
            raise ValueError(f"Out-of-range labels detected (showing up to 20): {bad[:20]}")
        # DEBUG: logit at the label position after masking.
        if DEBUG_TRAINING and hasattr(self, '_loss_debug_count') and self._loss_debug_count <= 5:
            for i, lbl in enumerate(labels):
                masked_logit = logits[i, lbl].item()
                print(f"[LOSS DEBUG] after mask: logit_at_label={masked_logit:.2f}")
        # Standard cross-entropy over the masked logits (expects raw logits).
        loss = F.cross_entropy(logits, labels.to(logits.device))
        # Track the true per-device loss for accurate reporting in log().
        # (loss.item() forces a host sync each step; accepted cost here.)
        self._accumulated_loss += loss.item()
        self._loss_count += 1
        # DEBUG: print the loss value.
        if DEBUG_TRAINING and hasattr(self, '_loss_debug_count') and self._loss_debug_count <= 5:
            print(f"[LOSS DEBUG] loss={loss.item():.2f}")
        # Return masked logits so compute_metrics sees properly masked predictions.
        masked_outputs = {"logits": logits}
        return (loss, masked_outputs) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """
        Evaluation step returning (loss, logits, labels).

        The default HF Trainer prediction_step doesn't work well with a custom
        compute_loss, so the masked loss is recomputed here and the masked
        logits are returned for compute_metrics.
        """
        model.eval()
        with torch.no_grad():
            # Get labels and mask (same field layout as compute_loss).
            labels = inputs.get("labels")
            valid_mask = inputs.get("valid_mask")
            # Forward pass without the custom fields.
            model_inputs = {k: v for k, v in inputs.items() if k not in ["labels", "valid_mask"]}
            outputs = model(**model_inputs)
            logits = outputs["logits"]
            # Mask invalid variables (same -1e4 convention as compute_loss).
            logits = logits.masked_fill(~valid_mask.to(logits.device), -1e4)
            # Eval loss over the masked logits.
            loss = F.cross_entropy(logits, labels.to(logits.device))
            # (loss, logits, labels) is the tuple the evaluation loop expects.
            return (loss, logits.detach(), labels.detach())

    def log(self, logs: Dict[str, float], start_time: float = None) -> None:
        """
        Add 'true_loss' to the logs, then defer to the parent logger.

        The default 'loss' in HF Trainer is summed across GPUs under
        DDP/DeepSpeed; 'true_loss' is the average per-call loss accumulated on
        this device since the last logging event.
        """
        if self._loss_count > 0:
            # Average loss on this device over the logging interval.
            true_loss = self._accumulated_loss / self._loss_count
            logs["true_loss"] = round(true_loss, 4)
            # Reset accumulators for the next logging interval.
            self._accumulated_loss = 0.0
            self._loss_count = 0
        # Let HF Trainer handle W&B logging — it manages step ordering correctly.
        super().log(logs, start_time)
def compute_metrics(eval_pred):
    """
    Top-1 accuracy over the evaluation set.

    Args:
        eval_pred: (logits, labels) pair from the Trainer's prediction loop.
            logits: [num_samples, max_vars + 1] numpy array, already masked
                (invalid variables sit near -1e4, so argmax avoids them)
            labels: [num_samples] numpy array of true variable IDs

    Returns:
        {"accuracy": float}; the Trainer reports it as "eval_accuracy".
        eval_loss comes from prediction_step and is not computed here.
    """
    logits, labels = eval_pred
    predictions = logits.argmax(axis=-1)
    return {"accuracy": float(np.mean(predictions == labels))}
def get_wandb_report_to():
    """
    Decide whether this process should report metrics to Weights & Biases.

    Only the main process (local rank 0) logs, so multi-GPU jobs create a
    single W&B run instead of one per rank.

    Returns:
        ["wandb"] when LOCAL_RANK is unset or 0, otherwise [].
    """
    rank = int(os.environ.get("LOCAL_RANK", 0))
    return ["wandb"] if rank == 0 else []
# =============================================================================
# MAIN: Training pipeline
# =============================================================================
def main():
    """
    Main training function.

    Pipeline:
        1. Parse command line arguments
        2. Load tokenizer and datasets
        3. Preprocess: tokenize CNF text, compute valid masks
        4. Initialize model with pretrained backbone + new classification head
        5. Configure training (optimizer, scheduler, logging, etc.)
        6. Train and evaluate

    Bug fix vs. the original: the --report_to flag is now honored. Previously
    get_wandb_report_to() was used unconditionally, so passing
    "--report_to tensorboard" or "--report_to none" still configured W&B
    logging on rank 0.
    """
    ap = argparse.ArgumentParser(
        description="Train a Qwen-based variable classifier for SAT branching"
    )
    # Model and data arguments
    ap.add_argument("--model_name", type=str, default="Qwen/Qwen3-4B",
                    help="HuggingFace model ID for the backbone")
    ap.add_argument("--train_jsonl", type=str, required=True,
                    help="Path to training data (JSONL with 'cnf' and 'label' fields)")
    ap.add_argument("--valid_jsonl", type=str, required=True,
                    help="Path to validation data (same format)")
    ap.add_argument("--output_dir", type=str, default="./out_qwen_var_sft",
                    help="Directory for checkpoints and logs")
    ap.add_argument("--max_vars", type=int, default=500,
                    help="Maximum variable ID (determines output dimension)")
    ap.add_argument("--max_length", type=int, default=8192,
                    help="Maximum sequence length in tokens (truncates longer CNFs)")
    ap.add_argument("--seed", type=int, default=0,
                    help="Random seed for reproducibility")
    # Training hyperparameters
    ap.add_argument("--per_device_train_batch_size", type=int, default=1,
                    help="Batch size per GPU for training")
    ap.add_argument("--per_device_eval_batch_size", type=int, default=1,
                    help="Batch size per GPU for evaluation")
    ap.add_argument("--gradient_accumulation_steps", type=int, default=8,
                    help="Accumulate gradients over this many steps (effective batch = this * batch_size * num_gpus)")
    ap.add_argument("--learning_rate", type=float, default=5e-6,
                    help="Peak learning rate (after warmup). Lower than typical fine-tuning due to classification head")
    ap.add_argument("--num_train_epochs", type=float, default=3.0,
                    help="Total training epochs")
    ap.add_argument("--warmup_ratio", type=float, default=0.03,
                    help="Fraction of training steps for learning rate warmup")
    ap.add_argument("--weight_decay", type=float, default=0.0,
                    help="Weight decay (L2 regularization)")
    ap.add_argument("--logging_steps", type=int, default=10,
                    help="Log training metrics every N steps")
    ap.add_argument("--eval_steps", type=int, default=200,
                    help="Evaluate every N steps")
    ap.add_argument("--save_steps", type=int, default=200,
                    help="Save checkpoint every N steps")
    ap.add_argument("--report_to", type=str, default="wandb",
                    choices=["wandb", "tensorboard", "none"],
                    help="Logging backend")
    ap.add_argument("--deepspeed", type=str, default=None,
                    help="Path to DeepSpeed config JSON for distributed training")
    args = ap.parse_args()

    # Set random seeds for reproducibility
    set_seed(args.seed)

    # Load tokenizer (Qwen uses a byte-level BPE tokenizer)
    tok = AutoTokenizer.from_pretrained(args.model_name, use_fast=True)
    if tok.pad_token is None:
        # Qwen ships no dedicated pad token; reuse eos for padding.
        tok.pad_token = tok.eos_token

    # Load datasets from JSONL files
    ds = load_dataset(
        "json",
        data_files={"train": args.train_jsonl, "validation": args.valid_jsonl},
    )

    def preprocess(ex):
        """
        Convert one raw example into model-ready features.

        Args:
            ex: Dict with 'cnf' (DIMACS string) and 'label' (int variable ID)

        Returns:
            Dict with input_ids, attention_mask, label, valid_mask
        """
        cnf = ex["cnf"]
        label = int(ex["label"])
        # Tokenize the raw CNF text — no instruction prompt is prepended;
        # the model learns to interpret DIMACS directly.
        enc = tok(
            cnf,
            truncation=True,
            max_length=args.max_length,
            padding=False,  # Padding is handled per-batch by the Collator.
        )
        return {
            "input_ids": enc["input_ids"],
            "attention_mask": enc["attention_mask"],
            "label": label,
            "valid_mask": cnf_valid_mask(cnf, args.max_vars),
        }

    # Drop raw columns (cnf, label) after extracting the needed features.
    ds = ds.map(preprocess, remove_columns=ds["train"].column_names)

    # Initialize model (pretrained backbone + freshly initialized head)
    model = QwenVarClassifier(args.model_name, max_vars=args.max_vars)
    # Gradient checkpointing trades compute for memory on long sequences.
    model.backbone.gradient_checkpointing_enable()

    # Resolve the logging backend, honoring --report_to (fix: the flag was
    # previously ignored and W&B was always selected on rank 0).
    if args.report_to == "wandb":
        # Only rank 0 logs to W&B to avoid duplicate runs in multi-GPU jobs.
        report_to = get_wandb_report_to()
    elif args.report_to == "none":
        report_to = []
    else:
        report_to = [args.report_to]

    # Configure training
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        overwrite_output_dir=True,
        # Precision settings for modern GPUs
        bf16=True,  # bfloat16 training (good for H100/A100)
        tf32=True,  # TF32 matmuls on Ampere+
        # Batch configuration
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        # Optimizer settings
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        num_train_epochs=args.num_train_epochs,
        weight_decay=args.weight_decay,
        # Clip gradient norm for training stability (prevents exploding
        # gradients from destabilizing training).
        max_grad_norm=1.0,
        # Logging and evaluation
        logging_steps=args.logging_steps,
        eval_strategy="steps",
        eval_steps=args.eval_steps,
        # Checkpointing — keep the best checkpoints by validation accuracy
        save_strategy="steps",
        save_steps=args.save_steps,
        save_total_limit=3,               # Keep at most 3 checkpoints
        load_best_model_at_end=True,      # Reload the best checkpoint at the end
        metric_for_best_model="eval_accuracy",
        greater_is_better=True,           # Higher accuracy is better
        # Logging backend
        report_to=report_to,
        run_name=os.environ.get("WANDB_RUN_NAME", "qwen-var-sft") if args.report_to == "wandb" else None,
        logging_dir=os.path.join(args.output_dir, "logs"),
        # Important: don't remove the valid_mask column (compute_loss needs it)
        remove_unused_columns=False,
        # DDP settings (for multi-GPU)
        ddp_find_unused_parameters=False,
        # DeepSpeed for efficient distributed training
        deepspeed=args.deepspeed,
        # Pickle format for saving (safetensors has issues with some
        # weight-tying configs)
        save_safetensors=False,
    )

    # Create trainer with custom masked-loss computation
    trainer = MaskedVarTrainer(
        model=model,
        args=training_args,
        train_dataset=ds["train"],
        eval_dataset=ds["validation"],
        tokenizer=tok,
        data_collator=Collator(tok),
        compute_metrics=compute_metrics,
        max_vars=args.max_vars,
    )

    # Train, then run a final evaluation pass
    trainer.train()
    trainer.evaluate()


if __name__ == "__main__":
    main()