"""Tests for the Fusion Engine."""

import math

from src.services.fusion_engine import (
    W_MODEL,
    W_RULE_BASED,
    W_SHAP_BOOST,
    W_USER_HISTORY,
    compute_shap_confidence_boost,
    fuse_scores,
    get_severity_label,
)


def test_weights_sum_to_one():
    """Fusion weights must sum to 1.0 (within floating-point tolerance)."""
    total = W_RULE_BASED + W_MODEL + W_SHAP_BOOST + W_USER_HISTORY
    # Float addition of the four weights can be off by a few ULPs, so compare
    # with a tolerance rather than ``==``.
    assert math.isclose(total, 1.0, abs_tol=1e-9), f"Weights sum to {total}, expected 1.0"


def test_fusion_high_risk():
    """When all components report high risk, fused score should be CRITICAL."""
    result = fuse_scores(
        rule_based_score=90,
        model_score=95,
        shap_values=[0.5, 0.4, 0.3, 0.2],
        user_history_boost=80,
    )
    assert result["score"] >= 80, f"Expected >= 80, got {result['score']}"
    assert result["severity"] == "CRITICAL"
    assert "Block" in result["cta"]


def test_fusion_safe():
    """When all components report safe, fused score should be LOW."""
    result = fuse_scores(
        rule_based_score=5,
        model_score=3,
        shap_values=[],
        user_history_boost=0,
    )
    assert result["score"] < 40, f"Expected < 40, got {result['score']}"
    assert result["severity"] == "LOW"
    assert "safe" in result["cta"].lower()


def test_fusion_moderate():
    """Mixed signals should land in the middle band (MODERATE or HIGH severity)."""
    result = fuse_scores(
        rule_based_score=50,
        model_score=60,
        shap_values=[0.15, 0.1],
        user_history_boost=0,
    )
    # The 40-79 score range spans both the MODERATE and HIGH labels
    # (see test_severity_labels: 50 -> MODERATE, 65 -> HIGH), so either is accepted.
    assert 40 <= result["score"] < 80, f"Expected 40-79, got {result['score']}"
    assert result["severity"] in ("MODERATE", "HIGH")


def test_fusion_component_scores_present():
    """Output must include all component scores."""
    result = fuse_scores(rule_based_score=50, model_score=50)
    cs = result["component_scores"]
    assert "rule_based" in cs
    assert "model" in cs
    assert "shap_boost" in cs
    assert "user_history" in cs


def test_shap_boost_empty():
    """No SHAP values -> 0 boost."""
    assert compute_shap_confidence_boost([]) == 0.0
    # NOTE(review): passing ``None`` directly is not asserted here; whether
    # compute_shap_confidence_boost accepts None is not established by this file.


def test_shap_boost_positive():
    """High SHAP positive values -> non-zero boost."""
    boost = compute_shap_confidence_boost([0.5, 0.3, 0.2])
    assert boost > 0, f"Expected positive boost, got {boost}"


def test_severity_labels():
    """Severity label mapping."""
    assert get_severity_label(90) == "CRITICAL"
    assert get_severity_label(65) == "HIGH"
    assert get_severity_label(50) == "MODERATE"
    assert get_severity_label(10) == "LOW"


def test_fusion_clamped_to_100():
    """Score should never exceed 100, even when every input is at its maximum."""
    result = fuse_scores(
        rule_based_score=100,
        model_score=100,
        shap_values=[1.0, 1.0, 1.0],
        user_history_boost=100,
    )
    assert result["score"] <= 100