# Snapshot metadata: file size 2,472 bytes, revision b7f3196.
import argparse
import json
from dataclasses import asdict

from pipeline import HealthQueryPipeline

EXIT_COMMANDS = ["exit", "quit"]
PROMPT = "\nQuery> "

def _report_index_progress(pipeline: HealthQueryPipeline) -> None:
    """Print the index loading percentage while it is below 100%."""
    curr, total = pipeline.get_index_progress()
    # A non-positive total would make the ratio meaningless (or divide by
    # zero), so nothing is reported in that case.
    if total > 0:
        pct = int((curr / total) * 100)
        if pct < 100:
            print(f"[Index: {pct}% loaded]")


def _report_classification(classification: dict) -> None:
    """Print the predicted category followed by each category's probability."""
    print(f"\nTriaging query as {classification['prediction']}")
    print("\nConfidence:")
    for cat, prob in classification["probabilities"].items():
        # Probabilities are multiplied by 100 for display as percentages.
        print(f"  {cat}: {prob * 100:3.2f}%")
    print()


def _report_hits(hits) -> None:
    """Print the number of retrieved documents, then each one as indented JSON."""
    print(f"Found {len(hits)} matching medical documents\n")

    if not hits:
        print("No medical documents found.\n")
        return

    for hit in hits:
        # Each hit is dumped as-is, so it has to be JSON-serializable
        # (presumably a plain dict produced by the pipeline).
        print(json.dumps(hit, indent=2, ensure_ascii=False))


def main(pipeline: HealthQueryPipeline, k: int) -> None:
    """Run the interactive query loop.

    Repeatedly prompts with ``PROMPT``, sends the stripped input to
    ``pipeline.predict`` and prints the classification it returns. When the
    prediction is ``"medical"`` the retrieved documents are printed as well;
    any other prediction only prints a TODO notice.

    The loop ends when the input is empty, when it matches one of
    ``EXIT_COMMANDS`` (case-insensitively), or on Ctrl-D / Ctrl-C, in which
    case a farewell line is printed first.

    Args:
        pipeline: Object providing ``get_index_progress()`` and
            ``predict(query, k=...)``.
        k: Forwarded to ``pipeline.predict`` as its ``k`` keyword argument.
    """
    print("(Ctrl-D or 'quit' to exit)")

    while True:
        try:
            query = input(PROMPT).strip()
            # An empty line ends the session just like an explicit exit command.
            if not query or query.lower() in EXIT_COMMANDS:
                break

            _report_index_progress(pipeline)

            result = pipeline.predict(query, k=k)
            classification = result["classification"]
            prediction = classification["prediction"]
            _report_classification(classification)

            if prediction == "medical":
                _report_hits(result["retrieval"])
            else:
                print(f"TODO: handle queries of type {prediction}")

        except (EOFError, KeyboardInterrupt):
            # The try block wraps the whole iteration, so Ctrl-C during a
            # prediction exits the same way as Ctrl-D at the prompt.
            print("\nBye!")
            break


if __name__ == "__main__":
    # Script entry point: read the command-line options, build the pipeline,
    # then hand control to the interactive loop.
    parser = argparse.ArgumentParser(
        description="Hybrid retrieval (BM25 + Dense + RRF, optional re-rank)",
    )
    parser.add_argument(
        "--k",
        type=int,
        default=10,
        help="Number of results to return",
    )
    parser.add_argument(
        "--rerank",
        action="store_true",
        help="Use cross-encoder reranker (slower, usually better)",
    )
    args = parser.parse_args()

    # Construct the pipeline and run its initialize() step before prompting.
    pipeline = HealthQueryPipeline(use_reranker=args.rerank)
    pipeline.initialize()

    main(pipeline, k=args.k)