"""
Batch retrieval test suite — loads the embedding model once and runs 10 test
queries against Qdrant, graded by difficulty (4 easy, 3 medium, 3 hard).

Each test case specifies:
  - query:            natural-language question a doctor/staff member might ask
  - expected_policy:  policy slug that MUST appear in the top-K results
  - expected_section: section that SHOULD appear for the best hit
  - filters:          optional {"section": ..., "policy": ...} payload filter to
                      exercise filtered search (empty dict = unfiltered search)
  - difficulty:       easy | medium | hard
  - rationale:        why the case sits at its difficulty (documentation only)

Scoring:
  - policy_hit@K  : expected policy appears anywhere in the top-K
  - policy_hit@1  : expected policy is the rank-1 result
  - section_match : rank-1 result matches the expected section (the summary's
                    "Section Match@1" only counts it when policy_hit@1 holds too)
  - mrr           : mean reciprocal rank — mean over tests of 1/rank of the
                    first correct-policy hit, with a miss contributing 0

Usage:
  python test_retrieval.py
  python test_retrieval.py --top-k 5
  python test_retrieval.py --verbose

A JSON report is also written to test_results.json in the working directory.
"""

import argparse
import json
import sys
import time
from dataclasses import dataclass, field

import torch
from qdrant_client import QdrantClient
from qdrant_client.models import FieldCondition, Filter, MatchValue
from sentence_transformers import SentenceTransformer

from config import (
    EMBEDDING_MODEL_NAME,
    MAX_SEQ_LENGTH,
    QDRANT_API_KEY,
    QDRANT_COLLECTION,
    QDRANT_HOST,
    QDRANT_PORT,
    QDRANT_URL,
    TOP_K,
)

# ── Test cases ────────────────────────────────────────────────────────────────

TEST_CASES = [
    # ── EASY (4): direct keyword overlap, single policy, obvious answer ──────
    {
        "id": "E1",
        "difficulty": "easy",
        "query": "Is bariatric surgery covered for patients with BMI over 40?",
        "expected_policy": "bariatric-surgery",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "Direct policy name in query; BMI 40 threshold explicitly stated in Coverage Rationale.",
    },
    {
        "id": "E2",
        "difficulty": "easy",
        "query": "What conditions are treated with hyperbaric oxygen therapy?",
        "expected_policy": "hyperbaric-topical-oxygen-therapy",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "Policy lists conditions (crush injury, osteomyelitis, etc.) directly in Coverage Rationale.",
    },
    {
        "id": "E3",
        "difficulty": "easy",
        "query": "What is the coverage policy for cochlear implants in adults?",
        "expected_policy": "cochlear-implants",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "Exact policy name; Coverage Rationale states criteria for adults 18+.",
    },
    {
        "id": "E4",
        "difficulty": "easy",
        "query": "Is TENS covered for pain management?",
        "expected_policy": "electrical-stimulation-treatment-pain-muscle-rehabilitation",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "TENS is the primary device discussed in this policy's Coverage Rationale.",
    },
    # ── MEDIUM (3): requires semantic understanding, cross-section, or filter ─
    {
        "id": "M1",
        "difficulty": "medium",
        "query": "What are the eligibility criteria for gene therapy in hemophilia B patients?",
        "expected_policy": "gene-therapies-hemophilia",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "Must match 'hemophilia B' to Beqvez criteria; query uses 'eligibility' not 'coverage'.",
    },
    {
        "id": "M2",
        "difficulty": "medium",
        "query": "When is proton beam radiation approved instead of standard radiation for cancer?",
        "expected_policy": "proton-beam-radiation-therapy",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "Requires understanding that PBRT is an alternative; policy specifies indications by age and tumor type.",
    },
    {
        "id": "M3",
        "difficulty": "medium",
        "query": "Does UHC cover continuous glucose monitors for diabetic patients on insulin pumps?",
        "expected_policy": "continuous-glucose-monitoring-insulin-delivery-managing-diabetes",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "Long policy slug; query combines two sub-topics (CGM + insulin delivery) from the same policy.",
    },
    # ── HARD (3): paraphrased, multi-hop, or requires domain reasoning ────────
    {
        "id": "H1",
        "difficulty": "hard",
        "query": "A 16-year-old patient needs genetic testing for an undiagnosed developmental disorder — is whole genome sequencing covered?",
        "expected_policy": "whole-exome-and-whole-genome-sequencing",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "Heavily paraphrased; must link 'undiagnosed developmental disorder' + 'genetic testing' to WES/WGS policy criteria about suspected genetic cause.",
    },
    {
        "id": "H2",
        "difficulty": "hard",
        "query": "What documentation is needed before a patient can get gender-affirming mastectomy?",
        "expected_policy": "gender-dysphoria-treatment",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "Uses 'gender-affirming mastectomy' instead of 'Gender Dysphoria'; must connect to breast surgery documentation requirements in Coverage Rationale.",
    },
    {
        "id": "H3",
        "difficulty": "hard",
        "query": "Patient has failed oral appliance therapy for sleep apnea — what surgical options does UHC cover?",
        "expected_policy": "obstructive-sleep-apnea-treatment",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "Multi-hop reasoning: failed OAT → surgical alternatives; query never uses 'obstructive' or policy name. Must infer from clinical scenario.",
    },
]

