# feat: Add K-Fold Cross-Validation to train.py (commit 0af6aee, author: LorenzoNava)
#!/usr/bin/env python3
"""
DeBERTa CWE Classification - Training Script with Cross-Validation
===================================================================
Direct training script without Gradio UI. Run via SSH.
Optimized for 4x NVIDIA L4 GPUs (24GB each = 96GB total VRAM)
- Gradient checkpointing DISABLED (we have plenty of VRAM)
- Batch size optimized for maximum GPU utilization
- Quality-focused training parameters
- K-Fold Cross-Validation support
Usage:
python3 train.py --model deberta-v3-base --epochs 10 --batch-size 32
python3 train.py --model deberta-v3-base --epochs 10 --batch-size 32 --cv-folds 5
Author: Berghem - Smart Information Security
"""
import argparse
import os
import sys
import time
import torch
from datasets import load_dataset, load_from_disk, concatenate_datasets
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
TrainingArguments,
Trainer,
EarlyStoppingCallback,
TrainerCallback,
)
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import KFold
import numpy as np
# Dataset (local cleaned version - only standard CWE-XXXX IDs)
DATASET_PATH = "./dataset/cleaned"
# Model options: CLI alias -> HuggingFace Hub model id (used by --model)
MODELS = {
    "deberta-v3-small": "microsoft/deberta-v3-small",
    "deberta-v3-base": "microsoft/deberta-v3-base",
    "deberta-v3-large": "microsoft/deberta-v3-large",
}
class CUDACacheClearCallback(TrainerCallback):
    """Trainer callback that frees cached CUDA memory at the end of each epoch.

    Long multi-epoch runs can accumulate fragmented allocator cache; emptying
    it between epochs keeps reported free memory accurate and avoids spurious
    OOMs on later epochs.
    """

    def on_epoch_end(self, args, state, control, **kwargs):
        # No-op on CPU-only machines.
        if not torch.cuda.is_available():
            return
        torch.cuda.empty_cache()
        print(f"\n🧹 CUDA cache cleared after epoch {state.epoch}")
def _bf16_capable(gpu_name):
    """Return True if the (upper-cased) GPU name supports native bfloat16.

    Covers Ampere/Hopper data-center parts and the Ada L4/L40 family that this
    script is tuned for. Anything else falls back to fp16 mixed precision.
    """
    return any(x in gpu_name for x in ['A100', 'H100', 'L4', 'L40'])


