Spaces:
Sleeping
Sleeping
"""
Batch retrieval test suite — loads the model once and runs 10 test queries
against Qdrant, graded by difficulty (4 easy, 3 medium, 3 hard).

Each test specifies:
  - query: natural-language question a doctor/staff might ask
  - expected_policy: slug that MUST appear in the top-K results
  - expected_section: section that SHOULD appear for the best hit
  - filters: optional section/policy filter to exercise filtered search
  - difficulty: easy | medium | hard

Scoring:
  - policy_hit@K : expected policy appears anywhere in top-K
  - policy_hit@1 : expected policy is the rank-1 result
  - section_match : rank-1 result matches expected section
  - mrr : 1/rank of first correct-policy hit (mean reciprocal rank)

Usage:
  python test_retrieval.py
  python test_retrieval.py --top-k 5
  python test_retrieval.py --verbose
"""
| import argparse | |
| import json | |
| import sys | |
| import time | |
| from dataclasses import dataclass, field | |
| import torch | |
| from sentence_transformers import SentenceTransformer | |
| from qdrant_client import QdrantClient | |
| from qdrant_client.models import Filter, FieldCondition, MatchValue | |
| from config import ( | |
| EMBEDDING_MODEL_NAME, | |
| MAX_SEQ_LENGTH, | |
| QDRANT_HOST, | |
| QDRANT_PORT, | |
| QDRANT_COLLECTION, | |
| QDRANT_URL, | |
| QDRANT_API_KEY, | |
| TOP_K, | |
| ) | |
# ── Test cases ───────────────────────────────────────────────────────────────
# Fix: restored em-dashes/arrows that had been mojibake'd to "β" inside the
# H1/H3 query strings (the garbled glyph was embedded into the query vector)
# and the H3 rationale. All other content is unchanged.
TEST_CASES = [
    # ── EASY (4): direct keyword overlap, single policy, obvious answer ──────
    {
        "id": "E1",
        "difficulty": "easy",
        "query": "Is bariatric surgery covered for patients with BMI over 40?",
        "expected_policy": "bariatric-surgery",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "Direct policy name in query; BMI 40 threshold explicitly stated in Coverage Rationale.",
    },
    {
        "id": "E2",
        "difficulty": "easy",
        "query": "What conditions are treated with hyperbaric oxygen therapy?",
        "expected_policy": "hyperbaric-topical-oxygen-therapy",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "Policy lists conditions (crush injury, osteomyelitis, etc.) directly in Coverage Rationale.",
    },
    {
        "id": "E3",
        "difficulty": "easy",
        "query": "What is the coverage policy for cochlear implants in adults?",
        "expected_policy": "cochlear-implants",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "Exact policy name; Coverage Rationale states criteria for adults 18+.",
    },
    {
        "id": "E4",
        "difficulty": "easy",
        "query": "Is TENS covered for pain management?",
        "expected_policy": "electrical-stimulation-treatment-pain-muscle-rehabilitation",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "TENS is the primary device discussed in this policy's Coverage Rationale.",
    },
    # ── MEDIUM (3): requires semantic understanding, cross-section, or filter ─
    {
        "id": "M1",
        "difficulty": "medium",
        "query": "What are the eligibility criteria for gene therapy in hemophilia B patients?",
        "expected_policy": "gene-therapies-hemophilia",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "Must match 'hemophilia B' to Beqvez criteria; query uses 'eligibility' not 'coverage'.",
    },
    {
        "id": "M2",
        "difficulty": "medium",
        "query": "When is proton beam radiation approved instead of standard radiation for cancer?",
        "expected_policy": "proton-beam-radiation-therapy",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "Requires understanding that PBRT is an alternative; policy specifies indications by age and tumor type.",
    },
    {
        "id": "M3",
        "difficulty": "medium",
        "query": "Does UHC cover continuous glucose monitors for diabetic patients on insulin pumps?",
        "expected_policy": "continuous-glucose-monitoring-insulin-delivery-managing-diabetes",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "Long policy slug; query combines two sub-topics (CGM + insulin delivery) from the same policy.",
    },
    # ── HARD (3): paraphrased, multi-hop, or requires domain reasoning ───────
    {
        "id": "H1",
        "difficulty": "hard",
        "query": "A 16-year-old patient needs genetic testing for an undiagnosed developmental disorder — is whole genome sequencing covered?",
        "expected_policy": "whole-exome-and-whole-genome-sequencing",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "Heavily paraphrased; must link 'undiagnosed developmental disorder' + 'genetic testing' to WES/WGS policy criteria about suspected genetic cause.",
    },
    {
        "id": "H2",
        "difficulty": "hard",
        "query": "What documentation is needed before a patient can get gender-affirming mastectomy?",
        "expected_policy": "gender-dysphoria-treatment",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "Uses 'gender-affirming mastectomy' instead of 'Gender Dysphoria'; must connect to breast surgery documentation requirements in Coverage Rationale.",
    },
    {
        "id": "H3",
        "difficulty": "hard",
        "query": "Patient has failed oral appliance therapy for sleep apnea — what surgical options does UHC cover?",
        "expected_policy": "obstructive-sleep-apnea-treatment",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "Multi-hop reasoning: failed OAT → surgical alternatives; query never uses 'obstructive' or policy name. Must infer from clinical scenario.",
    },
]
# ── Helpers ──────────────────────────────────────────────────────────────────
MAX_RETRIES = 3    # attempts per Qdrant query before giving up
RETRY_BACKOFF = 2  # backoff base: sleep RETRY_BACKOFF ** attempt seconds between attempts
# ANSI-colored status labels for per-test console output.
# Fix: the PASS/FAIL glyphs had been mojibake'd to "β"; restored as the
# presumed original check/cross marks (display-only — confirm against git history).
PASS = "\033[92m✓ PASS\033[0m"
FAIL = "\033[91m✗ FAIL\033[0m"
WARN = "\033[93m~ PARTIAL\033[0m"
def get_client() -> QdrantClient:
    """Build a Qdrant client from config.

    Uses the cloud URL (plus API key) when QDRANT_URL is set; otherwise
    falls back to a plain host/port connection.
    """
    if not QDRANT_URL:
        return QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT, timeout=30)
    return QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY, timeout=30, prefer_grpc=False)
