# verification_metrics.py
"""
Verification metrics calculator service.
Provides methods for calculating accuracy, confusion matrices, and error patterns
from verification records.
"""
from typing import Dict, List, Any
from src.core.verification_models import VerificationRecord
class VerificationMetricsCalculator:
"""Calculates performance metrics from verification records."""
@staticmethod
def calculate_accuracy(records: List[VerificationRecord]) -> float:
"""
Calculate overall accuracy from verification records.
Accuracy = (correct_count / total_count) * 100
Args:
records: List of verification records
Returns:
Accuracy as a percentage (0-100), or 0 if no records
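        Example (illustrative numbers):
            3 correct records out of 4 -> (3 / 4) * 100 = 75.0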
"""
if not records:
return 0.0
correct_count = sum(1 for record in records if record.is_correct)
return (correct_count / len(records)) * 100
@staticmethod
def calculate_accuracy_by_type(
records: List[VerificationRecord],
) -> Dict[str, float]:
"""
Calculate accuracy for each classification type.
For each type (green, yellow, red), calculates:
accuracy = (correct_count_for_type / total_count_for_type) * 100
Args:
records: List of verification records
Returns:
            Dictionary with keys "green", "yellow", and "red" mapped to
            accuracy percentages; types with no records map to 0.0
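        Example (illustrative numbers):
            If 5 records were classified "green" and 4 of them are correct,
            accuracy_by_type["green"] = (4 / 5) * 100 = 80.0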
"""
accuracy_by_type = {}
for classification_type in ["green", "yellow", "red"]:
type_records = [
r for r in records
if r.classifier_decision == classification_type
]
if type_records:
correct_count = sum(1 for r in type_records if r.is_correct)
accuracy_by_type[classification_type] = (
correct_count / len(type_records) * 100
)
else:
accuracy_by_type[classification_type] = 0.0
return accuracy_by_type
@staticmethod
def calculate_confusion_matrix(
records: List[VerificationRecord],
) -> Dict[str, Dict[str, int]]:
"""
Generate a confusion matrix from verification records.
The confusion matrix shows:
- Rows: classifier decisions (what the classifier said)
- Columns: ground truth labels (what the verifier said)
- Values: count of records in each cell
Args:
records: List of verification records
Returns:
Dictionary with structure:
{
"green": {"green": count, "yellow": count, "red": count},
"yellow": {"green": count, "yellow": count, "red": count},
"red": {"green": count, "yellow": count, "red": count},
}
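        Example (illustrative reading):
            matrix["yellow"]["green"] is the number of records that the
            classifier labeled "yellow" but the verifier labeled "green".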
"""
# Initialize matrix with zeros
matrix = {
"green": {"green": 0, "yellow": 0, "red": 0},
"yellow": {"green": 0, "yellow": 0, "red": 0},
"red": {"green": 0, "yellow": 0, "red": 0},
}
        # Populate the matrix; decisions and ground-truth labels are expected
        # to be one of "green"/"yellow"/"red" (any other value raises KeyError)
for record in records:
classifier_decision = record.classifier_decision
ground_truth = record.ground_truth_label
matrix[classifier_decision][ground_truth] += 1
return matrix
@staticmethod
def generate_error_patterns(
records: List[VerificationRecord],
) -> List[str]:
"""
Detect common error patterns from verification records.
Identifies patterns like:
- "Often misclassifies YELLOW as GREEN"
- "Frequently misses RED indicators"
Args:
records: List of verification records
Returns:
List of error pattern descriptions
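        Example (illustrative numbers):
            If 10 records were classified "green" and 3 of them were verified
            as "yellow", the 30% error rate meets the 20% reporting threshold
            and yields "Often misclassifies GREEN as YELLOW (30% of GREEN cases)".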
"""
if not records:
return []
patterns = []
# Analyze each classification type
for classifier_type in ["green", "yellow", "red"]:
type_records = [
r for r in records
if r.classifier_decision == classifier_type
]
if not type_records:
continue
# Find most common misclassification
misclassifications = {}
for record in type_records:
if not record.is_correct:
ground_truth = record.ground_truth_label
misclassifications[ground_truth] = (
misclassifications.get(ground_truth, 0) + 1
)
if misclassifications:
most_common_wrong = max(
misclassifications.items(), key=lambda x: x[1]
)
wrong_type, wrong_count = most_common_wrong
# Calculate percentage of misclassifications
error_rate = (wrong_count / len(type_records)) * 100
if error_rate >= 20: # Only report if >= 20% error rate
pattern = (
f"Often misclassifies {classifier_type.upper()} "
f"as {wrong_type.upper()} ({error_rate:.0f}% of {classifier_type.upper()} cases)"
)
patterns.append(pattern)
# Analyze missed classifications (false negatives)
for ground_truth_type in ["green", "yellow", "red"]:
# Find records where classifier missed this type
missed = [
r for r in records
if r.ground_truth_label == ground_truth_type
and r.classifier_decision != ground_truth_type
]
if missed:
missed_rate = (len(missed) / len(records)) * 100
if missed_rate >= 10: # Only report if >= 10% miss rate
pattern = (
f"Frequently misses {ground_truth_type.upper()} indicators "
f"({missed_rate:.0f}% of all messages)"
)
patterns.append(pattern)
return patterns
@staticmethod
def get_metrics_summary(records: List[VerificationRecord]) -> Dict[str, Any]:
"""
Get a comprehensive summary of all metrics.
Args:
records: List of verification records
Returns:
Dictionary containing all calculated metrics
"""
if not records:
return {
"total_records": 0,
"correct_count": 0,
"incorrect_count": 0,
"accuracy": 0.0,
"accuracy_by_type": {"green": 0.0, "yellow": 0.0, "red": 0.0},
"confusion_matrix": {
"green": {"green": 0, "yellow": 0, "red": 0},
"yellow": {"green": 0, "yellow": 0, "red": 0},
"red": {"green": 0, "yellow": 0, "red": 0},
},
"error_patterns": [],
}
correct_count = sum(1 for r in records if r.is_correct)
return {
"total_records": len(records),
"correct_count": correct_count,
"incorrect_count": len(records) - correct_count,
"accuracy": VerificationMetricsCalculator.calculate_accuracy(records),
"accuracy_by_type": (
VerificationMetricsCalculator.calculate_accuracy_by_type(records)
),
"confusion_matrix": (
VerificationMetricsCalculator.calculate_confusion_matrix(records)
),
"error_patterns": (
VerificationMetricsCalculator.generate_error_patterns(records)
),
}
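
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the service API).
# The stand-in records below are plain SimpleNamespace objects exposing the
# three attributes this module actually reads (classifier_decision,
# ground_truth_label, is_correct); the real VerificationRecord model in
# src.core.verification_models may require a different constructor, so adapt
# accordingly. Run from the project root so the import above resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    sample_records = [
        SimpleNamespace(classifier_decision="green", ground_truth_label="green", is_correct=True),
        SimpleNamespace(classifier_decision="green", ground_truth_label="yellow", is_correct=False),
        SimpleNamespace(classifier_decision="yellow", ground_truth_label="yellow", is_correct=True),
        SimpleNamespace(classifier_decision="red", ground_truth_label="red", is_correct=True),
    ]

    summary = VerificationMetricsCalculator.get_metrics_summary(sample_records)
    print(f"Accuracy: {summary['accuracy']:.1f}%")   # 75.0%
    print(f"By type:  {summary['accuracy_by_type']}")
    print(f"Patterns: {summary['error_patterns']}")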