# multilabel-news-classifier: tests/test_enhanced_evaluation.py
"""Tests for enhanced evaluation metrics."""
import pytest
import torch
import pandas as pd
from evaluation.metrics import (
precision,
recall,
f1_score,
exact_match,
per_class_metrics,
confusion_matrix_per_class,
get_optimal_threshold
)
from evaluation.error_analysis import ErrorAnalyzer
class TestF1Score:
"""Test suite for F1 score calculation."""
def test_f1_score_perfect(self):
"""Test F1 score with perfect predictions."""
target = torch.tensor([[1, 0, 1], [0, 1, 0]])
pred = torch.tensor([[1, 0, 1], [0, 1, 0]])
f1 = f1_score(target, pred)
assert f1 == pytest.approx(1.0, abs=0.01)
def test_f1_score_partial(self):
"""Test F1 score with partial predictions."""
target = torch.tensor([[1, 0, 1], [0, 1, 0]])
pred = torch.tensor([[1, 1, 1], [0, 1, 0]])
f1 = f1_score(target, pred)
        # One extra positive prediction, so the score is strictly between 0 and 1
        assert 0.0 < f1 < 1.0
def test_f1_score_no_predictions(self):
"""Test F1 score when no predictions are made."""
target = torch.tensor([[1, 0, 1], [0, 1, 0]])
pred = torch.tensor([[0, 0, 0], [0, 0, 0]])
f1 = f1_score(target, pred)
assert f1 == pytest.approx(0.0, abs=0.01)
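

class TestPrecisionRecall:
    """Hedged sanity checks for the precision/recall/exact_match helpers.

    These are imported above but not otherwise exercised; the tests assume
    they share the (target, pred) call signature used for f1_score in this
    file. Adjust if the actual signatures differ.
    """

    def test_precision_recall_perfect(self):
        """Perfect predictions should give precision = recall = 1.0."""
        target = torch.tensor([[1, 0, 1], [0, 1, 0]])
        pred = torch.tensor([[1, 0, 1], [0, 1, 0]])
        assert precision(target, pred) == pytest.approx(1.0, abs=0.01)
        assert recall(target, pred) == pytest.approx(1.0, abs=0.01)

    def test_exact_match_perfect(self):
        """Identical label matrices should give an exact-match score of 1.0."""
        target = torch.tensor([[1, 0], [0, 1]])
        pred = torch.tensor([[1, 0], [0, 1]])
        assert exact_match(target, pred) == pytest.approx(1.0, abs=0.01)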
class TestPerClassMetrics:
"""Test suite for per-class metrics."""
def test_per_class_metrics_basic(self):
"""Test basic per-class metrics calculation."""
target = torch.tensor([[1, 0, 1], [0, 1, 0]])
pred = torch.tensor([[1, 1, 1], [0, 1, 0]])
class_names = ["class_0", "class_1", "class_2"]
metrics = per_class_metrics(target, pred, class_names)
assert "class_0" in metrics
assert "class_1" in metrics
assert "class_2" in metrics
        for class_metrics in metrics.values():
assert "precision" in class_metrics
assert "recall" in class_metrics
assert "f1" in class_metrics
assert "support" in class_metrics
assert "tp" in class_metrics
assert "fp" in class_metrics
assert "fn" in class_metrics
assert "tn" in class_metrics
def test_per_class_metrics_default_names(self):
"""Test per-class metrics with default class names."""
target = torch.tensor([[1, 0], [0, 1]])
pred = torch.tensor([[1, 0], [0, 1]])
metrics = per_class_metrics(target, pred)
assert "class_0" in metrics
assert "class_1" in metrics
def test_per_class_metrics_values(self):
"""Test per-class metrics values."""
target = torch.tensor([[1, 0], [0, 1]])
pred = torch.tensor([[1, 0], [0, 1]])
metrics = per_class_metrics(target, pred, ["class_0", "class_1"])
# Perfect predictions should have precision, recall, f1 = 1.0
assert metrics["class_0"]["precision"] == pytest.approx(1.0, abs=0.01)
assert metrics["class_0"]["recall"] == pytest.approx(1.0, abs=0.01)
assert metrics["class_0"]["f1"] == pytest.approx(1.0, abs=0.01)
class TestConfusionMatrix:
"""Test suite for confusion matrix calculation."""
def test_confusion_matrix_per_class(self):
"""Test confusion matrix calculation per class."""
target = torch.tensor([[1, 0], [0, 1]])
pred = torch.tensor([[1, 0], [0, 1]])
class_names = ["class_0", "class_1"]
matrices = confusion_matrix_per_class(target, pred, class_names)
assert "class_0" in matrices
assert "class_1" in matrices
# Check matrix shape (2x2)
assert matrices["class_0"].shape == (2, 2)
assert matrices["class_1"].shape == (2, 2)
def test_confusion_matrix_values(self):
"""Test confusion matrix values."""
target = torch.tensor([[1, 0], [0, 1]])
pred = torch.tensor([[1, 0], [0, 1]])
matrices = confusion_matrix_per_class(target, pred, ["class_0", "class_1"])
# Perfect predictions: TN=1, TP=1, FP=0, FN=0
matrix = matrices["class_0"]
assert matrix[0, 0].item() == 1.0 # TN
assert matrix[0, 1].item() == 0.0 # FP
assert matrix[1, 0].item() == 0.0 # FN
assert matrix[1, 1].item() == 1.0 # TP
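
    def test_confusion_matrix_with_errors(self):
        """Hedged check with a single false positive for class_0.

        Assumes the same [[TN, FP], [FN, TP]] layout verified in the
        perfect-prediction test above.
        """
        target = torch.tensor([[1, 0], [0, 1]])
        pred = torch.tensor([[1, 0], [1, 1]])
        matrices = confusion_matrix_per_class(target, pred, ["class_0", "class_1"])
        matrix = matrices["class_0"]
        assert matrix[0, 1].item() == 1.0  # sample 1 is the single FP
        assert matrix[1, 1].item() == 1.0  # sample 0 is the single TP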
class TestErrorAnalysis:
"""Test suite for error analysis."""
@pytest.fixture
def analyzer(self):
"""Create ErrorAnalyzer instance."""
return ErrorAnalyzer()
def test_analyze_false_positives(self, analyzer):
"""Test false positive analysis."""
target = torch.tensor([[0, 1], [1, 0]])
pred = torch.tensor([[1, 1], [1, 0]])
fps = analyzer.analyze_false_positives(target, pred, ["class_0", "class_1"])
assert "class_0" in fps
assert "class_1" in fps
# Sample 0 has FP for class_0
assert 0 in fps["class_0"]
def test_analyze_false_negatives(self, analyzer):
"""Test false negative analysis."""
target = torch.tensor([[1, 0], [0, 1]])
pred = torch.tensor([[0, 0], [0, 0]])
fns = analyzer.analyze_false_negatives(target, pred, ["class_0", "class_1"])
assert "class_0" in fns
assert "class_1" in fns
# Sample 0 has FN for class_0, sample 1 has FN for class_1
assert 0 in fns["class_0"]
assert 1 in fns["class_1"]
def test_find_common_misclassification_patterns(self, analyzer):
"""Test finding common misclassification patterns."""
target = torch.tensor([[1, 0], [0, 1], [1, 1]])
pred = torch.tensor([[0, 1], [1, 0], [1, 1]])
patterns = analyzer.find_common_misclassification_patterns(
target, pred, ["class_0", "class_1"], top_k=5
)
        assert isinstance(patterns, list)
        # Samples 0 and 1 are misclassified, so some patterns should be found
        assert len(patterns) > 0
def test_analyze_class_confusion(self, analyzer):
"""Test class confusion analysis."""
target = torch.tensor([[1, 0], [0, 1]])
pred = torch.tensor([[0, 1], [1, 0]])
confusion_df = analyzer.analyze_class_confusion(
target, pred, ["class_0", "class_1"]
)
        assert isinstance(confusion_df, pd.DataFrame)
def test_get_error_summary(self, analyzer):
"""Test error summary generation."""
target = torch.tensor([[1, 0], [0, 1]])
pred = torch.tensor([[1, 0], [0, 1]])
summary = analyzer.get_error_summary(target, pred, ["class_0", "class_1"])
assert "total_samples" in summary
assert "total_classes" in summary
assert "total_false_positives" in summary
assert "total_false_negatives" in summary
assert "per_class_metrics" in summary
class TestThresholdOptimization:
"""Test suite for threshold optimization."""
    def test_get_optimal_threshold_f1(self):
        """Test that 'f1' is accepted as the optimization metric.

        Full optimization needs a trained model and a real dataset, so until
        integration fixtures exist this placeholder only records the intent.
        """
        pytest.skip("requires a trained model and dataset for get_optimal_threshold")
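
    def test_manual_threshold_sweep_with_f1(self):
        """Minimal sketch of the sweep get_optimal_threshold presumably performs.

        Uses only the f1_score helper imported above; the probabilities below
        are dummy values, not the output of a real model.
        """
        target = torch.tensor([[1, 0], [0, 1]])
        probs = torch.tensor([[0.9, 0.2], [0.1, 0.8]])
        thresholds = [0.1, 0.3, 0.5, 0.7]
        scores = {t: f1_score(target, (probs >= t).long()) for t in thresholds}
        best = max(scores, key=scores.get)
        # Thresholds of 0.3 and above binarize these probabilities into the
        # exact target matrix, so the best F1 should be (near) perfect
        assert scores[best] == pytest.approx(1.0, abs=0.01)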