Spaces:
Sleeping
Sleeping
| """ | |
| Test search interface against the Qdrant vector store. | |
| Encodes a query with the same MedEmbed model and retrieves the top-K | |
| most relevant policy chunks, demonstrating the full retrieval pipeline. | |
| Usage: | |
| python search.py "Is bariatric surgery covered for BMI over 40?" | |
| python search.py "What CPT codes are used for cochlear implants?" | |
| python search.py "criteria for sleep apnea treatment" --top-k 5 | |
| python search.py "coverage for gene therapy hemophilia" --section "Coverage Rationale" | |
| """ | |
| import argparse | |
| import sys | |
| import time | |
| import torch | |
| from sentence_transformers import SentenceTransformer | |
| from qdrant_client import QdrantClient | |
| from qdrant_client.models import Filter, FieldCondition, MatchValue | |
| from config import ( | |
| EMBEDDING_MODEL_NAME, | |
| MAX_SEQ_LENGTH, | |
| QDRANT_HOST, | |
| QDRANT_PORT, | |
| QDRANT_COLLECTION, | |
| QDRANT_URL, | |
| QDRANT_API_KEY, | |
| TOP_K, | |
| ) | |
# Maximum number of attempts search() makes for one Qdrant query before giving up.
MAX_RETRIES = 3
# Base of the exponential backoff: search() sleeps RETRY_BACKOFF ** attempt
# seconds after a failed attempt (attempt numbering starts at 1).
RETRY_BACKOFF = 2
def get_client() -> QdrantClient:
    """Create a Qdrant client from the connection settings in ``config``.

    When ``QDRANT_URL`` is set the client is built from that URL together
    with ``QDRANT_API_KEY`` (REST transport, gRPC not preferred); otherwise
    it is built from ``QDRANT_HOST`` / ``QDRANT_PORT``. Both variants use a
    30 second timeout.

    A missing or placeholder API key does not abort: a warning is written to
    stderr and the client is still created with whatever key is configured.
    """
    if not QDRANT_URL:
        # No URL configured: connect by host and port.
        return QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT, timeout=30)

    key_is_usable = bool(QDRANT_API_KEY) and QDRANT_API_KEY != "YOUR_API_KEY_HERE"
    if not key_is_usable:
        warning = (
            "WARNING: QDRANT_API_KEY is not set or still a placeholder.\n"
            " Filtered queries WILL fail. Set a real key in .env"
        )
        print(warning, file=sys.stderr)

    return QdrantClient(
        url=QDRANT_URL,
        api_key=QDRANT_API_KEY,
        timeout=30,
        prefer_grpc=False,
    )
def load_model():
    """Load the embedding model and choose the device used for encoding.

    Device preference is CUDA, then Apple MPS, then CPU. The MPS backend is
    looked up defensively because ``torch.backends.mps`` does not exist on
    every torch build; a missing backend simply means "not available".

    Returns:
        A ``(model, device)`` tuple: the ``SentenceTransformer`` instance for
        ``EMBEDDING_MODEL_NAME`` with ``max_seq_length`` set to
        ``MAX_SEQ_LENGTH``, and the device name (``"cuda"``, ``"mps"`` or
        ``"cpu"``).
    """
    if torch.cuda.is_available():
        device = "cuda"
    else:
        mps_backend = getattr(torch.backends, "mps", None)
        mps_ready = mps_backend is not None and mps_backend.is_available()
        device = "mps" if mps_ready else "cpu"

    # Load straight onto the chosen device so the model and the device handed
    # to encode() always agree.
    model = SentenceTransformer(
        EMBEDDING_MODEL_NAME,
        device=device,
        trust_remote_code=False,
    )
    model.max_seq_length = MAX_SEQ_LENGTH
    return model, device
