File size: 4,355 Bytes
65b7582
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""Regression tests for structural trio/group behavior.

Covers two invariants:
1) Runtime structural postprocess: selecting `trio` implies `group`.
2) Eval sample data integrity: no row has `trio` without `group`.
"""

from __future__ import annotations

import json
import sys
from pathlib import Path
from typing import Iterable, List

# Ensure repo root is importable when run as a script.
# `parents[1]` is the directory one level above the directory holding this
# file; it is treated as the repo root and put at the front of sys.path
# (only if not already present) before the project import that follows.
_REPO_ROOT = Path(__file__).resolve().parents[1]
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))

from psq_rag.llm.select import _postprocess_structural_tags


def _flatten_categorized_tags(raw: str) -> List[str]:
    """Flatten a JSON-encoded ``{category: [tag, ...]}`` mapping into one list.

    Args:
        raw: JSON text. A falsy value (``""`` / ``None``) is treated as ``"{}"``.

    Returns:
        Every item of every list-valued entry, converted with ``str()``, in
        mapping order. Returns ``[]`` when ``raw`` cannot be parsed or the
        top-level JSON value is not an object. Entries whose value is not a
        list are ignored.
    """
    try:
        parsed = json.loads(raw or "{}")
    except Exception:
        # Best-effort helper: unparseable input simply yields no tags.
        return []

    if not isinstance(parsed, dict):
        return []

    return [
        str(tag)
        for tag_group in parsed.values()
        if isinstance(tag_group, list)
        for tag in tag_group
    ]


def _check_invariant(tag_list: Iterable[str]) -> bool:
    """Return True when the tags satisfy the rule "``trio`` implies ``group``".

    Args:
        tag_list: Any iterable of tags; it is materialised once, so
            one-shot generators are accepted.

    Returns:
        ``False`` only when ``"trio"`` is present and ``"group"`` is absent;
        ``True`` in every other case (including an empty iterable).
    """
    tags = list(tag_list)
    if "trio" in tags:
        return "group" in tags
    # No trio present, so there is nothing to enforce.
    return True


def test_postprocess_logic() -> bool:
    """Check ``_postprocess_structural_tags`` against fixed trio/group cases.

    Each case pairs an input tag list with the exact list the postprocess
    step is expected to return. One PASS/FAIL line is printed per case.

    Returns:
        ``True`` when every case produced its expected output, else ``False``.
    """
    banner = "=" * 80
    print(banner)
    print("Testing structural postprocess trio->group rule")
    print(banner)

    # (input tags, expected output tags)
    scenarios = (
        (["solo"], ["solo"]),
        (["trio"], ["trio", "group"]),
        (["trio", "group"], ["trio", "group"]),
        (["duo", "trio", "male"], ["duo", "trio", "male", "group"]),
    )

    failures = 0
    for index, (given, wanted) in enumerate(scenarios, start=1):
        actual = _postprocess_structural_tags(given)
        if actual == wanted:
            label = "PASS"
        else:
            label = "FAIL"
            failures += 1
        print(f"{label}: case {index} input={given} expected={wanted} got={actual}")

    return failures == 0


def test_eval_sample_invariant() -> bool:
    """Verify that no eval sample row carries ``trio`` without ``group``.

    Scans every ``*.jsonl`` file under ``<repo root>/data/eval_samples``.
    For each JSON-object row (rows with ``"_meta": true`` are skipped) three
    fields are checked when present with the expected type:

    * ``tags_ground_truth_expanded`` -- a list of tags
    * ``tags_ground_truth_categorized`` -- a JSON string of category -> tags
    * ``tags_synthetic_categorized`` -- a JSON string of category -> tags

    One FAIL line is printed per violating field, followed by a per-file
    summary whose status is FAIL when that file had any violation.

    Returns:
        ``True`` when at least one sample file exists and no violation was
        found in any of them, else ``False``.
    """
    print("\n" + "=" * 80)
    print("Testing eval sample trio/group data invariant")
    print("=" * 80)

    eval_dir = _REPO_ROOT / "data" / "eval_samples"
    files = sorted(eval_dir.glob("*.jsonl"))
    if not files:
        print("FAIL: no eval sample JSONL files found")
        return False

    ok = True
    for fp in files:
        total = 0
        gt_exp_viol = 0
        gt_cat_viol = 0
        syn_cat_viol = 0

        with fp.open("r", encoding="utf-8") as f:
            for ln, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except json.JSONDecodeError:
                    # Best effort: lines that are not valid JSON are skipped.
                    continue
                if not isinstance(obj, dict):
                    # A bare list/number/string is not a sample row; calling
                    # .get() on it would raise AttributeError.
                    continue
                if obj.get("_meta") is True:
                    continue
                total += 1

                expanded = obj.get("tags_ground_truth_expanded")
                if isinstance(expanded, list) and not _check_invariant(str(x) for x in expanded):
                    gt_exp_viol += 1
                    print(f"FAIL: {fp.name}:{ln} ground_truth_expanded has trio without group")

                cat_raw = obj.get("tags_ground_truth_categorized")
                if isinstance(cat_raw, str):
                    cat_vals = _flatten_categorized_tags(cat_raw)
                    if not _check_invariant(cat_vals):
                        gt_cat_viol += 1
                        print(f"FAIL: {fp.name}:{ln} ground_truth_categorized has trio without group")

                syn_raw = obj.get("tags_synthetic_categorized")
                if isinstance(syn_raw, str):
                    syn_vals = _flatten_categorized_tags(syn_raw)
                    if not _check_invariant(syn_vals):
                        syn_cat_viol += 1
                        print(f"FAIL: {fp.name}:{ln} tags_synthetic_categorized has trio without group")

        # The summary status must reflect this file's own violations rather
        # than unconditionally reporting PASS.
        file_ok = (gt_exp_viol + gt_cat_viol + syn_cat_viol) == 0
        if not file_ok:
            ok = False
        status = "PASS" if file_ok else "FAIL"
        print(
            f"{status}: {fp.name} rows={total} "
            f"gt_exp_viol={gt_exp_viol} gt_cat_viol={gt_cat_viol} syn_cat_viol={syn_cat_viol}"
        )

    return ok


def main() -> int:
    """Run both regression checks and report an overall verdict.

    Returns:
        Process exit status: ``0`` when both checks passed, ``1`` otherwise.
    """
    # Build the list eagerly so the second check still runs (and prints its
    # report) even when the first one has already failed.
    outcomes = [test_postprocess_logic(), test_eval_sample_invariant()]
    success = all(outcomes)

    divider = "=" * 80
    print("\n" + divider)
    print("ALL TESTS PASSED" if success else "SOME TESTS FAILED")
    print(divider)

    if success:
        return 0
    return 1


# Script entry point: run all checks and exit with main()'s status code
# (0 when every check passed, 1 otherwise).
if __name__ == "__main__":
    raise SystemExit(main())