def run_query(client, model, device, query, top_k, section_filter=None, policy_filter=None):
    """Embed *query* and run a vector search against the policy collection.

    Optional section/policy filters are translated into Qdrant ``must``
    conditions. Transient failures are retried with exponential backoff;
    a RuntimeError is raised once MAX_RETRIES attempts have failed.
    """
    vec = model.encode(
        query, convert_to_numpy=True, normalize_embeddings=True, device=device
    ).tolist()
    # Build match clauses only for the filters that were actually supplied.
    clauses = [
        FieldCondition(key=key, match=MatchValue(value=value))
        for key, value in (("section", section_filter), ("policy_name", policy_filter))
        if value
    ]
    qf = Filter(must=clauses) if clauses else None
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            response = client.query_points(
                collection_name=QDRANT_COLLECTION,
                query=vec,
                query_filter=qf,
                limit=top_k,
                with_payload=True,
            )
        except Exception as e:
            if attempt >= MAX_RETRIES:
                raise RuntimeError(f"Qdrant query failed after {MAX_RETRIES} retries: {e}") from e
            # Exponential backoff: 2s, 4s, ... before the next attempt.
            time.sleep(RETRY_BACKOFF ** attempt)
        else:
            return response.points
@dataclass
class TestResult:
    """Scored outcome of a single retrieval test case.

    Bug fix: the original class declared annotated fields but carried no
    ``@dataclass`` decorator, so ``TestResult(test_id=..., ...)`` in
    ``evaluate()`` raised TypeError, and the defaulted attributes were
    shared class-level state. The decorator generates the keyword
    ``__init__`` the callers rely on and gives each instance its own fields.
    """
    test_id: str
    difficulty: str                # easy | medium | hard
    query: str
    expected_policy: str           # slug that must appear in the top-K results
    expected_section: str          # section expected for the rank-1 hit
    policy_hit_at_k: bool = False  # expected policy anywhere in top-K
    policy_hit_at_1: bool = False  # expected policy is the rank-1 result
    section_match: bool = False    # rank-1 section matches expectation
    first_hit_rank: int = 0        # rank of first correct-policy hit (0 = none)
    top1_policy: str = ""
    top1_section: str = ""
    top1_score: float = 0.0
    latency_ms: float = 0.0
    error: str = ""                # non-empty when the query failed
