minhthien's picture
fix: make status pill live-update by using gr.HTML with Timer instead of static header HTML
3331528
Raw
History Blame Contribute Delete
16.3 kB
import html
import json
import random
import re
import threading
from pathlib import Path

import gradio as gr
from llama_index.core import StorageContext, load_index_from_storage, Settings
from llama_index.llms.llama_cpp import LlamaCPP
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.query_engine import RetrieverQueryEngine

from src.retrieval.hybrid_retriever import GNNHybridRetriever
from src.utils import UTF8LocalFileSystem, resolve_gguf_model_path
# ---------------------------------------------------------------------------
# Global state
# ---------------------------------------------------------------------------
# Module-level state shared by the init thread, the chat handler and the
# status-pill renderer.
# NOTE(review): these live at module scope, so every browser session served by
# this process shares them — e.g. "continue" resumes whichever answer was
# interrupted last. Confirm this is acceptable for multi-user deployments.
query_engine: RetrieverQueryEngine | None = None  # stays None until init_system() succeeds
init_status = "initializing"  # progress label, "ready", or "error: …"
init_lock = threading.Lock()  # held for the whole of init_system()
_last_nodes: list = []  # sources of the most recent retrieval, reused by "continue"
_last_raw = ""  # partial raw LLM output saved when the user clicks Stop
# Replies (matched after .strip().lower()) that resume an interrupted answer.
CONTINUE_TRIGGERS = {"cont", "continue", "tiếp tục"}
N_EXAMPLE_QUESTIONS = 5 # how many chips to show in the UI
def _load_example_questions(n: int = N_EXAMPLE_QUESTIONS) -> list[str]:
"""
Pull real questions from the PubMedQA pqa_labeled split.
Falls back to a hardcoded list if the dataset is unavailable.
"""
try:
from datasets import load_dataset
ds = load_dataset("qiaojin/PubMedQA", "pqa_labeled", split="train")
questions = [item["question"] for item in ds if item.get("question")]
sample = random.sample(questions, min(n, len(questions)))
return sample
except Exception:
return [
"Do mitochondria play a role in remodelling lace plant leaves during programmed cell death?",
"Can aerobic exercise reduce the risk of developing type 2 diabetes?",
"Is hypertension a risk factor for chronic kidney disease progression?",
"Does statin therapy reduce cardiovascular mortality in heart failure patients?",
"Is there a link between gut microbiota and obesity?",
]
EXAMPLE_QUESTIONS: list[str] = _load_example_questions()
# ---------------------------------------------------------------------------
# Initialisation (background thread)
# ---------------------------------------------------------------------------
def init_system():
    """
    Build the global ``query_engine`` (index, embeddings, LLM, GNN retriever).

    Started in a background daemon thread at import time so the UI can be
    served while models load. Progress is published through the module-level
    ``init_status`` string, which the status pill reads. It holds a step label
    while work is in progress, "ready" on success, or "error: …" on failure,
    and exceptions are reported there rather than raised. Once a build has
    succeeded, later calls return immediately.
    """
    global query_engine, init_status

    def _set_status(msg):
        # Lets resolve_gguf_model_path surface its own progress in the pill.
        global init_status
        init_status = msg

    with init_lock:
        if query_engine is not None:
            return
        try:
            persist_dir = Path("./storage_graph")
            if not persist_dir.exists():
                init_status = "error: storage_graph not found"
                return

            init_status = "loading index…"
            storage_context = StorageContext.from_defaults(
                persist_dir=str(persist_dir), fs=UTF8LocalFileSystem()
            )

            init_status = "loading embedding model…"
            Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")

            # Resolve the GGUF path first (the helper may report progress via
            # the callback), then announce the LLM load *before* constructing
            # it so the pill shows the step that is actually running.
            model_path = resolve_gguf_model_path(status_callback=_set_status)
            init_status = "loading LLM…"
            llm = LlamaCPP(
                model_path=model_path,
                temperature=0.0,
                max_new_tokens=2048,
                context_window=8192,
                model_kwargs={"n_threads": 4, "n_ctx": 8192, "n_batch": 512},
                verbose=False,
            )
            Settings.llm = llm

            # Uses Settings.embed_model configured above.
            init_status = "loading index…"
            index = load_index_from_storage(storage_context)

            init_status = "loading GNN retriever…"
            retriever = GNNHybridRetriever(
                index,
                str(persist_dir / "pyg_data.pt"),
                str(persist_dir / "gnn_model.pth"),
                alpha=0.5,
                top_k=3,
            )
            query_engine = RetrieverQueryEngine.from_args(
                retriever=retriever, llm=llm, streaming=True
            )
            init_status = "ready"
        except Exception as e:
            init_status = f"error: {e}"


