# KES-Hack / tests/test_xai_engine.py
"""
Tests for the XAI Engine (SHAP token attribution)
Note: These tests mock the SHAP explainer since the actual HuggingFace model
may not be installed in the test environment.
"""
import asyncio
from unittest.mock import MagicMock, patch
from src.services.xai_engine import XAIEngine
def _run(coro):
"""Helper to run async coroutines in sync tests."""
return asyncio.run(coro)
def test_explain_without_initialization():
    """An uninitialized engine must degrade to the fallback explanation."""
    engine = XAIEngine()
    result = _run(engine.explain("verify your account immediately"))
    # Fallback path: flagged as fallback with empty attribution lists.
    assert result["fallback"] is True
    assert result["tokens"] == []
    assert result["shap_values"] == []
    # Metadata keys must still be present even on the fallback path.
    for key in ("model_score", "timestamp"):
        assert key in result
def test_model_prediction_without_pipeline():
    """With no pipeline attached, prediction degrades to UNKNOWN/0.0."""
    engine = XAIEngine()
    prediction = engine._get_model_prediction("test text")
    assert prediction["label"] == "UNKNOWN"
    assert prediction["score"] == 0.0
def test_model_prediction_with_pipeline():
    """A mocked pipeline's label and score pass straight through."""
    fake_pipeline = MagicMock(return_value=[{"label": "LABEL_1", "score": 0.92}])
    engine = XAIEngine()
    engine.model_pipeline = fake_pipeline
    prediction = engine._get_model_prediction("verify your account")
    assert prediction["label"] == "LABEL_1"
    assert prediction["score"] == 0.92
def test_explain_caches_results():
    """Repeated explain() calls with identical text return consistent results.

    NOTE(review): the engine is never initialized here, so the SHAP result
    cache is not actually exercised -- both calls take the fallback path.
    This test only pins that structure and labels agree across calls; a
    true cache test would need ``engine._initialized`` set with a mocked
    explainer.
    """
    engine = XAIEngine()
    # Mock a simple pipeline; no SHAP explainer is attached.
    mock_pipeline = MagicMock()
    mock_pipeline.return_value = [{"label": "LABEL_1", "score": 0.88}]
    engine.model_pipeline = mock_pipeline
    # Without initialization, caching won't apply, but both calls should
    # succeed and return a consistent structure.
    result1 = _run(engine.explain("test text for caching"))
    result2 = _run(engine.explain("test text for caching"))
    assert result1["fallback"] == result2["fallback"]
    assert result1["model_label"] == result2["model_label"]
    assert "timestamp" in result1
    assert "timestamp" in result2
def test_explain_timeout_fallback():
    """If SHAP computation takes too long, explain() should fall back gracefully."""
    import time

    engine = XAIEngine()
    # Pretend initialization succeeded so explain() attempts the SHAP path.
    engine._initialized = True
    # Pipeline mock returning a fixed classification.
    mock_pipeline = MagicMock()
    mock_pipeline.return_value = [{"label": "LABEL_1", "score": 0.88}]
    engine.model_pipeline = mock_pipeline

    # Explainer stub that stalls well past any reasonable timeout.
    def slow_explainer(texts):
        time.sleep(10)

    engine.explainer = slow_explainer

    # Wrap _compute_shap so the SHAP step itself stalls; the wrapped call
    # to the original is only reached if the sleep ever completes.
    original_compute = engine._compute_shap

    def slow_compute(text):
        # Fix: reuse the `time` imported above (was redundantly re-imported here).
        time.sleep(10)
        return original_compute(text)

    engine._compute_shap = slow_compute
    result = _run(engine.explain("timeout test text unique"))
    assert result["fallback"] is True
    assert "timed out" in result.get("reason", "").lower() or result["tokens"] == []