Tameem7 commited on
Commit
849ca5b
·
1 Parent(s): dd881ce

fix eval speed

Browse files
Files changed (3) hide show
  1. app.py +66 -10
  2. eval.py +48 -0
  3. train_prompt_injection_detector.py +393 -0
app.py CHANGED
@@ -12,7 +12,13 @@ import numpy as np
12
  import torch
13
  from datasets import DatasetDict
14
  from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
15
- from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer
 
 
 
 
 
 
16
 
17
  from load_aegis_dataset import load_aegis_dataset
18
 
@@ -48,7 +54,8 @@ def load_model_and_data(model_dir: str):
48
  print(f"Test samples: {len(test_dataset)}")
49
 
50
  def tokenize(batch):
51
- return tokenizer(batch['prompt'], truncation=True, padding='max_length', max_length=512)
 
52
 
53
  test_tokenized = test_dataset.map(tokenize, batched=True, remove_columns=['prompt'])
54
  test_tokenized = test_tokenized.rename_column('prompt_label', 'labels')
@@ -70,7 +77,26 @@ def load_model_and_data(model_dir: str):
70
  'confusion_matrix': cm.tolist()
71
  }
72
 
73
- trainer = Trainer(model=model, tokenizer=tokenizer, compute_metrics=compute_metrics)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  print("Model and dataset loaded successfully!")
76
  return "Model and dataset loaded successfully!"
@@ -119,16 +145,29 @@ def classify_prompt(prompt: str) -> tuple[str, str]:
119
  return result_text, label
120
 
121
 
122
- def evaluate_test_set(progress=gr.Progress()) -> str:
123
- """Evaluate the model on the test dataset and return metrics."""
 
 
 
 
124
  if trainer is None or test_tokenized is None:
125
  return "⚠️ Error: Model or test dataset not loaded."
126
 
 
 
 
 
 
 
 
 
 
127
  # Ensure tqdm is enabled for progress tracking
128
  trainer.args.disable_tqdm = False
129
 
130
  # Calculate total steps for progress tracking
131
- total_samples = len(test_tokenized)
132
  batch_size = trainer.args.per_device_eval_batch_size
133
  num_devices = max(1, torch.cuda.device_count()) if torch.cuda.is_available() else 1
134
  total_batches = (total_samples + batch_size * num_devices - 1) // (batch_size * num_devices)
@@ -162,7 +201,7 @@ def evaluate_test_set(progress=gr.Progress()) -> str:
162
 
163
  try:
164
  # Run evaluation - tqdm progress will be shown in console and Gradio should track it
165
- results = trainer.evaluate(eval_dataset=test_tokenized)
166
  progress(1.0, desc="✅ Evaluation complete!")
167
  finally:
168
  # Remove the callback
@@ -171,6 +210,12 @@ def evaluate_test_set(progress=gr.Progress()) -> str:
171
  # Format results
172
  output = "## Test Set Evaluation Results\n\n"
173
 
 
 
 
 
 
 
174
  # Main metrics
175
  output += "### Classification Metrics\n\n"
176
  output += f"- **Accuracy:** {results.get('eval_accuracy', 0):.4f}\n"
@@ -373,8 +418,17 @@ with app:
373
 
374
  # Tab 2: Test Set Evaluation
375
  with gr.Tab("📊 Evaluate Test Set"):
376
- gr.Markdown("### Evaluate the model on the full test dataset")
377
  gr.Markdown("**Note:** Progress percentage will be shown during evaluation.")
 
 
 
 
 
 
 
 
 
378
 
379
  eval_btn = gr.Button(
380
  "Run Evaluation",
@@ -383,9 +437,10 @@ with app:
383
  )
384
  eval_output = gr.Markdown(label="Evaluation Results")
385
 
386
- def run_evaluation():
387
  """Run evaluation and return result."""
388
- result = evaluate_test_set()
 
389
  return result
390
 
391
  def enable_button():
@@ -397,6 +452,7 @@ with app:
397
  outputs=eval_btn
