|
|
|
|
|
""" |
|
|
DeBERTa CWE Classification - Training Script with Cross-Validation |
|
|
=================================================================== |
|
|
Direct training script without Gradio UI. Run via SSH. |
|
|
|
|
|
Optimized for 4x NVIDIA L4 GPUs (24GB each = 96GB total VRAM) |
|
|
- Gradient checkpointing DISABLED (we have plenty of VRAM) |
|
|
- Batch size optimized for maximum GPU utilization |
|
|
- Quality-focused training parameters |
|
|
- K-Fold Cross-Validation support |
|
|
|
|
|
Usage: |
|
|
python3 train.py --model deberta-v3-base --epochs 10 --batch-size 32 |
|
|
python3 train.py --model deberta-v3-base --epochs 10 --batch-size 32 --cv-folds 5 |
|
|
|
|
|
Author: Berghem - Smart Information Security |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import os |
|
|
import sys |
|
|
import time |
|
|
import torch |
|
|
from datasets import load_dataset, load_from_disk, concatenate_datasets |
|
|
from transformers import ( |
|
|
AutoTokenizer, |
|
|
AutoModelForSequenceClassification, |
|
|
TrainingArguments, |
|
|
Trainer, |
|
|
EarlyStoppingCallback, |
|
|
TrainerCallback, |
|
|
) |
|
|
from sklearn.metrics import accuracy_score, f1_score |
|
|
from sklearn.model_selection import KFold |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
# Path to the locally cached, cleaned CVE->CWE dataset. main() falls back to
# downloading from the HuggingFace Hub when this directory does not exist.
DATASET_PATH = "./dataset/cleaned"


# CLI-friendly model aliases mapped to their HuggingFace Hub identifiers
# (used by the --model argument in main()).
MODELS = {
    "deberta-v3-small": "microsoft/deberta-v3-small",
    "deberta-v3-base": "microsoft/deberta-v3-base",
    "deberta-v3-large": "microsoft/deberta-v3-large",
}
|
|
|
|
|
|
|
|
class CUDACacheClearCallback(TrainerCallback):
    """Trainer callback that frees cached CUDA memory at every epoch boundary.

    Releasing the allocator cache between epochs keeps memory from slowly
    building up over long multi-epoch runs.
    """

    def on_epoch_end(self, args, state, control, **kwargs):
        # Nothing to release on CPU-only machines.
        if not torch.cuda.is_available():
            return
        torch.cuda.empty_cache()
        print(f"\n🧹 CUDA cache cleared after epoch {state.epoch}")
|
|
|
|
|
|
|
|
def train_fold(
    model_name,
    train_dataset,
    val_dataset,
    label2id,
    id2label,
    num_labels,
    epochs,
    batch_size,
    learning_rate,
    max_length,
    early_stopping_patience,
    fold_num=None,
    total_folds=None,
):
    """Train and evaluate one model (optionally one cross-validation fold).

    Args:
        model_name: HuggingFace model id or local checkpoint path.
        train_dataset: tokenized dataset with a ``labels`` column.
        val_dataset: tokenized evaluation dataset with a ``labels`` column.
        label2id: CWE label -> class-index mapping.
        id2label: class-index -> CWE label mapping.
        num_labels: number of CWE classes.
        epochs: number of training epochs.
        batch_size: *total* batch size, split across all visible GPUs.
        learning_rate: peak learning rate (cosine schedule with 10% warmup).
        max_length: max sequence length (informational here; tokenization is
            done by the caller).
        early_stopping_patience: evaluations without F1 improvement before
            training stops early.
        fold_num: current fold number during K-fold CV (None = single run).
        total_folds: total fold count during K-fold CV.

    Returns:
        dict with ``accuracy``, ``f1`` and ``loss`` of the trained model.
    """
    # ----- Banner ---------------------------------------------------------
    fold_str = f" (Fold {fold_num}/{total_folds})" if fold_num else ""
    print("=" * 80)
    print(f"DEBERTA CWE CLASSIFICATION TRAINING{fold_str}")
    print("=" * 80)
    print(f"Model: {model_name}")
    print(f"Epochs: {epochs}")
    print(f"Total batch size: {batch_size}")
    print(f"Learning rate: {learning_rate}")
    print(f"Max length: {max_length}")
    if fold_num:
        print(f"Training samples: {len(train_dataset):,}")
        print(f"Validation samples: {len(val_dataset):,}")
    print("=" * 80)

    # ----- Device ---------------------------------------------------------
    if torch.cuda.is_available():
        device = "cuda"
        print(f"\n🖥️ Device: {device}")
        print(f" GPU: {torch.cuda.get_device_name(0)}")
    else:
        device = "cpu"
        print(f"\n🖥️ Device: {device} (CPU only)")

    # ----- Tokenizer and model -------------------------------------------
    print(f"\n📚 Loading tokenizer: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    print(f"\n🤖 Loading model: {model_name}")

    # FIX: the model is deliberately loaded in full fp32. The previous code
    # loaded the weights in fp16/bf16 via ``torch_dtype`` while *also*
    # enabling mixed precision below; with fp16 that makes the Trainer's
    # GradScaler fail ("Attempting to unscale FP16 gradients"), and
    # half-precision master weights degrade optimizer updates even under
    # bf16. The Trainer's autocast handles the half-precision compute.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
        label2id=label2id,
        id2label=id2label,
    )

    model = model.to(device)
    print(f" ✅ Model loaded on {device}")
    print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # ----- Training configuration ----------------------------------------
    print("\n⚙️ Configuring training...")
    output_dir = f"./models/deberta-cwe-fold-{fold_num}" if fold_num else "./models/deberta-cwe-final"

    # Decide the mixed-precision mode once (this detection was previously
    # duplicated). A100/H100/L4/L40 have native bfloat16 support; other
    # CUDA GPUs fall back to fp16; CPU trains in full precision.
    use_bf16 = False
    use_fp16 = False
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0).upper()
        if any(x in gpu_name for x in ['A100', 'H100', 'L4', 'L40']):
            use_bf16 = True
            print(f" Using bf16 precision (optimal for {gpu_name})")
        else:
            use_fp16 = True
            print(f" Using fp16 precision ({gpu_name})")

    num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
    print(f" GPUs detected: {num_gpus}")

    if torch.cuda.is_available():
        for i in range(num_gpus):
            mem_total = torch.cuda.get_device_properties(i).total_memory / 1e9
            mem_allocated = torch.cuda.memory_allocated(i) / 1e9
            print(f" GPU {i}: {mem_total:.1f}GB total, {mem_allocated:.1f}GB allocated")

    # Split the requested total batch across devices; gradient accumulation
    # approximates the requested total when per-device memory is the limit.
    if num_gpus >= 4:
        per_device_batch = max(4, batch_size // num_gpus)
        gradient_accum = 1
    elif num_gpus == 2:
        per_device_batch = max(4, batch_size // num_gpus)
        gradient_accum = max(1, batch_size // (per_device_batch * num_gpus))
    else:
        per_device_batch = min(8, batch_size)
        gradient_accum = max(1, batch_size // per_device_batch)

    print(f" Per-device batch: {per_device_batch}")
    print(f" Gradient accumulation: {gradient_accum}")
    print(f" Effective batch: {per_device_batch * gradient_accum * num_gpus}")

    # NOTE: paged_adamw_8bit requires the bitsandbytes package and CUDA.
    # Fall back to stock AdamW on CPU so the script still runs there.
    optim_name = "paged_adamw_8bit" if torch.cuda.is_available() else "adamw_torch"

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=per_device_batch,
        per_device_eval_batch_size=per_device_batch * 2,  # no grads at eval
        gradient_accumulation_steps=gradient_accum,
        learning_rate=learning_rate,
        weight_decay=0.01,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        eval_strategy="steps",
        eval_steps=500,
        save_strategy="steps",
        save_steps=500,
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        logging_steps=100,
        logging_dir=f"{output_dir}/logs",
        fp16=use_fp16,
        bf16=use_bf16,
        dataloader_num_workers=0,
        report_to="none",
        push_to_hub=False,
        ddp_find_unused_parameters=False if num_gpus > 1 else None,
        gradient_checkpointing=False,  # plenty of VRAM on the target 4x L4
        optim=optim_name,
        max_grad_norm=1.0,
    )

    def compute_metrics(eval_pred):
        """Accuracy and weighted F1 over the evaluation set."""
        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=-1)
        acc = accuracy_score(labels, predictions)
        f1 = f1_score(labels, predictions, average='weighted')
        return {"accuracy": acc, "f1": f1}

    print("\n🚀 Starting training...")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        callbacks=[
            EarlyStoppingCallback(early_stopping_patience=early_stopping_patience),
            CUDACacheClearCallback(),
        ],
    )

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    try:
        train_result = trainer.train()
    except torch.cuda.OutOfMemoryError:
        print("\n❌ Out of Memory!")
        print(" Solutions:")
        print(f" 1. Reduce batch size (currently: {batch_size})")
        print(f" 2. Reduce max length (currently: {max_length})")
        print(" 3. Use smaller model")
        raise

    # ----- Evaluation, saving, cleanup -----------------------------------
    print("\n📊 Final evaluation...")
    eval_result = trainer.evaluate()

    print("\n✅ Training complete!")
    print(f" Final Loss: {train_result.training_loss:.4f}")
    print(f" Accuracy: {eval_result.get('eval_accuracy', 0):.4f}")
    print(f" F1 Score: {eval_result.get('eval_f1', 0):.4f}")

    print(f"\n💾 Saving model to: {output_dir}")
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("\n🧹 Final CUDA cache clear complete")

    print("=" * 80)

    return {
        "accuracy": eval_result.get('eval_accuracy', 0),
        "f1": eval_result.get('eval_f1', 0),
        "loss": train_result.training_loss,
    }
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: load the dataset, build labels, tokenize, train.

    Runs either a single 90/10 train/validation split or K-fold
    cross-validation depending on ``--cv-folds``.
    """
    parser = argparse.ArgumentParser(description="Train DeBERTa for CWE classification with K-Fold CV")
    parser.add_argument("--model", type=str, default="deberta-v3-base",
                        choices=list(MODELS.keys()),
                        help="Model to use (default: deberta-v3-base)")
    parser.add_argument("--epochs", type=int, default=10,
                        help="Number of epochs (default: 10)")
    parser.add_argument("--batch-size", type=int, default=32,
                        help="Total batch size - optimized for 4x L4 GPUs (default: 32)")
    parser.add_argument("--learning-rate", type=float, default=2e-5,
                        help="Learning rate (default: 2e-5)")
    parser.add_argument("--max-length", type=int, default=256,
                        help="Max sequence length (default: 256)")
    parser.add_argument("--early-stopping", type=int, default=5,
                        help="Early stopping patience (default: 5)")
    parser.add_argument("--cv-folds", type=int, default=1,
                        help="Number of cross-validation folds (default: 1 = no CV)")

    args = parser.parse_args()

    model_path = MODELS[args.model]

    # ------------------------------------------------------------------
    # Dataset loading: prefer the local cleaned copy, fall back to the Hub.
    # ------------------------------------------------------------------
    print("=" * 80)
    print("LOADING DATASET")
    print("=" * 80)
    print("\n📦 Loading cleaned dataset...")
    if os.path.exists(DATASET_PATH):
        print(f" Using local: {DATASET_PATH}")
        dataset = load_from_disk(DATASET_PATH)
    else:
        print(" Local dataset not found, downloading from HuggingFace...")
        dataset = load_dataset("LorenzoNava/cve-cwe-dataset-cleaned")
    print(f" ✅ Loaded {len(dataset['train']):,} samples (cleaned, no NVD-CWE-Other)")

    # ------------------------------------------------------------------
    # Label mapping. PERF: read the CWE column once instead of iterating
    # the dataset row by row (each row fetch decodes a full dict). The
    # column is also reused below when attaching integer labels.
    # ------------------------------------------------------------------
    print("\n🏷️ Building CWE label mapping...")
    cwe_column = dataset['train']['CWE-ID']
    cwe_list = sorted({cwe for cwe in cwe_column if cwe})  # drop empty/None IDs
    label2id = {cwe: idx for idx, cwe in enumerate(cwe_list)}
    id2label = {idx: cwe for cwe, idx in label2id.items()}
    num_labels = len(label2id)
    print(f" ✅ Found {num_labels} unique CWE classes")

    # ------------------------------------------------------------------
    # Tokenization (drops all original columns, so labels must be
    # re-attached afterwards from the materialized column above).
    # ------------------------------------------------------------------
    print("\n🔤 Tokenizing dataset...")
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    def tokenize_function(examples):
        return tokenizer(
            examples['DESCRIPTION'],
            padding='max_length',
            truncation=True,
            max_length=args.max_length
        )

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset['train'].column_names
    )
    print(" ✅ Tokenization complete")

    # Attach integer labels. PERF: indexing the plain ``cwe_column`` list
    # replaces the previous per-index ``dataset['train'][i]`` fetch, which
    # decoded one Arrow row per sample.
    def add_labels(examples, idx):
        return {'labels': [label2id.get(cwe_column[i], -100) for i in idx]}

    tokenized_dataset['train'] = tokenized_dataset['train'].map(
        add_labels,
        batched=True,
        with_indices=True
    )

    # Drop samples whose CWE was not in the label map (marked -100 above).
    print("\n🔍 Filtering invalid samples...")
    tokenized_dataset['train'] = tokenized_dataset['train'].filter(
        lambda x: x['labels'] != -100
    )
    print(f" ✅ Train: {len(tokenized_dataset['train']):,} valid samples")

    if args.cv_folds > 1:
        # ----------------------------- K-fold CV --------------------------
        print(f"\n🔄 Running {args.cv_folds}-Fold Cross-Validation")
        print("=" * 80)

        kfold = KFold(n_splits=args.cv_folds, shuffle=True, random_state=42)
        indices = np.arange(len(tokenized_dataset['train']))

        fold_results = []

        for fold, (train_idx, val_idx) in enumerate(kfold.split(indices), 1):
            print(f"\n{'=' * 80}")
            print(f"FOLD {fold}/{args.cv_folds}")
            print(f"{'=' * 80}")

            train_fold_dataset = tokenized_dataset['train'].select(train_idx.tolist())
            val_fold_dataset = tokenized_dataset['train'].select(val_idx.tolist())

            result = train_fold(
                model_name=model_path,
                train_dataset=train_fold_dataset,
                val_dataset=val_fold_dataset,
                label2id=label2id,
                id2label=id2label,
                num_labels=num_labels,
                epochs=args.epochs,
                batch_size=args.batch_size,
                learning_rate=args.learning_rate,
                max_length=args.max_length,
                early_stopping_patience=args.early_stopping,
                fold_num=fold,
                total_folds=args.cv_folds,
            )

            fold_results.append(result)

        # Aggregate per-fold metrics (mean ± std).
        print("\n" + "=" * 80)
        print("CROSS-VALIDATION RESULTS")
        print("=" * 80)
        print("\nPer-Fold Results:")
        for i, result in enumerate(fold_results, 1):
            print(f" Fold {i}: Accuracy={result['accuracy']:.4f}, F1={result['f1']:.4f}, Loss={result['loss']:.4f}")

        avg_accuracy = np.mean([r['accuracy'] for r in fold_results])
        avg_f1 = np.mean([r['f1'] for r in fold_results])
        avg_loss = np.mean([r['loss'] for r in fold_results])

        std_accuracy = np.std([r['accuracy'] for r in fold_results])
        std_f1 = np.std([r['f1'] for r in fold_results])
        std_loss = np.std([r['loss'] for r in fold_results])

        print(f"\nAverage Results ({args.cv_folds} folds):")
        print(f" Accuracy: {avg_accuracy:.4f} ± {std_accuracy:.4f}")
        print(f" F1 Score: {avg_f1:.4f} ± {std_f1:.4f}")
        print(f" Loss: {avg_loss:.4f} ± {std_loss:.4f}")
        print("=" * 80)

    else:
        # --------------------- Single train/validation run ----------------
        print("\n📊 Creating 90/10 train/validation split...")
        split_dataset = tokenized_dataset['train'].train_test_split(test_size=0.1, seed=42)
        train_dataset = split_dataset['train']
        val_dataset = split_dataset['test']
        print(f" Train: {len(train_dataset):,} samples")
        print(f" Validation: {len(val_dataset):,} samples")

        train_fold(
            model_name=model_path,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            label2id=label2id,
            id2label=id2label,
            num_labels=num_labels,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.learning_rate,
            max_length=args.max_length,
            early_stopping_patience=args.early_stopping,
        )

    print("\n🎉 Done!")
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|