"""Regression tests for structural trio/group behavior. Covers two invariants: 1) Runtime structural postprocess: selecting `trio` implies `group`. 2) Eval sample data integrity: no row has `trio` without `group`. """ from __future__ import annotations import json import sys from pathlib import Path from typing import Iterable, List # Ensure repo root is importable when run as a script. _REPO_ROOT = Path(__file__).resolve().parents[1] if str(_REPO_ROOT) not in sys.path: sys.path.insert(0, str(_REPO_ROOT)) from psq_rag.llm.select import _postprocess_structural_tags def _flatten_categorized_tags(raw: str) -> List[str]: try: obj = json.loads(raw or "{}") except Exception: return [] if not isinstance(obj, dict): return [] out: List[str] = [] for values in obj.values(): if not isinstance(values, list): continue for item in values: out.append(str(item)) return out def _check_invariant(tag_list: Iterable[str]) -> bool: vals = list(tag_list) return ("trio" not in vals) or ("group" in vals) def test_postprocess_logic() -> bool: print("=" * 80) print("Testing structural postprocess trio->group rule") print("=" * 80) ok = True cases = [ (["solo"], ["solo"]), (["trio"], ["trio", "group"]), (["trio", "group"], ["trio", "group"]), (["duo", "trio", "male"], ["duo", "trio", "male", "group"]), ] for i, (inp, expected) in enumerate(cases, start=1): got = _postprocess_structural_tags(inp) passed = got == expected status = "PASS" if passed else "FAIL" print(f"{status}: case {i} input={inp} expected={expected} got={got}") if not passed: ok = False return ok def test_eval_sample_invariant() -> bool: print("\n" + "=" * 80) print("Testing eval sample trio/group data invariant") print("=" * 80) eval_dir = _REPO_ROOT / "data" / "eval_samples" files = sorted(eval_dir.glob("*.jsonl")) if not files: print("FAIL: no eval sample JSONL files found") return False ok = True for fp in files: total = 0 gt_exp_viol = 0 gt_cat_viol = 0 syn_cat_viol = 0 with fp.open("r", encoding="utf-8") as f: for ln, line in enumerate(f, start=1): line = line.strip() if not line: continue try: obj = json.loads(line) except Exception: continue if obj.get("_meta") is True: continue total += 1 expanded = obj.get("tags_ground_truth_expanded") if isinstance(expanded, list) and not _check_invariant(str(x) for x in expanded): gt_exp_viol += 1 ok = False print(f"FAIL: {fp.name}:{ln} ground_truth_expanded has trio without group") cat_raw = obj.get("tags_ground_truth_categorized") if isinstance(cat_raw, str): cat_vals = _flatten_categorized_tags(cat_raw) if not _check_invariant(cat_vals): gt_cat_viol += 1 ok = False print(f"FAIL: {fp.name}:{ln} ground_truth_categorized has trio without group") syn_raw = obj.get("tags_synthetic_categorized") if isinstance(syn_raw, str): syn_vals = _flatten_categorized_tags(syn_raw) if not _check_invariant(syn_vals): syn_cat_viol += 1 ok = False print(f"FAIL: {fp.name}:{ln} tags_synthetic_categorized has trio without group") print( f"PASS: {fp.name} rows={total} " f"gt_exp_viol={gt_exp_viol} gt_cat_viol={gt_cat_viol} syn_cat_viol={syn_cat_viol}" ) return ok def main() -> int: ok = True ok &= test_postprocess_logic() ok &= test_eval_sample_invariant() print("\n" + "=" * 80) if ok: print("ALL TESTS PASSED") else: print("SOME TESTS FAILED") print("=" * 80) return 0 if ok else 1 if __name__ == "__main__": raise SystemExit(main())