# Source: SPG_ML/tests/test_explainability.py (author: meetmendapara)
# Commit df31aa1: Initial commit for ML space
"""
Unit tests for ML explainability module
"""
import pytest
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from explainability import (
SHAPExplainer,
CounterfactualExplainer,
RecommendationGenerator,
ExplanationAggregator,
)
class TestSHAPExplainer:
    """Tests for SHAPExplainer class"""

    @pytest.fixture
    def explainer(self):
        """Provide a fresh SHAPExplainer for each test."""
        return SHAPExplainer()

    @pytest.fixture
    def features(self):
        """Provide the feature dict shared by the basic explain() tests.

        Using a fixture (instead of a class-level constant) gives every test
        its own dict, so one test cannot observe mutations made by another.
        """
        return {
            "complexity_normalized": 0.6,
            "pri_attention_demand": 0.5,
            "trait_conscientiousness": 0.7,
        }

    def test_explain_returns_required_fields(self, explainer, features):
        """Test that explanation contains required fields"""
        prediction = 0.75
        explanation = explainer.explain(features, prediction)
        assert "shap_values" in explanation
        assert "base_value" in explanation
        assert "feature_ranking" in explanation
        assert "prediction" in explanation

    def test_shap_values_exist_for_features(self, explainer, features):
        """Test that SHAP values are computed for input features"""
        prediction = 0.75
        explanation = explainer.explain(features, prediction)
        # Only non-emptiness is checked; the container type of "shap_values"
        # (and whether it is keyed by feature name) is not pinned here.
        assert len(explanation["shap_values"]) > 0

    def test_feature_ranking_ordered_by_importance(self, explainer):
        """Test that features are ranked by importance"""
        features = {
            "complexity_normalized": 1.0,  # High complexity
            "pri_attention_demand": 0.1,
            "trait_conscientiousness": 0.3,
        }
        prediction = 0.4
        explanation = explainer.explain(features, prediction)
        ranking = explanation["feature_ranking"]
        # Guard against a vacuous pass: an empty ranking is trivially "sorted",
        # so the ordering check below would prove nothing without this.
        assert len(ranking) > 0, "feature_ranking is empty"
        # Check that ranking is sorted by absolute impact, largest first
        impacts = [abs(r["impact"]) for r in ranking]
        assert impacts == sorted(impacts, reverse=True)
class TestCounterfactualExplainer:
    """Tests for CounterfactualExplainer class"""

    @pytest.fixture
    def explainer(self):
        """Provide a fresh CounterfactualExplainer for each test."""
        return CounterfactualExplainer()

    def test_generate_counterfactuals(self, explainer):
        """Test counterfactual generation when the prediction is below target"""
        features = {"complexity_normalized": 1.0, "pri_attention_demand": 0.8}
        prediction = 0.4
        target = 0.7
        counterfactuals = explainer.generate_counterfactuals(features, prediction, target)
        # An empty list is acceptable here (no improvement may be possible),
        # so only the return type is asserted; a length check of ">= 0" would
        # be a tautology that can never fail.
        assert isinstance(counterfactuals, list)

    def test_counterfactuals_when_already_meeting_target(self, explainer):
        """Test counterfactuals when prediction already meets target"""
        features = {"complexity_normalized": 0.3, "pri_attention_demand": 0.6}
        prediction = 0.8  # Already above target
        target = 0.7
        counterfactuals = explainer.generate_counterfactuals(features, prediction, target)
        # Expected to be non-empty -- presumably a note that the target is
        # already met; the entry's content is not checked here.
        assert len(counterfactuals) > 0
class TestRecommendationGenerator:
    """Tests for RecommendationGenerator class"""

    @pytest.fixture
    def generator(self):
        """Provide a fresh RecommendationGenerator for each test."""
        return RecommendationGenerator()

    def test_generate_recommendations(self, generator):
        """Test recommendation generation and the shape of each recommendation"""
        features = {
            "complexity_normalized": 0.8,
            "time_pressure": 0.5,
            "trait_conscientiousness": 0.6,
        }
        prediction = 0.4
        stress_level = 8
        difficulty = "HARD"
        recommendations = generator.generate_recommendations(features, prediction, stress_level, difficulty)
        assert isinstance(recommendations, list)
        assert len(recommendations) > 0
        # Every recommendation must expose these keys
        for rec in recommendations:
            assert "title" in rec
            assert "description" in rec
            assert "priority" in rec

    def test_high_stress_triggers_recommendations(self, generator):
        """Test that high stress triggers appropriate recommendations"""
        features = {
            "complexity_normalized": 0.5,
            "time_pressure": 0.2,
        }
        prediction = 0.7
        stress_level = 9  # High stress
        difficulty = "MODERATE"
        recommendations = generator.generate_recommendations(features, prediction, stress_level, difficulty)
        # Only non-emptiness is checked; the recommendations' content is not.
        assert len(recommendations) > 0

    def test_low_probability_triggers_recommendations(self, generator):
        """Test that low probability triggers task-related recommendations"""
        features = {
            "complexity_normalized": 0.9,  # High complexity
            "time_pressure": 0.6,
        }
        prediction = 0.3
        stress_level = 4
        difficulty = "HARD"
        recommendations = generator.generate_recommendations(features, prediction, stress_level, difficulty)
        assert len(recommendations) > 0
        # At least one recommendation should be "high" or "medium" priority
        priorities = [r["priority"] for r in recommendations]
        assert any(p in ("high", "medium") for p in priorities), (
            f"expected a high or medium priority, got {priorities}"
        )
class TestExplanationAggregator:
    """Tests for ExplanationAggregator.generate_full_explanation."""

    @pytest.fixture
    def aggregator(self):
        """Provide a fresh ExplanationAggregator for each test."""
        return ExplanationAggregator()

    def test_generate_full_explanation(self, aggregator):
        """A moderate task yields a dict with summary, attribution and recommendations."""
        feature_vector = {
            "complexity_normalized": 0.6,
            "trait_conscientiousness": 0.7,
            "time_pressure": 0.4,
        }
        model_output = {
            "completion_probability": 0.65,
            "stress_level": 6,
            "difficulty_level": "MODERATE",
        }
        task = {"title": "Test Task", "category": "WORK"}

        result = aggregator.generate_full_explanation(feature_vector, model_output, task)

        assert isinstance(result, dict)
        for section in ("prediction_summary", "feature_attribution", "recommendations"):
            assert section in result

    def test_generate_full_explanation_with_low_probability(self, aggregator):
        """A low-probability, high-stress task gets counterfactuals and recommendations."""
        feature_vector = {
            "complexity_normalized": 0.9,
            "trait_conscientiousness": 0.4,
            "time_pressure": 0.7,
        }
        model_output = {
            "completion_probability": 0.35,
            "stress_level": 8,
            "difficulty_level": "HARD",
        }
        task = {"title": "Difficult Task", "category": "ACADEMIC"}

        result = aggregator.generate_full_explanation(feature_vector, model_output, task)

        # A low completion probability is expected to surface counterfactuals
        assert "counterfactual_scenarios" in result
        # ...and at least one recommendation
        assert len(result["recommendations"]) > 0
if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file directly
    # reports failures to the shell/CI instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))