|
|
"""
CTI Bench Evaluation Script for Cybersecurity Retrieval System

This script evaluates the retrieval supervisor system against the CTI Bench dataset,
including both CTI-ATE (attack technique extraction) and CTI-MCQ (multiple choice questions).
"""
|
|
|
|
|
import os |
|
|
import sys |
|
|
import pandas as pd |
|
|
import re |
|
|
import json |
|
|
import csv |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Tuple, Any, Optional |
|
|
from datetime import datetime |
|
|
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
from src.agents.retrieval_supervisor.supervisor import RetrievalSupervisor |
|
|
|
|
|
|
|
|
class CTIBenchEvaluator:
    """Evaluate a Retrieval Supervisor against the CTI Bench dataset.

    Two benchmark splits are handled:

    * CTI-ATE: the model lists MITRE technique IDs for a threat description.
      Scored with sample-averaged ("macro") and pooled ("micro")
      precision / recall / F1, plus exact-match and success rates.
    * CTI-MCQ: multiple-choice questions answered with one letter A-D.
      Scored with accuracy and macro/micro F1 over the answered samples.

    Progress is reported with ``print``. Per-sample results are written as
    CSV and aggregated metrics as JSON inside ``output_dir``.
    """

    # File names of the benchmark splits inside ``dataset_dir``.
    ATE_FILENAME = "cti-ate.tsv"
    MCQ_FILENAME = "cti-mcq.tsv"

    # How often one ATE sample is re-queried before it is recorded as failed.
    MAX_ATE_RETRIES = 3

    # A technique ID such as T1071, optionally with a sub-technique (T1071.001).
    _TECHNIQUE_ID_RE = re.compile(r"\bT\d{4}(?:\.\d{3})?\b")
    # Characters allowed on a final line that is purely a list of IDs.
    _ID_LIST_CHARS_RE = re.compile(r"[T\d.,\s]")
    # A standalone answer letter at the end of a line ("... is B." / "B)").
    _TRAILING_ANSWER_RE = re.compile(r"\b([ABCD])\b(?:\s*[.)]?)\s*$")
    # Whole-response fallbacks, tried in order. The keyword is matched
    # case-insensitively, but the letter must be a standalone capital so that
    # ordinary words ("answer is clearly ...") are not read as an answer.
    _FALLBACK_ANSWER_RES = (
        re.compile(r"(?i:answer|choice|option)[^\n]*?\b([ABCD])\b"),
        re.compile(r"\b([ABCD])\b(?:\s*[.)]?)\s*$", re.MULTILINE),
        re.compile(r"^([ABCD])$", re.IGNORECASE | re.MULTILINE),
    )
    # Result file stem written by save_results_to_csv:
    # cti-<type>_<model>_<YYYYmmdd>_<HHMMSS>
    _RESULT_STEM_RE = re.compile(
        r"cti-(?:ate|mcq)_(.+?)(?:_\d{8}_\d{6})?$", re.IGNORECASE
    )

    def __init__(
        self,
        supervisor: Optional["RetrievalSupervisor"],
        dataset_dir: str = "cti_bench/datasets",
        output_dir: str = "cti_bench/eval_output",
    ):
        """
        Initialize the CTI Bench evaluator.

        Creates ``output_dir`` (including parents) if it does not exist.

        Args:
            supervisor: RetrievalSupervisor instance (can be None for CSV processing)
            dataset_dir: Directory containing CTI Bench datasets
            output_dir: Directory to save evaluation results
        """
        self.supervisor = supervisor
        self.dataset_dir = Path(dataset_dir)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Prompt sent for every CTI-ATE sample. The extractor only accepts a
        # final line made up purely of technique IDs, as requested here.
        self.ate_query_template = """You are a cybersecurity expert specializing in cyber threat intelligence.
Extract all MITRE Enterprise attack patterns from the following text and map them to their corresponding MITRE technique IDs.
Provide reasoning for each identification.
Ensure the final line contains only the IDs for the main techniques, separated by commas, excluding any subtechnique IDs.

Example of the final line: T1071, T1560, T1547

Text:
{attack_description}
"""

    # ------------------------------------------------------------------
    # Dataset loading and filtering
    # ------------------------------------------------------------------

    def _load_dataset(self, filename: str, label: str) -> pd.DataFrame:
        """Read one tab-separated split from ``dataset_dir``.

        Any error is printed and re-raised unchanged.
        """
        path = self.dataset_dir / filename
        try:
            frame = pd.read_csv(path, sep="\t")
        except Exception as e:
            print(f"Error loading datasets: {e}")
            raise
        print(f"Loaded {label} dataset: {len(frame)} samples")
        return frame

    def load_datasets(self) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Load CTI-ATE and CTI-MCQ datasets.

        Returns:
            ``(ate_df, mcq_df)`` read from ``cti-ate.tsv`` and ``cti-mcq.tsv``.
        """
        ate_df = self._load_dataset(self.ATE_FILENAME, "CTI-ATE")
        mcq_df = self._load_dataset(self.MCQ_FILENAME, "CTI-MCQ")
        return ate_df, mcq_df

    def filter_dataset(self, df: pd.DataFrame, dataset_type: str) -> pd.DataFrame:
        """Filter dataset according to requirements.

        ``"ate"`` keeps rows whose ``Platform`` is ``"Enterprise"``; ``"mcq"``
        keeps rows whose ``URL`` contains ``"techniques"``.

        Raises:
            ValueError: If ``dataset_type`` is neither ``"ate"`` nor ``"mcq"``.
        """
        if dataset_type == "ate":
            filtered_df = df[df["Platform"] == "Enterprise"].copy()
            print(
                f"CTI-ATE filtered to Enterprise platform: {len(filtered_df)} samples"
            )
        elif dataset_type == "mcq":
            filtered_df = df[df["URL"].str.contains("techniques", na=False)].copy()
            print(f"CTI-MCQ filtered to technique URLs: {len(filtered_df)} samples")
        else:
            raise ValueError(f"Invalid dataset type: {dataset_type}")

        return filtered_df

    # ------------------------------------------------------------------
    # Response parsing
    # ------------------------------------------------------------------

    @staticmethod
    def _split_technique_list(raw: Any) -> List[str]:
        """Split a comma-separated ID cell into a list; missing cells give []."""
        if not isinstance(raw, str):
            if raw is None or pd.isna(raw):
                return []
            raw = str(raw)
        return [part.strip() for part in raw.split(",") if part.strip()]

    @staticmethod
    def _normalize_answer(raw: Any) -> str:
        """Upper-case and strip an answer cell; missing cells give ''."""
        if not isinstance(raw, str):
            if raw is None or pd.isna(raw):
                return ""
            raw = str(raw)
        return raw.strip().upper()

    @staticmethod
    def _extract_response_text(response: Any) -> str:
        """Pull the answer text out of a supervisor response.

        Prefers the content of the last message whose ``type`` is ``"ai"``,
        falls back to the last message, and finally to ``str(response)`` when
        there is no non-empty ``"messages"`` entry. Non-string content is
        converted with ``str`` so the extractors always receive text.
        """
        if not ("messages" in response and response["messages"]):
            return str(response)

        messages = response["messages"]
        for message in reversed(messages):
            if hasattr(message, "content") and getattr(message, "type", None) == "ai":
                content = message.content
                return content if isinstance(content, str) else str(content)

        last = messages[-1]
        try:
            content = last.content
        except (AttributeError, TypeError) as e:
            print(f" Warning: Error accessing last message content: {e}")
            return str(last)
        return content if isinstance(content, str) else str(content)

    def extract_technique_ids_from_response(self, response: str) -> List[str]:
        """
        Extract MITRE technique IDs from the response text.

        Only the final line is inspected, and it must consist solely of
        technique IDs, commas, dots and whitespace. Anything else is treated
        as a failed extraction.

        Args:
            response: Response text from the supervisor

        Returns:
            Main technique IDs (sub-techniques dropped, duplicates removed,
            original order kept), or an empty list if extraction fails.
        """
        if not isinstance(response, str):
            return []

        final_line = response.strip().split("\n")[-1].strip()
        if not final_line:
            return []

        found = self._TECHNIQUE_ID_RE.findall(final_line)
        if not found:
            return []

        # Reject lines that carry any text besides the ID list itself.
        if self._ID_LIST_CHARS_RE.sub("", final_line):
            return []

        return list(dict.fromkeys(t for t in found if "." not in t))

    def extract_mcq_answer_from_response(self, response: str) -> str:
        """
        Extract the final answer (A, B, C, or D) from MCQ response.

        The last three lines are checked first (bottom-up) for a bare letter
        or a standalone letter at the end of the line. If none is found the
        whole response is searched with the fallback patterns.

        Args:
            response: Response text from the supervisor

        Returns:
            Extracted answer letter or empty string if not found
        """
        if not isinstance(response, str):
            return ""

        lines = response.strip().split("\n")
        for line in reversed(lines[-3:]):
            line = line.strip()
            if line in ("A", "B", "C", "D"):
                return line
            match = self._TRAILING_ANSWER_RE.search(line)
            if match:
                return match.group(1)

        for pattern in self._FALLBACK_ANSWER_RES:
            matches = pattern.findall(response)
            if matches:
                return matches[-1].upper()

        return ""

    # ------------------------------------------------------------------
    # Per-sample evaluation
    # ------------------------------------------------------------------

    @staticmethod
    def _build_ate_result(
        row: Any,
        ground_truth: List[str],
        predicted: List[str],
        response_text: str,
        success: bool,
        attempts: int,
    ) -> Dict[str, Any]:
        """Assemble the result record for one ATE sample."""
        return {
            "url": row["URL"],
            "description": row["Description"],
            "ground_truth": ground_truth,
            "predicted": predicted,
            "response_text": response_text,
            "success": success,
            "attempts": attempts,
        }

    def evaluate_ate_dataset(self, ate_df: pd.DataFrame) -> List[Dict[str, Any]]:
        """
        Evaluate the CTI-ATE dataset.

        Each sample is queried up to ``MAX_ATE_RETRIES`` times and stops at
        the first attempt that yields at least one technique ID. A sample
        whose attempts all fail or raise is recorded with ``success=False``
        and an empty prediction.

        Args:
            ate_df: Filtered CTI-ATE dataset

        Returns:
            List of evaluation results
        """
        results = []
        max_retries = self.MAX_ATE_RETRIES

        print(f"\n{'='*60}")
        print("EVALUATING CTI-ATE DATASET")
        print(f"{'='*60}")

        total = len(ate_df)
        for i, (_, row) in enumerate(ate_df.iterrows()):
            print(f"Processing ATE sample {i + 1}/{total}: {row['URL']}")

            # Parsed once, up front, so the error path below never has to
            # touch the raw cell again (it may be NaN).
            gt_techniques = self._split_technique_list(row["GT"])
            result = None

            for attempt in range(max_retries):
                is_last_attempt = attempt == max_retries - 1
                try:
                    print(f" Attempt {attempt + 1}/{max_retries}")
                    query = self.ate_query_template.format(
                        attack_description=row["Description"]
                    )
                    response = self.supervisor.invoke_direct_query(query, trace=False)
                    response_text = self._extract_response_text(response)
                    predicted_techniques = self.extract_technique_ids_from_response(
                        response_text
                    )
                except Exception as e:
                    print(f" Error processing sample (attempt {attempt + 1}): {e}")
                    if is_last_attempt:
                        result = self._build_ate_result(
                            row, gt_techniques, [], f"Error: {str(e)}", False, max_retries
                        )
                        print(f" Success: {result['success']} (all attempts failed)")
                    continue

                if predicted_techniques:
                    result = self._build_ate_result(
                        row,
                        gt_techniques,
                        predicted_techniques,
                        response_text,
                        True,
                        attempt + 1,
                    )
                    print(f" GT: {gt_techniques}")
                    print(f" Predicted: {predicted_techniques}")
                    print(f" Success: {result['success']} (attempt {attempt + 1})")
                    break

                print(f" No techniques extracted on attempt {attempt + 1}")
                if is_last_attempt:
                    result = self._build_ate_result(
                        row, gt_techniques, [], response_text, False, max_retries
                    )
                    print(f" GT: {gt_techniques}")
                    print(f" Predicted: {predicted_techniques}")
                    print(f" Success: {result['success']} (all attempts failed)")
                    print(f" Response text: {response_text}")

            results.append(result)

        return results

    def evaluate_mcq_dataset(self, mcq_df: pd.DataFrame) -> List[Dict[str, Any]]:
        """
        Evaluate the CTI-MCQ dataset.

        Each prompt is sent once. A sample that raises is recorded with an
        empty prediction, ``correct=False`` and ``success=False``.

        Args:
            mcq_df: Filtered CTI-MCQ dataset

        Returns:
            List of evaluation results
        """
        results = []

        print(f"\n{'='*60}")
        print("EVALUATING CTI-MCQ DATASET")
        print(f"{'='*60}")

        total = len(mcq_df)
        for i, (_, row) in enumerate(mcq_df.iterrows()):
            print(f"Processing MCQ sample {i + 1}/{total}: {row['URL']}")

            # Parsed before the try so the error path cannot fail on it again.
            gt_answer = self._normalize_answer(row["GT"])

            try:
                response = self.supervisor.invoke_direct_query(
                    row["Prompt"], trace=False
                )
                response_text = self._extract_response_text(response)
                predicted_answer = self.extract_mcq_answer_from_response(response_text)
            except Exception as e:
                print(f" Error processing sample: {e}")
                results.append(
                    {
                        "url": row["URL"],
                        "prompt": row["Prompt"],
                        "ground_truth": gt_answer,
                        "predicted": "",
                        "response_text": f"Error: {str(e)}",
                        "correct": False,
                        "success": False,
                    }
                )
                continue

            result = {
                "url": row["URL"],
                "prompt": row["Prompt"],
                "ground_truth": gt_answer,
                "predicted": predicted_answer,
                "response_text": response_text,
                "correct": predicted_answer == gt_answer,
                "success": len(predicted_answer) > 0,
            }
            results.append(result)

            print(f" GT: {gt_answer}")
            print(f" Predicted: {predicted_answer}")
            print(f" Correct: {result['correct']}")

        return results

    # ------------------------------------------------------------------
    # Metrics
    # ------------------------------------------------------------------

    def calculate_ate_metrics(self, results: List[Dict[str, Any]]) -> Dict[str, float]:
        """
        Calculate evaluation metrics for ATE dataset using sample-level metrics.

        "Macro" values are the mean of per-sample precision / recall / F1.
        "Micro" values pool true/false positives and false negatives over all
        samples. A sample with neither ground truth nor prediction scores 1.0
        on all per-sample metrics, consistent with it counting as an exact
        match.

        Args:
            results: List of ATE evaluation results

        Returns:
            Dictionary of calculated metrics (empty if ``results`` is empty)
        """
        if not results:
            return {}

        all_techniques = set()
        sample_precisions = []
        sample_recalls = []
        sample_f1s = []
        total_tp = total_fp = total_fn = 0
        exact_matches = 0

        for result in results:
            gt_set = set(result["ground_truth"])
            pred_set = set(result["predicted"])
            all_techniques |= gt_set | pred_set

            tp = len(gt_set & pred_set)
            total_tp += tp
            total_fp += len(pred_set - gt_set)
            total_fn += len(gt_set - pred_set)
            exact_matches += gt_set == pred_set

            if not gt_set and not pred_set:
                # Nothing to find and nothing predicted: a perfect sample.
                precision = recall = 1.0
            else:
                precision = tp / len(pred_set) if pred_set else 0.0
                recall = tp / len(gt_set) if gt_set else 0.0

            if precision + recall == 0:
                f1 = 0.0
            else:
                f1 = 2 * (precision * recall) / (precision + recall)

            sample_precisions.append(precision)
            sample_recalls.append(recall)
            sample_f1s.append(f1)

        predicted_total = total_tp + total_fp
        actual_total = total_tp + total_fn
        micro_precision = total_tp / predicted_total if predicted_total else 0.0
        micro_recall = total_tp / actual_total if actual_total else 0.0
        if micro_precision + micro_recall == 0:
            micro_f1 = 0.0
        else:
            micro_f1 = (
                2 * (micro_precision * micro_recall) / (micro_precision + micro_recall)
            )

        sample_count = len(results)
        return {
            "macro_f1": float(np.mean(sample_f1s)),
            "macro_precision": float(np.mean(sample_precisions)),
            "macro_recall": float(np.mean(sample_recalls)),
            "micro_f1": micro_f1,
            "micro_precision": micro_precision,
            "micro_recall": micro_recall,
            "exact_match_ratio": exact_matches / sample_count,
            "success_rate": sum(1 for r in results if r["success"]) / sample_count,
            "total_samples": sample_count,
            "total_techniques": len(all_techniques),
        }

    def calculate_mcq_metrics(self, results: List[Dict[str, Any]]) -> Dict[str, float]:
        """
        Calculate evaluation metrics for MCQ dataset.

        Accuracy, F1, precision and recall are computed only over samples
        with ``success=True`` (an answer was extracted). ``success_rate``
        reports the share of such samples.

        Args:
            results: List of MCQ evaluation results

        Returns:
            Dictionary of calculated metrics (empty if ``results`` is empty)
        """
        if not results:
            return {}

        answered = [r for r in results if r["success"]]
        if not answered:
            return {
                "accuracy": 0.0,
                "f1_macro": 0.0,
                "f1_micro": 0.0,
                "precision_macro": 0.0,
                "recall_macro": 0.0,
                "success_rate": 0.0,
                "total_samples": len(results),
                "answered_samples": 0,
            }

        y_true = [r["ground_truth"] for r in answered]
        y_pred = [r["predicted"] for r in answered]

        return {
            "accuracy": float(accuracy_score(y_true, y_pred)),
            "f1_macro": float(
                f1_score(y_true, y_pred, average="macro", zero_division=0)
            ),
            "f1_micro": float(
                f1_score(y_true, y_pred, average="micro", zero_division=0)
            ),
            "precision_macro": float(
                precision_score(y_true, y_pred, average="macro", zero_division=0)
            ),
            "recall_macro": float(
                recall_score(y_true, y_pred, average="macro", zero_division=0)
            ),
            "success_rate": len(answered) / len(results),
            "total_samples": len(results),
            "answered_samples": len(answered),
        }

    @staticmethod
    def _print_ate_metrics(metrics: Dict[str, Any], prefix: str = "") -> None:
        """Print the headline ATE metrics, one per line, each with ``prefix``."""
        print(f"{prefix}Macro F1: {metrics.get('macro_f1', 0.0):.3f}")
        print(f"{prefix}Macro Precision: {metrics.get('macro_precision', 0.0):.3f}")
        print(f"{prefix}Macro Recall: {metrics.get('macro_recall', 0.0):.3f}")
        print(f"{prefix}Micro F1: {metrics.get('micro_f1', 0.0):.3f}")
        print(f"{prefix}Exact Match: {metrics.get('exact_match_ratio', 0.0):.3f}")
        print(f"{prefix}Success Rate: {metrics.get('success_rate', 0.0):.3f}")
        print(f"{prefix}Total Samples: {metrics.get('total_samples', 0)}")

    @staticmethod
    def _print_mcq_metrics(metrics: Dict[str, Any], prefix: str = "") -> None:
        """Print the headline MCQ metrics, one per line, each with ``prefix``."""
        print(f"{prefix}Accuracy: {metrics.get('accuracy', 0.0):.3f}")
        print(f"{prefix}F1 Macro: {metrics.get('f1_macro', 0.0):.3f}")
        print(f"{prefix}Success Rate: {metrics.get('success_rate', 0.0):.3f}")
        print(f"{prefix}Total Samples: {metrics.get('total_samples', 0)}")

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def _resolve_model_name(self, model_name: Optional[str]) -> str:
        """Return ``model_name``, or derive one from the supervisor.

        The supervisor's ``llm_model`` is split on ``":"`` and the last part
        is used. Without a supervisor the name is ``"unknown_model"``.
        """
        if model_name is not None:
            return model_name
        if self.supervisor is not None:
            return self.supervisor.llm_model.split(":")[-1]
        return "unknown_model"

    def _write_summary(
        self, summary: Dict[str, Any], dataset_type: str, model_name: str
    ) -> Path:
        """Write ``summary`` as a timestamped JSON file and return its path."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        summary_path = (
            self.output_dir
            / f"evaluation_summary_{dataset_type}_{self._sanitize_filename(model_name)}_{timestamp}.json"
        )
        with open(summary_path, "w", encoding="utf-8") as f:
            json.dump(summary, f, indent=2)

        print(f"Evaluation summary saved to: {summary_path}")
        return summary_path

    def save_results_to_csv(
        self,
        results: List[Dict[str, Any]],
        dataset_type: str,
        model_name: Optional[str] = None,
    ):
        """
        Save evaluation results to CSV file.

        The file is named ``cti-<type>_<model>_<timestamp>.csv`` and holds
        three columns: the description (ATE) or prompt (MCQ), the ground
        truth, and the prediction.

        Args:
            results: Evaluation results
            dataset_type: Type of dataset ("ate" or "mcq")
            model_name: Model name (if None, extracted from supervisor)

        Raises:
            ValueError: If ``dataset_type`` is neither ``"ate"`` nor ``"mcq"``.
        """
        if dataset_type == "ate":
            header = ["Description", "GT", "Predicted"]
            rows = [
                [
                    r["description"],
                    ", ".join(r["ground_truth"]),
                    ", ".join(r["predicted"]),
                ]
                for r in results
            ]
        elif dataset_type == "mcq":
            header = ["Prompt", "GT", "Predicted"]
            rows = [[r["prompt"], r["ground_truth"], r["predicted"]] for r in results]
        else:
            raise ValueError(f"Invalid dataset type: {dataset_type}")

        sanitized_model_name = self._sanitize_filename(
            self._resolve_model_name(model_name)
        )
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        csv_path = (
            self.output_dir
            / f"cti-{dataset_type}_{sanitized_model_name}_{timestamp}.csv"
        )

        with open(csv_path, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(header)
            writer.writerows(rows)

        print(f"{dataset_type.upper()} results saved to: {csv_path}")

    def save_evaluation_summary(
        self,
        metrics: Dict[str, float],
        dataset_type: str,
        model_name: Optional[str] = None,
    ):
        """
        Save evaluation summary to JSON file.

        Args:
            metrics: Evaluation metrics
            dataset_type: Type of dataset ("ate" or "mcq")
            model_name: Model name (if None, extracted from supervisor)
        """
        model_name = self._resolve_model_name(model_name)
        summary = {
            "evaluation_timestamp": datetime.now().isoformat(),
            "dataset_type": dataset_type,
            "model_name": model_name,
            "metrics": metrics,
        }
        self._write_summary(summary, dataset_type, model_name)

    def _extract_dataset_type_from_filename(self, filename: str) -> str:
        """
        Extract dataset type from CSV filename.

        Args:
            filename: The filename (without extension) to extract dataset type from

        Returns:
            Dataset type ("ate" or "mcq")

        Raises:
            ValueError: If the name contains neither "cti-ate" nor "cti-mcq".
        """
        lowered = filename.lower()
        if "cti-ate" in lowered:
            return "ate"
        if "cti-mcq" in lowered:
            return "mcq"
        raise ValueError(f"Cannot determine dataset type from filename: {filename}")

    def _extract_model_name_from_filename(self, filename: str) -> str:
        """Recover the model name from a result file stem.

        For stems shaped like ``cti-<type>_<model>_<YYYYmmdd>_<HHMMSS>`` the
        full model part is returned, even if it contains underscores. Other
        stems fall back to the second ``_``-separated part, or
        ``"unknown_model"`` when there is none.
        """
        match = self._RESULT_STEM_RE.match(filename)
        if match:
            return match.group(1)
        parts = filename.split("_")
        return parts[1] if len(parts) >= 2 else "unknown_model"

    def _sanitize_filename(self, filename: str) -> str:
        """
        Sanitize a string to be safe for use in filenames.

        Characters ``/ \\ : * ? " < > |`` become ``-``, runs of ``-`` are
        collapsed, and leading/trailing ``-`` are removed.

        Args:
            filename: The string to sanitize

        Returns:
            Sanitized filename string, or "unknown" if nothing is left
        """
        sanitized = re.sub(r'[/\\:*?"<>|]', "-", filename)
        sanitized = re.sub(r"-+", "-", sanitized).strip("-")
        return sanitized if sanitized else "unknown"

    # ------------------------------------------------------------------
    # Re-scoring previously saved results
    # ------------------------------------------------------------------

    def read_csv_results(
        self, csv_path: str, dataset_type: str
    ) -> List[Dict[str, Any]]:
        """
        Read existing CSV results and convert to evaluation results format.

        Cells are read as plain strings with NaN conversion disabled, so an
        empty ``Predicted`` cell (a failed sample) stays empty instead of
        turning into the text "nan".

        Args:
            csv_path: Path to the CSV file
            dataset_type: Type of dataset ("ate" or "mcq")

        Returns:
            List of evaluation results in the same format as evaluate_*_dataset methods

        Raises:
            ValueError: If ``dataset_type`` is neither ``"ate"`` nor ``"mcq"``.
        """
        try:
            df = pd.read_csv(csv_path, dtype=str, keep_default_na=False)
            results = []

            if dataset_type == "ate":
                for position, (_, row) in enumerate(df.iterrows()):
                    gt_techniques = self._split_technique_list(row["GT"])
                    predicted_techniques = self._split_technique_list(row["Predicted"])
                    results.append(
                        {
                            "url": f"csv_row_{position}",
                            "description": str(row["Description"]),
                            "ground_truth": gt_techniques,
                            "predicted": predicted_techniques,
                            "response_text": f"GT: {', '.join(gt_techniques)}, Predicted: {', '.join(predicted_techniques)}",
                            "success": len(predicted_techniques) > 0,
                            "attempts": 1,
                        }
                    )
            elif dataset_type == "mcq":
                for position, (_, row) in enumerate(df.iterrows()):
                    gt_answer = self._normalize_answer(row["GT"])
                    predicted_answer = self._normalize_answer(row["Predicted"])
                    results.append(
                        {
                            "url": f"csv_row_{position}",
                            "prompt": str(row["Prompt"]),
                            "ground_truth": gt_answer,
                            "predicted": predicted_answer,
                            "response_text": f"GT: {gt_answer}, Predicted: {predicted_answer}",
                            "correct": predicted_answer == gt_answer,
                            "success": len(predicted_answer) > 0,
                        }
                    )
            else:
                raise ValueError(f"Invalid dataset type: {dataset_type}")

            print(f"Successfully read {len(results)} results from {csv_path}")
            return results

        except Exception as e:
            print(f"Error reading CSV file {csv_path}: {e}")
            raise

    def calculate_metrics_from_csv(
        self, csv_path: str, model_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Read existing CSV results, calculate metrics, and save summary.

        The dataset type is taken from the file name ("cti-ate" / "cti-mcq").

        Args:
            csv_path: Path to the CSV file
            model_name: Model name to use in summary (if None, extracted from filename)

        Returns:
            Dictionary with keys "results", "metrics" and "summary_path"
        """
        filename = Path(csv_path).stem
        dataset_type = self._extract_dataset_type_from_filename(filename)

        if model_name is None:
            model_name = self._extract_model_name_from_filename(filename)

        print(f"Processing CSV file: {csv_path}")
        print(f"Dataset type: {dataset_type} (extracted from filename)")
        print(f"Model name: {model_name}")

        results = self.read_csv_results(csv_path, dataset_type)

        if dataset_type == "ate":
            metrics = self.calculate_ate_metrics(results)
        elif dataset_type == "mcq":
            metrics = self.calculate_mcq_metrics(results)
        else:
            raise ValueError(f"Invalid dataset type: {dataset_type}")

        summary = {
            "evaluation_timestamp": datetime.now().isoformat(),
            "dataset_type": dataset_type,
            "model_name": model_name,
            "source_csv": csv_path,
            "metrics": metrics,
        }
        summary_path = self._write_summary(summary, dataset_type, model_name)

        print(f"\n{'='*60}")
        print(f"METRICS FROM CSV: {dataset_type.upper()}")
        print(f"{'='*60}")
        if dataset_type == "ate":
            self._print_ate_metrics(metrics)
        else:
            self._print_mcq_metrics(metrics)
        print(f"{'='*60}")

        return {
            "results": results,
            "metrics": metrics,
            "summary_path": str(summary_path),
        }

    # ------------------------------------------------------------------
    # Pipelines
    # ------------------------------------------------------------------

    def run_full_evaluation(self) -> Dict[str, Any]:
        """
        Run the complete evaluation pipeline.

        Loads and filters both splits, evaluates them, saves the CSV and JSON
        outputs and prints a summary.

        Returns:
            Dictionary containing all evaluation results and metrics
        """
        print("Starting CTI Bench evaluation...")
        print(f"Output directory: {self.output_dir}")

        ate_df, mcq_df = self.load_datasets()
        ate_filtered = self.filter_dataset(ate_df, "ate")
        mcq_filtered = self.filter_dataset(mcq_df, "mcq")

        ate_results = self.evaluate_ate_dataset(ate_filtered)
        ate_metrics = self.calculate_ate_metrics(ate_results)

        mcq_results = self.evaluate_mcq_dataset(mcq_filtered)
        mcq_metrics = self.calculate_mcq_metrics(mcq_results)

        self.save_results_to_csv(ate_results, "ate")
        self.save_results_to_csv(mcq_results, "mcq")
        self.save_evaluation_summary(ate_metrics, "ate")
        self.save_evaluation_summary(mcq_metrics, "mcq")

        print(f"\n{'='*60}")
        print("EVALUATION SUMMARY")
        print(f"{'='*60}")
        print("CTI-ATE Results:")
        self._print_ate_metrics(ate_metrics, prefix=" ")
        print("\nCTI-MCQ Results:")
        self._print_mcq_metrics(mcq_metrics, prefix=" ")
        print(f"{'='*60}")

        return {
            "ate_results": ate_results,
            "mcq_results": mcq_results,
            "ate_metrics": ate_metrics,
            "mcq_metrics": mcq_metrics,
        }

    def run_ate_evaluation(self) -> Dict[str, Any]:
        """
        Run evaluation on ATE dataset only.

        Only ``cti-ate.tsv`` is read, so the MCQ file need not be present.

        Returns:
            Dictionary containing ATE evaluation results and metrics
        """
        print("Starting CTI-ATE evaluation...")
        print(f"Output directory: {self.output_dir}")

        ate_df = self._load_dataset(self.ATE_FILENAME, "CTI-ATE")
        ate_filtered = self.filter_dataset(ate_df, "ate")

        ate_results = self.evaluate_ate_dataset(ate_filtered)
        ate_metrics = self.calculate_ate_metrics(ate_results)

        self.save_results_to_csv(ate_results, "ate")
        self.save_evaluation_summary(ate_metrics, "ate")

        print(f"\n{'='*60}")
        print("CTI-ATE EVALUATION SUMMARY")
        print(f"{'='*60}")
        print("CTI-ATE Results:")
        self._print_ate_metrics(ate_metrics, prefix=" ")
        print(f"{'='*60}")

        return {
            "ate_results": ate_results,
            "ate_metrics": ate_metrics,
        }

    def run_mcq_evaluation(self) -> Dict[str, Any]:
        """
        Run evaluation on MCQ dataset only.

        Only ``cti-mcq.tsv`` is read, so the ATE file need not be present.

        Returns:
            Dictionary containing MCQ evaluation results and metrics
        """
        print("Starting CTI-MCQ evaluation...")
        print(f"Output directory: {self.output_dir}")

        mcq_df = self._load_dataset(self.MCQ_FILENAME, "CTI-MCQ")
        mcq_filtered = self.filter_dataset(mcq_df, "mcq")

        mcq_results = self.evaluate_mcq_dataset(mcq_filtered)
        mcq_metrics = self.calculate_mcq_metrics(mcq_results)

        self.save_results_to_csv(mcq_results, "mcq")
        self.save_evaluation_summary(mcq_metrics, "mcq")

        print(f"\n{'='*60}")
        print("CTI-MCQ EVALUATION SUMMARY")
        print(f"{'='*60}")
        print("CTI-MCQ Results:")
        self._print_mcq_metrics(mcq_metrics, prefix=" ")
        print(f"{'='*60}")

        return {
            "mcq_results": mcq_results,
            "mcq_metrics": mcq_metrics,
        }
|
|
|
|
|
|
|
|
def _run_limited_evaluation(evaluator, max_samples):
    """Evaluate at most ``max_samples`` rows of each filtered split.

    Uses only the evaluator's public methods: load, filter, truncate with
    ``head``, evaluate, score, then save the CSV and JSON outputs. Returns a
    dict keyed like ``run_full_evaluation``'s result.
    """
    ate_df, mcq_df = evaluator.load_datasets()
    plan = (
        ("ate", ate_df, evaluator.evaluate_ate_dataset, evaluator.calculate_ate_metrics),
        ("mcq", mcq_df, evaluator.evaluate_mcq_dataset, evaluator.calculate_mcq_metrics),
    )

    outcome = {}
    for kind, frame, evaluate, score in plan:
        subset = evaluator.filter_dataset(frame, kind).head(max_samples)
        results = evaluate(subset)
        metrics = score(results)

        evaluator.save_results_to_csv(results, kind)
        evaluator.save_evaluation_summary(metrics, kind)

        print(f"\nCTI-{kind.upper()} metrics ({len(results)} samples):")
        for name, value in metrics.items():
            print(f" {name}: {value}")

        outcome[f"{kind}_results"] = results
        outcome[f"{kind}_metrics"] = metrics
    return outcome


def main():
    """Main function to run the evaluation.

    Parses the command line, builds the RetrievalSupervisor and the
    evaluator, and runs both benchmarks. With ``--max-samples N`` only the
    first N filtered rows of each split are evaluated.

    Returns:
        0 on success, 1 if the evaluation raised (the traceback is printed).
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Evaluate Retrieval Supervisor on CTI Bench dataset"
    )
    parser.add_argument(
        "--dataset-dir",
        default="cti_bench/datasets",
        help="Directory containing CTI Bench datasets",
    )
    parser.add_argument(
        "--output-dir",
        default="cti_bench/eval_output",
        help="Directory to save evaluation results",
    )
    parser.add_argument(
        "--kb-path",
        default="./cyber_knowledge_base",
        help="Path to cyber knowledge base",
    )
    parser.add_argument(
        "--llm-model", default="google_genai:gemini-2.0-flash", help="LLM model to use"
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        help="Maximum number of samples to evaluate (for testing)",
    )

    args = parser.parse_args()
    if args.max_samples is not None and args.max_samples < 1:
        parser.error("--max-samples must be a positive integer")

    try:
        print("Initializing Retrieval Supervisor...")
        supervisor = RetrievalSupervisor(
            llm_model=args.llm_model, kb_path=args.kb_path, max_iterations=3
        )

        evaluator = CTIBenchEvaluator(
            supervisor=supervisor,
            dataset_dir=args.dataset_dir,
            output_dir=args.output_dir,
        )

        if args.max_samples is None:
            evaluator.run_full_evaluation()
        else:
            _run_limited_evaluation(evaluator, args.max_samples)

        print("Evaluation completed successfully!")
        return 0

    except Exception as e:
        print(f"Evaluation failed: {e}")
        import traceback

        traceback.print_exc()
        # Report the failure to the calling shell rather than exiting with 0.
        return 1


if __name__ == "__main__":
    sys.exit(main())
|
|
|