398
  ).then(
399
  fn=run_evaluation,
 
400
  outputs=eval_output
401
  ).then(
402
  fn=enable_button,
 
12
  import torch
13
  from datasets import DatasetDict
14
  from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
15
+ from transformers import (
16
+ AutoModelForSequenceClassification,
17
+ AutoTokenizer,
18
+ Trainer,
19
+ TrainingArguments,
20
+ DataCollatorWithPadding,
21
+ )
22
 
23
  from load_aegis_dataset import load_aegis_dataset
24
 
 
54
  print(f"Test samples: {len(test_dataset)}")
55
 
56
  def tokenize(batch):
57
+ # Use dynamic padding - DataCollatorWithPadding will handle padding efficiently
58
+ return tokenizer(batch['prompt'], truncation=True, max_length=512)
59
 
60
  test_tokenized = test_dataset.map(tokenize, batched=True, remove_columns=['prompt'])
61
  test_tokenized = test_tokenized.rename_column('prompt_label', 'labels')
 
77
  'confusion_matrix': cm.tolist()
78
  }
79
 
80
+ # Optimize evaluation performance with larger batch size and other settings
81
+ eval_batch_size = 64 if torch.cuda.is_available() else 32
82
+ training_args = TrainingArguments(
83
+ output_dir="./eval_output", # Temporary directory
84
+ per_device_eval_batch_size=eval_batch_size,
85
+ fp16=torch.cuda.is_available(), # Use mixed precision on GPU
86
+ dataloader_num_workers=0, # Avoid multiprocessing issues in Gradio
87
+ report_to="none", # Don't report to any service
88
+ disable_tqdm=False, # Show progress
89
+ )
90
+
91
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
92
+
93
+ trainer = Trainer(
94
+ model=model,
95
+ args=training_args,
96
+ tokenizer=tokenizer,
97
+ data_collator=data_collator,
98
+ compute_metrics=compute_metrics,
99
+ )
100
 
101
  print("Model and dataset loaded successfully!")
102
  return "Model and dataset loaded successfully!"
 
145
  return result_text, label
146
 
147
 
148
+ def evaluate_test_set(max_samples: int = None, progress=gr.Progress()) -> str:
149
+ """Evaluate the model on the test dataset and return metrics.
150
+
151
+ Args:
152
+ max_samples: Maximum number of samples to evaluate. If None, evaluates on full dataset.
153
+ """
154
  if trainer is None or test_tokenized is None:
155
  return "⚠️ Error: Model or test dataset not loaded."
156
 
157
+ # Limit dataset size if specified
158
+ eval_dataset = test_tokenized
159
+ if max_samples is not None and max_samples > 0:
160
+ max_samples = min(max_samples, len(test_tokenized))
161
+ eval_dataset = test_tokenized.select(range(max_samples))
162
+ print(f"Evaluating on {max_samples} samples (out of {len(test_tokenized)} total)")
163
+ else:
164
+ print(f"Evaluating on full test set ({len(test_tokenized)} samples)")
165
+
166
  # Ensure tqdm is enabled for progress tracking
167
  trainer.args.disable_tqdm = False
168
 
169
  # Calculate total steps for progress tracking
170
+ total_samples = len(eval_dataset)
171
  batch_size = trainer.args.per_device_eval_batch_size
172
  num_devices = max(1, torch.cuda.device_count()) if torch.cuda.is_available() else 1
173
  total_batches = (total_samples + batch_size * num_devices - 1) // (batch_size * num_devices)
 
201
 
202
  try:
203
  # Run evaluation - tqdm progress will be shown in console and Gradio should track it
204
+ results = trainer.evaluate(eval_dataset=eval_dataset)
205
  progress(1.0, desc="✅ Evaluation complete!")
206
  finally:
207
  # Remove the callback
 
210
  # Format results
211
  output = "## Test Set Evaluation Results\n\n"
212
 