def train_fold(
    model_name,
    train_dataset,
    val_dataset,
    label2id,
    id2label,
    num_labels,
    epochs,
    batch_size,
    learning_rate,
    max_length,
    early_stopping_patience,
    fold_num=None,
    total_folds=None,
):
    """Train and evaluate one model (optionally one fold of a K-fold CV run).

    Args:
        model_name: HuggingFace model id or local path.
        train_dataset / val_dataset: pre-tokenized datasets with a 'labels' column.
        label2id / id2label: CWE label <-> class-index mappings stored in the config.
        num_labels: number of output classes.
        epochs, batch_size, learning_rate, max_length: training hyperparameters
            (batch_size is the TOTAL batch, split across GPUs below).
        early_stopping_patience: eval steps without f1 improvement before stopping.
        fold_num / total_folds: set only during cross-validation; also selects the
            per-fold output directory.

    Returns:
        dict with 'accuracy', 'f1' (weighted) and final training 'loss'.

    Raises:
        torch.cuda.OutOfMemoryError: re-raised after printing mitigation hints.
    """
    fold_str = f" (Fold {fold_num}/{total_folds})" if fold_num else ""
    print("=" * 80)
    print(f"DEBERTA CWE CLASSIFICATION TRAINING{fold_str}")
    print("=" * 80)
    print(f"Model: {model_name}")
    print(f"Epochs: {epochs}")
    print(f"Total batch size: {batch_size}")
    print(f"Learning rate: {learning_rate}")
    print(f"Max length: {max_length}")
    if fold_num:
        print(f"Training samples: {len(train_dataset):,}")
        print(f"Validation samples: {len(val_dataset):,}")
    print("=" * 80)

    # Check device
    if torch.cuda.is_available():
        device = "cuda"
        print(f"\n🖥️ Device: {device}")
        print(f" GPU: {torch.cuda.get_device_name(0)}")
    else:
        device = "cpu"
        print(f"\n🖥️ Device: {device} (CPU only)")

    # Load tokenizer
    print(f"\n📚 Loading tokenizer: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Load model
    print(f"\n🤖 Loading model: {model_name}")

    # Determine precision ONCE (previously duplicated for model dtype and
    # TrainingArguments, which risked the two diverging).
    use_bf16 = False
    use_fp16 = False
    gpu_name = ""
    model_dtype = None
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0).upper()
        use_bf16 = _bf16_capable(gpu_name)
        use_fp16 = not use_bf16
        model_dtype = torch.bfloat16 if use_bf16 else torch.float16

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
        label2id=label2id,
        id2label=id2label,
        torch_dtype=model_dtype,
    )
    model = model.to(device)
    print(f" ✅ Model loaded on {device}")
    print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Clear CUDA cache before configuring training
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Training configuration
    print("\n⚙️ Configuring training...")
    output_dir = f"./models/deberta-cwe-fold-{fold_num}" if fold_num else "./models/deberta-cwe-final"

    # Report the precision chosen above (same flags drive TrainingArguments).
    if torch.cuda.is_available():
        if use_bf16:
            print(f" Using bf16 precision (optimal for {gpu_name})")
        else:
            print(f" Using fp16 precision ({gpu_name})")

    # Multi-GPU detection
    num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
    print(f" GPUs detected: {num_gpus}")

    # Memory monitoring
    if torch.cuda.is_available():
        for i in range(num_gpus):
            mem_total = torch.cuda.get_device_properties(i).total_memory / 1e9
            mem_allocated = torch.cuda.memory_allocated(i) / 1e9
            print(f" GPU {i}: {mem_total:.1f}GB total, {mem_allocated:.1f}GB allocated")

    # Optimized batch size distribution for 4x L4 GPUs (96GB total VRAM)
    if num_gpus >= 4:
        per_device_batch = max(4, batch_size // num_gpus)
        gradient_accum = 1
    elif num_gpus == 2:
        per_device_batch = max(4, batch_size // num_gpus)
        gradient_accum = max(1, batch_size // (per_device_batch * num_gpus))
    else:
        per_device_batch = min(8, batch_size)
        gradient_accum = max(1, batch_size // per_device_batch)
    print(f" Per-device batch: {per_device_batch}")
    print(f" Gradient accumulation: {gradient_accum}")
    print(f" Effective batch: {per_device_batch * gradient_accum * num_gpus}")

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=per_device_batch,
        per_device_eval_batch_size=per_device_batch * 2,
        gradient_accumulation_steps=gradient_accum,
        learning_rate=learning_rate,
        weight_decay=0.01,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        eval_strategy="steps",
        eval_steps=500,
        save_strategy="steps",
        save_steps=500,
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        logging_steps=100,
        logging_dir=f"{output_dir}/logs",
        fp16=use_fp16,
        bf16=use_bf16,
        dataloader_num_workers=0,
        report_to="none",
        push_to_hub=False,
        ddp_find_unused_parameters=False if num_gpus > 1 else None,
        # Plenty of VRAM on 4x L4 -> trade memory for speed.
        gradient_checkpointing=False,
        # NOTE: requires the bitsandbytes package at runtime.
        optim="paged_adamw_8bit",
        max_grad_norm=1.0,
    )

    # Metrics: weighted F1 matches metric_for_best_model above.
    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=-1)
        acc = accuracy_score(labels, predictions)
        f1 = f1_score(labels, predictions, average='weighted')
        return {"accuracy": acc, "f1": f1}

    # Create trainer
    print("\n🚀 Starting training...")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        callbacks=[
            EarlyStoppingCallback(early_stopping_patience=early_stopping_patience),
            CUDACacheClearCallback(),
        ],
    )

    # Clear cache before training
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Train (print actionable hints on OOM, then re-raise)
    try:
        train_result = trainer.train()
    except torch.cuda.OutOfMemoryError:
        print("\n❌ Out of Memory!")
        print(" Solutions:")
        print(f" 1. Reduce batch size (currently: {batch_size})")
        print(f" 2. Reduce max length (currently: {max_length})")
        print(" 3. Use smaller model")
        raise

    # Evaluate
    print("\n📊 Final evaluation...")
    eval_result = trainer.evaluate()
    print(f"\n✅ Training complete!")
    print(f" Final Loss: {train_result.training_loss:.4f}")
    print(f" Accuracy: {eval_result.get('eval_accuracy', 0):.4f}")
    print(f" F1 Score: {eval_result.get('eval_f1', 0):.4f}")

    # Save model + tokenizer together so the output dir is self-contained
    print(f"\n💾 Saving model to: {output_dir}")
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)

    # Final CUDA cache clear
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("\n🧹 Final CUDA cache clear complete")
    print("=" * 80)

    return {
        "accuracy": eval_result.get('eval_accuracy', 0),
        "f1": eval_result.get('eval_f1', 0),
        "loss": train_result.training_loss,
    }
