# GraphRAG-Live / ui.py
# aayush226's picture
# Update ui.py
# 48a34ca verified
# ui.py
import gradio as gr
import html
from datetime import date as _date
from app import (
add_doc_endpoint,
ask_endpoint,
metrics_endpoint,
DocInput,
QuestionInput,
)
# Helpers
def _parse_answer_sections(answer_text: str):
lines = [l.strip() for l in (answer_text or "").splitlines() if l.strip()]
out = {
"main": "",
"citations": "",
"graph_reasoning": "",
"confidence": "",
"knobs": "",
"knobs_explain": "",
}
main_parts = []
for ln in lines:
ll = ln.lower()
if ll.startswith("citations:"):
out["citations"] = ln.split(":", 1)[1].strip()
elif ll.startswith("graph reasoning:") or ll.startswith("graphreasoning:"):
out["graph_reasoning"] = ln.split(":", 1)[1].strip()
elif ll.startswith("confidence:"):
out["confidence"] = ln.split(":", 1)[1].strip()
elif ll.startswith("knobs explain:"):
out["knobs_explain"] = ln.split(":", 1)[1].strip()
elif ll.startswith("knobs:"):
out["knobs"] = ln.split(":", 1)[1].strip()
else:
main_parts.append(ln)
out["main"] = " ".join(main_parts).strip() or (answer_text or "").strip()
return out
def _confidence_class(conf: str) -> str:
c = (conf or "").strip().lower()
if c.startswith("high"):
return "badge-high"
if c.startswith("medium"):
return "badge-medium"
if c.startswith("low"):
return "badge-low"
return "badge-none"
def _render_answer_card(answer_text: str) -> str:
    """Render *answer_text* as the HTML answer card used in the Ask tab.

    The text is split with ``_parse_answer_sections``. Every section is
    HTML-escaped before it is interpolated. Missing sections get a placeholder:
    "None" for citations, an em dash for the other metadata, and nothing for
    the main answer. The confidence badge's CSS class comes from
    ``_confidence_class``.
    """
    parsed = _parse_answer_sections(answer_text)
    badge = _confidence_class(parsed["confidence"])
    placeholders = {
        "main": "",
        "citations": "None",
        "graph_reasoning": "—",
        "confidence": "—",
        "knobs": "—",
        "knobs_explain": "—",
    }
    esc = {
        key: html.escape(parsed[key] or fallback)
        for key, fallback in placeholders.items()
    }
    return f"""
<div class="card">
<div class="card-title">Answer</div>
<div class="answer">{esc['main']}</div>
<div class="meta">
<span class="badge {badge}">{esc['confidence']}</span>
</div>
<div class="sub"><b>Citations:</b> {esc['citations']}</div>
<div class="sub"><b>Graph reasoning:</b> {esc['graph_reasoning']}</div>
<div class="sub"><b>Knobs effect:</b> {esc['knobs']}</div>
<div class="sub"><b>Knobs explain:</b> {esc['knobs_explain']}</div>
</div>
"""
def _render_evidence_markdown(evidence_list):
if not evidence_list:
return "_No evidence returned._"
lines = []
for i, chunk in enumerate(evidence_list, 1):
chunk = chunk.strip()
lines.append(f"**E{i}.** {chunk}")
return "\n\n".join(lines)
def _wrap_svg(svg: str) -> str:
if not svg or "<svg" not in svg:
return "<div class='graph-empty'>No graph</div>"
return f"""<div class="graph-wrap">{svg}</div>"""
# direct function calls, no HTTP
def metrics_ui():
    """Run the evaluation in-process and format the scores as Markdown.

    Calls ``metrics_endpoint()`` directly (no HTTP). If the response's
    ``status`` is not ``"ok"``, the whole response is echoed back after
    "Error: ". Any exception (endpoint failure, missing keys, non-numeric
    values) is returned as ``"Error: <exception>"`` instead of being raised.
    """
    try:
        payload = metrics_endpoint()
        if payload.get("status") == "ok":
            res = payload["results"]
            return f"""
### 📊 Evaluation Results
**Baseline (cosine-only)**
- hit@10: {res['baseline']['hit@10']:.2f}
- nDCG@10: {res['baseline']['nDCG@10']:.2f}
**Hybrid (GraphRAG)**
- hit@10: {res['hybrid']['hit@10']:.2f}
- nDCG@10: {res['hybrid']['nDCG@10']:.2f}
**Other**
- Citation correctness: {res['citation_correctness']:.2f}
- Avg latency (s): {res['avg_latency_sec']:.2f}
"""
        return f"Error: {payload}"
    except Exception as exc:
        return f"Error: {exc}"
