#!/usr/bin/env python3
"""
Interactive CLI chatbot for querying UHC medical policies.
Loads the MedEmbed model once, retrieves relevant policy chunks from Qdrant,
and generates answers via Phi-3.5 Mini served by Ollama.
Usage:
python -m chatbot.cli
python -m chatbot.cli --top-k 5 --model phi3.5
"""
import argparse
import sys
import time
from chatbot.config import (
OLLAMA_MODEL,
RETRIEVAL_TOP_K,
MAX_HISTORY_TURNS,
)
from chatbot.retriever import PolicyRetriever
from chatbot.llm import OllamaClient, OllamaError
from chatbot.prompts import format_context, build_messages
# -- ANSI colors --------------------------------------------------------------
# SGR (Select Graphic Rendition) escape sequences used to style terminal
# output. Every styled print should end with RESET to restore the defaults.
_CSI = "\033["  # Control Sequence Introducer: ESC followed by "["
DIM = f"{_CSI}2m"
GREEN = f"{_CSI}92m"
CYAN = f"{_CSI}96m"
YELLOW = f"{_CSI}93m"
RED = f"{_CSI}91m"
BOLD = f"{_CSI}1m"
RESET = f"{_CSI}0m"
def print_banner():
    """Print the startup banner and the list of available REPL commands.

    Uses the module-level BOLD/DIM/RESET color constants; the command list
    must stay in sync with the commands handled in main().
    """
    rule = "=" * 64
    print(f"""
{BOLD}{rule}
UHC Medical Policy Chatbot
Model: Phi-3.5 Mini via Ollama | Retrieval: MedEmbed + Qdrant
{rule}{RESET}
{DIM}Commands:
/clear — reset conversation history
/debug — show retrieved chunks for the last query
/quit — exit{RESET}
""")
def print_sources(chunks):
    """Print a compact, de-duplicated list of the sources behind an answer.

    Each distinct ``(policy_name, section)`` pair is listed once, with the
    score of its first occurrence in ``chunks``. Prints nothing when
    ``chunks`` is empty.

    Args:
        chunks: retrieved chunk objects exposing ``score``, ``policy_name``
            and ``section``.
    """
    if not chunks:
        return
    lines = [f"\n{DIM}Sources:"]
    seen = set()
    for c in chunks:
        # Tuple key: a "/" inside a policy name or section can't make two
        # different sources collide, as a joined "name/section" string could.
        key = (c.policy_name, c.section)
        if key in seen:
            continue
        seen.add(key)
        lines.append(f"\n [{c.score:.2f}] {c.policy_name} — {c.section}")
    # One write instead of one print per source; RESET closes the DIM style.
    print("".join(lines) + RESET)
def print_debug(chunks, preview_chars=300):
    """Print full debug info for retrieved chunks.

    Args:
        chunks: retrieved chunk objects exposing ``score``, ``policy_name``,
            ``section``, ``plan_type``, ``page_start``, ``page_end`` and
            ``text``.
        preview_chars: maximum number of characters of each chunk's text to
            show (newlines flattened to spaces). "..." is appended only when
            the text was actually truncated.
    """
    rule = "─" * 64
    print(f"\n{YELLOW}{rule}")
    print(f" DEBUG: {len(chunks)} chunks retrieved")
    print(f"{rule}{RESET}")
    for i, c in enumerate(chunks, 1):
        print(f"\n{YELLOW} [{i}] score={c.score:.4f} {c.policy_name} / {c.section}{RESET}")
        print(f"{DIM} Plan: {c.plan_type} Pages: {c.page_start}-{c.page_end}")
        preview = c.text[:preview_chars].replace("\n", " ")
        ellipsis = "..." if len(c.text) > preview_chars else ""
        print(f" {preview}{ellipsis}{RESET}")
    print()
def _stream_answer(llm, messages):
    """Stream the model's reply to stdout and return it.

    Tokens are echoed as they arrive so the user sees the answer being
    generated. The terminal color is always reset, even if the stream fails
    or is interrupted.

    Returns:
        ``(text, token_count, t_first_token)`` where ``t_first_token`` is the
        ``time.perf_counter()`` value at the first token, or None if the
        stream produced no tokens.

    Raises:
        OllamaError: propagated from ``llm.chat_stream``.
        KeyboardInterrupt: if the user presses Ctrl-C mid-stream.
    """
    parts = []
    t_first_token = None
    print(f"\n{GREEN}", end="", flush=True)
    try:
        for token in llm.chat_stream(messages):
            if t_first_token is None:
                t_first_token = time.perf_counter()
            # Flush per token: this is a live stream, not batch output.
            print(token, end="", flush=True)
            parts.append(token)
    finally:
        print(RESET, end="", flush=True)
    return "".join(parts), len(parts), t_first_token


