# uhc-policy-chatbot / embedding / scripts / test_retrieval.py
# Author: Mayank Patel — initial deployment: UHC Medical Policy Chatbot (commit 5c32ed1)
"""
Batch retrieval test suite — loads the model once and runs 10 test queries
against Qdrant, graded by difficulty (4 easy, 3 medium, 3 hard).
Each test specifies:
- query: natural-language question a doctor/staff might ask
- expected_policy: slug that MUST appear in the top-K results
- expected_section: section that SHOULD appear for the best hit
- filters: optional section/policy filter to exercise filtered search
- difficulty: easy | medium | hard
Scoring:
- policy_hit@K : expected policy appears anywhere in top-K
- policy_hit@1 : expected policy is the rank-1 result
- section_match : rank-1 result matches expected section
- mrr : 1/rank of first correct-policy hit (mean reciprocal rank)
Usage:
python test_retrieval.py
python test_retrieval.py --top-k 5
python test_retrieval.py --verbose
"""
import argparse
import json
import sys
import time
from dataclasses import dataclass, field
import torch
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, MatchValue
from config import (
EMBEDDING_MODEL_NAME,
MAX_SEQ_LENGTH,
QDRANT_HOST,
QDRANT_PORT,
QDRANT_COLLECTION,
QDRANT_URL,
QDRANT_API_KEY,
TOP_K,
)
# ── Test cases ────────────────────────────────────────────────────────────────
# Each entry follows the schema in the module docstring. "rationale" is
# human-facing context only and never affects scoring. Mojibake in the query
# and rationale strings ("β€”", "β†’") has been repaired to "—"/"→" so the
# embedded query text is what a user would actually type.
TEST_CASES: list[dict] = [
    # ── EASY (4): direct keyword overlap, single policy, obvious answer ──────
    {
        "id": "E1",
        "difficulty": "easy",
        "query": "Is bariatric surgery covered for patients with BMI over 40?",
        "expected_policy": "bariatric-surgery",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "Direct policy name in query; BMI 40 threshold explicitly stated in Coverage Rationale.",
    },
    {
        "id": "E2",
        "difficulty": "easy",
        "query": "What conditions are treated with hyperbaric oxygen therapy?",
        "expected_policy": "hyperbaric-topical-oxygen-therapy",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "Policy lists conditions (crush injury, osteomyelitis, etc.) directly in Coverage Rationale.",
    },
    {
        "id": "E3",
        "difficulty": "easy",
        "query": "What is the coverage policy for cochlear implants in adults?",
        "expected_policy": "cochlear-implants",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "Exact policy name; Coverage Rationale states criteria for adults 18+.",
    },
    {
        "id": "E4",
        "difficulty": "easy",
        "query": "Is TENS covered for pain management?",
        "expected_policy": "electrical-stimulation-treatment-pain-muscle-rehabilitation",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "TENS is the primary device discussed in this policy's Coverage Rationale.",
    },
    # ── MEDIUM (3): requires semantic understanding, cross-section, or filter ─
    {
        "id": "M1",
        "difficulty": "medium",
        "query": "What are the eligibility criteria for gene therapy in hemophilia B patients?",
        "expected_policy": "gene-therapies-hemophilia",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "Must match 'hemophilia B' to Beqvez criteria; query uses 'eligibility' not 'coverage'.",
    },
    {
        "id": "M2",
        "difficulty": "medium",
        "query": "When is proton beam radiation approved instead of standard radiation for cancer?",
        "expected_policy": "proton-beam-radiation-therapy",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "Requires understanding that PBRT is an alternative; policy specifies indications by age and tumor type.",
    },
    {
        "id": "M3",
        "difficulty": "medium",
        "query": "Does UHC cover continuous glucose monitors for diabetic patients on insulin pumps?",
        "expected_policy": "continuous-glucose-monitoring-insulin-delivery-managing-diabetes",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "Long policy slug; query combines two sub-topics (CGM + insulin delivery) from the same policy.",
    },
    # ── HARD (3): paraphrased, multi-hop, or requires domain reasoning ────────
    {
        "id": "H1",
        "difficulty": "hard",
        "query": "A 16-year-old patient needs genetic testing for an undiagnosed developmental disorder — is whole genome sequencing covered?",
        "expected_policy": "whole-exome-and-whole-genome-sequencing",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "Heavily paraphrased; must link 'undiagnosed developmental disorder' + 'genetic testing' to WES/WGS policy criteria about suspected genetic cause.",
    },
    {
        "id": "H2",
        "difficulty": "hard",
        "query": "What documentation is needed before a patient can get gender-affirming mastectomy?",
        "expected_policy": "gender-dysphoria-treatment",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "Uses 'gender-affirming mastectomy' instead of 'Gender Dysphoria'; must connect to breast surgery documentation requirements in Coverage Rationale.",
    },
    {
        "id": "H3",
        "difficulty": "hard",
        "query": "Patient has failed oral appliance therapy for sleep apnea — what surgical options does UHC cover?",
        "expected_policy": "obstructive-sleep-apnea-treatment",
        "expected_section": "Coverage Rationale",
        "filters": {},
        "rationale": "Multi-hop reasoning: failed OAT → surgical alternatives; query never uses 'obstructive' or policy name. Must infer from clinical scenario.",
    },
]
# ── Helpers ───────────────────────────────────────────────────────────────────
# Retry policy for Qdrant queries: sleep RETRY_BACKOFF ** attempt seconds
# between attempts (2s, 4s), giving up after MAX_RETRIES tries.
MAX_RETRIES = 3
RETRY_BACKOFF = 2
# ANSI-colored console status tags. The check/cross marks were mojibake
# ("βœ“"/"βœ—") from a UTF-8/latin-1 round trip; repaired to "✓"/"✗".
PASS = "\033[92m✓ PASS\033[0m"
FAIL = "\033[91m✗ FAIL\033[0m"
WARN = "\033[93m~ PARTIAL\033[0m"
def get_client() -> QdrantClient:
    """Build a Qdrant client from config.

    A configured QDRANT_URL (cloud deployment, with optional API key) takes
    precedence; otherwise fall back to a plain host/port connection.
    """
    if not QDRANT_URL:
        return QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT, timeout=30)
    return QdrantClient(
        url=QDRANT_URL,
        api_key=QDRANT_API_KEY,
        timeout=30,
        prefer_grpc=False,
    )
