Spaces:
Running
Running
File size: 4,355 Bytes
"""Regression tests for structural trio/group behavior.

Covers two invariants:
1) Runtime structural post-processing: selecting `trio` implies `group`
   (``_postprocess_structural_tags`` must add `group` whenever `trio` is present).
2) Eval sample data integrity: no row in ``data/eval_samples/*.jsonl`` has
   `trio` without `group`.
"""
from __future__ import annotations
import json
import sys
from pathlib import Path
from typing import Iterable, List
# Running this file directly (``python path/to/this_file.py``) puts only the
# file's own directory on sys.path. Prepend the repository root -- the parent
# of that directory -- so the project import below resolves.
_REPO_ROOT = Path(__file__).resolve().parents[1]
_repo_root_entry = str(_REPO_ROOT)
if _repo_root_entry not in sys.path:
    sys.path.insert(0, _repo_root_entry)
from psq_rag.llm.select import _postprocess_structural_tags
def _flatten_categorized_tags(raw: str) -> List[str]:
try:
obj = json.loads(raw or "{}")
except Exception:
return []
if not isinstance(obj, dict):
return []
out: List[str] = []
for values in obj.values():
if not isinstance(values, list):
continue
for item in values:
out.append(str(item))
return out
def _check_invariant(tag_list: Iterable[str]) -> bool:
vals = list(tag_list)
return ("trio" not in vals) or ("group" in vals)
def test_postprocess_logic() -> bool:
    """Check that ``_postprocess_structural_tags`` adds ``group`` alongside ``trio``.

    Each case pairs an input tag list with the exact list expected back;
    the comparison is plain list equality, so order matters (the cases
    expect ``group`` appended at the end). One PASS/FAIL line is printed
    per case.

    Returns:
        True if every case matched, False otherwise.

    NOTE(review): this returns a bool instead of asserting, so if a test
    runner such as pytest collects it, a failing case would not fail the run;
    ``main()`` is what turns the result into an exit code.
    """
    banner = "=" * 80
    print(banner)
    print("Testing structural postprocess trio->group rule")
    print(banner)

    cases = [
        (["solo"], ["solo"]),
        (["trio"], ["trio", "group"]),
        (["trio", "group"], ["trio", "group"]),
        (["duo", "trio", "male"], ["duo", "trio", "male", "group"]),
    ]
    failures = 0
    for case_no, (given, want) in enumerate(cases, 1):
        result = _postprocess_structural_tags(given)
        matched = result == want
        label = "PASS" if matched else "FAIL"
        print(f"{label}: case {case_no} input={given} expected={want} got={result}")
        if not matched:
            failures += 1
    return failures == 0
def _row_violations(obj: dict) -> List[str]:
    """Return labels of the tag fields in one eval row that have `trio` without `group`.

    Fields are checked in this order; the label is what FAIL messages print:
      * ``tags_ground_truth_expanded`` (a JSON list)   -> "ground_truth_expanded"
      * ``tags_ground_truth_categorized`` (JSON text)  -> "ground_truth_categorized"
      * ``tags_synthetic_categorized`` (JSON text)     -> "tags_synthetic_categorized"
    A field that is missing or has a different type is not checked.
    """
    violations: List[str] = []
    expanded = obj.get("tags_ground_truth_expanded")
    if isinstance(expanded, list) and not _check_invariant(str(x) for x in expanded):
        violations.append("ground_truth_expanded")
    for key, label in (
        ("tags_ground_truth_categorized", "ground_truth_categorized"),
        ("tags_synthetic_categorized", "tags_synthetic_categorized"),
    ):
        raw = obj.get(key)
        if isinstance(raw, str) and not _check_invariant(_flatten_categorized_tags(raw)):
            violations.append(label)
    return violations


def _check_eval_file(fp: Path) -> bool:
    """Check every row of one JSONL eval file; print FAIL lines and a per-file summary.

    Blank lines and rows whose ``_meta`` field is ``True`` are skipped. Lines
    that are not valid JSON, or that decode to something other than a JSON
    object, are reported with a WARN line and skipped without failing the file.

    Returns:
        True if no row in the file violates the trio/group invariant.
    """
    viol_counts = {
        "ground_truth_expanded": 0,
        "ground_truth_categorized": 0,
        "tags_synthetic_categorized": 0,
    }
    total = 0
    skipped = 0
    with fp.open("r", encoding="utf-8") as f:
        # ln counts physical lines (blank ones included) so messages point at the file.
        for ln, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError as exc:
                skipped += 1
                print(f"WARN: {fp.name}:{ln} skipped invalid JSON ({exc.msg})")
                continue
            if not isinstance(obj, dict):
                skipped += 1
                print(f"WARN: {fp.name}:{ln} skipped row that is not a JSON object")
                continue
            if obj.get("_meta") is True:
                continue
            total += 1
            for label in _row_violations(obj):
                viol_counts[label] += 1
                print(f"FAIL: {fp.name}:{ln} {label} has trio without group")
    clean = not any(viol_counts.values())
    # The status label reflects this file's result rather than always "PASS".
    status = "PASS" if clean else "FAIL"
    print(
        f"{status}: {fp.name} rows={total} "
        f"gt_exp_viol={viol_counts['ground_truth_expanded']} "
        f"gt_cat_viol={viol_counts['ground_truth_categorized']} "
        f"syn_cat_viol={viol_counts['tags_synthetic_categorized']} "
        f"skipped={skipped}"
    )
    return clean


def test_eval_sample_invariant() -> bool:
    """Verify that no eval sample row has `trio` without `group`.

    Scans every ``*.jsonl`` file in ``<repo root>/data/eval_samples`` (sorted
    by path), printing a FAIL line per violating field and a PASS/FAIL
    summary per file.

    Returns:
        False if no sample files exist or any file has a violating row,
        True otherwise.

    NOTE(review): this returns a bool instead of asserting, so if a test
    runner such as pytest collects it, a failure would not fail the run;
    ``main()`` is what turns the result into an exit code.
    """
    print("\n" + "=" * 80)
    print("Testing eval sample trio/group data invariant")
    print("=" * 80)
    eval_dir = _REPO_ROOT / "data" / "eval_samples"
    files = sorted(eval_dir.glob("*.jsonl"))
    if not files:
        print("FAIL: no eval sample JSONL files found")
        return False
    ok = True
    for fp in files:
        # Check every file (no short-circuit) so all violations are reported.
        if not _check_eval_file(fp):
            ok = False
    return ok
def main() -> int:
    """Run every check in this module and report an overall verdict.

    Both checks always run, even if the first one fails, so a single
    invocation reports all problems.

    Returns:
        Process exit status: 0 when every check passed, 1 otherwise.
    """
    outcomes = [
        test_postprocess_logic(),
        test_eval_sample_invariant(),
    ]
    all_passed = all(outcomes)
    rule = "=" * 80
    print("\n" + rule)
    print("ALL TESTS PASSED" if all_passed else "SOME TESTS FAILED")
    print(rule)
    return 0 if all_passed else 1
if __name__ == "__main__":
    # Use main()'s 0/1 result as the process exit status.
    sys.exit(main())
|