"""
Tests for the Fusion Engine
"""
from src.services.fusion_engine import (
fuse_scores,
compute_shap_confidence_boost,
get_severity_label,
W_RULE_BASED, W_MODEL, W_SHAP_BOOST, W_USER_HISTORY
)
def test_weights_sum_to_one():
    """The four fusion weights add up to 1.0, allowing 1e-9 of float rounding slack."""
    weights = (W_RULE_BASED, W_MODEL, W_SHAP_BOOST, W_USER_HISTORY)
    total = sum(weights)
    deviation = abs(total - 1.0)
    assert deviation < 1e-9, f"Weights sum to {total}, expected 1.0"
def test_fusion_high_risk():
    """Uniformly alarming inputs fuse into a CRITICAL verdict whose CTA mentions 'Block'."""
    alarming_inputs = dict(
        rule_based_score=90,
        model_score=95,
        shap_values=[0.5, 0.4, 0.3, 0.2],
        user_history_boost=80,
    )
    fused = fuse_scores(**alarming_inputs)
    score = fused["score"]
    assert score >= 80, f"Expected >= 80, got {score}"
    assert fused["severity"] == "CRITICAL"
    assert "Block" in fused["cta"]
def test_fusion_safe():
    """Uniformly benign inputs fuse into a LOW verdict whose CTA mentions 'safe'."""
    benign_inputs = dict(
        rule_based_score=5,
        model_score=3,
        shap_values=[],
        user_history_boost=0,
    )
    fused = fuse_scores(**benign_inputs)
    score = fused["score"]
    assert score < 40, f"Expected < 40, got {score}"
    assert fused["severity"] == "LOW"
    # Case-insensitive match so "Safe"/"SAFE" wording in the CTA also passes.
    cta_text = fused["cta"].lower()
    assert "safe" in cta_text
def test_fusion_moderate():
    """Mixed signals land in the middle band: score in [40, 80), severity MODERATE or HIGH.

    Both labels are accepted because the [40, 80) band spans two severities
    (cf. get_severity_label: 50 -> MODERATE, 65 -> HIGH). Which one is reported
    depends on where the fused score falls inside that band.
    """
    result = fuse_scores(
        rule_based_score=50,
        model_score=60,
        shap_values=[0.15, 0.1],
        user_history_boost=0,
    )
    assert 40 <= result["score"] < 80, f"Expected 40-79, got {result['score']}"
    assert result["severity"] in ("MODERATE", "HIGH")
def test_fusion_component_scores_present():
    """The fused result reports a breakdown entry for every scoring component."""
    breakdown = fuse_scores(rule_based_score=50, model_score=50)["component_scores"]
    for component in ("rule_based", "model", "shap_boost", "user_history"):
        assert component in breakdown
def test_shap_boost_empty():
    """An empty list of SHAP values yields a boost of exactly 0.0."""
    assert compute_shap_confidence_boost([]) == 0.0
    # NOTE(review): a None argument is not covered here. `None or []` evaluates
    # to `[]` before the call, so writing it that way would only repeat the check
    # above. If the helper is meant to accept None, add this assertion once that is confirmed:
    #     assert compute_shap_confidence_boost(None) == 0.0
def test_shap_boost_positive():
    """Strongly positive SHAP contributions produce a boost strictly above zero."""
    shap_contributions = [0.5, 0.3, 0.2]
    boost = compute_shap_confidence_boost(shap_contributions)
    assert 0 < boost, f"Expected positive boost, got {boost}"
def test_severity_labels():
    """get_severity_label maps one representative score onto each severity band."""
    expected_labels = (
        (90, "CRITICAL"),
        (65, "HIGH"),
        (50, "MODERATE"),
        (10, "LOW"),
    )
    for score, label in expected_labels:
        assert get_severity_label(score) == label
def test_fusion_clamped_to_100():
    """Even with every input at its maximum, the fused score does not exceed 100."""
    maxed = fuse_scores(
        rule_based_score=100,
        model_score=100,
        shap_values=[1.0] * 3,
        user_history_boost=100,
    )
    ceiling = 100
    assert maxed["score"] <= ceiling
|