# Load in the background at import time so the UI is available immediately.
threading.Thread(target=init_system, daemon=True).start()
# ---------------------------------------------------------------------------
# Response formatting helpers
# ---------------------------------------------------------------------------
def _format_sources(nodes: list) -> str:
"""Render retrieved graph nodes as a Markdown sources block."""
if not nodes:
return ""
lines = ["**📄 Retrieved Sources**\n"]
for i, n in enumerate(nodes, 1):
snippet = n.node.text[:150].replace("\n", " ").strip()
score = f"{n.score:.3f}" if n.score is not None else "—"
lines.append(f"{i}. `{snippet}` *(score: {score})*")
return "\n".join(lines)
def _format_thinking_stream(accumulated: str) -> str:
"""Render an in-progress <think> block as a Markdown blockquote."""
thinking_so_far = accumulated.split("<think>", 1)[1]
lines = thinking_so_far.replace("\n", "\n> ")
return f"💭 **Reasoning…**\n\n> {lines}"
def _parse_final_response(raw: str) -> str:
"""
Split raw LLM output into an optional collapsible Reasoning block
followed by the Answer.
"""
thinking = re.search(r"<think>(.*?)</think>", raw, re.DOTALL)
answer = re.sub(r"<think>.*?</think>", "", raw, flags=re.DOTALL).strip()
if thinking:
thought = thinking.group(1).strip()
return (
f"<details>\n<summary>💭 Reasoning</summary>\n\n"
f"<blockquote><em>{thought}</em></blockquote>\n"
f"</details>\n\n**Answer:**\n\n{answer}"
)
return answer
# ---------------------------------------------------------------------------
# Chat handler
# ---------------------------------------------------------------------------
def _render_stream(sources_md: str, accumulated: str) -> str:
    """Build the chat message for a partially generated answer."""
    header = sources_md + "\n\n---\n\n"
    if "<think>" in accumulated and "</think>" not in accumulated:
        # Reasoning is still open: show it live as a blockquote.
        return header + _format_thinking_stream(accumulated)
    # No reasoning, or it has already closed: render it the same way as the
    # final message, so raw <think> tags are never shown mid-stream.
    return header + _parse_final_response(accumulated)


def _stream_answer(sources_md: str, pieces, accumulated: str = ""):
    """
    Append each text piece to ``accumulated`` and yield the rendered message
    after every piece, then the fully parsed message at the end.

    If the consumer closes the generator (the user clicked Stop), the partial
    raw text is saved in ``_last_raw`` so "continue" can resume it. On normal
    completion ``_last_raw`` is cleared. Other exceptions propagate.
    """
    global _last_raw
    try:
        for piece in pieces:
            accumulated += piece or ""  # guard against a None delta
            yield _render_stream(sources_md, accumulated)
        _last_raw = ""
        yield sources_md + "\n\n---\n\n" + _parse_final_response(accumulated)
    except GeneratorExit:
        _last_raw = accumulated
        raise


def ask_question(message: str, history: list):
    """
    Gradio ChatInterface handler: stream an answer to ``message``.

    ``history`` is accepted for the ChatInterface signature but not used.
    Yields progressively rendered Markdown: the retrieved sources, a divider,
    then the answer as it streams. A message in CONTINUE_TRIGGERS resumes the
    last interrupted answer instead of starting a new query. Errors are
    yielded as a "❌ …" message rather than raised.

    NOTE(review): ``_last_raw``/``_last_nodes`` are module globals, so they are
    shared by every session in this process.
    NOTE(review): continue mode feeds only the partial answer text back to the
    LLM; the original question and retrieved context are not re-sent.
    """
    global _last_raw, _last_nodes
    if query_engine is None:
        yield "⏳ System is still initializing — please wait a moment and try again."
        return

    # ── Continue mode: resume a previously interrupted stream ──
    if message.strip().lower() in CONTINUE_TRIGGERS:
        if not _last_raw:
            yield "Nothing to continue — ask a question first."
            return
        sources_md = _format_sources(_last_nodes)
        try:
            deltas = (chunk.delta for chunk in Settings.llm.stream_complete(_last_raw))
            yield from _stream_answer(sources_md, deltas, _last_raw)
        except Exception as e:
            _last_raw = ""
            yield f"❌ {e}"
        return

    # ── Normal query ──
    try:
        streaming_response = query_engine.query(message)
        # Prefer the nodes attached to this very response: the retriever's
        # shared `last_nodes` can be overwritten by a concurrent query.
        nodes = getattr(streaming_response, "source_nodes", None)
        if nodes is None:
            nodes = getattr(query_engine.retriever, "last_nodes", [])
        _last_nodes = nodes
        yield from _stream_answer(_format_sources(nodes), streaming_response.response_gen)
    except Exception as e:
        _last_raw = ""
        yield f"❌ {e}"
