"""
Test search interface against the Qdrant vector store.
Encodes a query with the same MedEmbed model and retrieves the top-K
most relevant policy chunks, demonstrating the full retrieval pipeline.
Usage:
python search.py "Is bariatric surgery covered for BMI over 40?"
python search.py "What CPT codes are used for cochlear implants?"
python search.py "criteria for sleep apnea treatment" --top-k 5
python search.py "coverage for gene therapy hemophilia" --section "Coverage Rationale"
"""
import argparse
import sys
import time
import torch
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, MatchValue
from config import (
EMBEDDING_MODEL_NAME,
MAX_SEQ_LENGTH,
QDRANT_HOST,
QDRANT_PORT,
QDRANT_COLLECTION,
QDRANT_URL,
QDRANT_API_KEY,
TOP_K,
)
# Maximum number of times search() attempts the Qdrant query before giving up.
MAX_RETRIES = 3
# Base of the exponential backoff used between attempts: search() sleeps
# RETRY_BACKOFF ** attempt seconds before trying again.
RETRY_BACKOFF = 2
def get_client() -> QdrantClient:
    """Create a Qdrant client from the connection settings in ``config``.

    If ``QDRANT_URL`` is set, the client connects to that URL with
    ``QDRANT_API_KEY`` and ``prefer_grpc=False``; otherwise it connects to
    ``QDRANT_HOST``/``QDRANT_PORT``. Both variants use a 30-second timeout.

    A missing or placeholder API key only produces a warning on stderr;
    the client is still created and returned.
    """
    # No URL configured: fall back to the host/port connection.
    if not QDRANT_URL:
        return QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT, timeout=30)

    key_missing = not QDRANT_API_KEY
    if key_missing or QDRANT_API_KEY == "YOUR_API_KEY_HERE":
        warning = (
            "WARNING: QDRANT_API_KEY is not set or still a placeholder.\n"
            " Filtered queries WILL fail. Set a real key in .env"
        )
        print(warning, file=sys.stderr)

    connection_options = {
        "url": QDRANT_URL,
        "api_key": QDRANT_API_KEY,
        "timeout": 30,
        "prefer_grpc": False,
    }
    return QdrantClient(**connection_options)
def load_model():
    """Load the embedding model and choose the device to encode on.

    Returns:
        A ``(model, device)`` tuple: the ``SentenceTransformer`` built from
        ``EMBEDDING_MODEL_NAME`` with ``max_seq_length`` set to
        ``MAX_SEQ_LENGTH``, and a device string that is ``"cuda"`` when CUDA
        is available, else ``"mps"`` when MPS is available, else ``"cpu"``.
    """
    # Preference order: CUDA first, then Apple MPS, then plain CPU.
    if torch.cuda.is_available():
        chosen_device = "cuda"
    elif torch.backends.mps.is_available():
        chosen_device = "mps"
    else:
        chosen_device = "cpu"

    encoder = SentenceTransformer(EMBEDDING_MODEL_NAME, trust_remote_code=False)
    encoder.max_seq_length = MAX_SEQ_LENGTH
    return encoder, chosen_device
