Spaces:
Running
Running
| """Regression tests for structural trio/group behavior. | |
| Covers two invariants: | |
| 1) Runtime structural postprocess: selecting `trio` implies `group`. | |
| 2) Eval sample data integrity: no row has `trio` without `group`. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import sys | |
| from pathlib import Path | |
| from typing import Iterable, List | |
# When this file is executed directly, Python puts only the file's own
# directory on sys.path; prepend the directory one level above it (the repo
# root) so project packages such as ``psq_rag`` can be imported.
_REPO_ROOT = Path(__file__).resolve().parent.parent
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))
| from psq_rag.llm.select import _postprocess_structural_tags | |
| def _flatten_categorized_tags(raw: str) -> List[str]: | |
| try: | |
| obj = json.loads(raw or "{}") | |
| except Exception: | |
| return [] | |
| if not isinstance(obj, dict): | |
| return [] | |
| out: List[str] = [] | |
| for values in obj.values(): | |
| if not isinstance(values, list): | |
| continue | |
| for item in values: | |
| out.append(str(item)) | |
| return out | |
| def _check_invariant(tag_list: Iterable[str]) -> bool: | |
| vals = list(tag_list) | |
| return ("trio" not in vals) or ("group" in vals) | |
def test_postprocess_logic() -> bool:
    """Check that ``_postprocess_structural_tags`` appends ``group`` after ``trio``.

    Each case pairs an input tag list with the exact expected output list
    (order matters). A PASS/FAIL line is printed for every case.

    Returns:
        True if every case produced exactly the expected list.

    NOTE(review): failures are reported only through this bool, which
    ``main()`` aggregates; a runner that ignores return values (e.g. pytest)
    would not register a failure here — confirm how this file is run.
    """
    banner = "=" * 80
    print(banner)
    print("Testing structural postprocess trio->group rule")
    print(banner)

    cases = (
        (["solo"], ["solo"]),
        (["trio"], ["trio", "group"]),
        (["trio", "group"], ["trio", "group"]),
        (["duo", "trio", "male"], ["duo", "trio", "male", "group"]),
    )

    all_passed = True
    for case_no, (tags_in, tags_expected) in enumerate(cases, start=1):
        tags_out = _postprocess_structural_tags(tags_in)
        matched = tags_out == tags_expected
        label = "PASS" if matched else "FAIL"
        print(f"{label}: case {case_no} input={tags_in} expected={tags_expected} got={tags_out}")
        all_passed = all_passed and matched
    return all_passed
def _scan_eval_file(fp: Path) -> bool:
    """Scan one eval-sample JSONL file for rows with ``trio`` but no ``group``.

    Blank lines and rows whose ``_meta`` field is ``True`` are ignored. Three
    fields are checked per row:
    ``tags_ground_truth_expanded`` (when it is a list) and the JSON-encoded
    ``tags_ground_truth_categorized`` / ``tags_synthetic_categorized`` (when
    they are strings). Prints a FAIL line per violation, a WARN line per row
    that is not a JSON object, and a per-file summary line.

    Returns:
        True if no row in the file violates the invariant.
    """
    total = 0
    skipped = 0
    gt_exp_viol = 0
    gt_cat_viol = 0
    syn_cat_viol = 0
    with fp.open("r", encoding="utf-8") as f:
        for ln, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                obj = None
            if not isinstance(obj, dict):
                # Such rows cannot be checked: malformed rows used to be dropped
                # silently and non-object rows crashed on .get(). Report them.
                skipped += 1
                print(f"WARN: {fp.name}:{ln} not a JSON object; row skipped")
                continue
            if obj.get("_meta") is True:
                continue
            total += 1

            expanded = obj.get("tags_ground_truth_expanded")
            if isinstance(expanded, list) and not _check_invariant(str(x) for x in expanded):
                gt_exp_viol += 1
                print(f"FAIL: {fp.name}:{ln} ground_truth_expanded has trio without group")

            cat_raw = obj.get("tags_ground_truth_categorized")
            if isinstance(cat_raw, str) and not _check_invariant(_flatten_categorized_tags(cat_raw)):
                gt_cat_viol += 1
                print(f"FAIL: {fp.name}:{ln} ground_truth_categorized has trio without group")

            syn_raw = obj.get("tags_synthetic_categorized")
            if isinstance(syn_raw, str) and not _check_invariant(_flatten_categorized_tags(syn_raw)):
                syn_cat_viol += 1
                print(f"FAIL: {fp.name}:{ln} tags_synthetic_categorized has trio without group")

    file_ok = gt_exp_viol == 0 and gt_cat_viol == 0 and syn_cat_viol == 0
    # The summary used to say PASS unconditionally, even after FAIL rows.
    status = "PASS" if file_ok else "FAIL"
    print(
        f"{status}: {fp.name} rows={total} "
        f"gt_exp_viol={gt_exp_viol} gt_cat_viol={gt_cat_viol} syn_cat_viol={syn_cat_viol} "
        f"skipped={skipped}"
    )
    return file_ok


def test_eval_sample_invariant() -> bool:
    """Check that no eval-sample row has ``trio`` without ``group``.

    Scans every ``*.jsonl`` file in ``<repo root>/data/eval_samples`` in
    sorted order.

    Returns:
        False if no JSONL files are found or if any file contains a violating
        row; True otherwise.
    """
    print("\n" + "=" * 80)
    print("Testing eval sample trio/group data invariant")
    print("=" * 80)
    eval_dir = _REPO_ROOT / "data" / "eval_samples"
    files = sorted(eval_dir.glob("*.jsonl"))
    if not files:
        print("FAIL: no eval sample JSONL files found")
        return False
    # Build the full list before all() so every file is scanned and reported,
    # rather than stopping at the first failing file.
    results = [_scan_eval_file(fp) for fp in files]
    return all(results)
def main() -> int:
    """Run both check groups, print an overall verdict, and return an exit code.

    Both groups always run, so every failure is printed in a single run.

    Returns:
        0 if every check passed, 1 otherwise.
    """
    outcomes = [test_postprocess_logic(), test_eval_sample_invariant()]
    all_ok = all(outcomes)
    rule = "=" * 80
    print("\n" + rule)
    print("ALL TESTS PASSED" if all_ok else "SOME TESTS FAILED")
    print(rule)
    return int(not all_ok)
if __name__ == "__main__":
    # Use main()'s 0/1 result as the process exit status.
    sys.exit(main())