| """ |
| Tests for the XAI Engine (SHAP token attribution) |
| |
| Note: These tests mock the SHAP explainer since the actual HuggingFace model |
| may not be installed in the test environment. |
| """ |
|
|
| import asyncio |
| from unittest.mock import MagicMock, patch |
|
|
| from src.services.xai_engine import XAIEngine |
|
|
|
|
def _run(coro):
    """Drive *coro* to completion via ``asyncio.run`` and hand back its result.

    Lets the synchronous test functions in this module call the engine's
    async methods without needing an async-aware test runner.
    """
    outcome = asyncio.run(coro)
    return outcome
|
|
|
|
def test_explain_without_initialization():
    """A freshly constructed engine answers explain() with the fallback payload."""
    payload = _run(XAIEngine().explain("verify your account immediately"))

    assert payload["fallback"] is True

    # The fallback payload carries no attribution data.
    for empty_field in ("tokens", "shap_values"):
        assert payload[empty_field] == []

    # Metadata keys must still be present.
    for required_key in ("model_score", "timestamp"):
        assert required_key in payload
|
|
|
|
def test_model_prediction_without_pipeline():
    """With no pipeline attached, the prediction helper reports UNKNOWN and 0.0."""
    prediction = XAIEngine()._get_model_prediction("test text")

    expected = {"label": "UNKNOWN", "score": 0.0}
    for key, want in expected.items():
        assert prediction[key] == want
|
|
|
|
def test_model_prediction_with_pipeline():
    """With a pipeline attached, the helper returns the pipeline's label and score."""
    engine = XAIEngine()
    # Stub pipeline: calling it yields a single classification entry.
    engine.model_pipeline = MagicMock(
        return_value=[{"label": "LABEL_1", "score": 0.92}]
    )

    prediction = engine._get_model_prediction("verify your account")

    for key, want in (("label", "LABEL_1"), ("score", 0.92)):
        assert prediction[key] == want
|
|
|
|
def test_explain_caches_results():
    """Two explain() calls with identical text must agree with each other.

    Checks that both results carry the same ``fallback`` flag and
    ``model_label`` and that each includes a ``timestamp``.

    NOTE(review): ``engine._initialized`` is never set in this test, so it
    presumably exercises the un-initialized path rather than a real cache
    hit -- confirm against XAIEngine.
    """
    engine = XAIEngine()
    # Stub pipeline: calling it yields a single classification entry.
    engine.model_pipeline = MagicMock(
        return_value=[{"label": "LABEL_1", "score": 0.88}]
    )

    text = "test text for caching"
    first, second = [_run(engine.explain(text)) for _ in range(2)]

    for shared_key in ("fallback", "model_label"):
        assert first[shared_key] == second[shared_key]
    for outcome in (first, second):
        assert "timestamp" in outcome
|
|
|
|
def test_explain_timeout_fallback():
    """A SHAP computation that stalls should yield the fallback response."""
    import time

    # Both stubs below block for this long.
    # NOTE(review): presumably chosen to exceed the engine's timeout -- confirm.
    stall_seconds = 10

    engine = XAIEngine()
    # Mark the engine as initialized without loading a real model.
    engine._initialized = True
    # Stub pipeline: calling it yields a single classification entry.
    engine.model_pipeline = MagicMock(
        return_value=[{"label": "LABEL_1", "score": 0.88}]
    )

    def stalled_explainer(texts):
        time.sleep(stall_seconds)

    engine.explainer = stalled_explainer

    # Wrap the real SHAP computation so it stalls before delegating.
    real_compute = engine._compute_shap

    def stalled_compute(text):
        time.sleep(stall_seconds)
        return real_compute(text)

    engine._compute_shap = stalled_compute

    result = _run(engine.explain("timeout test text unique"))

    assert result["fallback"] is True
    # Accept either an explicit "timed out" reason or an empty token list.
    reason = result.get("reason", "").lower()
    assert "timed out" in reason or result["tokens"] == []
|
|