def search(client, model, device, query, top_k=TOP_K, section_filter=None, policy_filter=None):
    """Embed ``query`` and return the best-matching points from Qdrant.

    Args:
        client: Qdrant client exposing ``query_points``.
        model: Embedding model exposing ``encode``.
        device: Device name passed through to ``model.encode``.
        query: Query text to embed.
        top_k: Maximum number of points to return.
        section_filter: If truthy, only match points whose ``section``
            payload field equals this value.
        policy_filter: If truthy, only match points whose ``policy_name``
            payload field equals this value.

    Returns:
        The ``points`` attribute of the ``query_points`` response.

    Raises:
        RuntimeError: If the query still fails after ``MAX_RETRIES``
            attempts; the last underlying exception is chained as the cause.
    """
    query_vector = model.encode(
        query,
        convert_to_numpy=True,
        normalize_embeddings=True,
        device=device,
    ).tolist()

    # Every supplied filter must match (logical AND).
    conditions = []
    if section_filter:
        conditions.append(FieldCondition(key="section", match=MatchValue(value=section_filter)))
    if policy_filter:
        conditions.append(FieldCondition(key="policy_name", match=MatchValue(value=policy_filter)))
    search_filter = Filter(must=conditions) if conditions else None

    last_error = None
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            response = client.query_points(
                collection_name=QDRANT_COLLECTION,
                query=query_vector,
                query_filter=search_filter,
                limit=top_k,
                with_payload=True,
            )
        except Exception as exc:  # client errors are not narrowed here; all are retried
            last_error = exc
            if attempt < MAX_RETRIES:
                # Exponential backoff: RETRY_BACKOFF ** attempt seconds.
                wait = RETRY_BACKOFF ** attempt
                # Report the real error on stderr instead of labelling every
                # failure a connection problem.
                print(
                    f" Query failed (attempt {attempt}/{MAX_RETRIES}): {exc}; retrying in {wait}s...",
                    file=sys.stderr,
                )
                time.sleep(wait)
        else:
            return response.points

    # Reached when every attempt failed (or MAX_RETRIES < 1): never fall
    # through and return None implicitly.
    raise RuntimeError(
        f"Failed after {MAX_RETRIES} attempts. Last error: {last_error}\n"
        "Check that QDRANT_URL and QDRANT_API_KEY in .env are correct."
    ) from last_error
def format_result(hit, rank):
    """Render one search hit as a multi-line, human-readable text block.

    Args:
        hit: Object with a numeric ``score`` attribute and a ``payload``
            attribute that is a mapping or ``None``.
        rank: Position of the hit in the result list, shown in the header.

    Returns:
        A string with a header (rank, score, policy, section, effective date,
        plan, page range) followed by the first 500 characters of the
        payload's ``text`` field; ``...`` is appended when the text was cut.
        Missing metadata is shown as ``N/A`` (``?`` for page numbers).
    """
    # A hit may carry no payload at all; treat that like an empty mapping
    # instead of failing on ``None.get``.
    payload = hit.payload or {}
    rule = "=" * 80

    lines = [
        f"\n{rule}",
        f" Rank #{rank} | Score: {hit.score:.4f}",
        f" Policy: {payload.get('policy_name', 'N/A')}",
        f" Section: {payload.get('section', 'N/A')}",
        f" Effective: {payload.get('effective_date', 'N/A')}",
        f" Plan: {payload.get('plan_type', 'N/A')}",
        f" Pages: {payload.get('page_start', '?')}-{payload.get('page_end', '?')}",
        f"{'─' * 80}",
    ]

    # The key can be present with a None value; fall back to an empty string
    # so slicing and len() are safe.
    text = payload.get("text") or ""
    preview = text[:500] + ("..." if len(text) > 500 else "")
    lines.append(f" {preview}")
    lines.append(rule)
    return "\n".join(lines)
def main():
    """Command-line entry point: parse arguments, run one search, print hits.

    Exits with status 2 (via ``argparse``) when ``--top-k`` is not a positive
    integer, and with status 1 when ``search`` raises ``RuntimeError``.
    """
    parser = argparse.ArgumentParser(description="Search UHC policy chunks")
    parser.add_argument("query", type=str, help="Search query")
    parser.add_argument("--top-k", type=int, default=TOP_K, help="Number of results")
    parser.add_argument("--section", type=str, default=None, help="Filter by section name")
    parser.add_argument("--policy", type=str, default=None, help="Filter by policy slug")
    args = parser.parse_args()

    # Reject a useless result limit up front; otherwise the bad value is only
    # discovered after the model is loaded and the retry/backoff cycle in
    # search() has run its course.
    if args.top_k < 1:
        parser.error("--top-k must be a positive integer")

    print(f"Query: \"{args.query}\"")
    if args.section:
        print(f"Section filter: {args.section}")
    if args.policy:
        print(f"Policy filter: {args.policy}")

    print("\nLoading model...")
    model, device = load_model()

    print("Connecting to Qdrant...")
    client = get_client()

    print(f"Searching (top-{args.top_k})...\n")
    try:
        results = search(client, model, device, args.query, args.top_k, args.section, args.policy)
    except RuntimeError as e:
        print(f"\nERROR: {e}", file=sys.stderr)
        sys.exit(1)

    if not results:
        print("No results found.")
        return

    for i, hit in enumerate(results, 1):
        print(format_result(hit, i))
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()