File size: 5,004 Bytes
c4fe0a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""Tests for the human feedback loop module."""
import tempfile
from pathlib import Path
from unittest.mock import patch

from pipeline.feedback import (
    save_feedback,
    save_approval,
    load_all_feedback,
    compute_agreement_stats,
    _reviewable_fields,
    _compute_field_agreement,
    FEEDBACK_PATH,
)


def _with_tmp_feedback(func):
    """Decorator that redirects feedback persistence to a throwaway file.

    Creates an empty temporary ``.jsonl`` file, patches
    ``pipeline.feedback.FEEDBACK_PATH`` to it for the duration of the call,
    passes the path to the wrapped test as its first argument, and deletes
    the file afterwards regardless of outcome.
    """
    def wrapper(*args, **kwargs):
        handle = tempfile.NamedTemporaryFile(suffix=".jsonl", delete=False)
        handle.close()
        temp = Path(handle.name)
        try:
            with patch("pipeline.feedback.FEEDBACK_PATH", temp):
                return func(temp, *args, **kwargs)
        finally:
            temp.unlink(missing_ok=True)
    return wrapper


@_with_tmp_feedback
def test_save_approval_creates_entry(tmp_path):
    """An approval is recorded with full agreement and the reviewer's note."""
    result = save_approval("case-001", {"root_cause_l1": "billing"}, "Looks good")
    assert result["case_id"] == "case-001"
    assert result["action"] == "approval"
    assert result["reviewer_notes"] == "Looks good"
    assert result["agreement"]["agreement_rate"] == 1.0


@_with_tmp_feedback
def test_save_feedback_records_correction(tmp_path):
    """A correction stores the touched original fields and flags each edit."""
    before = {"root_cause_l1": "billing", "risk_level": "low", "confidence": 0.9}
    after = {"root_cause_l1": "network", "risk_level": "high"}
    result = save_feedback("case-002", before, after, "Wrong root cause")

    assert result["action"] == "correction"
    # The stored snapshot of the original keeps only the corrected fields.
    assert result["original"] == {"root_cause_l1": "billing", "risk_level": "low"}
    assert result["corrected"] == after
    corrected_fields = result["agreement"]["fields_corrected"]
    assert "root_cause_l1" in corrected_fields
    assert "risk_level" in corrected_fields
    assert result["agreement"]["agreement_rate"] < 1.0


@_with_tmp_feedback
def test_load_all_feedback_roundtrip(tmp_path):
    """Entries written via the save_* helpers load back in insertion order."""
    save_approval("case-001", {})
    save_feedback("case-002", {"root_cause_l1": "billing"}, {"root_cause_l1": "network"})

    loaded = load_all_feedback()
    assert [entry["action"] for entry in loaded] == ["approval", "correction"]


@_with_tmp_feedback
def test_load_empty_feedback(tmp_path):
    """Loading when nothing has been saved yields an empty list."""
    assert load_all_feedback() == []


@_with_tmp_feedback
def test_compute_agreement_stats_empty(tmp_path):
    """Stats over zero reviews report a zero count and zero agreement."""
    stats = compute_agreement_stats()
    assert stats["overall_agreement_rate"] == 0.0
    assert stats["total_reviews"] == 0


@_with_tmp_feedback
def test_compute_agreement_stats_all_approvals(tmp_path):
    """Two approvals and no corrections give perfect agreement."""
    for case_id in ("case-001", "case-002"):
        save_approval(case_id, {})

    stats = compute_agreement_stats()
    assert stats["overall_agreement_rate"] == 1.0
    assert stats["corrections"] == 0
    assert stats["approvals"] == 2
    assert stats["total_reviews"] == 2


@_with_tmp_feedback
def test_compute_agreement_stats_mixed(tmp_path):
    """One approval plus one correction lands strictly between 0 and 1."""
    save_approval("case-001", {})
    save_feedback("case-002", {"root_cause_l1": "billing"}, {"root_cause_l1": "network"})

    stats = compute_agreement_stats()
    assert stats["total_reviews"] == 2
    assert stats["approvals"] == 1
    assert stats["corrections"] == 1
    rate = stats["overall_agreement_rate"]
    assert 0.0 < rate < 1.0
    # Exactly one of the two reviews corrected root_cause_l1.
    assert stats["per_field_agreement"]["root_cause_l1"] == 0.5
    assert stats["most_corrected_fields"][0] == ("root_cause_l1", 1)


@_with_tmp_feedback
def test_compute_agreement_per_field(tmp_path):
    """Per-field agreement is computed over all reviews, not just corrections."""
    save_feedback("case-001", {"root_cause_l1": "billing"}, {"root_cause_l1": "network"})
    save_feedback("case-002", {"risk_level": "low"}, {"risk_level": "high"})

    per_field = compute_agreement_stats()["per_field_agreement"]
    # Each field below was corrected in exactly one of the two reviews.
    assert per_field["root_cause_l1"] == 0.5
    assert per_field["risk_level"] == 0.5
    # confidence was never corrected, so it keeps full agreement.
    assert per_field["confidence"] == 1.0


def test_reviewable_fields_match_schema():
    """Ensure all reviewable fields exist in ExtractionOutput.

    Guards against ``_reviewable_fields()`` drifting out of sync with the
    schema dataclass when fields are renamed or removed there.
    """
    from pipeline.schemas import ExtractionOutput
    # __dataclass_fields__ maps field name -> Field, so its keys are already
    # the names; no need to round-trip through .values() and f.name.
    schema_fields = set(ExtractionOutput.__dataclass_fields__)
    for field in _reviewable_fields():
        assert field in schema_fields, f"Reviewable field '{field}' not in ExtractionOutput"


def test_compute_field_agreement_no_corrections():
    """No corrected fields means full agreement and an empty correction list."""
    result = _compute_field_agreement(
        {"root_cause_l1": "billing", "risk_level": "low"},
        {},
    )
    assert result["fields_corrected"] == []
    assert result["agreement_rate"] == 1.0


def test_compute_field_agreement_all_corrected():
    """Correcting every reviewable field drives agreement down to zero."""
    reviewable = _reviewable_fields()
    result = _compute_field_agreement(
        {},
        {name: "new_value" for name in reviewable},
    )
    assert result["agreement_rate"] == 0.0
    assert len(result["fields_corrected"]) == len(reviewable)