# ── Helpers ───────────────────────────────────────────────────────────────────

MAX_RETRIES = 3
# Base of the exponential backoff between attempts: sleeps 2**1 s, then 2**2 s.
RETRY_BACKOFF = 2

PASS = "\033[92m✓ PASS\033[0m"
FAIL = "\033[91m✗ FAIL\033[0m"
WARN = "\033[93m~ PARTIAL\033[0m"
ERROR = "\033[91mERROR\033[0m"


def get_client() -> QdrantClient:
    """Build a Qdrant client from config.

    A non-empty QDRANT_URL takes precedence and is used together with
    QDRANT_API_KEY over REST (prefer_grpc=False); otherwise the client connects
    to QDRANT_HOST:QDRANT_PORT. Both use a 30 s timeout.
    """
    if QDRANT_URL:
        return QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY, timeout=30, prefer_grpc=False)
    return QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT, timeout=30)


def run_query(client, model, device, query, top_k, section_filter=None, policy_filter=None):
    """Embed *query* and return up to *top_k* scored points from QDRANT_COLLECTION.

    Args:
        client: QdrantClient to search with.
        model: SentenceTransformer used to embed the query.
        device: torch device string forwarded to ``model.encode``.
        query: natural-language query text.
        top_k: maximum number of points to return.
        section_filter: if truthy, keep only points whose payload ``section`` equals it.
        policy_filter: if truthy, keep only points whose payload ``policy_name`` equals it.

    Returns:
        The ``points`` list from ``client.query_points`` (payloads included),
        in the ranking order Qdrant returns.

    Raises:
        RuntimeError: when all MAX_RETRIES attempts fail; the last underlying
            exception is chained as ``__cause__``.
    """
    # Normalized embeddings, so the returned scores are comparable across queries.
    vec = model.encode(query, convert_to_numpy=True, normalize_embeddings=True, device=device).tolist()

    conditions = []
    if section_filter:
        conditions.append(FieldCondition(key="section", match=MatchValue(value=section_filter)))
    if policy_filter:
        conditions.append(FieldCondition(key="policy_name", match=MatchValue(value=policy_filter)))
    qf = Filter(must=conditions) if conditions else None

    for attempt in range(1, MAX_RETRIES + 1):
        try:
            return client.query_points(
                collection_name=QDRANT_COLLECTION,
                query=vec,
                query_filter=qf,
                limit=top_k,
                with_payload=True,
            ).points
        except Exception as e:  # broad on purpose: transport and server errors are both retried
            if attempt == MAX_RETRIES:
                raise RuntimeError(f"Qdrant query failed after {MAX_RETRIES} retries: {e}") from e
            time.sleep(RETRY_BACKOFF ** attempt)


@dataclass
class TestResult:
    """Scored outcome of a single test case; every metric defaults to a miss."""

    # The file is named test_*.py; keep pytest from trying to collect this class.
    __test__ = False

    test_id: str
    difficulty: str
    query: str
    expected_policy: str
    expected_section: str
    policy_hit_at_k: bool = False
    policy_hit_at_1: bool = False
    section_match: bool = False
    first_hit_rank: int = 0  # 1-based rank of the first correct-policy hit; 0 = not found
    top1_policy: str = ""
    top1_section: str = ""
    top1_score: float = 0.0
    latency_ms: float = 0.0  # embedding + Qdrant round-trip, including any retries
    error: str = ""
    # Retrieved points, kept so --verbose can print them without re-querying.
    hits: list = field(default_factory=list, repr=False)