def main():
    """CLI entry point: load + tokenize the CWE dataset, then train either a
    single 90/10 split or K folds of cross-validation."""
    parser = argparse.ArgumentParser(description="Train DeBERTa for CWE classification with K-Fold CV")
    parser.add_argument("--model", type=str, default="deberta-v3-base",
                        choices=list(MODELS.keys()),
                        help="Model to use (default: deberta-v3-base)")
    parser.add_argument("--epochs", type=int, default=10,
                        help="Number of epochs (default: 10)")
    parser.add_argument("--batch-size", type=int, default=32,
                        help="Total batch size - optimized for 4x L4 GPUs (default: 32)")
    parser.add_argument("--learning-rate", type=float, default=2e-5,
                        help="Learning rate (default: 2e-5)")
    parser.add_argument("--max-length", type=int, default=256,
                        help="Max sequence length (default: 256)")
    parser.add_argument("--early-stopping", type=int, default=5,
                        help="Early stopping patience (default: 5)")
    parser.add_argument("--cv-folds", type=int, default=1,
                        help="Number of cross-validation folds (default: 1 = no CV)")
    args = parser.parse_args()

    model_path = MODELS[args.model]
    banner = "=" * 80

    # ---- Dataset loading (local copy preferred, HF Hub as fallback) ----
    print(banner)
    print("LOADING DATASET")
    print(banner)
    print("\n📦 Loading cleaned dataset...")
    if os.path.exists(DATASET_PATH):
        print(f" Using local: {DATASET_PATH}")
        dataset = load_from_disk(DATASET_PATH)
    else:
        print(" Local dataset not found, downloading from HuggingFace...")
        dataset = load_dataset("LorenzoNava/cve-cwe-dataset-cleaned")
    print(f" ✅ Loaded {len(dataset['train']):,} samples (cleaned, no NVD-CWE-Other)")

    # ---- Label mapping: every distinct non-empty CWE-ID becomes a class ----
    print("\n🏷️ Building CWE label mapping...")
    unique_cwes = {ex['CWE-ID'] for ex in dataset['train'] if ex.get('CWE-ID')}
    label2id = {cwe: idx for idx, cwe in enumerate(sorted(unique_cwes))}
    id2label = {idx: cwe for cwe, idx in label2id.items()}
    num_labels = len(label2id)
    print(f" ✅ Found {num_labels} unique CWE classes")

    # ---- Tokenization ----
    print("\n🔤 Tokenizing dataset...")
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    def _encode(examples):
        # Fixed-length padding so all batches share one shape.
        return tokenizer(
            examples['DESCRIPTION'],
            padding='max_length',
            truncation=True,
            max_length=args.max_length
        )

    tokenized_dataset = dataset.map(
        _encode,
        batched=True,
        remove_columns=dataset['train'].column_names
    )
    print(" ✅ Tokenization complete")

    # ---- Attach integer labels by looking indices back up in the raw split ----
    def _attach_labels(examples, indices):
        raw_train = dataset['train']
        return {'labels': [label2id.get(raw_train[i]['CWE-ID'], -100) for i in indices]}

    tokenized_dataset['train'] = tokenized_dataset['train'].map(
        _attach_labels,
        batched=True,
        with_indices=True
    )

    # ---- Drop rows whose CWE-ID had no mapping (-100 sentinel) ----
    print("\n🔍 Filtering invalid samples...")
    tokenized_dataset['train'] = tokenized_dataset['train'].filter(
        lambda row: row['labels'] != -100
    )
    print(f" ✅ Train: {len(tokenized_dataset['train']):,} valid samples")

    if args.cv_folds > 1:
        # ---- K-Fold cross-validation ----
        print(f"\n🔄 Running {args.cv_folds}-Fold Cross-Validation")
        print(banner)
        splitter = KFold(n_splits=args.cv_folds, shuffle=True, random_state=42)
        all_indices = np.arange(len(tokenized_dataset['train']))
        fold_results = []

        for fold, (train_idx, val_idx) in enumerate(splitter.split(all_indices), start=1):
            print(f"\n{banner}")
            print(f"FOLD {fold}/{args.cv_folds}")
            print(f"{banner}")
            fold_train = tokenized_dataset['train'].select(train_idx.tolist())
            fold_val = tokenized_dataset['train'].select(val_idx.tolist())
            outcome = train_fold(
                model_name=model_path,
                train_dataset=fold_train,
                val_dataset=fold_val,
                label2id=label2id,
                id2label=id2label,
                num_labels=num_labels,
                epochs=args.epochs,
                batch_size=args.batch_size,
                learning_rate=args.learning_rate,
                max_length=args.max_length,
                early_stopping_patience=args.early_stopping,
                fold_num=fold,
                total_folds=args.cv_folds,
            )
            fold_results.append(outcome)

        # ---- Aggregate report: mean ± std across folds ----
        print("\n" + banner)
        print("CROSS-VALIDATION RESULTS")
        print(banner)
        print("\nPer-Fold Results:")
        for i, result in enumerate(fold_results, start=1):
            print(f" Fold {i}: Accuracy={result['accuracy']:.4f}, F1={result['f1']:.4f}, Loss={result['loss']:.4f}")
        accuracies = [r['accuracy'] for r in fold_results]
        f1_scores = [r['f1'] for r in fold_results]
        losses = [r['loss'] for r in fold_results]
        print(f"\nAverage Results ({args.cv_folds} folds):")
        print(f" Accuracy: {np.mean(accuracies):.4f} ± {np.std(accuracies):.4f}")
        print(f" F1 Score: {np.mean(f1_scores):.4f} ± {np.std(f1_scores):.4f}")
        print(f" Loss: {np.mean(losses):.4f} ± {np.std(losses):.4f}")
        print(banner)
    else:
        # ---- Single run: deterministic 90/10 split ----
        print("\n📊 Creating 90/10 train/validation split...")
        split_dataset = tokenized_dataset['train'].train_test_split(test_size=0.1, seed=42)
        train_dataset = split_dataset['train']
        val_dataset = split_dataset['test']
        print(f" Train: {len(train_dataset):,} samples")
        print(f" Validation: {len(val_dataset):,} samples")
        train_fold(
            model_name=model_path,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            label2id=label2id,
            id2label=id2label,
            num_labels=num_labels,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.learning_rate,
            max_length=args.max_length,
            early_stopping_patience=args.early_stopping,
        )

    print("\n🎉 Done!")


if __name__ == "__main__":
    main()