"""Gradio chat front-end for the Healthcare GraphRAG demo.

A background thread loads the persisted graph index, the embedding model, a
local GGUF LLM (llama.cpp) and the GNN hybrid retriever, publishing progress
through ``init_status``; the UI polls that value for its status pill.
Answers are streamed token by token, and any ``<think>…</think>`` reasoning
is rendered separately from the final answer.
"""

import html
import logging
import random
import re
import threading
from pathlib import Path

import gradio as gr
from llama_index.core import Settings, StorageContext, load_index_from_storage
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.llama_cpp import LlamaCPP

from src.retrieval.hybrid_retriever import GNNHybridRetriever
from src.utils import UTF8LocalFileSystem, resolve_gguf_model_path

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Global state
# ---------------------------------------------------------------------------
# NOTE: these are module globals, so every browser session shares them.
query_engine: RetrieverQueryEngine | None = None
init_status = "initializing"
init_lock = threading.Lock()

_last_raw = ""  # partial raw output of the last interrupted answer ("" = none)
_last_nodes: list = []  # nodes retrieved for the last question (source display)

CONTINUE_TRIGGERS = {"continue", "tiếp tục", "cont"}
N_EXAMPLE_QUESTIONS = 5  # how many chips to show in the UI

# Delimiters the LLM is expected to wrap its reasoning in.
THINK_OPEN = "<think>"
THINK_CLOSE = "</think>"
_THINK_BLOCK_RE = re.compile(
    re.escape(THINK_OPEN) + r"(.*?)" + re.escape(THINK_CLOSE), re.DOTALL
)

_FALLBACK_QUESTIONS = [
    "Do mitochondria play a role in remodelling lace plant leaves during programmed cell death?",
    "Can aerobic exercise reduce the risk of developing type 2 diabetes?",
    "Is hypertension a risk factor for chronic kidney disease progression?",
    "Does statin therapy reduce cardiovascular mortality in heart failure patients?",
    "Is there a link between gut microbiota and obesity?",
]


def _load_example_questions(n: int = N_EXAMPLE_QUESTIONS) -> list[str]:
    """Return up to ``n`` example questions for the UI chips.

    Samples real questions from the PubMedQA ``pqa_labeled`` split.  Falls back
    to the first ``n`` entries of a hard-coded list when the dataset cannot be
    loaded or yields no questions.
    """
    try:
        from datasets import load_dataset  # optional dependency, loaded lazily

        ds = load_dataset("qiaojin/PubMedQA", "pqa_labeled", split="train")
        questions = [item["question"] for item in ds if item.get("question")]
        if questions:
            return random.sample(questions, min(n, len(questions)))
    except Exception as exc:  # best effort: the UI works without real examples
        logger.warning("PubMedQA examples unavailable (%s); using fallback list", exc)
    return _FALLBACK_QUESTIONS[: max(n, 0)]


EXAMPLE_QUESTIONS: list[str] = _load_example_questions()


# ---------------------------------------------------------------------------
# Initialisation (background thread)
# ---------------------------------------------------------------------------
def _set_status(msg: str) -> None:
    """Publish a human-readable initialisation step for the status pill."""
    global init_status
    init_status = msg