def evaluate(tc: dict, client, model, device, top_k: int) -> TestResult:
    """Run test case *tc* against Qdrant and score the retrieved hits.

    A RuntimeError from run_query is recorded in ``result.error`` instead of
    propagating, so one failing case does not abort the suite.
    """
    res = TestResult(
        test_id=tc["id"],
        difficulty=tc["difficulty"],
        query=tc["query"],
        expected_policy=tc["expected_policy"],
        expected_section=tc["expected_section"],
    )
    filters = tc.get("filters") or {}

    t0 = time.perf_counter()
    try:
        hits = run_query(
            client, model, device, tc["query"], top_k,
            filters.get("section"),
            filters.get("policy"),
        )
    except RuntimeError as e:
        res.error = str(e)
        hits = []
    res.latency_ms = (time.perf_counter() - t0) * 1000

    if res.error or not hits:
        return res

    res.hits = list(hits)
    top = hits[0]
    top_payload = top.payload or {}
    res.top1_policy = top_payload.get("policy_name", "")
    res.top1_section = top_payload.get("section", "")
    res.top1_score = top.score
    res.policy_hit_at_1 = res.top1_policy == tc["expected_policy"]
    res.section_match = res.top1_section == tc["expected_section"]

    for rank, hit in enumerate(hits, 1):
        if (hit.payload or {}).get("policy_name") == tc["expected_policy"]:
            res.policy_hit_at_k = True
            res.first_hit_rank = rank
            break
    return res