def evaluate(tc: dict, client, model, device, top_k: int) -> TestResult:
    """Run one test case against Qdrant and score the retrieved hits.

    On query failure the returned TestResult carries the error string and
    the measured latency, with all scoring fields left at their defaults.
    """
    outcome = TestResult(
        test_id=tc["id"],
        difficulty=tc["difficulty"],
        query=tc["query"],
        expected_policy=tc["expected_policy"],
        expected_section=tc["expected_section"],
    )
    started = time.perf_counter()
    points = []
    try:
        points = run_query(
            client, model, device,
            tc["query"], top_k,
            tc["filters"].get("section"),
            tc["filters"].get("policy"),
        )
    except RuntimeError as exc:
        outcome.error = str(exc)
    outcome.latency_ms = (time.perf_counter() - started) * 1000
    if not points:
        return outcome

    # Score the rank-1 hit.
    best = points[0]
    outcome.top1_policy = best.payload.get("policy_name", "")
    outcome.top1_section = best.payload.get("section", "")
    outcome.top1_score = best.score
    outcome.policy_hit_at_1 = outcome.top1_policy == tc["expected_policy"]
    outcome.section_match = outcome.top1_section == tc["expected_section"]

    # Rank (1-based) of the first hit whose policy matches; 0 if absent.
    hit_rank = next(
        (pos for pos, point in enumerate(points, 1)
         if point.payload.get("policy_name") == tc["expected_policy"]),
        0,
    )
    if hit_rank:
        outcome.policy_hit_at_k = True
        outcome.first_hit_rank = hit_rank
    return outcome
