Spaces:
Sleeping
Sleeping
| import re | |
| import random | |
| import threading | |
| import gradio as gr | |
| from pathlib import Path | |
| from llama_index.core import StorageContext, load_index_from_storage, Settings | |
| from llama_index.llms.llama_cpp import LlamaCPP | |
| from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
| from llama_index.core.query_engine import RetrieverQueryEngine | |
| from src.retrieval.hybrid_retriever import GNNHybridRetriever | |
| from src.utils import UTF8LocalFileSystem, resolve_gguf_model_path | |
| # --------------------------------------------------------------------------- | |
| # Global state | |
| # --------------------------------------------------------------------------- | |
# Module-level state: these names live for the whole process, so every chat
# handled by this process reads and writes the same values.

# Serialises init_system(); initialisation runs under this lock.
init_lock = threading.Lock()

# Progress / error text that get_status_html() turns into the status pill.
init_status = "initializing"

# Streaming query engine; remains None until init_system() has built it.
query_engine: RetrieverQueryEngine | None = None

# Nodes retrieved for the most recent query (used for the sources display).
_last_nodes: list = []

# Partial response text saved when the user clicks Stop.
_last_raw = ""

# Messages that, once stripped and lower-cased, resume an interrupted answer.
CONTINUE_TRIGGERS = {"continue", "tiếp tục", "cont"}
N_EXAMPLE_QUESTIONS = 5  # how many chips to show in the UI


def _load_example_questions(n: int = N_EXAMPLE_QUESTIONS) -> list[str]:
    """Return up to *n* example questions for the UI chips.

    Questions are sampled at random from the ``train`` split of the PubMedQA
    ``pqa_labeled`` configuration. If the ``datasets`` import, the load, or
    the sampling fails for any reason, a hardcoded list is used instead.

    Args:
        n: Maximum number of questions to return. Negative values are
            treated as 0.

    Returns:
        A list of at most *n* question strings.
    """
    fallback = [
        "Do mitochondria play a role in remodelling lace plant leaves during programmed cell death?",
        "Can aerobic exercise reduce the risk of developing type 2 diabetes?",
        "Is hypertension a risk factor for chronic kidney disease progression?",
        "Does statin therapy reduce cardiovascular mortality in heart failure patients?",
        "Is there a link between gut microbiota and obesity?",
    ]
    # Clamp so that both random.sample() and the fallback slice behave sanely.
    n = max(n, 0)
    try:
        from datasets import load_dataset

        ds = load_dataset("qiaojin/PubMedQA", "pqa_labeled", split="train")
        questions = [item["question"] for item in ds if item.get("question")]
        return random.sample(questions, min(n, len(questions)))
    except Exception:
        # Deliberate best-effort: the chips are decorative, so any failure
        # degrades to the static list. Honour *n* here as well so both paths
        # return the same number of chips.
        return fallback[:n]


EXAMPLE_QUESTIONS: list[str] = _load_example_questions()
| # --------------------------------------------------------------------------- | |
| # Initialisation (background thread) | |
| # --------------------------------------------------------------------------- | |
def init_system():
    """Build the global ``query_engine`` once, reporting progress in ``init_status``.

    Runs under ``init_lock`` and returns immediately if ``query_engine`` is
    already set. On success ``init_status`` ends as ``"ready"``; on any failure
    it is set to ``"error: <message>"`` and ``query_engine`` stays ``None``.

    Side effects: assigns ``Settings.embed_model`` and ``Settings.llm``.
    """
    import logging  # local import: only needed on the failure path

    global query_engine, init_status

    def _set_status(msg):
        # Handed to resolve_gguf_model_path() so it can publish its own
        # progress messages through the same status variable.
        global init_status
        init_status = msg

    with init_lock:
        if query_engine is not None:
            return
        try:
            persist_dir = Path("./storage_graph")
            if not persist_dir.exists():
                init_status = "error: storage_graph not found"
                return

            init_status = "loading index…"
            storage_context = StorageContext.from_defaults(
                persist_dir=str(persist_dir), fs=UTF8LocalFileSystem()
            )

            init_status = "loading embedding model…"
            Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")

            # Resolve the model file first (the callback may update the status
            # while it works), then label the actual model load.
            model_path = resolve_gguf_model_path(status_callback=_set_status)
            init_status = "loading LLM…"
            llm = LlamaCPP(
                model_path=model_path,
                temperature=0.0,
                max_new_tokens=2048,
                context_window=8192,
                model_kwargs={"n_threads": 4, "n_ctx": 8192, "n_batch": 512},
                verbose=False,
            )
            Settings.llm = llm

            # The LLM is built at this point, so the label must describe the
            # index load that follows rather than the LLM.
            init_status = "loading index…"
            index = load_index_from_storage(storage_context)

            init_status = "loading GNN retriever…"
            retriever = GNNHybridRetriever(
                index,
                str(persist_dir / "pyg_data.pt"),
                str(persist_dir / "gnn_model.pth"),
                alpha=0.5,
                top_k=3,
            )
            query_engine = RetrieverQueryEngine.from_args(
                retriever=retriever, llm=llm, streaming=True
            )
            init_status = "ready"
        except Exception as e:
            # This runs in a background thread; log the traceback so the cause
            # is not reduced to the one-line status text.
            logging.getLogger(__name__).exception("System initialisation failed")
            init_status = f"error: {e}"