def _build_timestamp(date_val, time_val):
    """Combine optional date/time picker values into "YYYY-MM-DDTHH:MM:SSZ".

    Args:
        date_val: ``datetime.date`` / ``datetime.datetime`` or a "YYYY-MM-DD"
            string (surrounding whitespace ignored). Falsy or blank means
            "no timestamp".
        time_val: ``datetime.time``, an "HH:MM" or "HH:MM:SS" string, or
            None/blank, which is treated as "00:00".

    Returns:
        The combined string, or "" when no usable date was given. String
        inputs are not validated beyond padding "HH:MM" with ":00". The values
        are labelled "Z" as-is; no timezone conversion is done.
    """
    import datetime as _dt  # local: the module only imports `date` at top level

    if not date_val:
        return ""
    if isinstance(date_val, _dt.datetime):
        # datetime subclasses date, and its isoformat() already includes a
        # time part, which would yield "...T10:00:00T00:00:00Z". Keep the date.
        dstr = date_val.date().isoformat()
    elif isinstance(date_val, _dt.date):
        dstr = date_val.isoformat()
    else:
        dstr = str(date_val).strip()
    if not dstr:
        return ""

    if isinstance(time_val, _dt.time):
        tstr = time_val.strftime("%H:%M:%S")
    else:
        # Blank or whitespace-only input falls back to midnight.
        tstr = str(time_val or "").strip() or "00:00"
    if len(tstr) == 5:  # HH:MM -> add seconds
        tstr = f"{tstr}:00"
    return f"{dstr}T{tstr}Z"


def add_doc_ui(text, source="user", date_val=None, time_val=None):
    """Ingest *text* through ``add_doc_endpoint`` and return its logs for display.

    An optional timestamp is built from the date/time pickers by
    ``_build_timestamp``. Without a date, the document is sent with
    ``timestamp=None``.

    Returns:
        The endpoint's ``logs`` joined with newlines, "No logs." when there
        are none, or ``"Error: <exception>"`` on any failure. Malformed picker
        values are included, so the UI always gets text back.
    """
    try:
        ts_iso = _build_timestamp(date_val, time_val)
        doc = DocInput(text=text, source=source, timestamp=ts_iso or None)
        j = add_doc_endpoint(doc)
        return "\n".join(j.get("logs", [])) or "No logs."
    except Exception as e:
        return f"Error: {e}"
