"""Tests for the human feedback loop module."""
import tempfile
from pathlib import Path
from unittest.mock import patch
from pipeline.feedback import (
save_feedback,
save_approval,
load_all_feedback,
compute_agreement_stats,
_reviewable_fields,
_compute_field_agreement,
FEEDBACK_PATH,
)
def _with_tmp_feedback(func):
    """Decorator that redirects feedback writes to a throwaway temp file.

    The wrapped test receives the temporary path as its first positional
    argument; the file is removed after the test finishes, pass or fail.
    """
    def wrapper(*args, **kwargs):
        # delete=False: tempfile is only used to mint a unique path; the
        # feedback module reopens the file itself and we unlink it below.
        with tempfile.NamedTemporaryFile(suffix=".jsonl", delete=False) as f:
            tmp_path = Path(f.name)
        with patch("pipeline.feedback.FEEDBACK_PATH", tmp_path):
            try:
                return func(tmp_path, *args, **kwargs)
            finally:
                tmp_path.unlink(missing_ok=True)

    # Copy identity attributes by hand so pytest failure reports show the
    # real test name instead of "wrapper". Deliberately NOT functools.wraps:
    # wraps sets __wrapped__, which makes pytest resolve the signature to the
    # wrapped function and try to inject a `tmp_path` fixture on top of the
    # path we already pass positionally.
    wrapper.__name__ = func.__name__
    wrapper.__qualname__ = func.__qualname__
    wrapper.__doc__ = func.__doc__
    return wrapper
@_with_tmp_feedback
def test_save_approval_creates_entry(tmp_path):
    """Approving a case produces an 'approval' entry with full agreement."""
    result = save_approval("case-001", {"root_cause_l1": "billing"}, "Looks good")
    assert result["action"] == "approval"
    assert result["case_id"] == "case-001"
    assert result["agreement"]["agreement_rate"] == 1.0
    assert result["reviewer_notes"] == "Looks good"
@_with_tmp_feedback
def test_save_feedback_records_correction(tmp_path):
    """A correction entry stores originals for corrected keys and flags each change."""
    before = {"root_cause_l1": "billing", "risk_level": "low", "confidence": 0.9}
    after = {"root_cause_l1": "network", "risk_level": "high"}
    result = save_feedback("case-002", before, after, "Wrong root cause")
    assert result["action"] == "correction"
    # Only the keys present in the correction are retained from the original.
    assert result["original"] == {"root_cause_l1": "billing", "risk_level": "low"}
    assert result["corrected"] == after
    changed = result["agreement"]["fields_corrected"]
    assert "root_cause_l1" in changed
    assert "risk_level" in changed
    assert result["agreement"]["agreement_rate"] < 1.0
@_with_tmp_feedback
def test_load_all_feedback_roundtrip(tmp_path):
    """Entries written by the save helpers come back in write order."""
    save_approval("case-001", {})
    save_feedback("case-002", {"root_cause_l1": "billing"}, {"root_cause_l1": "network"})
    loaded = load_all_feedback()
    assert len(loaded) == 2
    assert [item["action"] for item in loaded] == ["approval", "correction"]
@_with_tmp_feedback
def test_load_empty_feedback(tmp_path):
    """With no feedback recorded yet, loading yields an empty list."""
    assert load_all_feedback() == []
@_with_tmp_feedback
def test_compute_agreement_stats_empty(tmp_path):
    """Stats over an empty feedback log report zero reviews and zero agreement."""
    result = compute_agreement_stats()
    assert result["total_reviews"] == 0
    assert result["overall_agreement_rate"] == 0.0
@_with_tmp_feedback
def test_compute_agreement_stats_all_approvals(tmp_path):
    """Approvals with no corrections yield a perfect agreement rate."""
    for case_id in ("case-001", "case-002"):
        save_approval(case_id, {})
    result = compute_agreement_stats()
    assert result["total_reviews"] == 2
    assert result["approvals"] == 2
    assert result["corrections"] == 0
    assert result["overall_agreement_rate"] == 1.0
@_with_tmp_feedback
def test_compute_agreement_stats_mixed(tmp_path):
    """One approval plus one correction lands agreement strictly between 0 and 1."""
    save_approval("case-001", {})
    save_feedback("case-002", {"root_cause_l1": "billing"}, {"root_cause_l1": "network"})
    result = compute_agreement_stats()
    assert result["total_reviews"] == 2
    assert result["approvals"] == 1
    assert result["corrections"] == 1
    assert 0.0 < result["overall_agreement_rate"] < 1.0
    # root_cause_l1 was corrected in exactly one of the two reviews.
    assert result["per_field_agreement"]["root_cause_l1"] == 0.5
    assert result["most_corrected_fields"][0] == ("root_cause_l1", 1)
@_with_tmp_feedback
def test_compute_agreement_per_field(tmp_path):
    """Per-field agreement is measured over all reviews, not just those touching the field."""
    # Two reviews, each correcting a different field.
    save_feedback("case-001", {"root_cause_l1": "billing"}, {"root_cause_l1": "network"})
    save_feedback("case-002", {"risk_level": "low"}, {"risk_level": "high"})
    result = compute_agreement_stats()
    # Each corrected field was wrong in one review out of two.
    assert result["per_field_agreement"]["root_cause_l1"] == 0.5
    assert result["per_field_agreement"]["risk_level"] == 0.5
    # A field that was never corrected keeps full agreement.
    assert result["per_field_agreement"]["confidence"] == 1.0
def test_reviewable_fields_match_schema():
    """Ensure all reviewable fields exist in ExtractionOutput."""
    from pipeline.schemas import ExtractionOutput

    # __dataclass_fields__ maps field name -> Field, so its keys are the names.
    schema_fields = set(ExtractionOutput.__dataclass_fields__)
    for field in _reviewable_fields():
        assert field in schema_fields, f"Reviewable field '{field}' not in ExtractionOutput"
def test_compute_field_agreement_no_corrections():
    """An empty corrections dict means perfect agreement and nothing flagged."""
    extracted = {"root_cause_l1": "billing", "risk_level": "low"}
    result = _compute_field_agreement(extracted, {})
    assert result["agreement_rate"] == 1.0
    assert result["fields_corrected"] == []
def test_compute_field_agreement_all_corrected():
    """Correcting every reviewable field drives the agreement rate to zero."""
    fields = _reviewable_fields()
    all_corrected = {name: "new_value" for name in fields}
    result = _compute_field_agreement({}, all_corrected)
    assert result["agreement_rate"] == 0.0
    assert len(result["fields_corrected"]) == len(fields)
|