#!/usr/bin/env python3
"""
Interactive CLI chatbot for querying UHC medical policies.

Loads the MedEmbed model once, retrieves relevant policy chunks from Qdrant,
and generates answers via Phi-3.5 Mini served by Ollama.

Usage:
    python -m chatbot.cli
    python -m chatbot.cli --top-k 5 --model phi3.5
"""
import argparse
import sys
import time

from chatbot.config import (
    OLLAMA_MODEL,
    RETRIEVAL_TOP_K,
    MAX_HISTORY_TURNS,
)
from chatbot.retriever import PolicyRetriever
from chatbot.llm import OllamaClient, OllamaError
from chatbot.prompts import format_context, build_messages
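
# For orientation, the sibling modules are assumed to expose roughly the
# following interfaces. This is a sketch reconstructed from how they are
# used below, not the actual implementations:
#
#     class PolicyRetriever:
#         def init(self, status_callback=None) -> None: ...
#         def retrieve(self, query: str, top_k: int) -> list: ...
#             # returns chunk objects with .score, .policy_name, .section,
#             # .text, .plan_type, .page_start, .page_end
#
#     class OllamaClient:
#         def __init__(self, model: str): ...
#         def check_ready(self) -> str | None: ...    # error message, or None if OK
#         def chat_stream(self, messages): ...        # yields response tokens (str)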

# -- ANSI colors --------------------------------------------------------------
DIM = "\033[2m"
GREEN = "\033[92m"
CYAN = "\033[96m"
YELLOW = "\033[93m"
RED = "\033[91m"
BOLD = "\033[1m"
RESET = "\033[0m"


def print_banner():
    print(f"""
{BOLD}{'=' * 64}
  UHC Medical Policy Chatbot
  Model: Phi-3.5 Mini via Ollama | Retrieval: MedEmbed + Qdrant
{'=' * 64}{RESET}
{DIM}Commands:
  /clear  →  reset conversation history
  /debug  →  show retrieved chunks for the last query
  /quit   →  exit{RESET}
""")


def print_sources(chunks):
    """Print a compact list of sources used."""
    if not chunks:
        return
    seen = set()
    print(f"\n{DIM}Sources:", end="")
    for c in chunks:
        key = f"{c.policy_name}/{c.section}"
        if key not in seen:
            seen.add(key)
            print(f"\n  [{c.score:.2f}] {c.policy_name} · {c.section}", end="")
    print(RESET)
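
# Example of the rendered source list (policy names and scores are
# illustrative only):
#
#   Sources:
#     [0.87] Bariatric Surgery · Coverage Rationale
#     [0.81] Bariatric Surgery · Applicable Codes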


def print_debug(chunks):
    """Print full debug info for retrieved chunks."""
    print(f"\n{YELLOW}{'─' * 64}")
    print(f"  DEBUG: {len(chunks)} chunks retrieved")
    print(f"{'─' * 64}{RESET}")
    for i, c in enumerate(chunks, 1):
        print(f"\n{YELLOW}  [{i}] score={c.score:.4f}  {c.policy_name} / {c.section}{RESET}")
        print(f"{DIM}      Plan: {c.plan_type}   Pages: {c.page_start}-{c.page_end}")
        preview = c.text[:300].replace("\n", " ")
        print(f"      {preview}...{RESET}")
    print()


def main():
    parser = argparse.ArgumentParser(description="UHC Policy Chatbot CLI")
    parser.add_argument("--top-k", type=int, default=RETRIEVAL_TOP_K)
    parser.add_argument("--model", type=str, default=OLLAMA_MODEL)
    args = parser.parse_args()

    print_banner()

    # -- Check Ollama ---------------------------------------------------------
    # If this fails, make sure `ollama serve` is running and the model has
    # been pulled (e.g. `ollama pull phi3.5`).
    llm = OllamaClient(model=args.model)
    err = llm.check_ready()
    if err:
        print(f"{RED}ERROR: {err}{RESET}")
        sys.exit(1)
    print(f"{DIM}Ollama ready ({args.model}){RESET}")

    # -- Init retriever -------------------------------------------------------
    retriever = PolicyRetriever()
    retriever.init(status_callback=lambda msg: print(f"{DIM}{msg}{RESET}"))
    print()

    # -- REPL -------------------------------------------------------------------
    history: list[dict] = []
    last_chunks = []
    while True:
        try:
            query = input(f"{CYAN}{BOLD}> {RESET}").strip()
        except (KeyboardInterrupt, EOFError):
            print(f"\n{DIM}Goodbye.{RESET}")
            break
        if not query:
            continue

        # -- Commands ---------------------------------------------------------
        if query.lower() == "/quit":
            print(f"{DIM}Goodbye.{RESET}")
            break
        if query.lower() == "/clear":
            history.clear()
            last_chunks.clear()
            print(f"{DIM}History cleared.{RESET}\n")
            continue
        if query.lower() == "/debug":
            if last_chunks:
                print_debug(last_chunks)
            else:
                print(f"{DIM}No chunks retrieved yet.{RESET}\n")
            continue

        # -- Retrieve ---------------------------------------------------------
        t_start = time.perf_counter()
        try:
            chunks = retriever.retrieve(query, top_k=args.top_k)
        except RuntimeError as e:
            print(f"{RED}Retrieval error: {e}{RESET}\n")
            continue
        t_retrieval = time.perf_counter()
        last_chunks = chunks
        context = format_context(chunks)

        # -- Build messages and stream ------------------------------------------
        messages = build_messages(query, context, history=history)
        print(f"\n{GREEN}", end="", flush=True)
        full_response = []
        # token_count counts streamed chunks; Ollama emits roughly one token
        # per chunk, so the tok/s figure below is approximate.
        token_count = 0
        t_first_token = None
        try:
            for token in llm.chat_stream(messages):
                if t_first_token is None:
                    t_first_token = time.perf_counter()
                print(token, end="", flush=True)
                full_response.append(token)
                token_count += 1
        except OllamaError as e:
            print(f"{RESET}\n{RED}LLM error: {e}{RESET}\n")
            continue
        t_done = time.perf_counter()
        print(RESET)

        # -- Sources ------------------------------------------------------------
        print_sources(chunks)

        # -- Latency ------------------------------------------------------------
        retrieval_ms = (t_retrieval - t_start) * 1000
        first_tok_ms = ((t_first_token or t_done) - t_retrieval) * 1000
        gen_ms = (t_done - (t_first_token or t_retrieval)) * 1000
        total_ms = (t_done - t_start) * 1000
        tok_per_s = token_count / (gen_ms / 1000) if gen_ms > 0 else 0
        print(f"\n{DIM}{'─' * 48}")
        print(f"  Retrieval:   {retrieval_ms:7.0f} ms")
        print(f"  First token: {first_tok_ms:7.0f} ms")
        print(f"  Generation:  {gen_ms:7.0f} ms  ({token_count} tok, {tok_per_s:.1f} tok/s)")
        print(f"  Total:       {total_ms:7.0f} ms")
        print(f"{'─' * 48}{RESET}")
        print()
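        # Example of the latency footer (numbers are illustrative only):
        #
        #   ────────────────────────────────────────────────
        #     Retrieval:        42 ms
        #     First token:     310 ms
        #     Generation:     5120 ms  (212 tok, 41.4 tok/s)
        #     Total:          5472 ms
        #   ────────────────────────────────────────────────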

        # -- Update history -------------------------------------------------------
        history.append({"role": "user", "content": query})
        history.append({"role": "assistant", "content": "".join(full_response)})
        # Keep only the most recent MAX_HISTORY_TURNS exchanges (2 messages each).
        if len(history) > MAX_HISTORY_TURNS * 2:
            history = history[-(MAX_HISTORY_TURNS * 2):]


if __name__ == "__main__":
    main()