def init_system():
    """Build the global ``query_engine``; intended to run once in a thread.

    Progress and failures are reported through ``init_status`` (polled by the
    UI) rather than raised.  Calling it again after success is a no-op.
    """
    global query_engine
    with init_lock:
        if query_engine is not None:
            return
        try:
            persist_dir = Path("./storage_graph")
            if not persist_dir.exists():
                _set_status("error: storage_graph not found")
                return

            _set_status("loading embedding model…")
            Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")

            # The resolver reports its own progress (e.g. a download) through
            # the callback, so the "loading LLM…" step is published only after
            # it returns, right before the model is actually constructed.
            model_path = resolve_gguf_model_path(status_callback=_set_status)
            _set_status("loading LLM…")
            llm = LlamaCPP(
                model_path=model_path,
                temperature=0.0,
                max_new_tokens=2048,
                context_window=8192,
                model_kwargs={"n_threads": 4, "n_ctx": 8192, "n_batch": 512},
                verbose=False,
            )
            Settings.llm = llm

            _set_status("loading index…")
            storage_context = StorageContext.from_defaults(
                persist_dir=str(persist_dir), fs=UTF8LocalFileSystem()
            )
            index = load_index_from_storage(storage_context)

            _set_status("loading GNN retriever…")
            retriever = GNNHybridRetriever(
                index,
                str(persist_dir / "pyg_data.pt"),
                str(persist_dir / "gnn_model.pth"),
                alpha=0.5,
                top_k=3,
            )
            query_engine = RetrieverQueryEngine.from_args(
                retriever=retriever, llm=llm, streaming=True
            )
            _set_status("ready")
        except Exception as exc:
            logger.exception("System initialisation failed")
            _set_status(f"error: {exc}")


threading.Thread(target=init_system, daemon=True).start()


# ---------------------------------------------------------------------------
# Response formatting helpers
# ---------------------------------------------------------------------------
def _format_sources(nodes: list, snippet_chars: int = 150) -> str:
    """Render retrieved nodes as a numbered Markdown "Retrieved Sources" list.

    Each entry shows the first ``snippet_chars`` characters of the node text
    (newlines flattened, "…" appended when truncated) and its retrieval score,
    or "—" when the node has no score.  Returns "" when there are no nodes.
    """
    if not nodes:
        return ""
    lines = ["**📄 Retrieved Sources**\n"]
    for i, n in enumerate(nodes, 1):
        text = n.node.text
        snippet = text[:snippet_chars].replace("\n", " ").strip()
        if len(text) > snippet_chars:
            snippet += "…"
        # A backtick inside the snippet would terminate the inline-code span.
        snippet = snippet.replace("`", "'")
        score = f"{n.score:.3f}" if n.score is not None else "—"
        lines.append(f"{i}. `{snippet}` *(score: {score})*")
    return "\n".join(lines)


def _format_thinking_stream(accumulated: str) -> str:
    """Render a still-open reasoning section as a Markdown blockquote.

    Everything after the first ``<think>`` tag is quoted under a
    "Reasoning…" heading.
    """
    _, _, thinking_so_far = accumulated.partition(THINK_OPEN)
    quoted = thinking_so_far.strip("\n").replace("\n", "\n> ")
    return f"💭 **Reasoning…**\n\n> {quoted}"


def _parse_final_response(raw: str) -> str:
    """Split raw LLM output into an optional collapsible Reasoning block + answer.

    Handles a complete ``<think>…</think>`` block, output containing only the
    closing tag (e.g. when the opening tag was part of the prompt), and a
    reasoning block that never closed.  Empty reasoning is omitted.
    """
    match = _THINK_BLOCK_RE.search(raw)
    if match:
        thought = match.group(1).strip()
        answer = _THINK_BLOCK_RE.sub("", raw).strip()
    elif THINK_CLOSE in raw:
        thought, _, answer = raw.partition(THINK_CLOSE)
        thought, answer = thought.strip(), answer.strip()
    elif THINK_OPEN in raw:
        # Generation stopped while still reasoning: show what exists so far.
        return _format_thinking_stream(raw)
    else:
        return raw.strip()

    if not thought:
        return answer
    return (
        "<details>\n<summary>💭 Reasoning</summary>\n\n"
        f"<blockquote>{html.escape(thought)}</blockquote>\n"
        f"</details>\n\n**Answer:**\n\n{answer}"
    )


def _render_streaming(sources_md: str, accumulated: str) -> str:
    """Compose the in-progress chat message: sources, a divider, streamed text."""
    if THINK_OPEN in accumulated and THINK_CLOSE not in accumulated:
        body = _format_thinking_stream(accumulated)
    else:
        body = accumulated
    return f"{sources_md}\n\n---\n\n{body}"


def _continuation_tokens(prefix: str):
    """Yield the LLM's raw completion of ``prefix`` token by token.

    A generator function, so the ``stream_complete`` call only happens on the
    first ``next()`` — inside the caller's error handling.
    """
    for chunk in Settings.llm.stream_complete(prefix):
        yield chunk.delta or ""