def run_query(client, model, device, query, top_k, section_filter=None, policy_filter=None):
    """Embed *query* and run a top-k search against Qdrant, retrying on failure.

    Optional section/policy payload filters are ANDed together. Retries up to
    MAX_RETRIES times with exponential backoff (RETRY_BACKOFF ** attempt
    seconds); raises RuntimeError once all attempts have failed.
    Returns the list of scored points (with payloads).
    """
    vec = model.encode(
        query, convert_to_numpy=True, normalize_embeddings=True, device=device
    ).tolist()
    # Build the optional payload filter; order (section, then policy) is
    # irrelevant to Qdrant but kept stable for readability.
    must = [
        FieldCondition(key=key, match=MatchValue(value=value))
        for key, value in (("section", section_filter), ("policy_name", policy_filter))
        if value
    ]
    query_filter = Filter(must=must) if must else None
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            response = client.query_points(
                collection_name=QDRANT_COLLECTION,
                query=vec,
                query_filter=query_filter,
                limit=top_k,
                with_payload=True,
            )
        except Exception as exc:
            if attempt == MAX_RETRIES:
                raise RuntimeError(
                    f"Qdrant query failed after {MAX_RETRIES} retries: {exc}"
                ) from exc
            time.sleep(RETRY_BACKOFF ** attempt)
        else:
            return response.points
@dataclass
class TestResult:
    """Score card for one retrieval test case.

    Populated by evaluate(); the defaulted fields keep their zero values
    when the query errors out or returns no hits.
    """
    test_id: str            # TEST_CASES "id" (e.g. "E1", "M2", "H3")
    difficulty: str         # "easy" | "medium" | "hard"
    query: str              # natural-language question sent to the retriever
    expected_policy: str    # policy slug that must appear in the top-K results
    expected_section: str   # section the rank-1 hit should come from
    policy_hit_at_k: bool = False   # expected policy appeared anywhere in top-K
    policy_hit_at_1: bool = False   # expected policy was the rank-1 result
    section_match: bool = False     # rank-1 section equals expected_section
    first_hit_rank: int = 0         # 1-based rank of first correct-policy hit; 0 = miss
    top1_policy: str = ""           # policy slug of the rank-1 hit
    top1_section: str = ""          # section of the rank-1 hit
    top1_score: float = 0.0         # similarity score of the rank-1 hit
    latency_ms: float = 0.0         # wall-clock time for embed + search
    error: str = ""                 # non-empty iff the query failed after retries
