# Prompt_Squirrel_RAG/scripts/test_structural_trio_group_rule.py
# Food Desert
# Add trio->group structural postprocess and regression test
# 65b7582
"""Regression tests for structural trio/group behavior.
Covers two invariants:
1) Runtime structural postprocess: selecting `trio` implies `group`.
2) Eval sample data integrity: no row has `trio` without `group`.
"""
from __future__ import annotations
import json
import sys
from pathlib import Path
from typing import Iterable, List
# Ensure repo root is importable when run as a script.
_REPO_ROOT = Path(__file__).resolve().parents[1]
if str(_REPO_ROOT) not in sys.path:
sys.path.insert(0, str(_REPO_ROOT))
from psq_rag.llm.select import _postprocess_structural_tags
def _flatten_categorized_tags(raw: str) -> List[str]:
try:
obj = json.loads(raw or "{}")
except Exception:
return []
if not isinstance(obj, dict):
return []
out: List[str] = []
for values in obj.values():
if not isinstance(values, list):
continue
for item in values:
out.append(str(item))
return out
def _check_invariant(tag_list: Iterable[str]) -> bool:
vals = list(tag_list)
return ("trio" not in vals) or ("group" in vals)
def test_postprocess_logic() -> bool:
    """Check ``_postprocess_structural_tags`` against a fixed table of cases.

    Each case pairs an input tag list with the exact list expected back
    (compared with ``==``, so order matters). One PASS/FAIL line is printed
    per case.

    Returns:
        True when every case produced the expected list, False otherwise.
    """
    banner = "=" * 80
    print(banner)
    print("Testing structural postprocess trio->group rule")
    print(banner)

    # (input tags, expected output tags)
    table = (
        (["solo"], ["solo"]),
        (["trio"], ["trio", "group"]),
        (["trio", "group"], ["trio", "group"]),
        (["duo", "trio", "male"], ["duo", "trio", "male", "group"]),
    )

    all_passed = True
    for number, (tags_in, want) in enumerate(table, start=1):
        actual = _postprocess_structural_tags(tags_in)
        if actual == want:
            label = "PASS"
        else:
            label = "FAIL"
            all_passed = False
        print(f"{label}: case {number} input={tags_in} expected={want} got={actual}")
    return all_passed
def test_eval_sample_invariant() -> bool:
    """Scan eval-sample JSONL files for rows that have ``trio`` without ``group``.

    Every ``*.jsonl`` file under ``_REPO_ROOT / "data" / "eval_samples"`` is
    read line by line. Blank lines and rows whose ``_meta`` field is ``True``
    are skipped. For each remaining row, three fields are checked when they
    have the expected type:

    * ``tags_ground_truth_expanded`` (a list of tags)
    * ``tags_ground_truth_categorized`` (a JSON string, flattened first)
    * ``tags_synthetic_categorized`` (a JSON string, flattened first)

    Lines that are not valid JSON, or whose JSON value is not an object,
    are reported with a ``WARN`` line and skipped; they do not fail the test.
    A per-file summary line is printed, labelled ``FAIL`` if that file had
    any violation and ``PASS`` otherwise.

    Returns:
        True when at least one file was found and no row violated the
        invariant; False otherwise.
    """
    print("\n" + "=" * 80)
    print("Testing eval sample trio/group data invariant")
    print("=" * 80)

    eval_dir = _REPO_ROOT / "data" / "eval_samples"
    files = sorted(eval_dir.glob("*.jsonl"))
    if not files:
        print("FAIL: no eval sample JSONL files found")
        return False

    ok = True
    for fp in files:
        total = 0
        gt_exp_viol = 0
        gt_cat_viol = 0
        syn_cat_viol = 0
        with fp.open("r", encoding="utf-8") as f:
            for ln, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except json.JSONDecodeError as exc:
                    # Keep the scan best-effort, but make the skip visible
                    # instead of silently dropping the row.
                    print(f"WARN: {fp.name}:{ln} skipped, invalid JSON ({exc.msg})")
                    continue
                if not isinstance(obj, dict):
                    # A valid JSON scalar/array has no fields to inspect;
                    # calling .get() on it would raise AttributeError.
                    print(f"WARN: {fp.name}:{ln} skipped, row is not a JSON object")
                    continue
                if obj.get("_meta") is True:
                    continue
                total += 1

                expanded = obj.get("tags_ground_truth_expanded")
                if isinstance(expanded, list) and not _check_invariant(str(x) for x in expanded):
                    gt_exp_viol += 1
                    ok = False
                    print(f"FAIL: {fp.name}:{ln} ground_truth_expanded has trio without group")

                cat_raw = obj.get("tags_ground_truth_categorized")
                if isinstance(cat_raw, str):
                    cat_vals = _flatten_categorized_tags(cat_raw)
                    if not _check_invariant(cat_vals):
                        gt_cat_viol += 1
                        ok = False
                        print(f"FAIL: {fp.name}:{ln} ground_truth_categorized has trio without group")

                syn_raw = obj.get("tags_synthetic_categorized")
                if isinstance(syn_raw, str):
                    syn_vals = _flatten_categorized_tags(syn_raw)
                    if not _check_invariant(syn_vals):
                        syn_cat_viol += 1
                        ok = False
                        print(f"FAIL: {fp.name}:{ln} tags_synthetic_categorized has trio without group")

        # The summary label must reflect this file's own result; it used to
        # say PASS even when violations had just been reported above.
        file_ok = not (gt_exp_viol or gt_cat_viol or syn_cat_viol)
        status = "PASS" if file_ok else "FAIL"
        print(
            f"{status}: {fp.name} rows={total} "
            f"gt_exp_viol={gt_exp_viol} gt_cat_viol={gt_cat_viol} syn_cat_viol={syn_cat_viol}"
        )
    return ok
def main() -> int:
    """Run both regression checks and print an overall verdict.

    Both checks are always executed, even when the first one fails, so the
    output covers every check.

    Returns:
        0 when every check passed, 1 otherwise.
    """
    outcomes = [test_postprocess_logic(), test_eval_sample_invariant()]
    all_ok = all(outcomes)

    rule = "=" * 80
    print("\n" + rule)
    print("ALL TESTS PASSED" if all_ok else "SOME TESTS FAILED")
    print(rule)
    return 0 if all_ok else 1
# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    sys.exit(main())