threading.Thread(target=init_system, daemon=True).start()
| # --------------------------------------------------------------------------- | |
| # Response formatting helpers | |
| # --------------------------------------------------------------------------- | |
def _format_sources(nodes: list) -> str:
    """Render retrieved nodes as a Markdown "sources" block.

    Args:
        nodes: Objects exposing ``.node.text`` (str) and ``.score``
            (float or ``None``).

    Returns:
        An empty string when *nodes* is empty; otherwise a bold heading
        followed by one numbered line per node, showing the first 150
        characters of its text as inline code and its score to three decimals
        (or an em dash when the score is ``None``).
    """
    if not nodes:
        return ""
    lines = ["**📄 Retrieved Sources**\n"]
    for i, n in enumerate(nodes, 1):
        snippet = n.node.text[:150].replace("\n", " ").strip()
        # The snippet is wrapped in backticks below; a backtick inside it
        # would end the inline-code span early and garble the line.
        snippet = snippet.replace("`", "'")
        score = f"{n.score:.3f}" if n.score is not None else "—"
        lines.append(f"{i}. `{snippet}` *(score: {score})*")
    return "\n".join(lines)
def _format_thinking_stream(accumulated: str) -> str:
    """Render an in-progress ``<think>`` block as a Markdown blockquote.

    Everything after the first ``<think>`` tag is quoted line by line under a
    "Reasoning…" heading. If the tag is absent the quote body is empty
    instead of raising ``IndexError``.
    """
    _, _, thinking_so_far = accumulated.partition("<think>")
    quoted = thinking_so_far.replace("\n", "\n> ")
    return f"💭 **Reasoning…**\n\n> {quoted}"
def _parse_final_response(raw: str) -> str:
    """Split raw LLM output into a collapsible Reasoning block plus the Answer.

    * With a closed ``<think>…</think>`` block: the first block becomes the
      reasoning, and every closed block is removed from the answer.
    * With only an unterminated ``<think>`` (generation ended mid-thought):
      the text after the tag is the reasoning and the text before it is the
      answer, so the raw tag never reaches the rendered output.
    * With no tag at all: the stripped text is returned unchanged.

    The reasoning is HTML-escaped because it is placed inside HTML elements,
    where a stray ``<`` or ``&`` (e.g. "p < 0.05") would otherwise be parsed
    as markup.
    """
    import html  # local import keeps this helper self-contained

    closed = re.search(r"<think>(.*?)</think>", raw, re.DOTALL)
    if closed:
        thought = closed.group(1)
        answer = re.sub(r"<think>.*?</think>", "", raw, flags=re.DOTALL)
    elif "<think>" in raw:
        answer, _, thought = raw.partition("<think>")
    else:
        return raw.strip()

    thought = html.escape(thought.strip(), quote=False)
    answer = answer.strip()
    return (
        f"<details>\n<summary>💭 Reasoning</summary>\n\n"
        f"<blockquote><em>{thought}</em></blockquote>\n"
        f"</details>\n\n**Answer:**\n\n{answer}"
    )
| # --------------------------------------------------------------------------- | |
| # Chat handler | |
| # --------------------------------------------------------------------------- | |
def _render_stream_frame(sources_md: str, accumulated: str) -> str:
    """Compose one streamed frame: sources, a horizontal rule, then the text so far."""
    if "<think>" in accumulated and "</think>" not in accumulated:
        body = _format_thinking_stream(accumulated)
    else:
        body = accumulated
    return sources_md + "\n\n---\n\n" + body