def evaluate(tc: dict, client, model, device, top_k: int) -> TestResult:
    """Run one test case through the retriever and score the results.

    Always records latency; on query failure the error message is captured
    and the scoring fields are left at their defaults.
    """
    result = TestResult(
        test_id=tc["id"],
        difficulty=tc["difficulty"],
        query=tc["query"],
        expected_policy=tc["expected_policy"],
        expected_section=tc["expected_section"],
    )
    started = time.perf_counter()
    try:
        hits = run_query(
            client, model, device,
            tc["query"], top_k,
            tc["filters"].get("section"),
            tc["filters"].get("policy"),
        )
    except RuntimeError as exc:
        result.error = str(exc)
        hits = []
    result.latency_ms = (time.perf_counter() - started) * 1000
    if not hits:
        # Either the query errored or nothing was retrieved — all misses.
        return result
    top = hits[0]
    result.top1_policy = top.payload.get("policy_name", "")
    result.top1_section = top.payload.get("section", "")
    result.top1_score = top.score
    result.policy_hit_at_1 = result.top1_policy == tc["expected_policy"]
    result.section_match = result.top1_section == tc["expected_section"]
    # 1-based rank of the first hit whose policy matches; 0 means not found.
    rank = next(
        (i for i, hit in enumerate(hits, 1)
         if hit.payload.get("policy_name") == tc["expected_policy"]),
        0,
    )
    if rank:
        result.policy_hit_at_k = True
        result.first_hit_rank = rank
    return result
# ── Main ──────────────────────────────────────────────────────────────────────
def _load_model(device: str) -> SentenceTransformer:
    """Load the embedding model directly onto *device* with the configured length cap."""
    # Pass the device explicitly so the model really lives where the banner
    # says it does (previously it was constructed on the library default and
    # only moved at encode() time via run_query's device argument).
    model = SentenceTransformer(EMBEDDING_MODEL_NAME, trust_remote_code=False, device=device)
    model.max_seq_length = MAX_SEQ_LENGTH
    return model


def _status_for(r) -> str:
    """Return the colored console status tag for one TestResult."""
    if r.error:
        return "\033[91mERROR\033[0m"  # plain string (was an f-string with no placeholders)
    if r.policy_hit_at_1 and r.section_match:
        return PASS
    if r.policy_hit_at_k:
        return WARN
    return FAIL


def _print_result(r) -> None:
    """Print the one-line verdict, plus diagnostics when the top-1 hit is wrong."""
    print(f" [{r.test_id}] {_status_for(r)} ({r.difficulty.upper():6s}) {r.latency_ms:6.0f}ms "
          f"score={r.top1_score:.4f} {r.query[:60]}...")
    if r.error:
        print(f" ERROR: {r.error[:120]}")
    elif not r.policy_hit_at_1:
        print(f" Expected: {r.expected_policy} / {r.expected_section}")
        print(f" Got top1: {r.top1_policy} / {r.top1_section}")
        if r.policy_hit_at_k:
            print(f" Correct policy first found at rank #{r.first_hit_rank}")


def _print_top3(client, model, device, tc, top_k) -> None:
    """Verbose mode: re-run the query and print up to 3 hits with payload text."""
    # NOTE(review): this repeats the embed + search already done in evaluate();
    # kept as-is to preserve behavior, but the hits could be carried on
    # TestResult to avoid the second round trip.
    hits = run_query(
        client, model, device,
        tc["query"], min(3, top_k),
        tc["filters"].get("section"),
        tc["filters"].get("policy"),
    )
    for rank, hit in enumerate(hits, 1):
        p = hit.payload
        print(f" #{rank} [{hit.score:.4f}] {p.get('policy_name')} / {p.get('section')} | {p.get('text')}")


def _aggregate(valid) -> dict:
    """Compute aggregate metrics over the non-errored results (must be non-empty)."""
    n = len(valid)
    return {
        "hit_at_1": sum(1 for r in valid if r.policy_hit_at_1),
        "hit_at_k": sum(1 for r in valid if r.policy_hit_at_k),
        "section_ok": sum(1 for r in valid if r.policy_hit_at_1 and r.section_match),
        # Misses (first_hit_rank == 0) contribute 0 to the MRR numerator.
        "mrr": sum(1.0 / r.first_hit_rank for r in valid if r.first_hit_rank > 0) / n,
        "avg_latency": sum(r.latency_ms for r in valid) / n,
        "avg_score": sum(r.top1_score for r in valid) / n,
    }