def ask_ui(question, w_cos, w_path, w_fresh, w_deg):
    """Answer *question* through ``ask_endpoint`` and build the Ask-tab outputs.

    Returns a 5-tuple, in the order of the ``ask_btn.click`` outputs:
    (answer-card HTML, evidence Markdown, log text, graph-panel HTML,
    subgraph JSON).

    If building the request or calling the endpoint raises, the tuple instead
    holds a "don't know" card with Low confidence, the error text in the log
    slot, an empty ``#graph`` container and ``{}``.

    On success, the graph panel is an empty ``#graph`` container (filled
    client-side by the D3 renderer) when the subgraph JSON has nodes, and
    otherwise the server SVG wrapped by ``_wrap_svg``.
    """
    d3_target = "<div id='graph' style='height:600px'></div>"
    try:
        request = QuestionInput(
            question=question,
            w_cos=w_cos,
            w_path=w_path,
            w_fresh=w_fresh,
            w_deg=w_deg,
        )
        resp = ask_endpoint(request)
    except Exception as exc:
        message = f"Error: {exc}"
        fallback_card = _render_answer_card(
            "I don’t know based on the given evidence.\nConfidence: Low"
        )
        return (fallback_card, "_No evidence returned._", message, d3_target, {})

    card = _render_answer_card(resp.get("answer", ""))
    evidence = _render_evidence_markdown(resp.get("evidence", []))
    log_text = "\n".join(resp.get("logs", [])) or "No logs."
    # Prefer the client-side D3 container; fall back to server-rendered SVG.
    subgraph = resp.get("subgraph_json", {})
    panel = (
        d3_target
        if subgraph and subgraph.get("nodes")
        else _wrap_svg(resp.get("subgraph_svg", ""))
    )
    return card, evidence, log_text, panel, subgraph
