#!/usr/bin/env python3 """ Project #1: Prompt Injection Detection Classifier Train a binary classifier to detect safe (0) vs unsafe (1) prompts using the Aegis AI Content Safety Dataset 2.0. Steps: 1. Load dataset with prompt and prompt_label fields 2. Convert labels: "safe" → 0, "unsafe" → 1 3. Create train/validation split (since dataset is for "testing") 4. Train a sequence classification model 5. Evaluate on test split """ from __future__ import annotations import argparse import logging from pathlib import Path import matplotlib.pyplot as plt import numpy as np from datasets import Dataset, DatasetDict from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix from transformers import ( AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding, TrainingArguments, Trainer, TrainerCallback, ) from load_aegis_dataset import load_aegis_dataset # Set up logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S' ) logger = logging.getLogger(__name__) def compute_metrics(eval_pred): """Compute classification metrics.""" predictions, labels = eval_pred predictions = np.argmax(predictions, axis=1) precision, recall, f1, _ = precision_recall_fscore_support( labels, predictions, average='weighted', zero_division=0 ) accuracy = accuracy_score(labels, predictions) # Confusion matrix cm = confusion_matrix(labels, predictions) return { 'accuracy': accuracy, 'f1': f1, 'precision': precision, 'recall': recall, 'confusion_matrix': cm.tolist(), } def tokenize_function(examples, tokenizer): """Tokenize the prompts.""" return tokenizer( examples["prompt"], truncation=True, padding="max_length", max_length=512, ) class TestLossCallback(TrainerCallback): """Callback to track test loss after each epoch.""" def __init__(self, test_dataset, trainer): self.test_dataset = test_dataset self.trainer = trainer self.test_losses = [] self.test_epochs = [] def on_epoch_end(self, args, state, control, **kwargs): """Evaluate on test set after each epoch.""" if self.test_dataset is not None: test_results = self.trainer.evaluate(eval_dataset=self.test_dataset) if "eval_loss" in test_results: self.test_losses.append(test_results["eval_loss"]) self.test_epochs.append(state.epoch) logger.info(f"Epoch {state.epoch}: Test Loss = {test_results['eval_loss']:.4f}") def main(): parser = argparse.ArgumentParser(description="Train prompt injection detection classifier") parser.add_argument( "--model-name", type=str, default="distilbert-base-uncased", help="Base model for classification (distilbert-base-uncased, bert-base-uncased, roberta-base)" ) parser.add_argument( "--output-dir", type=str, default="./prompt-injection-detector", help="Directory to save the trained model" ) parser.add_argument( "--num-epochs", type=int, default=3, help="Number of training epochs" ) parser.add_argument( "--batch-size", type=int, default=16, help="Training batch size" ) parser.add_argument( "--learning-rate", type=float, default=5e-5, help="Learning rate" ) parser.add_argument( "--test-size", type=float, default=0.1, help="Fraction of data to use for validation (rest for training)" ) parser.add_argument( "--seed", type=int, default=42, help="Random seed for reproducibility" ) args = parser.parse_args() logger.info("=" * 60) logger.info("Project #1: Prompt Injection Detection Classifier") logger.info("=" * 60) logger.info(f"Model: {args.model_name}") logger.info(f"Output directory: {args.output_dir}") logger.info(f"Epochs: {args.num_epochs}, Batch size: {args.batch_size}") logger.info("=" * 60) # Step 1: Load dataset (train/validation/test if available) logger.info("Step 1: Loading Aegis dataset splits...") dataset = load_aegis_dataset() if isinstance(dataset, DatasetDict): logger.info(f"Available splits: {list(dataset.keys())}") train_dataset = dataset.get("train") val_dataset = dataset.get("validation") or dataset.get("val") test_dataset = dataset.get("test") elif isinstance(dataset, Dataset): logger.warning("Dataset returned a single split. Treating as 'train'.") train_dataset = dataset val_dataset = None test_dataset = None else: raise ValueError("Unexpected dataset type returned from load_aegis_dataset.") if train_dataset is None: raise ValueError("Train split not found in dataset.") logger.info(f"Train split size: {len(train_dataset)}") logger.info(f"Train fields: {train_dataset.column_names}") logger.info(f"Train sample: {train_dataset[0]}") if val_dataset is not None: logger.info(f"Validation split size: {len(val_dataset)}") else: logger.info("Validation split not found; will create from train split.") if test_dataset is not None: logger.info(f"Test split size: {len(test_dataset)}") else: logger.info("Test split not found; will fall back to validation split for final evaluation if needed.") # Step 2: Verify label mapping and create validation split if missing logger.info("\nStep 2: Verifying label mapping and preparing splits...") unique_labels = set(train_dataset["prompt_label"]) logger.info(f"Unique labels: {unique_labels}") assert unique_labels == {0, 1}, f"Expected labels {{0, 1}}, got {unique_labels}" # Count safe vs unsafe safe_count = sum(1 for label in train_dataset["prompt_label"] if label == 0) unsafe_count = sum(1 for label in train_dataset["prompt_label"] if label == 1) logger.info(f"Safe prompts: {safe_count}, Unsafe prompts: {unsafe_count}") if val_dataset is None: logger.info("Creating validation split from train data...") split_dataset = train_dataset.train_test_split( test_size=args.test_size, shuffle=True, seed=args.seed ) train_dataset = split_dataset["train"] val_dataset = split_dataset["test"] logger.info(f"Final train samples: {len(train_dataset)}") logger.info(f"Final validation samples: {len(val_dataset)}") # Step 3: Load model and tokenizer logger.info(f"\nStep 3: Loading model and tokenizer: {args.model_name}") tokenizer = AutoTokenizer.from_pretrained(args.model_name) model = AutoModelForSequenceClassification.from_pretrained( args.model_name, num_labels=2, ) # Step 4: Tokenize datasets logger.info("\nStep 4: Tokenizing datasets...") tokenize_fn = lambda examples: tokenize_function(examples, tokenizer) train_tokenized = train_dataset.map( tokenize_fn, batched=True, remove_columns=["prompt"], # Keep prompt_label for labels ) val_tokenized = val_dataset.map( tokenize_fn, batched=True, remove_columns=["prompt"], ) # Rename prompt_label to labels for Trainer train_tokenized = train_tokenized.rename_column("prompt_label", "labels") val_tokenized = val_tokenized.rename_column("prompt_label", "labels") # Set format for PyTorch train_tokenized.set_format("torch") val_tokenized.set_format("torch") # Prepare test dataset if available test_tokenized = None if test_dataset is not None: test_tokenized = test_dataset.map( tokenize_fn, batched=True, remove_columns=["prompt"], ) test_tokenized = test_tokenized.rename_column("prompt_label", "labels") test_tokenized.set_format("torch") # Step 5: Set up training logger.info("\nStep 5: Setting up training...") output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) training_args = TrainingArguments( output_dir=str(output_dir), num_train_epochs=args.num_epochs, per_device_train_batch_size=args.batch_size, per_device_eval_batch_size=args.batch_size, learning_rate=args.learning_rate, weight_decay=0.01, warmup_steps=500, logging_dir=str(output_dir / "logs"), logging_steps=100, eval_strategy="epoch", save_strategy="epoch", load_best_model_at_end=True, metric_for_best_model="f1", greater_is_better=True, save_total_limit=3, fp16=False, # Set to True if you have GPU report_to="none", ) data_collator = DataCollatorWithPadding(tokenizer=tokenizer) trainer = Trainer( model=model, args=training_args, train_dataset=train_tokenized, eval_dataset=val_tokenized, tokenizer=tokenizer, data_collator=data_collator, compute_metrics=compute_metrics, ) # Add callback to track test loss if test dataset is available test_callback = None if test_tokenized is not None: test_callback = TestLossCallback(test_tokenized, trainer) trainer.add_callback(test_callback) # Step 6: Train logger.info("\nStep 6: Training classifier...") trainer.train() # Extract training history for plotting train_losses = [] train_epochs = [] val_losses = [] val_epochs = [] for log_entry in trainer.state.log_history: if "loss" in log_entry and "epoch" in log_entry: train_losses.append(log_entry["loss"]) train_epochs.append(log_entry["epoch"]) elif "eval_loss" in log_entry and "epoch" in log_entry: val_losses.append(log_entry["eval_loss"]) val_epochs.append(log_entry["epoch"]) # Step 7: Evaluate on validation set logger.info("\nStep 7: Evaluating on validation set...") eval_results = trainer.evaluate() logger.info("\nValidation Results:") for key, value in eval_results.items(): if key != "confusion_matrix": logger.info(f" {key}: {value:.4f}") else: logger.info(f" {key}:") logger.info(" " + "\n ".join(str(row) for row in value)) # Step 8: Test on test split (if available) logger.info("\nStep 8: Testing on test split...") if test_tokenized is not None: logger.info(f"Test dataset found with {len(test_dataset)} samples.") # Get test losses from callback if available if test_callback and test_callback.test_losses: test_losses = test_callback.test_losses test_epochs = test_callback.test_epochs logger.info(f"Test losses tracked over {len(test_losses)} epochs via callback.") else: # Fallback: evaluate final model on test set test_results = trainer.evaluate(eval_dataset=test_tokenized) test_losses = [test_results["eval_loss"]] test_epochs = [args.num_epochs] logger.info("Evaluated final model on test set.") # Final test evaluation test_results = trainer.evaluate(eval_dataset=test_tokenized) logger.info("\nFinal Test Results:") for key, value in test_results.items(): if key != "confusion_matrix": logger.info(f" {key}: {value:.4f}") else: logger.info(f" {key}:") logger.info(" " + "\n ".join(str(row) for row in value)) else: logger.warning("Test split not found; using validation losses for plotting.") # Use validation losses as test losses for plotting test_losses = val_losses test_epochs = val_epochs # Step 9: Plot training and test loss logger.info("\nStep 9: Plotting training and test loss...") plt.figure(figsize=(10, 6)) if train_losses and train_epochs: plt.plot(train_epochs, train_losses, 'b-o', label='Train Loss', linewidth=2, markersize=6) if test_losses and test_epochs: plt.plot(test_epochs, test_losses, 'r-s', label='Test Loss', linewidth=2, markersize=6) plt.xlabel('Epoch', fontsize=12) plt.ylabel('Loss', fontsize=12) plt.title('Training and Test Loss Over Epochs', fontsize=14, fontweight='bold') plt.legend(fontsize=11) plt.grid(True, alpha=0.3) plt.tight_layout() # Save plot plot_path = output_dir / "loss_plot.png" plt.savefig(plot_path, dpi=300, bbox_inches='tight') logger.info(f"Loss plot saved to: {plot_path}") plt.close() # Step 10: Save model logger.info(f"\nStep 10: Saving model to {output_dir}...") trainer.save_model() tokenizer.save_pretrained(str(output_dir)) logger.info("=" * 60) logger.info("Training complete!") logger.info(f"Model saved to: {output_dir}") logger.info(f"Loss plot saved to: {plot_path}") logger.info("=" * 60) if __name__ == "__main__": main()