Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| """ | |
| Project #1: Prompt Injection Detection Classifier | |
| Train a binary classifier to detect safe (0) vs unsafe (1) prompts | |
| using the Aegis AI Content Safety Dataset 2.0. | |
| Steps: | |
| 1. Load dataset with prompt and prompt_label fields | |
| 2. Convert labels: "safe" → 0, "unsafe" → 1 | |
| 3. Create train/validation split (since dataset is for "testing") | |
| 4. Train a sequence classification model | |
| 5. Evaluate on test split | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import logging | |
| from pathlib import Path | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from datasets import Dataset, DatasetDict | |
| from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix | |
| from transformers import ( | |
| AutoModelForSequenceClassification, | |
| AutoTokenizer, | |
| DataCollatorWithPadding, | |
| TrainingArguments, | |
| Trainer, | |
| TrainerCallback, | |
| ) | |
| from load_aegis_dataset import load_aegis_dataset | |
# Set up logging: timestamped INFO-level messages on the root handler.
# All progress reporting in this script goes through `logger` below.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
# Module-level logger (named after this module, per logging convention).
logger = logging.getLogger(__name__)
def compute_metrics(eval_pred):
    """Compute scalar classification metrics for the Trainer.

    Args:
        eval_pred: ``(predictions, labels)`` pair where ``predictions``
            are raw logits of shape (n_samples, n_classes) and
            ``labels`` are the integer class ids.

    Returns:
        dict with float values 'accuracy', 'f1', 'precision', 'recall'
        (precision/recall/f1 are support-weighted averages).

    The confusion matrix is logged instead of returned: the Trainer
    prefixes metric names (e.g. ``eval_confusion_matrix``) and treats
    values as scalars, so a nested list in the metrics dict crashes the
    ``{value:.4f}`` formatting of the results downstream.
    """
    predictions, labels = eval_pred
    # Logits -> predicted class ids.
    predictions = np.argmax(predictions, axis=1)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predictions, average='weighted', zero_division=0
    )
    accuracy = accuracy_score(labels, predictions)
    # Log the confusion matrix for inspection rather than returning it.
    cm = confusion_matrix(labels, predictions)
    logging.getLogger(__name__).info("Confusion matrix:\n%s", cm)
    return {
        'accuracy': accuracy,
        'f1': f1,
        'precision': precision,
        'recall': recall,
    }
def tokenize_function(examples, tokenizer):
    """Tokenize a batch of prompts.

    Padding is deliberately NOT applied here: training uses
    ``DataCollatorWithPadding``, which pads each batch dynamically to
    its longest sequence. Padding every example to ``max_length`` up
    front (the previous behavior) made the collator's dynamic padding a
    no-op and wasted memory/compute on mostly-pad tensors.

    Args:
        examples: batch dict with a "prompt" list of strings.
        tokenizer: a Hugging Face tokenizer (or compatible callable).

    Returns:
        dict of tokenizer outputs (input_ids, attention_mask, ...),
        truncated to at most 512 tokens per prompt.
    """
    return tokenizer(
        examples["prompt"],
        truncation=True,
        max_length=512,
    )
class TestLossCallback(TrainerCallback):
    """Callback to track test loss after each epoch.

    Evaluations run here use ``metric_key_prefix="test"`` so they are
    recorded as ``test_loss`` in ``trainer.state.log_history``. With
    the default ``"eval"`` prefix (the previous behavior), per-epoch
    test evaluations were indistinguishable from validation
    evaluations when the loss history is later scanned for plotting,
    corrupting the validation-loss curve.
    """

    def __init__(self, test_dataset, trainer):
        self.test_dataset = test_dataset
        self.trainer = trainer
        self.test_losses = []   # one loss value per completed epoch
        self.test_epochs = []   # epoch number matching each loss

    def on_epoch_end(self, args, state, control, **kwargs):
        """Evaluate on test set after each epoch."""
        if self.test_dataset is None:
            return
        results = self.trainer.evaluate(
            eval_dataset=self.test_dataset, metric_key_prefix="test"
        )
        if "test_loss" in results:
            self.test_losses.append(results["test_loss"])
            self.test_epochs.append(state.epoch)
            logger.info(f"Epoch {state.epoch}: Test Loss = {results['test_loss']:.4f}")
def _log_metrics(results):
    """Log a Trainer metrics dict at INFO level.

    Scalar values are printed with 4 decimal places; a confusion matrix
    (a nested list) is printed row by row. The Trainer prefixes metric
    names with the split (e.g. ``eval_confusion_matrix``), so the key is
    matched with ``in`` rather than ``==`` — an exact-match check would
    let the nested list fall through to the ``:.4f`` float format and
    raise a TypeError.
    """
    for key, value in results.items():
        if "confusion_matrix" in key:
            logger.info(f"  {key}:")
            logger.info("    " + "\n    ".join(str(row) for row in value))
        else:
            logger.info(f"  {key}: {value:.4f}")


def _parse_args():
    """Build the command-line interface and parse arguments."""
    parser = argparse.ArgumentParser(description="Train prompt injection detection classifier")
    parser.add_argument(
        "--model-name",
        type=str,
        default="distilbert-base-uncased",
        help="Base model for classification (distilbert-base-uncased, bert-base-uncased, roberta-base)"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./prompt-injection-detector",
        help="Directory to save the trained model"
    )
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=3,
        help="Number of training epochs"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=16,
        help="Training batch size"
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=5e-5,
        help="Learning rate"
    )
    parser.add_argument(
        "--test-size",
        type=float,
        default=0.1,
        help="Fraction of data to use for validation (rest for training)"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility"
    )
    return parser.parse_args()


def main():
    """Train and evaluate the safe/unsafe prompt classifier end to end.

    Pipeline: load the Aegis splits, verify the 0/1 label mapping
    (creating a validation split if missing), tokenize, fine-tune a
    sequence-classification model, evaluate on validation and (if
    present) test splits, plot the loss curves, and save the model.
    """
    args = _parse_args()
    logger.info("=" * 60)
    logger.info("Project #1: Prompt Injection Detection Classifier")
    logger.info("=" * 60)
    logger.info(f"Model: {args.model_name}")
    logger.info(f"Output directory: {args.output_dir}")
    logger.info(f"Epochs: {args.num_epochs}, Batch size: {args.batch_size}")
    logger.info("=" * 60)

    # Step 1: Load dataset (train/validation/test if available)
    logger.info("Step 1: Loading Aegis dataset splits...")
    dataset = load_aegis_dataset()
    if isinstance(dataset, DatasetDict):
        logger.info(f"Available splits: {list(dataset.keys())}")
        train_dataset = dataset.get("train")
        val_dataset = dataset.get("validation") or dataset.get("val")
        test_dataset = dataset.get("test")
    elif isinstance(dataset, Dataset):
        logger.warning("Dataset returned a single split. Treating as 'train'.")
        train_dataset = dataset
        val_dataset = None
        test_dataset = None
    else:
        raise ValueError("Unexpected dataset type returned from load_aegis_dataset.")
    if train_dataset is None:
        raise ValueError("Train split not found in dataset.")
    logger.info(f"Train split size: {len(train_dataset)}")
    logger.info(f"Train fields: {train_dataset.column_names}")
    logger.info(f"Train sample: {train_dataset[0]}")
    if val_dataset is not None:
        logger.info(f"Validation split size: {len(val_dataset)}")
    else:
        logger.info("Validation split not found; will create from train split.")
    if test_dataset is not None:
        logger.info(f"Test split size: {len(test_dataset)}")
    else:
        logger.info("Test split not found; will fall back to validation split for final evaluation if needed.")

    # Step 2: Verify label mapping and create validation split if missing
    logger.info("\nStep 2: Verifying label mapping and preparing splits...")
    unique_labels = set(train_dataset["prompt_label"])
    logger.info(f"Unique labels: {unique_labels}")
    # Raise instead of assert: asserts are stripped under `python -O`.
    if unique_labels != {0, 1}:
        raise ValueError(f"Expected labels {{0, 1}}, got {unique_labels}")
    # Count safe vs unsafe as a quick class-balance sanity check.
    safe_count = sum(1 for label in train_dataset["prompt_label"] if label == 0)
    unsafe_count = sum(1 for label in train_dataset["prompt_label"] if label == 1)
    logger.info(f"Safe prompts: {safe_count}, Unsafe prompts: {unsafe_count}")
    if val_dataset is None:
        logger.info("Creating validation split from train data...")
        split_dataset = train_dataset.train_test_split(
            test_size=args.test_size,
            shuffle=True,
            seed=args.seed
        )
        train_dataset = split_dataset["train"]
        val_dataset = split_dataset["test"]
    logger.info(f"Final train samples: {len(train_dataset)}")
    logger.info(f"Final validation samples: {len(val_dataset)}")

    # Step 3: Load model and tokenizer
    logger.info(f"\nStep 3: Loading model and tokenizer: {args.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name,
        num_labels=2,  # binary: 0 = safe, 1 = unsafe
    )

    # Step 4: Tokenize datasets
    logger.info("\nStep 4: Tokenizing datasets...")

    def tokenize_fn(examples):
        # Named def instead of a lambda assignment (PEP 8 E731).
        return tokenize_function(examples, tokenizer)

    def prepare(split):
        # Tokenize a split, rename prompt_label -> labels for the
        # Trainer, and set the PyTorch output format.
        tokenized = split.map(
            tokenize_fn,
            batched=True,
            remove_columns=["prompt"],  # keep prompt_label for labels
        )
        tokenized = tokenized.rename_column("prompt_label", "labels")
        tokenized.set_format("torch")
        return tokenized

    train_tokenized = prepare(train_dataset)
    val_tokenized = prepare(val_dataset)
    test_tokenized = prepare(test_dataset) if test_dataset is not None else None

    # Step 5: Set up training
    logger.info("\nStep 5: Setting up training...")
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    training_args = TrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        weight_decay=0.01,
        warmup_steps=500,
        logging_dir=str(output_dir / "logs"),
        logging_steps=100,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        save_total_limit=3,
        fp16=False,  # Set to True if you have GPU
        report_to="none",
    )
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_tokenized,
        eval_dataset=val_tokenized,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )
    # Track test loss per epoch if a test split exists.
    test_callback = None
    if test_tokenized is not None:
        test_callback = TestLossCallback(test_tokenized, trainer)
        trainer.add_callback(test_callback)

    # Step 6: Train
    logger.info("\nStep 6: Training classifier...")
    trainer.train()

    # Extract training history for plotting.
    # NOTE(review): any extra evaluate() call that uses the default
    # "eval" metric prefix (e.g. a test-set evaluation) also lands in
    # log_history as "eval_loss" and would be picked up as validation
    # loss here — entries are only distinguishable by key prefix.
    train_losses = []
    train_epochs = []
    val_losses = []
    val_epochs = []
    for log_entry in trainer.state.log_history:
        if "loss" in log_entry and "epoch" in log_entry:
            train_losses.append(log_entry["loss"])
            train_epochs.append(log_entry["epoch"])
        elif "eval_loss" in log_entry and "epoch" in log_entry:
            val_losses.append(log_entry["eval_loss"])
            val_epochs.append(log_entry["epoch"])

    # Step 7: Evaluate on validation set
    logger.info("\nStep 7: Evaluating on validation set...")
    eval_results = trainer.evaluate()
    logger.info("\nValidation Results:")
    _log_metrics(eval_results)

    # Step 8: Test on test split (if available)
    logger.info("\nStep 8: Testing on test split...")
    if test_tokenized is not None:
        logger.info(f"Test dataset found with {len(test_dataset)} samples.")
        # Get test losses from callback if available
        if test_callback and test_callback.test_losses:
            test_losses = test_callback.test_losses
            test_epochs = test_callback.test_epochs
            logger.info(f"Test losses tracked over {len(test_losses)} epochs via callback.")
        else:
            # Fallback: evaluate final model on test set
            test_results = trainer.evaluate(eval_dataset=test_tokenized)
            test_losses = [test_results["eval_loss"]]
            test_epochs = [args.num_epochs]
            logger.info("Evaluated final model on test set.")
        # Final test evaluation
        test_results = trainer.evaluate(eval_dataset=test_tokenized)
        logger.info("\nFinal Test Results:")
        _log_metrics(test_results)
    else:
        logger.warning("Test split not found; using validation losses for plotting.")
        # Use validation losses as test losses for plotting
        test_losses = val_losses
        test_epochs = val_epochs

    # Step 9: Plot training and test loss
    logger.info("\nStep 9: Plotting training and test loss...")
    plt.figure(figsize=(10, 6))
    if train_losses and train_epochs:
        plt.plot(train_epochs, train_losses, 'b-o', label='Train Loss', linewidth=2, markersize=6)
    if test_losses and test_epochs:
        plt.plot(test_epochs, test_losses, 'r-s', label='Test Loss', linewidth=2, markersize=6)
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Loss', fontsize=12)
    plt.title('Training and Test Loss Over Epochs', fontsize=14, fontweight='bold')
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    # Save plot to the model output directory.
    plot_path = output_dir / "loss_plot.png"
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    logger.info(f"Loss plot saved to: {plot_path}")
    plt.close()

    # Step 10: Save model
    logger.info(f"\nStep 10: Saving model to {output_dir}...")
    trainer.save_model()
    tokenizer.save_pretrained(str(output_dir))
    logger.info("=" * 60)
    logger.info("Training complete!")
    logger.info(f"Model saved to: {output_dir}")
    logger.info(f"Loss plot saved to: {plot_path}")
    logger.info("=" * 60)


if __name__ == "__main__":
    main()