def _stream_answer(sources_md: str, prefix: str, tokens):
    """Yield chat renderings while appending ``tokens`` to ``prefix``.

    ``_last_raw`` mirrors the partial output after every token, so if the user
    presses Stop — whether the generator is closed or just no longer iterated —
    the text so far can be resumed with "continue".  It is cleared once the
    stream completes or fails.
    """
    global _last_raw
    accumulated = prefix
    try:
        for token in tokens:
            accumulated += token
            _last_raw = accumulated
            yield _render_streaming(sources_md, accumulated)
    except Exception as exc:
        _last_raw = ""
        yield f"❌ {exc}"
        return
    _last_raw = ""
    yield f"{sources_md}\n\n---\n\n{_parse_final_response(accumulated)}"


# ---------------------------------------------------------------------------
# Chat handler
# ---------------------------------------------------------------------------
def ask_question(message: str, history: list):
    """Gradio ``ChatInterface`` handler that streams a GraphRAG answer.

    Args:
        message: The user's chat input.  One of ``CONTINUE_TRIGGERS`` resumes
            the last answer that was interrupted with Stop.
        history: Chat history supplied by Gradio; unused, every question is
            answered independently.

    Yields:
        Markdown strings, each replacing the previous one in the chat bubble.
    """
    global _last_raw, _last_nodes

    if query_engine is None:
        status = init_status  # snapshot: the init thread may change it
        if status.startswith("error"):
            yield f"❌ System failed to initialise ({status})."
        else:
            yield "⏳ System is still initializing — please wait a moment and try again."
        return

    # ── Continue mode: resume a previously interrupted stream ──
    if message.strip().lower() in CONTINUE_TRIGGERS:
        if not _last_raw:
            yield "Nothing to continue — ask a question first."
            return
        # NOTE(review): the continuation prompt is only the partial answer;
        # neither the question nor the retrieved context is re-sent.
        prefix = _last_raw
        yield from _stream_answer(
            _format_sources(_last_nodes), prefix, _continuation_tokens(prefix)
        )
        return

    # ── Normal query ──
    _last_raw = ""  # a new question supersedes any interrupted answer
    try:
        streaming_response = query_engine.query(message)
        # Retrieval is complete once .query() returns.  Prefer the nodes
        # attached to this response; fall back to the retriever's own record.
        nodes = list(
            getattr(streaming_response, "source_nodes", None)
            or getattr(query_engine.retriever, "last_nodes", None)
            or []
        )
        sources_md = _format_sources(nodes)
    except Exception as exc:
        yield f"❌ {exc}"
        return
    _last_nodes = nodes
    yield from _stream_answer(sources_md, "", streaming_response.response_gen)


# ---------------------------------------------------------------------------
# Status indicator
# ---------------------------------------------------------------------------
def get_status_html() -> str:
    """Return the status-pill HTML for the current ``init_status``."""
    status = init_status  # snapshot: the init thread may change it concurrently
    if status == "ready":
        return '<div class="status-pill ready"><span class="dot"></span>System Ready</div>'
    if status.startswith("error"):
        # Full error text goes in the tooltip (the pill has cursor: help).
        return (
            f'<div class="status-pill error" title="{html.escape(status)}">'
            '<span class="dot"></span>⚠ Error</div>'
        )
    # Upper-case only the first letter; str.capitalize() would turn "LLM" into "llm".
    label = status[:1].upper() + status[1:] if status else "Initialising…"
    return (
        '<div class="status-pill loading"><span class="dot pulse"></span>'
        f"{html.escape(label)}</div>"
    )


# ---------------------------------------------------------------------------
# Styles
# ---------------------------------------------------------------------------
# Copies the chip's question into the chat textbox and fires an "input" event
# so Gradio registers the change.  NOTE(review): relies on gr.HTML keeping
# inline event handlers — confirm for the Gradio version in use.
_CHIP_ONCLICK = (
    "const box = document.querySelector('.chat-area textarea');"
    "if (box) { box.value = this.dataset.q;"
    " box.dispatchEvent(new Event('input', {bubbles: true}));"
    " box.focus(); }"
)