213
+ # Show dataset size info
214
+ if max_samples is not None and max_samples < len(test_tokenized):
215
+ output += f"**Note:** Evaluated on {max_samples} samples (out of {len(test_tokenized)} total)\n\n"
216
+ else:
217
+ output += f"**Note:** Evaluated on full test set ({len(test_tokenized)} samples)\n\n"
218
+
219
  # Main metrics
220
  output += "### Classification Metrics\n\n"
221
  output += f"- **Accuracy:** {results.get('eval_accuracy', 0):.4f}\n"
 
418
 
419
  # Tab 2: Test Set Evaluation
420
  with gr.Tab("📊 Evaluate Test Set"):
421
+ gr.Markdown("### Evaluate the model on the test dataset")
422
  gr.Markdown("**Note:** Progress percentage will be shown during evaluation.")
423
+ gr.Markdown("**Tip:** Limit the number of samples for faster evaluation during testing.")
424
+
425
+ max_samples_input = gr.Number(
426
+ label="Maximum samples to evaluate (leave empty for full dataset)",
427
+ value=None,
428
+ minimum=1,
429
+ precision=0,
430
+ info="Set a limit to evaluate faster. Leave empty to evaluate on the full dataset."
431
+ )
432
 
433
  eval_btn = gr.Button(
434
  "Run Evaluation",
 
437
  )
438
  eval_output = gr.Markdown(label="Evaluation Results")
439
 
440
+ def run_evaluation(max_samples):
441
  """Run evaluation and return result."""
442
+ max_samples_int = int(max_samples) if max_samples is not None and max_samples > 0 else None
443
+ result = evaluate_test_set(max_samples=max_samples_int)
444
  return result
445
 
446
  def enable_button():
 
452
  outputs=eval_btn