def _print_latency(t_start, t_retrieval, t_first_token, t_done, token_count):
    """Print the latency breakdown for one query (times are perf_counter() values).

    If no token was produced (``t_first_token`` is None), the whole
    post-retrieval interval counts as time-to-first-token and generation
    time is zero, instead of the same interval being reported twice.
    """
    t_first = t_first_token if t_first_token is not None else t_done
    retrieval_ms = (t_retrieval - t_start) * 1000
    first_tok_ms = (t_first - t_retrieval) * 1000
    gen_ms = (t_done - t_first) * 1000
    total_ms = (t_done - t_start) * 1000
    tok_per_s = token_count / (gen_ms / 1000) if gen_ms > 0 else 0.0
    rule = "─" * 48
    print(f"\n{DIM}{rule}")
    print(f" Retrieval: {retrieval_ms:7.0f} ms")
    print(f" First token: {first_tok_ms:7.0f} ms")
    print(f" Generation: {gen_ms:7.0f} ms ({token_count} tok, {tok_per_s:.1f} tok/s)")
    print(f" Total: {total_ms:7.0f} ms")
    print(f"{rule}{RESET}")
    print()


def main():
    """Run the interactive policy-chatbot REPL.

    Parses ``--top-k`` and ``--model``, exits with status 1 if Ollama is not
    ready, initializes the retriever, then answers queries until ``/quit``,
    EOF, or Ctrl-C at the prompt. Ctrl-C while a query is being retrieved or
    answered aborts only that query.
    """
    parser = argparse.ArgumentParser(description="UHC Policy Chatbot CLI")
    parser.add_argument("--top-k", type=int, default=RETRIEVAL_TOP_K)
    parser.add_argument("--model", type=str, default=OLLAMA_MODEL)
    args = parser.parse_args()
    if args.top_k < 1:
        parser.error("--top-k must be a positive integer")

    print_banner()

    # -- Check Ollama ---------------------------------------------------------
    llm = OllamaClient(model=args.model)
    err = llm.check_ready()
    if err:
        print(f"{RED}ERROR: {err}{RESET}")
        sys.exit(1)
    print(f"{DIM}Ollama ready ({args.model}){RESET}")

    # -- Init retriever -------------------------------------------------------
    retriever = PolicyRetriever()
    retriever.init(status_callback=lambda msg: print(f"{DIM}{msg}{RESET}"))
    print()

    # -- REPL -----------------------------------------------------------------
    history: list[dict] = []
    last_chunks = []
    # Each turn is one user message plus one assistant message.
    max_messages = max(MAX_HISTORY_TURNS, 0) * 2

    while True:
        try:
            query = input(f"{CYAN}{BOLD}> {RESET}").strip()
        except (KeyboardInterrupt, EOFError):
            print(f"\n{DIM}Goodbye.{RESET}")
            break
        if not query:
            continue

        # -- Commands ---------------------------------------------------------
        command = query.lower()
        if command == "/quit":
            print(f"{DIM}Goodbye.{RESET}")
            break
        if command == "/clear":
            history.clear()
            # Rebind rather than .clear(): the list was returned by the
            # retriever and may not be ours to mutate.
            last_chunks = []
            print(f"{DIM}History cleared.{RESET}\n")
            continue
        if command == "/debug":
            if last_chunks:
                print_debug(last_chunks)
            else:
                print(f"{DIM}No chunks retrieved yet.{RESET}\n")
            continue
        if command.startswith("/"):
            # Don't send a mistyped command to the retriever/LLM as a question.
            print(f"{DIM}Unknown command: {query} (try /clear, /debug or /quit){RESET}\n")
            continue

        # -- Retrieve ---------------------------------------------------------
        t_start = time.perf_counter()
        try:
            chunks = retriever.retrieve(query, top_k=args.top_k)
        except RuntimeError as e:
            print(f"{RED}Retrieval error: {e}{RESET}\n")
            continue
        except KeyboardInterrupt:
            print(f"\n{DIM}Interrupted.{RESET}\n")
            continue
        t_retrieval = time.perf_counter()
        last_chunks = chunks
        context = format_context(chunks)

        # -- Build messages and stream ----------------------------------------
        messages = build_messages(query, context, history=history)
        try:
            answer, token_count, t_first_token = _stream_answer(llm, messages)
        except OllamaError as e:
            print(f"\n{RED}LLM error: {e}{RESET}\n")
            continue
        except KeyboardInterrupt:
            # Abort only this answer; the partial reply is not kept in history.
            print(f"\n{DIM}Interrupted.{RESET}\n")
            continue
        t_done = time.perf_counter()
        print()

        print_sources(chunks)
        _print_latency(t_start, t_retrieval, t_first_token, t_done, token_count)

        # -- Update history ---------------------------------------------------
        history.append({"role": "user", "content": query})
        history.append({"role": "assistant", "content": answer})
        # Trim by start index: a negative slice like history[-0:] would keep
        # everything when MAX_HISTORY_TURNS is 0.
        if len(history) > max_messages:
            del history[: len(history) - max_messages]
if __name__ == "__main__":
    # Script entry point; main() returns None, so the exit status is 0 unless
    # main() itself calls sys.exit with another code.
    sys.exit(main())