def ask_question(message: str, history: list):
    """Chat handler: stream an answer for *message* as successive Markdown frames.

    Args:
        message: The user's text. If it matches one of ``CONTINUE_TRIGGERS``
            (after stripping and lower-casing), the previously interrupted
            output stored in ``_last_raw`` is passed to
            ``Settings.llm.stream_complete`` and streaming resumes from there.
        history: Chat history supplied by the chat interface; not used here.

    Yields:
        The full display string for the current state of the response. The
        last frame is built with ``_parse_final_response``.

    Side effects:
        Updates the module-level ``_last_nodes`` and ``_last_raw``. When the
        generator is closed mid-stream (``GeneratorExit``), the partial text
        is kept in ``_last_raw`` so a later "continue" can pick it up.
    """
    global _last_raw, _last_nodes

    if query_engine is None:
        if init_status.startswith("error"):
            # Initialisation has failed; asking the user to wait would be misleading.
            yield f"❌ System failed to initialize — {init_status}"
        else:
            yield "⏳ System is still initializing — please wait a moment and try again."
        return

    # ── Continue mode: resume a previously interrupted stream ──
    if message.strip().lower() in CONTINUE_TRIGGERS:
        if not _last_raw:
            yield "Nothing to continue — ask a question first."
            return
        sources_md = _format_sources(_last_nodes)
        accumulated = _last_raw
        try:
            for chunk in Settings.llm.stream_complete(accumulated):
                # Guard against a chunk that carries no delta text.
                accumulated += chunk.delta or ""
                yield _render_stream_frame(sources_md, accumulated)
            _last_raw = ""
            yield sources_md + "\n\n---\n\n" + _parse_final_response(accumulated)
        except GeneratorExit:
            _last_raw = accumulated
            raise
        except Exception as e:
            _last_raw = ""
            yield f"❌ {e}"
        return

    # ── Normal query ──
    accumulated = ""
    try:
        streaming_response = query_engine.query(message)
        # Retrieval is complete by the time .query() returns; grab the nodes.
        nodes = getattr(query_engine.retriever, "last_nodes", [])
        _last_nodes = nodes
        sources_md = _format_sources(nodes)
        for token in streaming_response.response_gen:
            accumulated += token
            yield _render_stream_frame(sources_md, accumulated)
        _last_raw = ""
        yield sources_md + "\n\n---\n\n" + _parse_final_response(accumulated)
    except GeneratorExit:
        _last_raw = accumulated
        raise
    except Exception as e:
        _last_raw = ""
        yield f"❌ {e}"
| # --------------------------------------------------------------------------- | |
| # Status indicator | |
| # --------------------------------------------------------------------------- | |
def get_status_html():
    """Return the HTML for the status pill that reflects ``init_status``.

    * ``"ready"``  -> green "System Ready" pill.
    * ``"error…"`` -> red "⚠ Error" pill with the full status text as tooltip.
    * otherwise    -> blue pulsing pill labelled with the status text, or
      "Initialising…" when the status is empty.

    The status text is HTML-escaped before interpolation because it can carry
    arbitrary exception messages or callback-supplied text.
    """
    import html  # local import keeps this helper self-contained

    if init_status == "ready":
        return '<div class="status-pill ready"><span class="dot"></span>System Ready</div>'
    if init_status.startswith("error"):
        return (
            '<div class="status-pill error" title="'
            + html.escape(init_status, quote=True)
            + '">⚠ Error</div>'
        )
    # Upper-case only the first character: str.capitalize() would also
    # lower-case the rest and turn "loading LLM…" into "Loading llm…".
    label = init_status[:1].upper() + init_status[1:] if init_status else "Initialising…"
    return (
        '<div class="status-pill loading"><span class="dot pulse"></span>'
        f"{html.escape(label, quote=False)}</div>"
    )