# ── Main ─────────────────────────────────────────────────────────────────────
def main():
    """CLI entry point.

    Loads the embedding model once, runs every test case against Qdrant,
    prints a graded per-test and summary report, and writes a
    machine-readable copy to test_results.json.

    Fixes: dropped a pointless f-string prefix on the ERROR label,
    restored the mojibake'd em-dash in the banner, and pinned UTF-8
    encoding on the JSON report file.
    """
    parser = argparse.ArgumentParser(description="Batch retrieval test suite")
    parser.add_argument("--top-k", type=int, default=TOP_K)
    parser.add_argument("--verbose", "-v", action="store_true", help="Print top-3 results per test")
    args = parser.parse_args()

    print("=" * 80)
    print(" UHC Policy RAG — Retrieval Test Suite")
    print(f" Model: {EMBEDDING_MODEL_NAME} | Top-K: {args.top_k}")
    print("=" * 80)
    print("\nLoading model (one-time)...")
    # Prefer CUDA, then Apple MPS, then CPU; the device is passed to encode().
    device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
    model = SentenceTransformer(EMBEDDING_MODEL_NAME, trust_remote_code=False)
    model.max_seq_length = MAX_SEQ_LENGTH
    print(f" Model loaded on {device}")
    print("Connecting to Qdrant...\n")
    client = get_client()

    results: list[TestResult] = []
    for tc in TEST_CASES:
        r = evaluate(tc, client, model, device, args.top_k)
        results.append(r)
        # PASS needs both the right policy and the right section at rank 1;
        # PARTIAL means the policy appeared somewhere in the top-K.
        if r.error:
            status = "\033[91mERROR\033[0m"  # was a no-placeholder f-string
        elif r.policy_hit_at_1 and r.section_match:
            status = PASS
        elif r.policy_hit_at_k:
            status = WARN
        else:
            status = FAIL
        print(f" [{r.test_id}] {status} ({r.difficulty.upper():6s}) {r.latency_ms:6.0f}ms "
              f"score={r.top1_score:.4f} {r.query[:60]}...")
        if r.error:
            print(f" ERROR: {r.error[:120]}")
        elif not r.policy_hit_at_1:
            print(f" Expected: {r.expected_policy} / {r.expected_section}")
            print(f" Got top1: {r.top1_policy} / {r.top1_section}")
            if r.policy_hit_at_k:
                print(f" Correct policy first found at rank #{r.first_hit_rank}")
        if args.verbose and not r.error:
            # NOTE(review): evaluate() does not expose the raw hits, so verbose
            # mode re-embeds and re-queries, doubling per-test latency.
            hits = run_query(
                client, model, device,
                tc["query"], min(3, args.top_k),
                tc["filters"].get("section"),
                tc["filters"].get("policy"),
            )
            for rank, hit in enumerate(hits, 1):
                p = hit.payload
                print(f" #{rank} [{hit.score:.4f}] {p.get('policy_name')} / {p.get('section')} | {p.get('text')}")

    # ── Summary ──────────────────────────────────────────────────────────
    print("\n" + "=" * 80)
    print(" SUMMARY")
    print("=" * 80)
    valid = [r for r in results if not r.error]
    errored = [r for r in results if r.error]
    if not valid:
        print(" All tests errored. Check Qdrant connection / API key.")
        sys.exit(1)
    hit_at_1 = sum(1 for r in valid if r.policy_hit_at_1)
    hit_at_k = sum(1 for r in valid if r.policy_hit_at_k)
    section_ok = sum(1 for r in valid if r.policy_hit_at_1 and r.section_match)
    # Misses contribute 0 to MRR but still count in the denominator.
    mrr = sum((1.0 / r.first_hit_rank) for r in valid if r.first_hit_rank > 0) / len(valid)
    avg_latency = sum(r.latency_ms for r in valid) / len(valid)
    avg_score = sum(r.top1_score for r in valid) / len(valid)
    print(f"\n Total tests: {len(results)}")
    print(f" Successful: {len(valid)}")
    if errored:
        print(f" Errors: {len(errored)} ({', '.join(r.test_id for r in errored)})")
    print(f"\n Policy Hit@1: {hit_at_1}/{len(valid)} ({100*hit_at_1/len(valid):.0f}%)")
    print(f" Policy Hit@K: {hit_at_k}/{len(valid)} ({100*hit_at_k/len(valid):.0f}%)")
    print(f" Section Match@1: {section_ok}/{len(valid)} ({100*section_ok/len(valid):.0f}%)")
    print(f" MRR: {mrr:.4f}")
    print(f" Avg Cosine Score: {avg_score:.4f}")
    print(f" Avg Latency: {avg_latency:.0f}ms")
    # Per-difficulty breakdown (easy / medium / hard).
    for difficulty in ("easy", "medium", "hard"):
        subset = [r for r in valid if r.difficulty == difficulty]
        if not subset:
            continue
        h1 = sum(1 for r in subset if r.policy_hit_at_1)
        hk = sum(1 for r in subset if r.policy_hit_at_k)
        sub_mrr = sum((1.0 / r.first_hit_rank) for r in subset if r.first_hit_rank > 0) / len(subset)
        print(f"\n {difficulty.upper():6s} Hit@1: {h1}/{len(subset)} Hit@K: {hk}/{len(subset)} MRR: {sub_mrr:.4f}")
    print("\n" + "=" * 80)

    # ── JSON dump for programmatic use ───────────────────────────────────
    report = {
        "model": EMBEDDING_MODEL_NAME,
        "top_k": args.top_k,
        "total": len(results),
        "policy_hit_at_1": hit_at_1,
        "policy_hit_at_k": hit_at_k,
        "section_match_at_1": section_ok,
        "mrr": round(mrr, 4),
        "avg_cosine_score": round(avg_score, 4),
        "avg_latency_ms": round(avg_latency, 1),
        "tests": [
            {
                "id": r.test_id,
                "difficulty": r.difficulty,
                "query": r.query,
                "policy_hit_at_1": r.policy_hit_at_1,
                "policy_hit_at_k": r.policy_hit_at_k,
                "section_match": r.section_match,
                "first_hit_rank": r.first_hit_rank,
                "top1_score": round(r.top1_score, 4),
                "latency_ms": round(r.latency_ms, 1),
                "error": r.error or None,
            }
            for r in results
        ],
    }
    out_path = "test_results.json"
    # Pin the encoding so the report is UTF-8 regardless of platform default.
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)
    print(f" Results saved to {out_path}\n")
# Standard script guard: run the suite only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()