File size: 4,866 Bytes
08ac559
34caac9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8c7fad
34caac9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08ac559
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import csv
import json
from datetime import datetime


def test_conversation_verification_export_serializes_without_record_id(tmp_path, monkeypatch):
    """Regression test: conversation verification records use exchange_id, not record_id.

    The stand-in record below deliberately has no ``record_id`` attribute. The
    export payload must still serialize: the identifier is taken from
    ``exchange_id``, and ``record_id`` is emitted as an alias carrying that same
    value. The payload is written to JSON under ``tmp_path`` and read back.

    ``monkeypatch`` is requested but not used by this test.
    """

    # Minimal stand-in objects (match the attributes used by open_verification_window export code)
    class _Record:
        def __init__(self):
            self.exchange_id = "sess_1"
            self.exchange_number = 1
            self.timestamp = datetime(2025, 12, 12, 0, 0, 0)
            self.user_message = "hi"
            self.assistant_response = "hello"
            self.original_classification = "YELLOW"
            self.original_confidence = 0.9
            self.original_indicators = ["stress"]
            self.original_reasoning = "reason"
            self.is_correct = None
            self.correct_classification = None
            self.correction_reason = None
            self.verifier_notes = None

    class _Session:
        def __init__(self):
            self.session_id = "verification_test"
            self.patient_name = "Test"
            self.verifier_name = "Verifier"
            self.start_time = datetime(2025, 12, 12, 0, 0, 0)
            self.verification_records = [_Record()]

    vs = _Session()

    # Pin the regression premise: the record type exposes no record_id, so the
    # export must never read r.record_id directly.
    assert not hasattr(vs.verification_records[0], "record_id")

    # This mirrors the payload schema in gradio_app.open_verification_window
    payload = {
        "session_id": vs.session_id,
        "patient_name": vs.patient_name,
        "verifier_name": vs.verifier_name,
        "start_time": vs.start_time.isoformat(),
        "verification_records": [
            {
                "exchange_id": getattr(r, "exchange_id", None),
                # Alias: record_id is populated from exchange_id.
                "record_id": getattr(r, "exchange_id", None),
                "timestamp": r.timestamp.isoformat(),
                "user_message": r.user_message,
                "assistant_response": r.assistant_response,
                "original_classification": r.original_classification,
                "original_confidence": r.original_confidence,
                "original_indicators": r.original_indicators,
                "original_reasoning": r.original_reasoning,
                "is_correct": r.is_correct,
                "correct_classification": r.correct_classification,
                "correction_reason": r.correction_reason,
                "verifier_notes": r.verifier_notes,
            }
            for r in vs.verification_records
        ],
    }

    out = tmp_path / "export.json"
    out.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")

    loaded = json.loads(out.read_text(encoding="utf-8"))
    assert loaded["session_id"] == "verification_test"
    assert loaded["start_time"] == "2025-12-12T00:00:00"
    assert len(loaded["verification_records"]) == 1

    loaded_record = loaded["verification_records"][0]
    assert loaded_record["exchange_id"] == "sess_1"
    assert loaded_record["record_id"] == "sess_1"
    assert loaded_record["timestamp"] == "2025-12-12T00:00:00"


def test_conversation_verification_csv_contains_expected_columns(tmp_path):
    """Check that the conversation-verification CSV export has the expected columns.

    Each CSV row combines the session-level metadata in ``meta`` with the
    per-exchange fields of one record. The test writes the rows to
    ``tmp_path / "export.csv"``, reads the file back, and asserts the header, the
    row count and a cell value.
    """
    meta = {
        "session_id": "verification_test",
        "patient_name": "Test",
        "verifier_name": "Verifier",
        "start_time": "2025-12-12T00:00:00",
    }
    records = [
        {
            "exchange_id": "sess_1",
            "exchange_number": 1,
            "original_classification": "YELLOW",
            "original_confidence": 0.9,
            "is_correct": False,
            "verifier_notes": "Needs follow-up",
            "user_message": "hi",
            "assistant_response": "hello",
        }
    ]

    out = tmp_path / "export.csv"

    # Session metadata columns come first, then per-exchange columns. The
    # concatenation is the CSV header order.
    meta_fields = [
        "session_id",
        "patient_name",
        "verifier_name",
        "start_time",
    ]
    record_fields = [
        "exchange_number",
        "exchange_id",
        "original_classification",
        "original_confidence",
        "is_correct",
        "verifier_notes",
        "user_message",
        "assistant_response",
    ]
    fieldnames = meta_fields + record_fields

    with out.open("w", encoding="utf-8", newline="") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        for r in records:
            # Metadata is required (KeyError if missing); record fields are
            # optional and become empty cells when absent.
            row = {k: meta[k] for k in meta_fields}
            row.update({k: r.get(k) for k in record_fields})
            w.writerow(row)

    # Read back inside a with-block so the handle is closed. newline="" is the
    # mode the csv module expects for reading too.
    with out.open("r", encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f)
        rows = list(reader)
        header = reader.fieldnames

    assert header == fieldnames
    assert len(rows) == len(records)
    assert rows and rows[0]["verifier_notes"] == "Needs follow-up"