# ---------------------------------------------------------------------------
# Status indicator
# ---------------------------------------------------------------------------
def get_status_html():
    """
    Render the current ``init_status`` as an HTML status pill.

    - "ready" gives a ``status-pill ready`` pill reading "System Ready".
    - "error…" gives a ``status-pill error`` pill whose tooltip holds the
      full message.
    - Anything else gives a pulsing ``status-pill loading`` pill labelled
      with the status text (first letter upper-cased), or "Initialising…"
      when the status is empty.
    """
    if init_status == "ready":
        return '<div class="status-pill ready"><span class="dot"></span>System Ready</div>'
    if init_status.startswith("error"):
        # Exception text may contain <, > or & as well as quotes; escape it
        # for the attribute instead of only dropping double quotes.
        tooltip = html.escape(init_status, quote=True)
        return f'<div class="status-pill error" title="{tooltip}">⚠ Error</div>'
    if init_status:
        # str.capitalize() would lowercase the rest ("LLM" -> "llm", "GNN" -> "gnn").
        label = html.escape(init_status[:1].upper() + init_status[1:])
    else:
        label = "Initialising…"
    return f'<div class="status-pill loading"><span class="dot pulse"></span>{label}</div>'
# ---------------------------------------------------------------------------
# Example-question chips & styles
# ---------------------------------------------------------------------------
def _build_example_chips_html(questions: list[str]) -> str:
chips = "".join(
f'<button class="eq-chip" onclick="'
f'(function(){{var tb=document.querySelector(\'textarea\');'
f'if(tb){{var nv=Object.getOwnPropertyDescriptor(window.HTMLTextAreaElement.prototype,\'value\');'
f'nv.set.call(tb,{repr(q)});tb.dispatchEvent(new Event(\'input\',{{bubbles:true}}));}}}})()">'
f"{q}</button>"
for q in questions
)
return f'<div class="eq-wrap">{chips}</div>'
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600&display=swap');
/* ── Base ── */
*, *::before, *::after { box-sizing: border-box; }
body, .gradio-container {
font-family: 'Inter', system-ui, sans-serif !important;
background: #f0f2f5 !important;
}
/* ── Outer shell: single centred card ── */
.shell {
max-width: 820px;
margin: 32px auto;
background: #ffffff;
border-radius: 20px;
box-shadow: 0 4px 24px rgba(0,0,0,0.08), 0 1px 4px rgba(0,0,0,0.04);
overflow: hidden;
}
/* ── Header band ── */
.hdr {
display: flex;
align-items: center;
justify-content: space-between;
padding: 18px 24px;
border-bottom: 1px solid #f1f3f6;
background: #ffffff;
}
.status-bar {
position: absolute;
top: 18px;
right: 24px;
}
.status-bar > .wrap { padding: 0 !important; background: transparent !important; border: none !important; }
.brand {
display: flex;
align-items: center;
gap: 12px;
}
.brand-icon {
width: 42px;
height: 42px;
border-radius: 12px;
background: linear-gradient(135deg, #eef2ff 0%, #e0e7ff 100%);
display: flex;
align-items: center;
justify-content: center;
font-size: 22px;
flex-shrink: 0;
}
.brand-name {
display: block;
font-size: 15px;
font-weight: 600;
color: #111827;
letter-spacing: -0.2px;
}
.brand-sub {
display: block;
font-size: 11.5px;
color: #9ca3af;
margin-top: 2px;
font-weight: 400;
}
/* ── Status pill ── */
.status-pill {
display: inline-flex;
align-items: center;
gap: 6px;
padding: 5px 13px;
border-radius: 999px;
font-size: 12px;
font-weight: 500;
white-space: nowrap;
transition: all 0.2s ease;
}
.status-pill.ready { background: #f0fdf4; color: #15803d; border: 1px solid #bbf7d0; }
.status-pill.loading { background: #eff6ff; color: #1d4ed8; border: 1px solid #bfdbfe; }
.status-pill.error { background: #fef2f2; color: #b91c1c; border: 1px solid #fecaca; cursor: help; }
.dot { width: 7px; height: 7px; border-radius: 50%; background: currentColor; flex-shrink: 0; }
.dot.pulse { animation: blink 1.4s ease-in-out infinite; }
@keyframes blink { 0%,100%{opacity:1} 50%{opacity:0.15} }
/* ── Chat area padding ── */
.chat-area { padding: 0 20px 8px; }
/* ── Example question chips ── */
.eq-section {
padding: 4px 20px 20px;
border-top: 1px solid #f1f3f6;
}
.eq-label {
font-size: 10.5px;
font-weight: 600;
letter-spacing: 0.8px;
text-transform: uppercase;
color: #c4c9d4;
margin-bottom: 10px;
}
.eq-wrap {
display: flex;
flex-wrap: wrap;
gap: 7px;
}
.eq-chip {
background: #f8f9fb;
border: 1px solid #e8eaed;
border-radius: 20px;
padding: 5px 13px;
font-size: 12.5px;
color: #4b5563;
cursor: pointer;
font-family: inherit;
line-height: 1.4;
transition: background 0.15s, border-color 0.15s, color 0.15s, transform 0.1s;
}
.eq-chip:hover {
background: #eef2ff;
border-color: #c7d2fe;
color: #4338ca;
transform: translateY(-1px);
}
.eq-chip:active { transform: translateY(0); }
/* ── Gradio internals cleanup ── */
.gradio-container > .main { background: transparent !important; }
.contain { background: transparent !important; }
footer { display: none !important; }
::-webkit-scrollbar { width: 4px; }
::-webkit-scrollbar-thumb { background: #e5e7eb; border-radius: 99px; }
"""
# Indigo/slate Soft theme. The overrides flatten Gradio's default block chrome
# (transparent background, no border, no shadow) so the white `.shell` card
# from CSS supplies the framing. They also match input and secondary-button
# colours to the chip palette used there.
_theme = gr.themes.Soft(
    primary_hue=gr.themes.colors.indigo,
    neutral_hue=gr.themes.colors.slate,
    font=gr.themes.GoogleFont("Inter"),
)
_theme = _theme.set(
    # Page and block surfaces
    body_background_fill="#f0f2f5",
    block_background_fill="transparent",
    block_border_width="0px",
    block_shadow="none",
    # Text inputs
    input_background_fill="#f8f9fb",
    input_border_color="#e8eaed",
    input_border_width="1px",
    # Primary buttons
    button_primary_background_fill="#4f46e5",
    button_primary_background_fill_hover="#4338ca",
    button_primary_text_color="#ffffff",
    # Secondary buttons
    button_secondary_background_fill="#f8f9fb",
    button_secondary_background_fill_hover="#eef2ff",
    button_secondary_border_color="#e8eaed",
    button_secondary_text_color="#374151",
)
# ---------------------------------------------------------------------------
# Layout
# ---------------------------------------------------------------------------
# Single centred card: header, chat, example chips and a live status pill.
with gr.Blocks(title="Healthcare GraphRAG") as demo:
    with gr.Column(elem_classes="shell"):
        # ── Header ──
        gr.HTML("""
<div class="hdr">
<div class="brand">
<div class="brand-icon">🏥</div>
<div>
<span class="brand-name">Healthcare GraphRAG</span>
<span class="brand-sub">GNN-enhanced retrieval · Qwen3.5-4B · BGE-small-en-v1.5</span>
</div>
</div>
</div>
""")

        # ── Chat interface ──
        with gr.Column(elem_classes="chat-area"):
            gr.ChatInterface(
                fn=ask_question,
                chatbot=gr.Chatbot(
                    height=460,
                    show_label=False,
                    render_markdown=True,
                    placeholder=(
                        "<div style='text-align:center;padding:56px 20px'>"
                        "<div style='font-size:2rem;margin-bottom:10px'>🏥</div>"
                        "<div style='font-size:15px;font-weight:600;color:#374151'>"
                        "Ask a medical question</div>"
                        "<div style='font-size:13px;margin-top:6px;color:#9ca3af'>"
                        "Powered by graph-augmented reasoning</div>"
                        "<div style='font-size:11px;margin-top:10px;color:#cbd5e1'>"
                        "Tip: type <code>continue</code> to resume an interrupted response</div>"
                        "</div>"
                    ),
                ),
                textbox=gr.Textbox(
                    placeholder="Ask a medical question…",
                    show_label=False,
                    lines=1,
                    max_lines=5,
                    autofocus=True,
                    container=False,
                    submit_btn="Send",
                    stop_btn="Stop",
                ),
                show_progress="hidden",
            )

        # ── Example question chips (pure HTML — flex-wrap, no Gradio Row) ──
        gr.HTML(
            '<div class="eq-section">'
            '<div class="eq-label">Try asking</div>'
            + _build_example_chips_html(EXAMPLE_QUESTIONS)
            + "</div>"
        )

        # ── Status indicator (live, refreshed every 3 s while initialising) ──
        # Passing the function (not its result) re-evaluates the pill on every
        # page load, so visitors arriving later don't start from the status
        # captured when the layout was built.
        status_html = gr.HTML(value=get_status_html, elem_classes="status-bar")
        status_timer = gr.Timer(value=3)

        def _poll_status():
            """Refresh the pill; switch the timer off once init has finished."""
            finished = init_status == "ready" or init_status.startswith("error")
            return get_status_html(), gr.Timer(active=not finished)

        status_timer.tick(fn=_poll_status, outputs=[status_html, status_timer])


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, theme=_theme, css=CSS)