def _build_example_chips_html(questions: list[str]) -> str:
    """Return a flex-wrapped row of clickable chip buttons, one per question.

    Question text is HTML-escaped, since it comes from an external dataset.
    """
    onclick = html.escape(_CHIP_ONCLICK)
    chips = "".join(
        f'<button class="eq-chip" type="button" data-q="{html.escape(q)}" '
        f'onclick="{onclick}">{html.escape(q)}</button>'
        for q in questions
    )
    return f'<div class="eq-wrap">{chips}</div>'


CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600&display=swap');

/* ── Base ── */
*, *::before, *::after { box-sizing: border-box; }
body, .gradio-container {
    font-family: 'Inter', system-ui, sans-serif !important;
    background: #f0f2f5 !important;
}

/* ── Outer shell: single centred card ── */
.shell {
    position: relative;  /* anchor for the absolutely-positioned .status-bar */
    max-width: 820px;
    margin: 32px auto;
    background: #ffffff;
    border-radius: 20px;
    box-shadow: 0 4px 24px rgba(0,0,0,0.08), 0 1px 4px rgba(0,0,0,0.04);
    overflow: hidden;
}

/* ── Header band ── */
.hdr {
    display: flex;
    align-items: center;
    justify-content: space-between;
    padding: 18px 24px;
    border-bottom: 1px solid #f1f3f6;
    background: #ffffff;
}
.status-bar {
    position: absolute;
    top: 18px;
    right: 24px;
}
.status-bar > .wrap {
    padding: 0 !important;
    background: transparent !important;
    border: none !important;
}
.brand { display: flex; align-items: center; gap: 12px; }
.brand-icon {
    width: 42px;
    height: 42px;
    border-radius: 12px;
    background: linear-gradient(135deg, #eef2ff 0%, #e0e7ff 100%);
    display: flex;
    align-items: center;
    justify-content: center;
    font-size: 22px;
    flex-shrink: 0;
}
.brand-name {
    display: block;
    font-size: 15px;
    font-weight: 600;
    color: #111827;
    letter-spacing: -0.2px;
}
.brand-sub {
    display: block;
    font-size: 11.5px;
    color: #9ca3af;
    margin-top: 2px;
    font-weight: 400;
}

/* ── Status pill ── */
.status-pill {
    display: inline-flex;
    align-items: center;
    gap: 6px;
    padding: 5px 13px;
    border-radius: 999px;
    font-size: 12px;
    font-weight: 500;
    white-space: nowrap;
    transition: all 0.2s ease;
}
.status-pill.ready { background: #f0fdf4; color: #15803d; border: 1px solid #bbf7d0; }
.status-pill.loading { background: #eff6ff; color: #1d4ed8; border: 1px solid #bfdbfe; }
.status-pill.error { background: #fef2f2; color: #b91c1c; border: 1px solid #fecaca; cursor: help; }
.dot { width: 7px; height: 7px; border-radius: 50%; background: currentColor; flex-shrink: 0; }
.dot.pulse { animation: blink 1.4s ease-in-out infinite; }
@keyframes blink { 0%,100%{opacity:1} 50%{opacity:0.15} }

/* ── Chat area padding ── */
.chat-area { padding: 0 20px 8px; }

/* ── Example question chips ── */
.eq-section { padding: 4px 20px 20px; border-top: 1px solid #f1f3f6; }
.eq-label {
    font-size: 10.5px;
    font-weight: 600;
    letter-spacing: 0.8px;
    text-transform: uppercase;
    color: #c4c9d4;
    margin-bottom: 10px;
}
.eq-wrap { display: flex; flex-wrap: wrap; gap: 7px; }
.eq-chip {
    background: #f8f9fb;
    border: 1px solid #e8eaed;
    border-radius: 20px;
    padding: 5px 13px;
    font-size: 12.5px;
    color: #4b5563;
    cursor: pointer;
    font-family: inherit;
    line-height: 1.4;
    transition: background 0.15s, border-color 0.15s, color 0.15s, transform 0.1s;
}
.eq-chip:hover {
    background: #eef2ff;
    border-color: #c7d2fe;
    color: #4338ca;
    transform: translateY(-1px);
}
.eq-chip:active { transform: translateY(0); }

/* ── Gradio internals cleanup ── */
.gradio-container > .main { background: transparent !important; }
.contain { background: transparent !important; }
footer { display: none !important; }
::-webkit-scrollbar { width: 4px; }
::-webkit-scrollbar-thumb { background: #e5e7eb; border-radius: 99px; }
"""

_theme = gr.themes.Soft(
    primary_hue=gr.themes.colors.indigo,
    neutral_hue=gr.themes.colors.slate,
    font=gr.themes.GoogleFont("Inter"),
).set(
    body_background_fill="#f0f2f5",
    block_background_fill="transparent",
    block_border_width="0px",
    block_shadow="none",
    input_background_fill="#f8f9fb",
    input_border_color="#e8eaed",
    input_border_width="1px",
    button_primary_background_fill="#4f46e5",
    button_primary_background_fill_hover="#4338ca",
    button_primary_text_color="#ffffff",
    button_secondary_background_fill="#f8f9fb",
    button_secondary_background_fill_hover="#eef2ff",
    button_secondary_border_color="#e8eaed",
    button_secondary_text_color="#374151",
)

_HEADER_HTML = """
<div class="hdr">
  <div class="brand">
    <div class="brand-icon">🏥</div>
    <div>
      <span class="brand-name">Healthcare GraphRAG</span>
      <span class="brand-sub">GNN-enhanced retrieval · Qwen3.5-4B · BGE-small-en-v1.5</span>
    </div>
  </div>
</div>
"""

_CHAT_PLACEHOLDER_HTML = (
    "<div style='text-align:center;padding:48px 16px;color:#6b7280'>"
    "<div style='font-size:40px;margin-bottom:12px'>🏥</div>"
    "<div style='font-size:16px;font-weight:600;color:#111827'>"
    "Ask a medical question</div>"
    "<div style='font-size:13px;margin-top:4px'>"
    "Powered by graph-augmented reasoning</div>"
    "<div style='font-size:12px;margin-top:16px;color:#9ca3af'>"
    "Tip: type <code>continue</code> to resume an interrupted response</div>"
    "</div>"
)

# ---------------------------------------------------------------------------
# Layout
# ---------------------------------------------------------------------------
with gr.Blocks(title="Healthcare GraphRAG") as demo:
    with gr.Column(elem_classes="shell"):
        # ── Header ──
        gr.HTML(_HEADER_HTML)

        # ── Chat interface ──
        with gr.Column(elem_classes="chat-area"):
            gr.ChatInterface(
                fn=ask_question,
                chatbot=gr.Chatbot(
                    height=460,
                    show_label=False,
                    render_markdown=True,
                    placeholder=_CHAT_PLACEHOLDER_HTML,
                ),
                textbox=gr.Textbox(
                    placeholder="Ask a medical question…",
                    show_label=False,
                    lines=1,
                    max_lines=5,
                    autofocus=True,
                    container=False,
                    submit_btn="Send",
                    stop_btn="Stop",
                ),
                show_progress="hidden",
            )

        # ── Example question chips (pure HTML — flex-wrap, no Gradio Row) ──
        gr.HTML(
            '<div class="eq-section">'
            '<div class="eq-label">Try asking</div>'
            + _build_example_chips_html(EXAMPLE_QUESTIONS)
            + "</div>"
        )

        # ── Status indicator (live, refreshed every 3 s) ──
        # Positioned over the header's right edge by the .status-bar CSS.
        status_html = gr.HTML(value=get_status_html(), elem_classes="status-bar")
        gr.Timer(value=3).tick(fn=get_status_html, outputs=[status_html])


if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        theme=_theme,
        css=CSS,
    )