File size: 2,835 Bytes
214209a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""
Tests for the Fusion Engine
"""

from src.services.fusion_engine import (
    fuse_scores,
    compute_shap_confidence_boost,
    get_severity_label,
    W_RULE_BASED, W_MODEL, W_SHAP_BOOST, W_USER_HISTORY
)


def test_weights_sum_to_one():
    """Check that the four fusion weights add up to 1.0 (tolerance 1e-9)."""
    # Accumulate left to right so the floating-point result matches a plain
    # a + b + c + d expression exactly.
    total = W_RULE_BASED
    for weight in (W_MODEL, W_SHAP_BOOST, W_USER_HISTORY):
        total = total + weight

    deviation = abs(total - 1.0)
    assert deviation < 1e-9, f"Weights sum to {total}, expected 1.0"


def test_fusion_high_risk():
    """High inputs on every component must fuse to a CRITICAL result.

    Asserts a score of at least 80, the "CRITICAL" severity label, and a
    call-to-action text containing "Block".
    """
    high_risk_inputs = dict(
        rule_based_score=90,
        model_score=95,
        shap_values=[0.5, 0.4, 0.3, 0.2],
        user_history_boost=80,
    )
    result = fuse_scores(**high_risk_inputs)

    assert result["score"] >= 80, f"Expected >= 80, got {result['score']}"
    assert result["severity"] == "CRITICAL"
    assert "Block" in result["cta"]


def test_fusion_safe():
    """Near-zero inputs on every component must fuse to a LOW result.

    Asserts a score below 40, the "LOW" severity label, and a call-to-action
    text that mentions "safe" (case-insensitive).
    """
    benign_inputs = dict(
        rule_based_score=5,
        model_score=3,
        shap_values=[],
        user_history_boost=0,
    )
    result = fuse_scores(**benign_inputs)

    assert result["score"] < 40, f"Expected < 40, got {result['score']}"
    assert result["severity"] == "LOW"

    # Compare case-insensitively: only the word matters, not its casing.
    cta_lowered = result["cta"].lower()
    assert "safe" in cta_lowered


def test_fusion_moderate():
    """Mid-range inputs must fuse to a score in [40, 80).

    The severity label is accepted as either "MODERATE" or "HIGH", since the
    asserted score band is wider than a single label.
    """
    mixed_inputs = dict(
        rule_based_score=50,
        model_score=60,
        shap_values=[0.15, 0.1],
        user_history_boost=0,
    )
    result = fuse_scores(**mixed_inputs)

    assert 40 <= result["score"] < 80, f"Expected 40-79, got {result['score']}"

    acceptable_labels = ("MODERATE", "HIGH")
    assert result["severity"] in acceptable_labels


def test_fusion_component_scores_present():
    """The fused result must expose a per-component score for every input source."""
    cs = fuse_scores(rule_based_score=50, model_score=50)["component_scores"]

    # Checked in a fixed order so the first missing key is the one reported.
    for component in ("rule_based", "model", "shap_boost", "user_history"):
        assert component in cs


def test_shap_boost_empty():
    """An empty list of SHAP values must yield a boost of exactly 0.0."""
    # NOTE(review): `None or []` evaluates to `[]`, so both inputs below are
    # empty lists; a literal None is never passed to the function here.
    for empty_values in ([], None or []):
        assert compute_shap_confidence_boost(empty_values) == 0.0


def test_shap_boost_positive():
    """Positive SHAP values must produce a boost strictly greater than zero."""
    positive_shap_values = [0.5, 0.3, 0.2]
    boost = compute_shap_confidence_boost(positive_shap_values)

    assert boost > 0, f"Expected positive boost, got {boost}"


def test_severity_labels():
    """Each sample score must map to its expected severity label."""
    # One representative score per label, checked from highest to lowest.
    expected_by_score = (
        (90, "CRITICAL"),
        (65, "HIGH"),
        (50, "MODERATE"),
        (10, "LOW"),
    )
    for score, label in expected_by_score:
        assert get_severity_label(score) == label


def test_fusion_clamped_to_100():
    """Even with every input at its largest test value, the score must not exceed 100."""
    saturated_inputs = dict(
        rule_based_score=100,
        model_score=100,
        shap_values=[1.0, 1.0, 1.0],
        user_history_boost=100,
    )
    result = fuse_scores(**saturated_inputs)

    assert result["score"] <= 100