def _select_device() -> str:
    """Return the preferred torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def _mrr(results: list[TestResult]) -> float:
    """Mean reciprocal rank over *results*; a miss (rank 0) contributes 0."""
    if not results:
        return 0.0
    return sum(1.0 / r.first_hit_rank for r in results if r.first_hit_rank > 0) / len(results)


def _status(r: TestResult) -> str:
    """Colored status: PASS needs the right policy AND section at rank 1."""
    if r.error:
        return ERROR
    if r.policy_hit_at_1 and r.section_match:
        return PASS
    if r.policy_hit_at_k:
        return WARN
    return FAIL


def _print_result(r: TestResult, verbose: bool) -> None:
    """Print the status line for *r*, plus diagnostics for anything short of PASS."""
    query = r.query if len(r.query) <= 60 else f"{r.query[:60]}..."
    print(f" [{r.test_id}] {_status(r)} ({r.difficulty.upper():6s}) {r.latency_ms:6.0f}ms "
          f"score={r.top1_score:.4f} {query}")

    if r.error:
        print(f" ERROR: {r.error[:120]}")
    elif not (r.policy_hit_at_1 and r.section_match):
        # Also covers PARTIAL caused by a correct policy with the wrong section.
        print(f" Expected: {r.expected_policy} / {r.expected_section}")
        print(f" Got top1: {r.top1_policy} / {r.top1_section}")
        if r.policy_hit_at_k and not r.policy_hit_at_1:
            print(f" Correct policy first found at rank #{r.first_hit_rank}")

    if verbose and not r.error:
        # Reuse the hits evaluate() already fetched: no second embed/query round-trip.
        for rank, hit in enumerate(r.hits[:3], 1):
            p = hit.payload or {}
            print(f" #{rank} [{hit.score:.4f}] {p.get('policy_name')} / {p.get('section')} | {p.get('text')}")


# ── Main ──────────────────────────────────────────────────────────────────────

def main():
    """Run every case in TEST_CASES, print a report, and write test_results.json.

    Exits with status 1 when every test errored (e.g. Qdrant unreachable), and
    with argparse's usage error when --top-k is not positive.
    """
    parser = argparse.ArgumentParser(description="Batch retrieval test suite")
    parser.add_argument("--top-k", type=int, default=TOP_K)
    parser.add_argument("--verbose", "-v", action="store_true", help="Print top-3 results per test")
    args = parser.parse_args()
    if args.top_k < 1:
        parser.error("--top-k must be >= 1")

    print("=" * 80)
    print(" UHC Policy RAG — Retrieval Test Suite")
    print(f" Model: {EMBEDDING_MODEL_NAME} | Top-K: {args.top_k}")
    print("=" * 80)

    print("\nLoading model (one-time)...")
    device = _select_device()
    # Load straight onto the selected device instead of relying on the library's
    # own default placement, so the message below reports where the weights are.
    model = SentenceTransformer(EMBEDDING_MODEL_NAME, device=device, trust_remote_code=False)
    model.max_seq_length = MAX_SEQ_LENGTH
    print(f" Model loaded on {device}")

    print("Connecting to Qdrant...\n")
    client = get_client()

    results: list[TestResult] = []
    for tc in TEST_CASES:
        r = evaluate(tc, client, model, device, args.top_k)
        results.append(r)
        _print_result(r, args.verbose)

    # ── Summary ───────────────────────────────────────────────────────────
    print("\n" + "=" * 80)
    print(" SUMMARY")
    print("=" * 80)

    valid = [r for r in results if not r.error]
    errored = [r for r in results if r.error]

    if not valid:
        print(" All tests errored. Check Qdrant connection / API key.")
        sys.exit(1)

    # All rates below are over successful (non-errored) tests only.
    n_valid = len(valid)
    hit_at_1 = sum(1 for r in valid if r.policy_hit_at_1)
    hit_at_k = sum(1 for r in valid if r.policy_hit_at_k)
    section_ok = sum(1 for r in valid if r.policy_hit_at_1 and r.section_match)
    mrr = _mrr(valid)
    avg_latency = sum(r.latency_ms for r in valid) / n_valid
    avg_score = sum(r.top1_score for r in valid) / n_valid

    print(f"\n Total tests: {len(results)}")
    print(f" Successful: {n_valid}")
    if errored:
        print(f" Errors: {len(errored)} ({', '.join(r.test_id for r in errored)})")
    print(f"\n Policy Hit@1: {hit_at_1}/{n_valid} ({100 * hit_at_1 / n_valid:.0f}%)")
    print(f" Policy Hit@K: {hit_at_k}/{n_valid} ({100 * hit_at_k / n_valid:.0f}%)")
    print(f" Section Match@1: {section_ok}/{n_valid} ({100 * section_ok / n_valid:.0f}%)")
    print(f" MRR: {mrr:.4f}")
    print(f" Avg Cosine Score: {avg_score:.4f}")
    print(f" Avg Latency: {avg_latency:.0f}ms")

    for difficulty in ("easy", "medium", "hard"):
        subset = [r for r in valid if r.difficulty == difficulty]
        if not subset:
            continue
        h1 = sum(1 for r in subset if r.policy_hit_at_1)
        hk = sum(1 for r in subset if r.policy_hit_at_k)
        print(f"\n {difficulty.upper():6s} Hit@1: {h1}/{len(subset)} Hit@K: {hk}/{len(subset)} "
              f"MRR: {_mrr(subset):.4f}")

    print("\n" + "=" * 80)

    # ── JSON dump for programmatic use ────────────────────────────────────
    report = {
        "model": EMBEDDING_MODEL_NAME,
        "top_k": args.top_k,
        "total": len(results),
        "successful": n_valid,  # denominator of the aggregate rates below
        "policy_hit_at_1": hit_at_1,
        "policy_hit_at_k": hit_at_k,
        "section_match_at_1": section_ok,
        "mrr": round(mrr, 4),
        "avg_cosine_score": round(avg_score, 4),
        "avg_latency_ms": round(avg_latency, 1),
        "tests": [
            {
                "id": r.test_id,
                "difficulty": r.difficulty,
                "query": r.query,
                "policy_hit_at_1": r.policy_hit_at_1,
                "policy_hit_at_k": r.policy_hit_at_k,
                "section_match": r.section_match,
                "first_hit_rank": r.first_hit_rank,
                "top1_score": round(r.top1_score, 4),
                "latency_ms": round(r.latency_ms, 1),
                "error": r.error or None,
            }
            for r in results
        ],
    }
    out_path = "test_results.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)
    print(f" Results saved to {out_path}\n")


if __name__ == "__main__":
    main()