| # --------------------------------------------------------------------------- | |
| # Styles | |
| # --------------------------------------------------------------------------- | |
def _build_example_chips_html(questions: list[str]) -> str:
    """Build the HTML for the clickable example-question chips.

    Each chip is a ``<button>`` whose ``onclick`` writes the question into the
    first ``<textarea>`` on the page through the native ``value`` setter and
    then dispatches a bubbling ``input`` event.

    The question is embedded as a JSON string literal and the whole handler
    is HTML-escaped, so quotes, apostrophes and angle brackets in a question
    can no longer terminate the attribute or inject markup.

    Args:
        questions: Question strings, one chip each.

    Returns:
        A ``<div class="eq-wrap">`` containing one button per question.
    """
    import html  # local imports keep this helper self-contained
    import json

    def _chip(question: str) -> str:
        js = (
            "(function(){var tb=document.querySelector('textarea');"
            "if(tb){var nv=Object.getOwnPropertyDescriptor("
            "window.HTMLTextAreaElement.prototype,'value');"
            f"nv.set.call(tb,{json.dumps(question)});"
            "tb.dispatchEvent(new Event('input',{bubbles:true}));}})()"
        )
        return (
            f'<button class="eq-chip" onclick="{html.escape(js, quote=True)}">'
            f"{html.escape(question, quote=False)}</button>"
        )

    chips = "".join(_chip(q) for q in questions)
    return f'<div class="eq-wrap">{chips}</div>'
