# verification_metrics.py
"""
Verification metrics calculator service.

Provides methods for calculating accuracy, confusion matrices, and error
patterns from verification records.
"""

from typing import Dict, List, Any

from src.core.verification_models import VerificationRecord


class VerificationMetricsCalculator:
    """Calculates performance metrics from verification records."""

    @staticmethod
    def calculate_accuracy(records: List[VerificationRecord]) -> float:
        """
        Calculate overall accuracy from verification records.

        Accuracy = (correct_count / total_count) * 100

        Args:
            records: List of verification records

        Returns:
            Accuracy as a percentage (0-100), or 0.0 if there are no records
        """
        if not records:
            return 0.0
        correct_count = sum(1 for record in records if record.is_correct)
        return (correct_count / len(records)) * 100
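
    # Worked example for calculate_accuracy: 3 correct records out of 4
    # yields (3 / 4) * 100 = 75.0.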

    @staticmethod
    def calculate_accuracy_by_type(
        records: List[VerificationRecord],
    ) -> Dict[str, float]:
        """
        Calculate accuracy for each classification type.

        For each type (green, yellow, red), calculates:
            accuracy = (correct_count_for_type / total_count_for_type) * 100

        Args:
            records: List of verification records

        Returns:
            Dictionary with keys "green", "yellow", and "red" mapping to
            accuracy percentages (0.0 for types with no records)
        """
        accuracy_by_type = {}
        for classification_type in ["green", "yellow", "red"]:
            type_records = [
                r for r in records
                if r.classifier_decision == classification_type
            ]
            if type_records:
                correct_count = sum(1 for r in type_records if r.is_correct)
                accuracy_by_type[classification_type] = (
                    correct_count / len(type_records) * 100
                )
            else:
                accuracy_by_type[classification_type] = 0.0
        return accuracy_by_type

    @staticmethod
    def calculate_confusion_matrix(
        records: List[VerificationRecord],
    ) -> Dict[str, Dict[str, int]]:
        """
        Generate a confusion matrix from verification records.

        The confusion matrix shows:
        - Rows: classifier decisions (what the classifier said)
        - Columns: ground truth labels (what the verifier said)
        - Values: count of records in each cell

        Args:
            records: List of verification records

        Returns:
            Dictionary with structure:
            {
                "green": {"green": count, "yellow": count, "red": count},
                "yellow": {"green": count, "yellow": count, "red": count},
                "red": {"green": count, "yellow": count, "red": count},
            }
        """
        # Initialize matrix with zeros
        matrix = {
            "green": {"green": 0, "yellow": 0, "red": 0},
            "yellow": {"green": 0, "yellow": 0, "red": 0},
            "red": {"green": 0, "yellow": 0, "red": 0},
        }
        # Populate matrix: row = classifier decision, column = ground truth
        for record in records:
            classifier_decision = record.classifier_decision
            ground_truth = record.ground_truth_label
            matrix[classifier_decision][ground_truth] += 1
        return matrix
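
    # Reading the matrix: matrix["yellow"]["green"] counts records the
    # classifier labeled "yellow" whose ground-truth label is "green".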

    @staticmethod
    def generate_error_patterns(
        records: List[VerificationRecord],
    ) -> List[str]:
        """
        Detect common error patterns from verification records.

        Identifies patterns like:
        - "Often misclassifies YELLOW as GREEN"
        - "Frequently misses RED indicators"

        Args:
            records: List of verification records

        Returns:
            List of error pattern descriptions
        """
        if not records:
            return []
        patterns = []

        # For each classifier decision, find its most common misclassification
        for classifier_type in ["green", "yellow", "red"]:
            type_records = [
                r for r in records
                if r.classifier_decision == classifier_type
            ]
            if not type_records:
                continue
            # Count incorrect records by their ground-truth label
            misclassifications = {}
            for record in type_records:
                if not record.is_correct:
                    ground_truth = record.ground_truth_label
                    misclassifications[ground_truth] = (
                        misclassifications.get(ground_truth, 0) + 1
                    )
            if misclassifications:
                wrong_type, wrong_count = max(
                    misclassifications.items(), key=lambda x: x[1]
                )
                # Share of this type's records hit by the most common error
                error_rate = (wrong_count / len(type_records)) * 100
                if error_rate >= 20:  # Only report if >= 20% error rate
                    patterns.append(
                        f"Often misclassifies {classifier_type.upper()} "
                        f"as {wrong_type.upper()} "
                        f"({error_rate:.0f}% of {classifier_type.upper()} cases)"
                    )

        # Analyze missed classifications (false negatives)
        for ground_truth_type in ["green", "yellow", "red"]:
            # Records whose true label is this type but were classified otherwise
            missed = [
                r for r in records
                if r.ground_truth_label == ground_truth_type
                and r.classifier_decision != ground_truth_type
            ]
            if missed:
                missed_rate = (len(missed) / len(records)) * 100
                if missed_rate >= 10:  # Only report if >= 10% miss rate
                    patterns.append(
                        f"Frequently misses {ground_truth_type.upper()} indicators "
                        f"({missed_rate:.0f}% of all messages)"
                    )
        return patterns

    @staticmethod
    def get_metrics_summary(records: List[VerificationRecord]) -> Dict[str, Any]:
        """
        Get a comprehensive summary of all metrics.

        Args:
            records: List of verification records

        Returns:
            Dictionary containing all calculated metrics
        """
        if not records:
            return {
                "total_records": 0,
                "correct_count": 0,
                "incorrect_count": 0,
                "accuracy": 0.0,
                "accuracy_by_type": {"green": 0.0, "yellow": 0.0, "red": 0.0},
                "confusion_matrix": {
                    "green": {"green": 0, "yellow": 0, "red": 0},
                    "yellow": {"green": 0, "yellow": 0, "red": 0},
                    "red": {"green": 0, "yellow": 0, "red": 0},
                },
                "error_patterns": [],
            }
        correct_count = sum(1 for r in records if r.is_correct)
        return {
            "total_records": len(records),
            "correct_count": correct_count,
            "incorrect_count": len(records) - correct_count,
            "accuracy": VerificationMetricsCalculator.calculate_accuracy(records),
            "accuracy_by_type": (
                VerificationMetricsCalculator.calculate_accuracy_by_type(records)
            ),
            "confusion_matrix": (
                VerificationMetricsCalculator.calculate_confusion_matrix(records)
            ),
            "error_patterns": (
                VerificationMetricsCalculator.generate_error_patterns(records)
            ),
        }
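

# Usage sketch (illustrative only). VerificationRecord's constructor is not
# shown in this file, so the dataclass below is a hypothetical stand-in that
# exposes only the three attributes the methods above actually read:
# classifier_decision, ground_truth_label, and is_correct.
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass
    class _DemoRecord:  # stand-in for VerificationRecord, demo purposes only
        classifier_decision: str
        ground_truth_label: str
        is_correct: bool

    demo_records = [
        _DemoRecord("green", "green", True),
        _DemoRecord("yellow", "green", False),
        _DemoRecord("red", "red", True),
        _DemoRecord("green", "yellow", False),
    ]
    summary = VerificationMetricsCalculator.get_metrics_summary(demo_records)
    print(summary["accuracy"])          # 50.0
    print(summary["confusion_matrix"])  # per-cell counts, rows = decisions
    print(summary["error_patterns"])    # human-readable pattern strings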