# Author: Mayank Patel
# Initial deployment: UHC Medical Policy Chatbot
# Commit: 5c32ed1
#!/usr/bin/env python3
"""
Interactive CLI chatbot for querying UHC medical policies.
Loads the MedEmbed model once, retrieves relevant policy chunks from Qdrant,
and generates answers via Phi-3.5 Mini served by Ollama.
Usage:
python -m chatbot.cli
python -m chatbot.cli --top-k 5 --model phi3.5
"""
import argparse
import sys
import time
from chatbot.config import (
OLLAMA_MODEL,
RETRIEVAL_TOP_K,
MAX_HISTORY_TURNS,
)
from chatbot.retriever import PolicyRetriever
from chatbot.llm import OllamaClient, OllamaError
from chatbot.prompts import format_context, build_messages
# -- ANSI colors --------------------------------------------------------------
# SGR escape sequences used for terminal styling; RESET restores defaults.
DIM = "\033[2m"      # faint text: status lines, sources, latency stats
GREEN = "\033[92m"   # bright green: streamed LLM answer
CYAN = "\033[96m"    # bright cyan: input prompt
YELLOW = "\033[93m"  # bright yellow: /debug output
RED = "\033[91m"     # bright red: error messages
BOLD = "\033[1m"     # bold text
RESET = "\033[0m"    # reset all attributes
def print_banner():
    """Print the startup banner and the list of available REPL commands."""
    # NOTE(review): em dashes below were mojibake ("β€”") in the original
    # source — repaired to the intended "—" characters.
    print(f"""
{BOLD}{'=' * 64}
  UHC Medical Policy Chatbot
  Model: Phi-3.5 Mini via Ollama | Retrieval: MedEmbed + Qdrant
{'=' * 64}{RESET}
{DIM}Commands:
  /clear  — reset conversation history
  /debug  — show retrieved chunks for the last query
  /quit   — exit{RESET}
""")
def print_sources(chunks):
    """Print a compact, deduplicated list of the policy sections cited.

    Args:
        chunks: retrieved chunk objects exposing ``policy_name``,
            ``section`` and ``score`` attributes. No-op when empty.
    """
    if not chunks:
        return
    seen = set()
    print(f"\n{DIM}Sources:", end="")
    for c in chunks:
        # Deduplicate on policy/section so each source appears once even
        # when several chunks come from the same section.
        key = f"{c.policy_name}/{c.section}"
        if key not in seen:
            seen.add(key)
            # em dash repaired from mojibake ("β€”") in the original source
            print(f"\n  [{c.score:.2f}] {c.policy_name} — {c.section}", end="")
    print(RESET)
def print_debug(chunks):
    """Print full diagnostics for retrieved chunks: score, metadata, preview.

    Args:
        chunks: retrieved chunk objects exposing ``score``, ``policy_name``,
            ``section``, ``plan_type``, ``page_start``, ``page_end`` and
            ``text`` attributes.
    """
    # Horizontal rules repaired from mojibake ("─") to "─" (U+2500).
    print(f"\n{YELLOW}{'─' * 64}")
    print(f"  DEBUG: {len(chunks)} chunks retrieved")
    print(f"{'─' * 64}{RESET}")
    for i, c in enumerate(chunks, 1):
        print(f"\n{YELLOW}  [{i}] score={c.score:.4f}  {c.policy_name} / {c.section}{RESET}")
        print(f"{DIM}      Plan: {c.plan_type}  Pages: {c.page_start}-{c.page_end}")
        # Flatten newlines so the preview stays on a single line.
        preview = c.text[:300].replace("\n", " ")
        print(f"      {preview}...{RESET}")
    print()