# Stylesheet handed to demo.launch(css=CSS).
# The selectors target class names set through elem_classes in the layout
# (.shell, .chat-area, .status-bar) and class names used inside the raw HTML
# strings built by this file (.hdr, .brand*, .status-pill, .dot, .eq-*).
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600&display=swap');
/* ── Base ── */
*, *::before, *::after { box-sizing: border-box; }
body, .gradio-container {
font-family: 'Inter', system-ui, sans-serif !important;
background: #f0f2f5 !important;
}
/* ── Outer shell: single centred card ── */
.shell {
max-width: 820px;
margin: 32px auto;
background: #ffffff;
border-radius: 20px;
box-shadow: 0 4px 24px rgba(0,0,0,0.08), 0 1px 4px rgba(0,0,0,0.04);
overflow: hidden;
}
/* ── Header band ── */
.hdr {
display: flex;
align-items: center;
justify-content: space-between;
padding: 18px 24px;
border-bottom: 1px solid #f1f3f6;
background: #ffffff;
}
.status-bar {
position: absolute;
top: 18px;
right: 24px;
}
.status-bar > .wrap { padding: 0 !important; background: transparent !important; border: none !important; }
.brand {
display: flex;
align-items: center;
gap: 12px;
}
.brand-icon {
width: 42px;
height: 42px;
border-radius: 12px;
background: linear-gradient(135deg, #eef2ff 0%, #e0e7ff 100%);
display: flex;
align-items: center;
justify-content: center;
font-size: 22px;
flex-shrink: 0;
}
.brand-name {
display: block;
font-size: 15px;
font-weight: 600;
color: #111827;
letter-spacing: -0.2px;
}
.brand-sub {
display: block;
font-size: 11.5px;
color: #9ca3af;
margin-top: 2px;
font-weight: 400;
}
/* ── Status pill ── */
.status-pill {
display: inline-flex;
align-items: center;
gap: 6px;
padding: 5px 13px;
border-radius: 999px;
font-size: 12px;
font-weight: 500;
white-space: nowrap;
transition: all 0.2s ease;
}
.status-pill.ready { background: #f0fdf4; color: #15803d; border: 1px solid #bbf7d0; }
.status-pill.loading { background: #eff6ff; color: #1d4ed8; border: 1px solid #bfdbfe; }
.status-pill.error { background: #fef2f2; color: #b91c1c; border: 1px solid #fecaca; cursor: help; }
.dot { width: 7px; height: 7px; border-radius: 50%; background: currentColor; flex-shrink: 0; }
.dot.pulse { animation: blink 1.4s ease-in-out infinite; }
@keyframes blink { 0%,100%{opacity:1} 50%{opacity:0.15} }
/* ── Chat area padding ── */
.chat-area { padding: 0 20px 8px; }
/* ── Example question chips ── */
.eq-section {
padding: 4px 20px 20px;
border-top: 1px solid #f1f3f6;
}
.eq-label {
font-size: 10.5px;
font-weight: 600;
letter-spacing: 0.8px;
text-transform: uppercase;
color: #c4c9d4;
margin-bottom: 10px;
}
.eq-wrap {
display: flex;
flex-wrap: wrap;
gap: 7px;
}
.eq-chip {
background: #f8f9fb;
border: 1px solid #e8eaed;
border-radius: 20px;
padding: 5px 13px;
font-size: 12.5px;
color: #4b5563;
cursor: pointer;
font-family: inherit;
line-height: 1.4;
transition: background 0.15s, border-color 0.15s, color 0.15s, transform 0.1s;
}
.eq-chip:hover {
background: #eef2ff;
border-color: #c7d2fe;
color: #4338ca;
transform: translateY(-1px);
}
.eq-chip:active { transform: translateY(0); }
/* ── Gradio internals cleanup ── */
.gradio-container > .main { background: transparent !important; }
.contain { background: transparent !important; }
footer { display: none !important; }
::-webkit-scrollbar { width: 4px; }
::-webkit-scrollbar-thumb { background: #e5e7eb; border-radius: 99px; }
"""
# Start from Gradio's Soft theme (indigo / slate, Inter font) ...
_soft_base = gr.themes.Soft(
    font=gr.themes.GoogleFont("Inter"),
    neutral_hue=gr.themes.colors.slate,
    primary_hue=gr.themes.colors.indigo,
)

# ... then override page, block, input and button styling.
_theme = _soft_base.set(
    # Page and blocks
    body_background_fill="#f0f2f5",
    block_background_fill="transparent",
    block_border_width="0px",
    block_shadow="none",
    # Inputs
    input_background_fill="#f8f9fb",
    input_border_color="#e8eaed",
    input_border_width="1px",
    # Primary buttons
    button_primary_background_fill="#4f46e5",
    button_primary_background_fill_hover="#4338ca",
    button_primary_text_color="#ffffff",
    # Secondary buttons
    button_secondary_background_fill="#f8f9fb",
    button_secondary_background_fill_hover="#eef2ff",
    button_secondary_border_color="#e8eaed",
    button_secondary_text_color="#374151",
)
| # --------------------------------------------------------------------------- | |
| # Layout | |
| # --------------------------------------------------------------------------- | |
# NOTE(review): the nesting below was reconstructed from flattened source —
# confirm that the chips, status pill and timer belong inside the "shell"
# column.

# Static markup for the header band (brand icon, name and subtitle).
_HEADER_HTML = """
<div class="hdr">
<div class="brand">
<div class="brand-icon">🏥</div>
<div>
<span class="brand-name">Healthcare GraphRAG</span>
<span class="brand-sub">GNN-enhanced retrieval · Qwen3.5-4B · BGE-small-en-v1.5</span>
</div>
</div>
</div>
"""

# Shown inside the chatbot while the conversation is still empty.
_CHAT_PLACEHOLDER_HTML = (
    "<div style='text-align:center;padding:56px 20px'>"
    "<div style='font-size:2rem;margin-bottom:10px'>🏥</div>"
    "<div style='font-size:15px;font-weight:600;color:#374151'>"
    "Ask a medical question</div>"
    "<div style='font-size:13px;margin-top:6px;color:#9ca3af'>"
    "Powered by graph-augmented reasoning</div>"
    "<div style='font-size:11px;margin-top:10px;color:#cbd5e1'>"
    "Tip: type <code>continue</code> to resume an interrupted response</div>"
    "</div>"
)

with gr.Blocks(title="Healthcare GraphRAG") as demo:
    with gr.Column(elem_classes="shell"):
        # Header band
        gr.HTML(_HEADER_HTML)

        # Chat interface, driven by ask_question()
        with gr.Column(elem_classes="chat-area"):
            _chatbot = gr.Chatbot(
                height=460,
                show_label=False,
                render_markdown=True,
                placeholder=_CHAT_PLACEHOLDER_HTML,
            )
            _textbox = gr.Textbox(
                placeholder="Ask a medical question…",
                show_label=False,
                lines=1,
                max_lines=5,
                autofocus=True,
                container=False,
                submit_btn="Send",
                stop_btn="Stop",
            )
            gr.ChatInterface(
                fn=ask_question,
                chatbot=_chatbot,
                textbox=_textbox,
                show_progress="hidden",
            )

        # Example question chips: raw HTML so they flex-wrap without a Gradio Row
        gr.HTML(
            '<div class="eq-section">'
            '<div class="eq-label">Try asking</div>'
            + _build_example_chips_html(EXAMPLE_QUESTIONS)
            + "</div>"
        )

        # Status pill, re-rendered from get_status_html() on a 3-second timer
        status_html = gr.HTML(value=get_status_html(), elem_classes="status-bar")
        _status_timer = gr.Timer(value=3)
        _status_timer.tick(fn=get_status_html, outputs=[status_html])

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        theme=_theme,
        css=CSS,
    )