453
  ).then(
454
  fn=run_evaluation,
455
+ inputs=max_samples_input,
456
  outputs=eval_output
457
  ).then(
458
  fn=enable_button,
eval.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from load_aegis_dataset import load_aegis_dataset
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer
3
+ from datasets import DatasetDict
4
+ import numpy as np
5
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
6
+
7
+ def compute_metrics(eval_pred):
8
+ predictions, labels = eval_pred
9
+ preds = np.argmax(predictions, axis=1)
10
+ precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted', zero_division=0)
11
+ accuracy = accuracy_score(labels, preds)
12
+ cm = confusion_matrix(labels, preds)
13
+ return {
14
+ 'accuracy': accuracy,
15
+ 'precision': precision,
16
+ 'recall': recall,
17
+ 'f1': f1,
18
+ 'confusion_matrix': cm.tolist()
19
+ }
20
+
21
+ model_dir = 'prompt-injection-detector/checkpoint-5628'
22
+ print(f'Loading model from {model_dir}')
23
+
24
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
25
+ model = AutoModelForSequenceClassification.from_pretrained(model_dir)
26
+
27
+ print('Loading dataset...')
28
+ ds = load_aegis_dataset()
29
+ if not isinstance(ds, DatasetDict) or 'test' not in ds:
30
+ raise RuntimeError('Test split not available in dataset.')
31
+
32
+ test_ds = ds['test']
33
+ print(f'Test samples: {len(test_ds)}')
34
+
35
+ def tokenize(batch):
36
+ return tokenizer(batch['prompt'], truncation=True, padding='max_length', max_length=512)
37
+
38
+ test_tok = test_ds.map(tokenize, batched=True, remove_columns=['prompt'])
39
+ test_tok = test_tok.rename_column('prompt_label', 'labels')
40
+ test_tok.set_format('torch')
41
+
42
+ trainer = Trainer(model=model, tokenizer=tokenizer, compute_metrics=compute_metrics)
43
+
44
+ print('Evaluating...')
45
+ results = trainer.evaluate(eval_dataset=test_tok)
46
+ print('Test metrics:')
47
+ for k, v in results.items():
48
+ print(f' {k}: {v}')
train_prompt_injection_detector.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Project #1: Prompt Injection Detection Classifier
4
+
5
+ Train a binary classifier to detect safe (0) vs unsafe (1) prompts
6
+ using the Aegis AI Content Safety Dataset 2.0.
7
+
8
+ Steps:
9
+ 1. Load dataset with prompt and prompt_label fields
10
+ 2. Convert labels: "safe" → 0, "unsafe" → 1
11
+ 3. Create train/validation split (since dataset is for "testing")
12
+ 4. Train a sequence classification model
13
+ 5. Evaluate on test split
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import argparse
19
+ import logging
20
+ from pathlib import Path
21
+
22
+ import matplotlib.pyplot as plt
23
+ import numpy as np
24
+ from datasets import Dataset, DatasetDict
25
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
26
+ from transformers import (
27
+ AutoModelForSequenceClassification,
28
+ AutoTokenizer,
29
+ DataCollatorWithPadding,
30
+ TrainingArguments,
31
+ Trainer,
32
+ TrainerCallback,
33
+ )
34
+
35
+ from load_aegis_dataset import load_aegis_dataset
36
+
37
+ # Set up logging
38
+ logging.basicConfig(
39
+ level=logging.INFO,
40
+ format='%(asctime)s - %(levelname)s - %(message)s',
41
+ datefmt='%Y-%m-%d %H:%M:%S'
42
+ )
43
+ logger = logging.getLogger(__name__)
44
+
45
+
46
+ def compute_metrics(eval_pred):
47
+ """Compute classification metrics."""
48
+ predictions, labels = eval_pred
49
+ predictions = np.argmax(predictions, axis=1)
50
+
51
+ precision, recall, f1, _ = precision_recall_fscore_support(
52
+ labels, predictions, average='weighted', zero_division=0
53
+ )
54
+ accuracy = accuracy_score(labels, predictions)
55
+
56
+ # Confusion matrix
57
+ cm = confusion_matrix(labels, predictions)
58
+
59
+ return {
60
+ 'accuracy': accuracy,
61
+ 'f1': f1,
62
+ 'precision': precision,
63
+ 'recall': recall,
64
+ 'confusion_matrix': cm.tolist(),
65
+ }
66
+
67
+
68
+ def tokenize_function(examples, tokenizer):
69
+ """Tokenize the prompts."""
70
+ return tokenizer(
71
+ examples["prompt"],
72
+ truncation=True,
73
+ padding="max_length",
74
+ max_length=512,
75
+ )
76
+
77
+
78
+ class TestLossCallback(TrainerCallback):
79
+ """Callback to track test loss after each epoch."""
80
+
81
+ def __init__(self, test_dataset, trainer):
82
+ self.test_dataset = test_dataset
83
+ self.trainer = trainer
84
+ self.test_losses = []
85
+ self.test_epochs = []
86
+
87
+ def on_epoch_end(self, args, state, control, **kwargs):
88
+ """Evaluate on test set after each epoch."""
89
+ if self.test_dataset is not None:
90
+ test_results = self.trainer.evaluate(eval_dataset=self.test_dataset)
91
+ if "eval_loss" in test_results:
92
+ self.test_losses.append(test_results["eval_loss"])
93
+ self.test_epochs.append(state.epoch)
94
+ logger.info(f"Epoch {state.epoch}: Test Loss = {test_results['eval_loss']:.4f}")
95
+
96
+
97
+ def main():
98
+ parser = argparse.ArgumentParser(description="Train prompt injection detection classifier")
99
+ parser.add_argument(
100
+ "--model-name",
101
+ type=str,
102
+ default="distilbert-base-uncased",
103
+ help="Base model for classification (distilbert-base-uncased, bert-base-uncased, roberta-base)"
104
+ )
105
+ parser.add_argument(
106
+ "--output-dir",
107
+ type=str,
108
+ default="./prompt-injection-detector",
109
+ help="Directory to save the trained model"
110
+ )
111
+ parser.add_argument(
112
+ "--num-epochs",
113
+ type=int,
114
+ default=3,
115
+ help="Number of training epochs"
116
+ )
117
+ parser.add_argument(
118
+ "--batch-size",
119
+ type=int,
120
+ default=16,
121
+ help="Training batch size"
122
+ )
123
+ parser.add_argument(
124
+ "--learning-rate",
125
+ type=float,
126
+ default=5e-5,
127
+ help="Learning rate"
128
+ )
129
+ parser.add_argument(
130
+ "--test-size",
131
+ type=float,
132
+ default=0.1,
133
+ help="Fraction of data to use for validation (rest for training)"
134
+ )
135
+ parser.add_argument(
136
+ "--seed",
137
+ type=int,
138
+ default=42,
139
+ help="Random seed for reproducibility"
140
+ )
141
+ args = parser.parse_args()
142
+
143
+ logger.info("=" * 60)
144
+ logger.info("Project #1: Prompt Injection Detection Classifier")
145
+ logger.info("=" * 60)
146
+ logger.info(f"Model: {args.model_name}")
147
+ logger.info(f"Output directory: {args.output_dir}")
148
+ logger.info(f"Epochs: {args.num_epochs}, Batch size: {args.batch_size}")
149
+ logger.info("=" * 60)
150
+
151
+ # Step 1: Load dataset (train/validation/test if available)
152
+ logger.info("Step 1: Loading Aegis dataset splits...")
153
+ dataset = load_aegis_dataset()
154
+
155
+ if isinstance(dataset, DatasetDict):
156
+ logger.info(f"Available splits: {list(dataset.keys())}")
157
+ train_dataset = dataset.get("train")
158
+ val_dataset = dataset.get("validation") or dataset.get("val")
159
+ test_dataset = dataset.get("test")
160
+ elif isinstance(dataset, Dataset):
161
+ logger.warning("Dataset returned a single split. Treating as 'train'.")
162
+ train_dataset = dataset
163
+ val_dataset = None
164
+ test_dataset = None
165
+ else:
166
+ raise ValueError("Unexpected dataset type returned from load_aegis_dataset.")
167
+
168
+ if train_dataset is None:
169
+ raise ValueError("Train split not found in dataset.")
170
+
171
+ logger.info(f"Train split size: {len(train_dataset)}")
172
+ logger.info(f"Train fields: {train_dataset.column_names}")
173
+ logger.info(f"Train sample: {train_dataset[0]}")
174
+
175
+ if val_dataset is not None:
176
+ logger.info(f"Validation split size: {len(val_dataset)}")
177
+ else:
178
+ logger.info("Validation split not found; will create from train split.")
179
+
180
+ if test_dataset is not None:
181
+ logger.info(f"Test split size: {len(test_dataset)}")
182
+ else:
183
+ logger.info("Test split not found; will fall back to validation split for final evaluation if needed.")
184
+
185
+ # Step 2: Verify label mapping and create validation split if missing
186
+ logger.info("\nStep 2: Verifying label mapping and preparing splits...")
187
+ unique_labels = set(train_dataset["prompt_label"])
188
+ logger.info(f"Unique labels: {unique_labels}")
189
+ assert unique_labels == {0, 1}, f"Expected labels {{0, 1}}, got {unique_labels}"
190
+
191
+ # Count safe vs unsafe
192
+ safe_count = sum(1 for label in train_dataset["prompt_label"] if label == 0)
193
+ unsafe_count = sum(1 for label in train_dataset["prompt_label"] if label == 1)
194
+ logger.info(f"Safe prompts: {safe_count}, Unsafe prompts: {unsafe_count}")
195
+
196
+ if val_dataset is None:
197
+ logger.info("Creating validation split from train data...")
198
+ split_dataset = train_dataset.train_test_split(
199
+ test_size=args.test_size,
200
+ shuffle=True,
201
+ seed=args.seed
202
+ )
203
+ train_dataset = split_dataset["train"]
204
+ val_dataset = split_dataset["test"]
205
+
206
+ logger.info(f"Final train samples: {len(train_dataset)}")
207
+ logger.info(f"Final validation samples: {len(val_dataset)}")
208
+
209
+ # Step 3: Load model and tokenizer
210
+ logger.info(f"\nStep 3: Loading model and tokenizer: {args.model_name}")
211
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
212
+ model = AutoModelForSequenceClassification.from_pretrained(
213
+ args.model_name,
214
+ num_labels=2,
215
+ )
216
+
217
+ # Step 4: Tokenize datasets
218
+ logger.info("\nStep 4: Tokenizing datasets...")
219
+ tokenize_fn = lambda examples: tokenize_function(examples, tokenizer)
220
+
221
+ train_tokenized = train_dataset.map(
222
+ tokenize_fn,
223
+ batched=True,
224
+ remove_columns=["prompt"], # Keep prompt_label for labels
225
+ )
226
+ val_tokenized = val_dataset.map(
227
+ tokenize_fn,
228
+ batched=True,
229
+ remove_columns=["prompt"],
230
+ )
231
+
232
+ # Rename prompt_label to labels for Trainer
233
+ train_tokenized = train_tokenized.rename_column("prompt_label", "labels")
234
+ val_tokenized = val_tokenized.rename_column("prompt_label", "labels")
235
+
236
+ # Set format for PyTorch
237
+ train_tokenized.set_format("torch")
238
+ val_tokenized.set_format("torch")
239
+
240
+ # Prepare test dataset if available
241
+ test_tokenized = None
242
+ if test_dataset is not None:
243
+ test_tokenized = test_dataset.map(
244
+ tokenize_fn,
245
+ batched=True,
246
+ remove_columns=["prompt"],
247
+ )
248
+ test_tokenized = test_tokenized.rename_column("prompt_label", "labels")
249
+ test_tokenized.set_format("torch")
250
+
251
+ # Step 5: Set up training
252
+ logger.info("\nStep 5: Setting up training...")
253
+ output_dir = Path(args.output_dir)
254
+ output_dir.mkdir(parents=True, exist_ok=True)
255
+
256
+ training_args = TrainingArguments(
257
+ output_dir=str(output_dir),
258
+ num_train_epochs=args.num_epochs,
259
+ per_device_train_batch_size=args.batch_size,
260
+ per_device_eval_batch_size=args.batch_size,
261
+ learning_rate=args.learning_rate,
262
+ weight_decay=0.01,
263
+ warmup_steps=500,
264
+ logging_dir=str(output_dir / "logs"),
265
+ logging_steps=100,
266
+ eval_strategy="epoch",
267
+ save_strategy="epoch",
268
+ load_best_model_at_end=True,
269
+ metric_for_best_model="f1",
270
+ greater_is_better=True,
271
+ save_total_limit=3,
272
+ fp16=False, # Set to True if you have GPU
273
+ report_to="none",
274
+ )
275
+
276
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
277
+
278
+ trainer = Trainer(
279
+ model=model,
280
+ args=training_args,
281
+ train_dataset=train_tokenized,
282
+ eval_dataset=val_tokenized,
283
+ tokenizer=tokenizer,
284
+ data_collator=data_collator,
285
+ compute_metrics=compute_metrics,
286
+ )
287
+
288
+ # Add callback to track test loss if test dataset is available
289
+ test_callback = None
290
+ if test_tokenized is not None:
291
+ test_callback = TestLossCallback(test_tokenized, trainer)
292
+ trainer.add_callback(test_callback)
293
+
294
+ # Step 6: Train
295
+ logger.info("\nStep 6: Training classifier...")
296
+ trainer.train()
297
+
298
+ # Extract training history for plotting
299
+ train_losses = []
300
+ train_epochs = []
301
+ val_losses = []
302
+ val_epochs = []
303
+
304
+ for log_entry in trainer.state.log_history:
305
+ if "loss" in log_entry and "epoch" in log_entry:
306
+ train_losses.append(log_entry["loss"])
307
+ train_epochs.append(log_entry["epoch"])
308
+ elif "eval_loss" in log_entry and "epoch" in log_entry:
309
+ val_losses.append(log_entry["eval_loss"])
310
+ val_epochs.append(log_entry["epoch"])
311
+
312
+ # Step 7: Evaluate on validation set
313
+ logger.info("\nStep 7: Evaluating on validation set...")
314
+ eval_results = trainer.evaluate()
315
+ logger.info("\nValidation Results:")
316
+ for key, value in eval_results.items():
317
+ if key != "confusion_matrix":
318
+ logger.info(f" {key}: {value:.4f}")
319
+ else:
320
+ logger.info(f" {key}:")
321
+ logger.info(" " + "\n ".join(str(row) for row in value))
322
+
323
+ # Step 8: Test on test split (if available)
324
+ logger.info("\nStep 8: Testing on test split...")
325
+
326
+ if test_tokenized is not None:
327
+ logger.info(f"Test dataset found with {len(test_dataset)} samples.")
328
+
329
+ # Get test losses from callback if available
330
+ if test_callback and test_callback.test_losses:
331
+ test_losses = test_callback.test_losses
332
+ test_epochs = test_callback.test_epochs
333
+ logger.info(f"Test losses tracked over {len(test_losses)} epochs via callback.")
334
+ else:
335
+ # Fallback: evaluate final model on test set
336
+ test_results = trainer.evaluate(eval_dataset=test_tokenized)
337
+ test_losses = [test_results["eval_loss"]]
338
+ test_epochs = [args.num_epochs]
339
+ logger.info("Evaluated final model on test set.")
340
+
341
+ # Final test evaluation
342
+ test_results = trainer.evaluate(eval_dataset=test_tokenized)
343
+ logger.info("\nFinal Test Results:")
344
+ for key, value in test_results.items():
345
+ if key != "confusion_matrix":
346
+ logger.info(f" {key}: {value:.4f}")
347
+ else:
348
+ logger.info(f" {key}:")
349
+ logger.info(" " + "\n ".join(str(row) for row in value))
350
+ else:
351
+ logger.warning("Test split not found; using validation losses for plotting.")
352
+ # Use validation losses as test losses for plotting
353
+ test_losses = val_losses
354
+ test_epochs = val_epochs
355
+
356
+ # Step 9: Plot training and test loss
357
+ logger.info("\nStep 9: Plotting training and test loss...")
358
+ plt.figure(figsize=(10, 6))
359
+
360
+ if train_losses and train_epochs:
361
+ plt.plot(train_epochs, train_losses, 'b-o', label='Train Loss', linewidth=2, markersize=6)
362
+
363
+ if test_losses and test_epochs:
364
+ plt.plot(test_epochs, test_losses, 'r-s', label='Test Loss', linewidth=2, markersize=6)
365
+
366
+ plt.xlabel('Epoch', fontsize=12)
367
+ plt.ylabel('Loss', fontsize=12)
368
+ plt.title('Training and Test Loss Over Epochs', fontsize=14, fontweight='bold')
369
+ plt.legend(fontsize=11)
370
+ plt.grid(True, alpha=0.3)
371
+ plt.tight_layout()
372
+
373
+ # Save plot
374
+ plot_path = output_dir / "loss_plot.png"
375
+ plt.savefig(plot_path, dpi=300, bbox_inches='tight')
376
+ logger.info(f"Loss plot saved to: {plot_path}")
377
+ plt.close()
378
+
379
+ # Step 10: Save model
380
+ logger.info(f"\nStep 10: Saving model to {output_dir}...")
381
+ trainer.save_model()
382
+ tokenizer.save_pretrained(str(output_dir))
383
+
384
+ logger.info("=" * 60)
385
+ logger.info("Training complete!")
386
+ logger.info(f"Model saved to: {output_dir}")
387
+ logger.info(f"Loss plot saved to: {plot_path}")
388
+ logger.info("=" * 60)
389
+
390
+
391
+ if __name__ == "__main__":
392
+ main()
393
+