# KES-Hack / tests / test_fusion_engine.py
# Meshyboi's picture
# Upload 86 files
# 214209a verified
"""
Tests for the Fusion Engine
"""
from src.services.fusion_engine import (
fuse_scores,
compute_shap_confidence_boost,
get_severity_label,
W_RULE_BASED, W_MODEL, W_SHAP_BOOST, W_USER_HISTORY
)
def test_weights_sum_to_one():
    """The four fusion weights must add up to 1.0 (tolerance 1e-9)."""
    # Same left-to-right accumulation order as a chained `+` expression.
    total = sum((W_RULE_BASED, W_MODEL, W_SHAP_BOOST, W_USER_HISTORY))
    deviation = abs(total - 1.0)
    assert deviation < 1e-9, f"Weights sum to {total}, expected 1.0"
def test_fusion_high_risk():
    """High inputs on every component must fuse to a CRITICAL result.

    Checks three things on the fused output: the score is at least 80,
    the severity label is "CRITICAL", and the call-to-action text
    contains "Block".
    """
    high_risk_inputs = {
        "rule_based_score": 90,
        "model_score": 95,
        "shap_values": [0.5, 0.4, 0.3, 0.2],
        "user_history_boost": 80,
    }
    result = fuse_scores(**high_risk_inputs)

    score = result["score"]
    assert score >= 80, f"Expected >= 80, got {score}"
    assert result["severity"] == "CRITICAL"
    assert "Block" in result["cta"]
def test_fusion_safe():
    """Near-zero inputs on every component must fuse to a LOW result.

    Checks that the fused score is below 40, the severity label is
    "LOW", and the call-to-action text mentions "safe" (case-insensitive).
    """
    safe_inputs = {
        "rule_based_score": 5,
        "model_score": 3,
        "shap_values": [],
        "user_history_boost": 0,
    }
    result = fuse_scores(**safe_inputs)

    score = result["score"]
    assert score < 40, f"Expected < 40, got {score}"
    assert result["severity"] == "LOW"
    cta_lowered = result["cta"].lower()
    assert "safe" in cta_lowered
def test_fusion_moderate():
    """Mixed mid-range signals must land in the middle severity band.

    The fused score has to fall in [40, 80) and the severity label must
    be either "MODERATE" or "HIGH"; both labels are accepted here.
    """
    mixed_inputs = {
        "rule_based_score": 50,
        "model_score": 60,
        "shap_values": [0.15, 0.1],
        "user_history_boost": 0,
    }
    result = fuse_scores(**mixed_inputs)

    score = result["score"]
    assert 40 <= score < 80, f"Expected 40-79, got {score}"
    accepted_labels = ("MODERATE", "HIGH")
    assert result["severity"] in accepted_labels
def test_fusion_component_scores_present():
    """The fused result must expose a score entry for every component."""
    component_scores = fuse_scores(rule_based_score=50, model_score=50)[
        "component_scores"
    ]
    # Checked in a fixed order so the first missing key is the one reported.
    expected_keys = ("rule_based", "model", "shap_boost", "user_history")
    for key in expected_keys:
        assert key in component_scores
def test_shap_boost_empty():
    """An empty list of SHAP values must yield a boost of exactly 0.0."""
    empty_values = []
    assert compute_shap_confidence_boost(empty_values) == 0.0

    # `None or []` evaluates to `[]`, so this call also receives an empty
    # list rather than None.
    # NOTE(review): None itself is never passed to the function here --
    # confirm whether a direct None input is meant to be covered.
    fallback_values = None or []
    assert compute_shap_confidence_boost(fallback_values) == 0.0
def test_shap_boost_positive():
    """Positive SHAP values must produce a boost strictly above zero."""
    positive_values = [0.5, 0.3, 0.2]
    boost = compute_shap_confidence_boost(positive_values)
    assert boost > 0, f"Expected positive boost, got {boost}"
def test_severity_labels():
    """Each sample score must map to its expected severity label."""
    # (score, expected label) pairs, one sample per severity level.
    expectations = (
        (90, "CRITICAL"),
        (65, "HIGH"),
        (50, "MODERATE"),
        (10, "LOW"),
    )
    for score, expected_label in expectations:
        assert get_severity_label(score) == expected_label
def test_fusion_clamped_to_100():
    """Maximal inputs on every component must not push the score past 100."""
    maxed_inputs = {
        "rule_based_score": 100,
        "model_score": 100,
        "shap_values": [1.0, 1.0, 1.0],
        "user_history_boost": 100,
    }
    fused = fuse_scores(**maxed_inputs)
    assert fused["score"] <= 100