def main():
    """Run the interactive REPL.

    Parses CLI arguments, verifies that Ollama is reachable, initializes the
    retriever, then loops: read a query (or slash command), retrieve policy
    chunks, stream the LLM answer, and print sources plus latency stats.

    Exits with status 1 if the Ollama server is not ready.
    """
    parser = argparse.ArgumentParser(description="UHC Policy Chatbot CLI")
    parser.add_argument("--top-k", type=int, default=RETRIEVAL_TOP_K)
    parser.add_argument("--model", type=str, default=OLLAMA_MODEL)
    args = parser.parse_args()

    print_banner()

    # -- Check Ollama ---------------------------------------------------------
    llm = OllamaClient(model=args.model)
    err = llm.check_ready()
    if err:
        print(f"{RED}ERROR: {err}{RESET}")
        sys.exit(1)
    print(f"{DIM}Ollama ready ({args.model}){RESET}")

    # -- Init retriever -------------------------------------------------------
    retriever = PolicyRetriever()
    retriever.init(status_callback=lambda msg: print(f"{DIM}{msg}{RESET}"))
    print()

    # -- REPL -----------------------------------------------------------------
    history: list[dict] = []
    last_chunks: list = []
    while True:
        try:
            query = input(f"{CYAN}{BOLD}> {RESET}").strip()
        except (KeyboardInterrupt, EOFError):
            print(f"\n{DIM}Goodbye.{RESET}")
            break
        if not query:
            continue

        # -- Commands ---------------------------------------------------------
        if query.lower() == "/quit":
            print(f"{DIM}Goodbye.{RESET}")
            break
        if query.lower() == "/clear":
            history.clear()
            # Rebind rather than .clear(): last_chunks aliases the list
            # returned by retriever.retrieve(), so clearing it in place
            # would mutate the retriever's result object as a side effect.
            last_chunks = []
            print(f"{DIM}History cleared.{RESET}\n")
            continue
        if query.lower() == "/debug":
            if last_chunks:
                print_debug(last_chunks)
            else:
                print(f"{DIM}No chunks retrieved yet.{RESET}\n")
            continue

        # -- Retrieve ---------------------------------------------------------
        t_start = time.perf_counter()
        try:
            chunks = retriever.retrieve(query, top_k=args.top_k)
        except RuntimeError as e:
            print(f"{RED}Retrieval error: {e}{RESET}\n")
            continue
        t_retrieval = time.perf_counter()
        last_chunks = chunks
        context = format_context(chunks)

        # -- Build messages and stream ----------------------------------------
        messages = build_messages(query, context, history=history)
        print(f"\n{GREEN}", end="", flush=True)
        full_response = []
        token_count = 0
        t_first_token = None
        try:
            for token in llm.chat_stream(messages):
                if t_first_token is None:
                    t_first_token = time.perf_counter()
                print(token, end="", flush=True)
                full_response.append(token)
                token_count += 1
        except OllamaError as e:
            print(f"{RESET}\n{RED}LLM error: {e}{RESET}\n")
            continue
        t_done = time.perf_counter()
        print(RESET)

        # -- Sources ----------------------------------------------------------
        print_sources(chunks)

        # -- Latency ----------------------------------------------------------
        retrieval_ms = (t_retrieval - t_start) * 1000
        # When the stream yielded no tokens, fall back to t_done / t_retrieval
        # so the arithmetic below never sees None.
        first_tok_ms = ((t_first_token or t_done) - t_retrieval) * 1000
        gen_ms = (t_done - (t_first_token or t_retrieval)) * 1000
        total_ms = (t_done - t_start) * 1000
        tok_per_s = token_count / (gen_ms / 1000) if gen_ms > 0 else 0
        # Horizontal rules repaired from mojibake ("─") to "─" (U+2500).
        print(f"\n{DIM}{'─' * 48}")
        print(f"  Retrieval:   {retrieval_ms:7.0f} ms")
        print(f"  First token: {first_tok_ms:7.0f} ms")
        print(f"  Generation:  {gen_ms:7.0f} ms ({token_count} tok, {tok_per_s:.1f} tok/s)")
        print(f"  Total:       {total_ms:7.0f} ms")
        print(f"{'─' * 48}{RESET}")
        print()

        # -- Update history ---------------------------------------------------
        history.append({"role": "user", "content": query})
        history.append({"role": "assistant", "content": "".join(full_response)})
        # Keep only the most recent MAX_HISTORY_TURNS exchanges
        # (each exchange is two messages: user + assistant).
        if len(history) > MAX_HISTORY_TURNS * 2:
            history = history[-(MAX_HISTORY_TURNS * 2):]


if __name__ == "__main__":
    main()