# UI
# Build the Gradio app `demo`: three tabs (Add Document, Ask Question, Metrics)
# wired to the *_ui handlers above, plus a browser-side D3 renderer for the
# evidence graph. Components are laid out in creation order inside each `with`
# context, so the statement order below is significant.
with gr.Blocks(
    css="""
/* Layout & theme */
body { background: #0b0f14; color: #e6edf3; }
.gradio-container { max-width: 1180px !important; }
.section-title { font-size: 22px; font-weight: 700; margin: 6px 0 12px; }
/* Cards */
.card { background: #0f1720; border: 1px solid #1f2a36; border-radius: 14px; padding: 14px; }
.card-title { font-size: 16px; letter-spacing: .3px; color: #9fb3c8; margin-bottom: 8px; text-transform: uppercase; }
.answer { font-size: 18px; line-height: 1.5; margin-bottom: 8px; }
.sub { color: #a8b3bf; margin-top: 6px; font-size: 14px; }
/* Badges */
.badge { padding: 3px 10px; border-radius: 999px; font-size: 12px; font-weight: 700; display: inline-block; }
.badge-high { background: #12391a; color: #6ee787; border: 1px solid #285f36; }
.badge-medium { background: #3a2b13; color: #ffd277; border: 1px solid #6b4e1f; }
.badge-low { background: #3b1616; color: #ff9492; border: 1px solid #6b2020; }
.badge-none { background: #223; color: #9fb3c8; border: 1px solid #334; }
/* Graph */
.graph-wrap { background: #0f1720; border: 1px solid #1f2a36; border-radius: 14px;
padding: 12px; height: 460px; overflow: auto; }
.graph-empty { color: #9fb3c8; font-style: italic; padding: 16px; }
/* Logs */
#logs-box textarea {
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", monospace !important;
max-height: 280px !important;
overflow-y: auto !important;
}
"""
) as demo:
    gr.Markdown("### 🚀 GraphRAG — Live Demo")

    # --- Tab 1: ingest a document (text, source label, optional timestamp) ---
    with gr.Tab("Add Document"):
        with gr.Row():
            with gr.Column(scale=3):
                text_in = gr.Textbox(
                    label="Document",
                    lines=10,
                    placeholder="Paste text to inject into Graph + Vector DB…",
                )
            with gr.Column(scale=1):
                source_in = gr.Textbox(label="Source", value="user")
                # Date & Time pickers
                # Use a native picker component when this Gradio build exposes
                # one; otherwise fall back to a free-text box. add_doc_ui
                # accepts a date object or a "YYYY-MM-DD" string for the date.
                if hasattr(gr, "Date"):
                    ts_date = gr.Date(label="Date (optional)")
                else:
                    ts_date = gr.Textbox(label="Date (YYYY-MM-DD, optional)")
                if hasattr(gr, "Time"):
                    ts_time = gr.Time(label="Time (optional)", value="00:00")
                else:
                    ts_time = gr.Textbox(label="Time (HH:MM, optional)", value="00:00")
                add_btn = gr.Button("Add Doc", variant="primary")
        # NOTE(review): elem_id "logs-box" is reused by the Debug logs textbox
        # in the Ask tab, so two DOM elements share one id; the #logs-box CSS
        # rule applies to both.
        add_logs = gr.Textbox(label="Ingestion Logs", lines=14, elem_id="logs-box")
        # Positional inputs map to add_doc_ui(text, source, date_val, time_val).
        add_btn.click(
            add_doc_ui,
            inputs=[text_in, source_in, ts_date, ts_time],
            outputs=add_logs
        )

    # --- Tab 2: ask a question with adjustable rerank weights ---
    with gr.Tab("Ask Question"):
        with gr.Row():
            q_in = gr.Textbox(
                label="Question", placeholder="e.g., Who acquired Instagram?"
            )
            ask_btn = gr.Button("Ask", variant="primary")
        # Slider values are forwarded unchanged to QuestionInput by ask_ui;
        # they are not normalised to sum to 1 here.
        with gr.Accordion("Rerank Weights", open=False):
            w_cos = gr.Slider(0, 1, value=0.60, step=0.05, label="Cosine weight")
            w_path = gr.Slider(0, 1, value=0.20, step=0.05, label="Path proximity weight")
            w_fresh = gr.Slider(0, 1, value=0.15, step=0.05, label="Freshness weight")
            w_deg = gr.Slider(0, 1, value=0.05, step=0.05, label="Degree norm weight")
        with gr.Row():
            # Left column: answer card, ranked evidence, debug logs.
            with gr.Column(scale=1):
                gr.Markdown("<div class='section-title'>Answer</div>")
                ans_html = gr.HTML(value=_render_answer_card("Ask something to see results."))
                evid = gr.Accordion("Evidence (ranked)", open=True)
                with evid:
                    evid_md = gr.Markdown()
                logs = gr.Accordion("Debug logs", open=False)
                with logs:
                    logs_txt = gr.Textbox(lines=14, elem_id="logs-box")
            # Right column: evidence graph.
            with gr.Column(scale=1):
                gr.Markdown("<div class='section-title'>Evidence Graph</div>")
                # Hosts the id="graph" container that DRAW_JS draws into.
                # ask_ui replaces it with a fresh empty container or with
                # server-rendered SVG.
                graph_html = gr.HTML(value="<div id='graph' style='height:600px'></div>")
                # Hidden JSON component that passes the subgraph to DRAW_JS.
                graph_data = gr.JSON(label="graph-data", visible=False)
        # ask -> 5 outputs (answer, evidence, logs, graph container, graph JSON)
        ask_btn.click(
            ask_ui,
            inputs=[q_in, w_cos, w_path, w_fresh, w_deg],
            outputs=[ans_html, evid_md, logs_txt, graph_html, graph_data],
        )

    # --- Tab 3: offline evaluation (baseline vs hybrid retrieval) ---
    with gr.Tab("Metrics"):
        metrics_btn = gr.Button("Run Evaluation", variant="primary")
        metrics_out = gr.Markdown("Click run to evaluate baseline vs hybrid.")
        metrics_btn.click(metrics_ui, inputs=[], outputs=metrics_out)

    # D3 renderer (zoom + pan)
    # Browser-side function that receives graph_data's value. It:
    #   - clears #graph, and shows "No graph" when there are no nodes;
    #   - otherwise loads D3 v7 from the jsDelivr CDN if window.d3 is absent
    #     (a network dependency);
    #   - draws a force-directed graph with draggable nodes and zoom/pan
    #     (scale 0.2-3).
    # The expected value shape is inferred from the accessors used in the script
    # (unverified): {nodes: [{id, ...}], links: [{source, target, label}]}, with
    # link source/target given as node ids.
    DRAW_JS = r"""
(value) => {
const el = document.querySelector("#graph");
if (!el) return null;
el.innerHTML = "";
if (!value || !value.nodes || value.nodes.length === 0) {
el.innerHTML = "<div class='graph-empty'>No graph</div>";
return null;
}
function ensureD3(cb) {
if (window.d3) return cb();
const s = document.createElement("script");
s.src = "https://cdn.jsdelivr.net/npm/d3@7";
s.onload = cb;
document.head.appendChild(s);
}
ensureD3(() => {
const width = el.clientWidth || 900;
const height = 600;
const svg = d3.select(el).append("svg")
.attr("viewBox", [0, 0, width, height])
.attr("preserveAspectRatio", "xMidYMid meet")
.style("width", "100%")
.style("height", "100%");
// Zoomable container
const container = svg.append("g");
// Enable zoom & pan
svg.call(
d3.zoom()
.scaleExtent([0.2, 3])
.on("zoom", (event) => {
container.attr("transform", event.transform);
})
);
const sim = d3.forceSimulation(value.nodes)
.force("link", d3.forceLink(value.links).id(d => d.id).distance(140).strength(0.4))
.force("charge", d3.forceManyBody().strength(-220))
.force("center", d3.forceCenter(width / 2, height / 2));
const link = container.append("g")
.attr("stroke", "#999")
.attr("stroke-opacity", 0.6)
.selectAll("line")
.data(value.links)
.enter().append("line")
.attr("stroke-width", 1.5);
const edgeLabels = container.append("g")
.selectAll("text")
.data(value.links)
.enter().append("text")
.attr("font-size", 10)
.attr("fill", "#bbb")
.text(d => d.label);
const node = container.append("g")
.selectAll("circle")
.data(value.nodes)
.enter().append("circle")
.attr("r", 12)
.attr("fill", "#69b3a2")
.attr("stroke", "#2dd4bf")
.attr("stroke-width", 1.2)
.call(d3.drag()
.on("start", (event, d) => { if (!event.active) sim.alphaTarget(0.3).restart(); d.fx = d.x; d.fy = d.y; })
.on("drag", (event, d) => { d.fx = event.x; d.fy = event.y; })
.on("end", (event, d) => { if (!event.active) sim.alphaTarget(0); d.fx = null; d.fy = null; })
);
const labels = container.append("g")
.selectAll("text")
.data(value.nodes)
.enter().append("text")
.attr("font-size", 12)
.attr("fill", "#ddd")
.attr("dy", 18)
.attr("text-anchor", "middle")
.text(d => d.id);
sim.on("tick", () => {
link
.attr("x1", d => d.source.x)
.attr("y1", d => d.source.y)
.attr("x2", d => d.target.x)
.attr("y2", d => d.target.y);
edgeLabels
.attr("x", d => (d.source.x + d.target.x) / 2)
.attr("y", d => (d.source.y + d.target.y) / 2);
node
.attr("cx", d => d.x)
.attr("cy", d => d.y);
labels
.attr("x", d => d.x)
.attr("y", d => d.y);
});
});
return null;
}
"""
    # When graph_data changes (ask_ui writes it), run an identity Python step,
    # then run DRAW_JS in the browser on the current value.
    # NOTE(review): the identity step writes graph_data back to itself.
    # Presumably it exists only so `.then(..., js=DRAW_JS)` can be chained;
    # confirm it cannot re-fire `.change`. Also confirm the graph is redrawn
    # when the same JSON is returned twice, since ask_ui resets graph_html to
    # an empty container each time.
    graph_data.change(lambda x: x, inputs=graph_data, outputs=graph_data).then(
        None, inputs=graph_data, outputs=None, js=DRAW_JS
    )

# Launch the app when this file is run as a script.
if __name__ == "__main__":
    demo.launch()