File size: 2,472 Bytes
b7f3196 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import argparse
import json
from dataclasses import asdict
from pipeline import HealthQueryPipeline
# Commands that end the interactive session; main() compares the lowercased
# user input against this list.
EXIT_COMMANDS = ["exit", "quit"]
# Prompt string passed to input() before each query; it starts with a newline.
PROMPT = "\nQuery> "
def main(pipeline: HealthQueryPipeline, k: int) -> None:
    """Run the interactive query prompt until the user leaves.

    Every line read from stdin is passed to ``pipeline.predict``. The
    predicted category and the per-category probabilities are printed; for
    queries predicted as ``"medical"`` each retrieved hit is printed as
    indented JSON, while other categories only print a TODO notice.

    The loop ends on an empty line, on ``exit``/``quit`` (case-insensitive),
    or on Ctrl-D / Ctrl-C; the last two also print a goodbye message.

    Args:
        pipeline: Object providing ``get_index_progress()`` (returning a
            ``(current, total)`` pair) and ``predict(query, k=...)``.
        k: Number of results requested from ``pipeline.predict``.
    """
    print("(Ctrl-D or 'quit' to exit)")
    while True:
        try:
            query = input(PROMPT).strip()
            if not query or query.lower() in EXIT_COMMANDS:
                break

            # Report index progress only while it is below 100%; the
            # ``total > 0`` guard avoids dividing by zero.
            curr, total = pipeline.get_index_progress()
            if total > 0:
                pct = int((curr / total) * 100)
                if pct < 100:
                    print(f"[Index: {pct}% loaded]")

            result = pipeline.predict(query, k=k)
            classification = result["classification"]
            prediction = classification["prediction"]

            print(f"\nTriaging query as {prediction}")
            print("\nConfidence:")
            for cat, prob in classification["probabilities"].items():
                print(f" {cat}: {prob * 100:3.2f}%")
            print()

            # Only "medical" queries have a retrieval step so far.
            if prediction != "medical":
                print(f"TODO: handle queries of type {prediction}")
                continue

            hits = result["retrieval"]
            print(f"Found {len(hits)} matching medical documents\n")
            if not hits:
                print("No medical documents found.\n")
                continue

            # Hits are expected to be plain dicts from the pipeline, so they
            # can be serialized directly.
            for hit in hits:
                print(json.dumps(hit, indent=2, ensure_ascii=False))
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D (EOF) and Ctrl-C both leave the session the same way.
            print("\nBye!")
            break
if __name__ == "__main__":
    # Script entry point: parse the command-line options, build and
    # initialize the pipeline, then hand control to the interactive loop.
    parser = argparse.ArgumentParser(
        description="Hybrid retrieval (BM25 + Dense + RRF, optional re-rank)",
    )
    parser.add_argument(
        "--k",
        type=int,
        default=10,
        help="Number of results to return",
    )
    parser.add_argument(
        "--rerank",
        action="store_true",
        help="Use cross-encoder reranker (slower, usually better)",
    )
    cli_args = parser.parse_args()

    query_pipeline = HealthQueryPipeline(use_reranker=cli_args.rerank)
    query_pipeline.initialize()
    main(query_pipeline, k=cli_args.k)
|