#!/usr/bin/env python3 """ DeBERTa CWE Classification - Training Script (Optimized for 4x L4 GPUs) ======================================================================== Direct training script without Gradio UI. Run via SSH. Optimized for 4x NVIDIA L4 GPUs (24GB each = 96GB total VRAM) - Gradient checkpointing DISABLED (we have plenty of VRAM) - Batch size optimized for maximum GPU utilization - Quality-focused training parameters Usage: python3 train.py --model deberta-v3-base --epochs 10 --batch-size 32 Author: Berghem - Smart Information Security """ import argparse import os import sys import time import torch from datasets import load_dataset from transformers import ( AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, EarlyStoppingCallback, TrainerCallback, ) from sklearn.metrics import accuracy_score, f1_score import numpy as np # Dataset DATASET_NAME = "stasvinokur/cve-and-cwe-dataset-1999-2025" # Model options MODELS = { "deberta-v3-small": "microsoft/deberta-v3-small", "deberta-v3-base": "microsoft/deberta-v3-base", "deberta-v3-large": "microsoft/deberta-v3-large", } class CUDACacheClearCallback(TrainerCallback): """Clear CUDA cache after each epoch to prevent memory buildup""" def on_epoch_end(self, args, state, control, **kwargs): if torch.cuda.is_available(): torch.cuda.empty_cache() print(f"\n๐Ÿงน CUDA cache cleared after epoch {state.epoch}") def train_model(model_name, epochs, batch_size, learning_rate, max_length, early_stopping_patience): """Train DeBERTa model on CVE-CWE dataset""" print("=" * 80) print("DEBERTA CWE CLASSIFICATION TRAINING") print("=" * 80) print(f"Model: {model_name}") print(f"Epochs: {epochs}") print(f"Total batch size: {batch_size}") print(f"Learning rate: {learning_rate}") print(f"Max length: {max_length}") print("=" * 80) # Check device if torch.cuda.is_available(): device = "cuda" print(f"\n๐Ÿ–ฅ๏ธ Device: {device}") print(f" GPU: {torch.cuda.get_device_name(0)}") else: device = "cpu" print(f"\n๐Ÿ–ฅ๏ธ Device: {device} (CPU only)") # Load dataset with retry logic print("\n๐Ÿ“ฆ Loading dataset...") max_retries = 3 dataset = None for attempt in range(max_retries): try: dataset = load_dataset(DATASET_NAME) print(f" โœ… Loaded {len(dataset['train']):,} samples") break except Exception as e: if attempt < max_retries - 1: wait_time = 2 ** attempt print(f" โš ๏ธ Attempt {attempt + 1} failed: {str(e)}") print(f" Retrying in {wait_time}s...") time.sleep(wait_time) else: print(f" โŒ Failed after {max_retries} attempts") raise # Create validation split if needed if 'validation' not in dataset and 'test' not in dataset: print("\n๐Ÿ“Š Creating 90/10 train/validation split...") split_dataset = dataset['train'].train_test_split(test_size=0.1, seed=42) dataset['train'] = split_dataset['train'] dataset['validation'] = split_dataset['test'] print(f" Train: {len(dataset['train']):,} samples") print(f" Validation: {len(dataset['validation']):,} samples") # Build label mapping print("\n๐Ÿท๏ธ Building CWE label mapping...") cwe_set = set() for example in dataset['train']: if example.get('CWE-ID'): cwe_set.add(example['CWE-ID']) cwe_list = sorted(list(cwe_set)) label2id = {cwe: idx for idx, cwe in enumerate(cwe_list)} id2label = {idx: cwe for cwe, idx in label2id.items()} num_labels = len(label2id) print(f" โœ… Found {num_labels} unique CWE classes") # Load tokenizer print(f"\n๐Ÿ“š Loading tokenizer: {model_name}") tokenizer = AutoTokenizer.from_pretrained(model_name) # Tokenize dataset print("\n๐Ÿ”ค Tokenizing dataset...") def tokenize_function(examples): return tokenizer( examples['DESCRIPTION'], padding='max_length', truncation=True, max_length=max_length ) tokenized_dataset = dataset.map( tokenize_function, batched=True, remove_columns=dataset['train'].column_names ) print(" โœ… Tokenization complete") # Clear CUDA cache if torch.cuda.is_available(): torch.cuda.empty_cache() # Add labels def add_labels(examples, idx): cwe_ids = [dataset['train'][i]['CWE-ID'] for i in idx] return {'labels': [label2id.get(cwe, -100) for cwe in cwe_ids]} tokenized_dataset['train'] = tokenized_dataset['train'].map( add_labels, batched=True, with_indices=True ) if 'validation' in tokenized_dataset: def add_val_labels(examples, idx): cwe_ids = [dataset['validation'][i]['CWE-ID'] for i in idx] return {'labels': [label2id.get(cwe, -100) for cwe in cwe_ids]} tokenized_dataset['validation'] = tokenized_dataset['validation'].map( add_val_labels, batched=True, with_indices=True ) # Filter invalid labels print("\n๐Ÿ” Filtering invalid samples...") tokenized_dataset['train'] = tokenized_dataset['train'].filter( lambda x: x['labels'] != -100 ) if 'validation' in tokenized_dataset: tokenized_dataset['validation'] = tokenized_dataset['validation'].filter( lambda x: x['labels'] != -100 ) print(f" โœ… Train: {len(tokenized_dataset['train']):,} valid samples") # Load model print(f"\n๐Ÿค– Loading model: {model_name}") # Determine precision use_bf16_model = False use_fp16_model = False if torch.cuda.is_available(): gpu_name = torch.cuda.get_device_name(0).upper() if any(x in gpu_name for x in ['A100', 'H100', 'L4', 'L40']): use_bf16_model = True else: use_fp16_model = True # Determine model dtype model_dtype = None if torch.cuda.is_available(): model_dtype = torch.bfloat16 if use_bf16_model else torch.float16 model = AutoModelForSequenceClassification.from_pretrained( model_name, num_labels=num_labels, label2id=label2id, id2label=id2label, torch_dtype=model_dtype, ) model = model.to(device) print(f" โœ… Model loaded on {device}") print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}") # Clear CUDA cache if torch.cuda.is_available(): torch.cuda.empty_cache() # Training configuration print("\nโš™๏ธ Configuring training...") output_dir = "./models/deberta-cwe-final" # Precision settings use_bf16 = False use_fp16 = False if torch.cuda.is_available(): gpu_name = torch.cuda.get_device_name(0).upper() if any(x in gpu_name for x in ['A100', 'H100', 'L4', 'L40']): use_bf16 = True print(f" Using bf16 precision (optimal for {gpu_name})") else: use_fp16 = True print(f" Using fp16 precision ({gpu_name})") # Multi-GPU detection num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1 print(f" GPUs detected: {num_gpus}") # Memory monitoring if torch.cuda.is_available(): for i in range(num_gpus): mem_total = torch.cuda.get_device_properties(i).total_memory / 1e9 mem_allocated = torch.cuda.memory_allocated(i) / 1e9 print(f" GPU {i}: {mem_total:.1f}GB total, {mem_allocated:.1f}GB allocated") # Optimized batch size distribution for 4x L4 GPUs (96GB total VRAM) # With gradient checkpointing DISABLED, we can use larger batches if num_gpus >= 4: # 4 GPUs: batch_size 32 โ†’ 8 per GPU, no accumulation needed per_device_batch = max(4, batch_size // num_gpus) gradient_accum = 1 # No accumulation needed with 4 GPUs elif num_gpus == 2: # 2 GPUs: use accumulation if needed per_device_batch = max(4, batch_size // num_gpus) gradient_accum = max(1, batch_size // (per_device_batch * num_gpus)) else: # Single GPU: use smaller batch + accumulation per_device_batch = min(8, batch_size) gradient_accum = max(1, batch_size // per_device_batch) print(f" Per-device batch: {per_device_batch}") print(f" Gradient accumulation: {gradient_accum}") print(f" Effective batch: {per_device_batch * gradient_accum * num_gpus}") training_args = TrainingArguments( output_dir=output_dir, num_train_epochs=epochs, per_device_train_batch_size=per_device_batch, per_device_eval_batch_size=per_device_batch * 2, # Can use 2x for eval gradient_accumulation_steps=gradient_accum, learning_rate=learning_rate, weight_decay=0.01, warmup_ratio=0.1, lr_scheduler_type="cosine", eval_strategy="steps", eval_steps=500, save_strategy="steps", save_steps=500, save_total_limit=2, load_best_model_at_end=True, metric_for_best_model="f1", greater_is_better=True, logging_steps=100, logging_dir=f"{output_dir}/logs", fp16=use_fp16, bf16=use_bf16, dataloader_num_workers=0, report_to="none", push_to_hub=False, ddp_find_unused_parameters=False if num_gpus > 1 else None, # CRITICAL FIX: Disable gradient checkpointing for multi-GPU # With 96GB VRAM (4x L4), we don't need it and it causes backward pass errors gradient_checkpointing=False, # Use memory-efficient optimizer (compatible with no checkpointing) optim="paged_adamw_8bit", max_grad_norm=1.0, ) # Metrics def compute_metrics(eval_pred): logits, labels = eval_pred predictions = np.argmax(logits, axis=-1) acc = accuracy_score(labels, predictions) f1 = f1_score(labels, predictions, average='weighted') return {"accuracy": acc, "f1": f1} # Create trainer with CUDA cache clearing callback print("\n๐Ÿš€ Starting training...") trainer = Trainer( model=model, args=training_args, train_dataset=tokenized_dataset['train'], eval_dataset=tokenized_dataset.get('validation'), tokenizer=tokenizer, compute_metrics=compute_metrics, callbacks=[ EarlyStoppingCallback(early_stopping_patience=early_stopping_patience), CUDACacheClearCallback(), # Clear cache after each epoch ], ) # Clear cache before training if torch.cuda.is_available(): torch.cuda.empty_cache() # Train try: train_result = trainer.train() except torch.cuda.OutOfMemoryError as e: print(f"\nโŒ Out of Memory!") print(f" Solutions:") print(f" 1. Reduce batch size (currently: {batch_size})") print(f" 2. Reduce max length (currently: {max_length})") print(f" 3. Use smaller model") raise # Evaluate print("\n๐Ÿ“Š Final evaluation...") eval_result = trainer.evaluate() print(f"\nโœ… Training complete!") print(f" Final Loss: {train_result.training_loss:.4f}") print(f" Accuracy: {eval_result.get('eval_accuracy', 0):.4f}") print(f" F1 Score: {eval_result.get('eval_f1', 0):.4f}") # Save model print(f"\n๐Ÿ’พ Saving model to: {output_dir}") trainer.save_model(output_dir) tokenizer.save_pretrained(output_dir) # Final CUDA cache clear if torch.cuda.is_available(): torch.cuda.empty_cache() print("\n๐Ÿงน Final CUDA cache clear complete") print(f"\n๐ŸŽ‰ Done! Model saved successfully.") print("=" * 80) def main(): parser = argparse.ArgumentParser(description="Train DeBERTa for CWE classification") parser.add_argument("--model", type=str, default="deberta-v3-base", choices=list(MODELS.keys()), help="Model to use (default: deberta-v3-base)") parser.add_argument("--epochs", type=int, default=10, help="Number of epochs (default: 10)") parser.add_argument("--batch-size", type=int, default=32, help="Total batch size - optimized for 4x L4 GPUs (default: 32)") parser.add_argument("--learning-rate", type=float, default=2e-5, help="Learning rate (default: 2e-5)") parser.add_argument("--max-length", type=int, default=256, help="Max sequence length (default: 256)") parser.add_argument("--early-stopping", type=int, default=5, help="Early stopping patience (default: 5)") args = parser.parse_args() model_path = MODELS[args.model] train_model( model_name=model_path, epochs=args.epochs, batch_size=args.batch_size, learning_rate=args.learning_rate, max_length=args.max_length, early_stopping_patience=args.early_stopping ) if __name__ == "__main__": main()