def _print_summary(results, valid, errored, m: dict) -> None:
    """Print overall metrics plus a per-difficulty breakdown."""
    print(f"\n Total tests: {len(results)}")
    print(f" Successful: {len(valid)}")
    if errored:
        print(f" Errors: {len(errored)} ({', '.join(r.test_id for r in errored)})")
    print(f"\n Policy Hit@1: {m['hit_at_1']}/{len(valid)} ({100*m['hit_at_1']/len(valid):.0f}%)")
    print(f" Policy Hit@K: {m['hit_at_k']}/{len(valid)} ({100*m['hit_at_k']/len(valid):.0f}%)")
    print(f" Section Match@1: {m['section_ok']}/{len(valid)} ({100*m['section_ok']/len(valid):.0f}%)")
    print(f" MRR: {m['mrr']:.4f}")
    print(f" Avg Cosine Score: {m['avg_score']:.4f}")
    print(f" Avg Latency: {m['avg_latency']:.0f}ms")
    for difficulty in ("easy", "medium", "hard"):
        subset = [r for r in valid if r.difficulty == difficulty]
        if not subset:
            continue
        h1 = sum(1 for r in subset if r.policy_hit_at_1)
        hk = sum(1 for r in subset if r.policy_hit_at_k)
        sub_mrr = sum((1.0 / r.first_hit_rank) for r in subset if r.first_hit_rank > 0) / len(subset)
        print(f"\n {difficulty.upper():6s} Hit@1: {h1}/{len(subset)} Hit@K: {hk}/{len(subset)} MRR: {sub_mrr:.4f}")


def _write_report(top_k: int, results, m: dict) -> None:
    """Dump run metrics plus per-test outcomes to test_results.json."""
    report = {
        "model": EMBEDDING_MODEL_NAME,
        "top_k": top_k,
        "total": len(results),
        "policy_hit_at_1": m["hit_at_1"],
        "policy_hit_at_k": m["hit_at_k"],
        "section_match_at_1": m["section_ok"],
        "mrr": round(m["mrr"], 4),
        "avg_cosine_score": round(m["avg_score"], 4),
        "avg_latency_ms": round(m["avg_latency"], 1),
        "tests": [
            {
                "id": r.test_id,
                "difficulty": r.difficulty,
                "query": r.query,
                "policy_hit_at_1": r.policy_hit_at_1,
                "policy_hit_at_k": r.policy_hit_at_k,
                "section_match": r.section_match,
                "first_hit_rank": r.first_hit_rank,
                "top1_score": round(r.top1_score, 4),
                "latency_ms": round(r.latency_ms, 1),
                "error": r.error or None,
            }
            for r in results
        ],
    }
    out_path = "test_results.json"
    with open(out_path, "w", encoding="utf-8") as f:  # explicit encoding for portability
        json.dump(report, f, indent=2)
    print(f" Results saved to {out_path}\n")


def main():
    """Entry point: load the model once, run all test cases, print and save a report.

    Exits with status 1 when every test errors (e.g. Qdrant is unreachable).
    """
    parser = argparse.ArgumentParser(description="Batch retrieval test suite")
    parser.add_argument("--top-k", type=int, default=TOP_K)
    parser.add_argument("--verbose", "-v", action="store_true", help="Print top-3 results per test")
    args = parser.parse_args()

    print("=" * 80)
    print(" UHC Policy RAG — Retrieval Test Suite")  # repaired mojibake em dash
    print(f" Model: {EMBEDDING_MODEL_NAME} | Top-K: {args.top_k}")
    print("=" * 80)
    print("\nLoading model (one-time)...")
    device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
    model = _load_model(device)
    print(f" Model loaded on {device}")
    print("Connecting to Qdrant...\n")
    client = get_client()

    results: list[TestResult] = []
    for tc in TEST_CASES:
        r = evaluate(tc, client, model, device, args.top_k)
        results.append(r)
        _print_result(r)
        if args.verbose and not r.error:
            _print_top3(client, model, device, tc, args.top_k)

    # ── Summary ───────────────────────────────────────────────────────────
    print("\n" + "=" * 80)
    print(" SUMMARY")
    print("=" * 80)
    valid = [r for r in results if not r.error]
    errored = [r for r in results if r.error]
    if not valid:
        print(" All tests errored. Check Qdrant connection / API key.")
        sys.exit(1)
    metrics = _aggregate(valid)
    _print_summary(results, valid, errored, metrics)
    print("\n" + "=" * 80)
    # ── JSON dump for programmatic use ────────────────────────────────────
    _write_report(args.top_k, results, metrics)
# Script entry guard — lets this module be imported (e.g. for TEST_CASES)
# without running the whole suite.
if __name__ == "__main__":
    main()