| """Tests for enhanced evaluation metrics.""" |
|
|
| import pytest |
| import torch |
| import pandas as pd |
| from evaluation.metrics import ( |
| precision, |
| recall, |
| f1_score, |
| exact_match, |
| per_class_metrics, |
| confusion_matrix_per_class, |
| get_optimal_threshold |
| ) |
| from evaluation.error_analysis import ErrorAnalyzer |
|
|
|
|
| class TestF1Score: |
| """Test suite for F1 score calculation.""" |

    def test_f1_score_perfect(self):
        """Test F1 score with perfect predictions."""
        target = torch.tensor([[1, 0, 1], [0, 1, 0]])
        pred = torch.tensor([[1, 0, 1], [0, 1, 0]])

        f1 = f1_score(target, pred)
        assert f1 == pytest.approx(1.0, abs=0.01)

    def test_f1_score_partial(self):
        """Test F1 score with partially correct predictions."""
        target = torch.tensor([[1, 0, 1], [0, 1, 0]])
        pred = torch.tensor([[1, 1, 1], [0, 1, 0]])

        # One false positive (sample 0, class 1) and no false negatives,
        # so the score must be strictly between 0 and 1.
        f1 = f1_score(target, pred)
        assert 0.0 < f1 < 1.0

    def test_f1_score_no_predictions(self):
        """Test F1 score when no positive predictions are made."""
        target = torch.tensor([[1, 0, 1], [0, 1, 0]])
        pred = torch.tensor([[0, 0, 0], [0, 0, 0]])

        f1 = f1_score(target, pred)
        assert f1 == pytest.approx(0.0, abs=0.01)
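
    # Hedged sketch: assumes the imported precision() and recall() share
    # f1_score's (target, pred) call signature; adjust if they differ.
    def test_precision_recall_perfect(self):
        """Precision and recall should both be 1.0 for perfect predictions."""
        target = torch.tensor([[1, 0, 1], [0, 1, 0]])
        pred = torch.tensor([[1, 0, 1], [0, 1, 0]])

        assert precision(target, pred) == pytest.approx(1.0, abs=0.01)
        assert recall(target, pred) == pytest.approx(1.0, abs=0.01)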


class TestPerClassMetrics:
    """Test suite for per-class metrics."""

    def test_per_class_metrics_basic(self):
        """Test basic per-class metrics calculation."""
        target = torch.tensor([[1, 0, 1], [0, 1, 0]])
        pred = torch.tensor([[1, 1, 1], [0, 1, 0]])
        class_names = ["class_0", "class_1", "class_2"]

        metrics = per_class_metrics(target, pred, class_names)

        assert "class_0" in metrics
        assert "class_1" in metrics
        assert "class_2" in metrics

        for class_metrics in metrics.values():
            assert "precision" in class_metrics
            assert "recall" in class_metrics
            assert "f1" in class_metrics
            assert "support" in class_metrics
            assert "tp" in class_metrics
            assert "fp" in class_metrics
            assert "fn" in class_metrics
            assert "tn" in class_metrics

    def test_per_class_metrics_default_names(self):
        """Test per-class metrics with default class names."""
        target = torch.tensor([[1, 0], [0, 1]])
        pred = torch.tensor([[1, 0], [0, 1]])

        metrics = per_class_metrics(target, pred)

        assert "class_0" in metrics
        assert "class_1" in metrics

    def test_per_class_metrics_values(self):
        """Test per-class metrics values."""
        target = torch.tensor([[1, 0], [0, 1]])
        pred = torch.tensor([[1, 0], [0, 1]])

        metrics = per_class_metrics(target, pred, ["class_0", "class_1"])

        assert metrics["class_0"]["precision"] == pytest.approx(1.0, abs=0.01)
        assert metrics["class_0"]["recall"] == pytest.approx(1.0, abs=0.01)
        assert metrics["class_0"]["f1"] == pytest.approx(1.0, abs=0.01)


class TestConfusionMatrix:
    """Test suite for confusion matrix calculation."""

    def test_confusion_matrix_per_class(self):
        """Test confusion matrix calculation per class."""
        target = torch.tensor([[1, 0], [0, 1]])
        pred = torch.tensor([[1, 0], [0, 1]])
        class_names = ["class_0", "class_1"]

        matrices = confusion_matrix_per_class(target, pred, class_names)

        assert "class_0" in matrices
        assert "class_1" in matrices

        assert matrices["class_0"].shape == (2, 2)
        assert matrices["class_1"].shape == (2, 2)

    def test_confusion_matrix_values(self):
        """Test confusion matrix values."""
        target = torch.tensor([[1, 0], [0, 1]])
        pred = torch.tensor([[1, 0], [0, 1]])

        matrices = confusion_matrix_per_class(target, pred, ["class_0", "class_1"])

        # With perfect predictions, all mass lies on the diagonal.
        matrix = matrices["class_0"]
        assert matrix[0, 0].item() == 1.0
        assert matrix[0, 1].item() == 0.0
        assert matrix[1, 0].item() == 0.0
        assert matrix[1, 1].item() == 1.0
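
    # Layout-independent check (hedged: assumes the matrix holds raw counts,
    # not rates): the entries of each 2x2 matrix must sum to the sample count.
    def test_confusion_matrix_sums_to_samples(self):
        """Matrix entries should sum to the number of samples."""
        target = torch.tensor([[1, 0], [0, 1]])
        pred = torch.tensor([[0, 0], [0, 1]])

        matrices = confusion_matrix_per_class(target, pred, ["class_0", "class_1"])

        for matrix in matrices.values():
            assert matrix.sum().item() == pytest.approx(2.0)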


class TestErrorAnalysis:
    """Test suite for error analysis."""

    @pytest.fixture
    def analyzer(self):
        """Create an ErrorAnalyzer instance."""
        return ErrorAnalyzer()

    def test_analyze_false_positives(self, analyzer):
        """Test false positive analysis."""
        target = torch.tensor([[0, 1], [1, 0]])
        pred = torch.tensor([[1, 1], [1, 0]])

        fps = analyzer.analyze_false_positives(target, pred, ["class_0", "class_1"])

        assert "class_0" in fps
        assert "class_1" in fps

        # Sample 0 is predicted positive for class_0 but labeled negative.
        assert 0 in fps["class_0"]

    def test_analyze_false_negatives(self, analyzer):
        """Test false negative analysis."""
        target = torch.tensor([[1, 0], [0, 1]])
        pred = torch.tensor([[0, 0], [0, 0]])

        fns = analyzer.analyze_false_negatives(target, pred, ["class_0", "class_1"])

        assert "class_0" in fns
        assert "class_1" in fns

        # All positives are missed: sample 0 for class_0, sample 1 for class_1.
        assert 0 in fns["class_0"]
        assert 1 in fns["class_1"]

    def test_find_common_misclassification_patterns(self, analyzer):
        """Test finding common misclassification patterns."""
        target = torch.tensor([[1, 0], [0, 1], [1, 1]])
        pred = torch.tensor([[0, 1], [1, 0], [1, 1]])

        patterns = analyzer.find_common_misclassification_patterns(
            target, pred, ["class_0", "class_1"], top_k=5
        )

        assert isinstance(patterns, list)

    def test_analyze_class_confusion(self, analyzer):
        """Test class confusion analysis."""
        target = torch.tensor([[1, 0], [0, 1]])
        pred = torch.tensor([[0, 1], [1, 0]])

        confusion_df = analyzer.analyze_class_confusion(
            target, pred, ["class_0", "class_1"]
        )

        assert isinstance(confusion_df, pd.DataFrame)
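
        # Hedged assumption: the confusion table is square over the class
        # names, giving an n_classes x n_classes DataFrame.
        assert confusion_df.shape == (2, 2)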

    def test_get_error_summary(self, analyzer):
        """Test error summary generation."""
        target = torch.tensor([[1, 0], [0, 1]])
        pred = torch.tensor([[1, 0], [0, 1]])

        summary = analyzer.get_error_summary(target, pred, ["class_0", "class_1"])

        assert "total_samples" in summary
        assert "total_classes" in summary
        assert "total_false_positives" in summary
        assert "total_false_negatives" in summary
        assert "per_class_metrics" in summary


class TestThresholdOptimization:
    """Test suite for threshold optimization."""

    def test_get_optimal_threshold_f1(self):
        """Test threshold optimization for F1 score."""
        # Hedged sketch: assumes get_optimal_threshold(target, probs) sweeps
        # candidate thresholds and returns the one that maximizes F1; adjust
        # the call if the actual signature differs.
        target = torch.tensor([[1, 0], [0, 1]])
        probs = torch.tensor([[0.9, 0.2], [0.1, 0.8]])

        threshold = get_optimal_threshold(target, probs)

        # Any threshold between 0.2 and 0.8 separates positives from negatives
        # here, so the returned value should at least be a valid probability.
        assert 0.0 <= threshold <= 1.0
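

# Hedged sketch: assumes exact_match(target, pred) returns the fraction of
# samples whose full label vectors match exactly (subset accuracy); adjust
# if the actual semantics differ.
class TestExactMatch:
    """Test suite for exact match (subset accuracy) calculation."""

    def test_exact_match_perfect(self):
        """Exact match should be 1.0 when every label vector matches."""
        target = torch.tensor([[1, 0], [0, 1]])
        pred = torch.tensor([[1, 0], [0, 1]])

        assert exact_match(target, pred) == pytest.approx(1.0, abs=0.01)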