def search(client, model, device, query, top_k=TOP_K, section_filter=None, policy_filter=None):
    """Embed ``query`` and return the closest points from the Qdrant collection.

    Args:
        client: Qdrant client used to run the query.
        model: Embedding model; its ``encode`` method is called on ``query``.
        device: Device string forwarded to ``model.encode``.
        query: Query text to embed.
        top_k: Maximum number of points to request.
        section_filter: If truthy, restrict results to points whose
            ``section`` payload field matches this value.
        policy_filter: If truthy, restrict results to points whose
            ``policy_name`` payload field matches this value.

    Returns:
        The ``points`` of the query response (payloads included).

    Raises:
        RuntimeError: If the query keeps failing after ``MAX_RETRIES``
            attempts, or fails with an error that retrying cannot fix.
            The original exception is chained as the cause.
    """
    query_vector = model.encode(
        query,
        convert_to_numpy=True,
        normalize_embeddings=True,
        device=device,
    ).tolist()

    conditions = []
    if section_filter:
        conditions.append(FieldCondition(key="section", match=MatchValue(value=section_filter)))
    if policy_filter:
        conditions.append(FieldCondition(key="policy_name", match=MatchValue(value=policy_filter)))
    search_filter = Filter(must=conditions) if conditions else None

    last_error = None
    attempts_made = 0
    for attempt in range(1, MAX_RETRIES + 1):
        attempts_made = attempt
        try:
            response = client.query_points(
                collection_name=QDRANT_COLLECTION,
                query=query_vector,
                query_filter=search_filter,
                limit=top_k,
                with_payload=True,
            )
        except Exception as e:  # boundary: every failure is re-raised below
            last_error = e
            # HTTP-style errors expose a numeric ``status_code``. A 4xx
            # (bad request, auth failure, unknown collection) is deterministic,
            # so sleeping and retrying only delays the error. 408 and 429 are
            # the transient exceptions to that rule.
            status = getattr(e, "status_code", None)
            client_error = (
                isinstance(status, int)
                and 400 <= status < 500
                and status not in (408, 429)
            )
            if client_error or attempt >= MAX_RETRIES:
                break
            wait = RETRY_BACKOFF ** attempt
            # Diagnostics go to stderr (with the actual cause) so stdout
            # carries only the search output.
            print(
                f" Query failed (attempt {attempt}/{MAX_RETRIES}): {e}; "
                f"retrying in {wait}s...",
                file=sys.stderr,
            )
            time.sleep(wait)
        else:
            return response.points

    # Also reached when MAX_RETRIES < 1, so the function never silently
    # returns None.
    raise RuntimeError(
        f"Failed after {attempts_made} attempt(s). Last error: {last_error}\n"
        "Check that QDRANT_URL and QDRANT_API_KEY in .env are correct."
    ) from last_error
def format_result(hit, rank):
    """Render one search hit as a printable, multi-line text block.

    Args:
        hit: Scored point exposing ``score`` and ``payload`` attributes.
        rank: Position of the hit in the result list, shown in the header.

    Returns:
        A string with a header (rank, score, policy, section, effective
        date, plan, page range) followed by up to the first 500 characters
        of the chunk text, with ``...`` appended when the text was cut.
        Missing payload fields are shown as ``N/A`` (``?`` for pages).
    """
    # The payload can be absent (None); treat that like an empty payload
    # instead of crashing on ``.get``.
    payload = hit.payload or {}
    lines = [
        f"\n{'='*80}",
        f" Rank #{rank} | Score: {hit.score:.4f}",
        f" Policy: {payload.get('policy_name', 'N/A')}",
        f" Section: {payload.get('section', 'N/A')}",
        f" Effective: {payload.get('effective_date', 'N/A')}",
        f" Plan: {payload.get('plan_type', 'N/A')}",
        f" Pages: {payload.get('page_start', '?')}-{payload.get('page_end', '?')}",
        f"{'─'*80}",
    ]
    # ``or ""`` also covers a "text" key that is present but stored as None.
    text = payload.get("text") or ""
    preview = text[:500] + ("..." if len(text) > 500 else "")
    lines.append(f" {preview}")
    lines.append(f"{'='*80}")
    return "\n".join(lines)
def main():
    """Command-line entry point: parse arguments, run one search, print hits.

    Loads the embedding model, connects to Qdrant, runs ``search`` with the
    optional section/policy filters and prints each hit via
    ``format_result``. Exits with status 1 when ``search`` raises
    ``RuntimeError``, and with argparse's usage error when ``--top-k`` is
    not a positive integer.
    """
    parser = argparse.ArgumentParser(description="Search UHC policy chunks")
    parser.add_argument("query", type=str, help="Search query")
    parser.add_argument("--top-k", type=int, default=TOP_K, help="Number of results")
    parser.add_argument("--section", type=str, default=None, help="Filter by section name")
    parser.add_argument("--policy", type=str, default=None, help="Filter by policy slug")
    args = parser.parse_args()

    # Reject a non-positive result count here, before the model is loaded
    # and before a meaningless limit is sent to the server.
    if args.top_k < 1:
        parser.error("--top-k must be a positive integer")

    print(f"Query: \"{args.query}\"")
    if args.section:
        print(f"Section filter: {args.section}")
    if args.policy:
        print(f"Policy filter: {args.policy}")

    print("\nLoading model...")
    model, device = load_model()
    print("Connecting to Qdrant...")
    client = get_client()
    print(f"Searching (top-{args.top_k})...\n")

    try:
        results = search(client, model, device, args.query, args.top_k, args.section, args.policy)
    except RuntimeError as e:
        print(f"\nERROR: {e}", file=sys.stderr)
        sys.exit(1)

    if not results:
        print("No results found.")
        return

    for i, hit in enumerate(results, 1):
        print(format_result(hit, i))
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|