Spaces:
Runtime error
Runtime error
Upload 10 files
Browse files
README.md
CHANGED
|
@@ -1,14 +1,33 @@
|
|
| 1 |
-
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
--
|
| 13 |
-
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# GraphRAG-Live
|
| 2 |
+
|
| 3 |
+
**Hybrid Retrieval-Augmented Generation (RAG) with Graph + Vectors.**
|
| 4 |
+
|
| 5 |
+
This project shows how knowledge graphs (Neo4j Aura) and vector databases (Qdrant) can be combined with re-ranking heuristics to build a smarter, cheaper and more explainable RAG system.
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
## 🌟 Features
|
| 10 |
+
- **Hybrid Retrieval:** Combines semantic search (Qdrant) with graph proximity scoring (Neo4j).
|
| 11 |
+
- **Dynamic Knowledge Injection:** Add new documents on the fly → pipeline updates instantly.
|
| 12 |
+
- **Evidence Subgraphs:** Each answer includes a small 2-hop evidence graph.
|
| 13 |
+
- **Metrics Dashboard:** Compare GraphRAG vs. baseline RAG on hit@10, nDCG@10, citation correctness.
|
| 14 |
+
- **Hosted Demo:** Deployed via Hugging Face Spaces (Gradio UI).
|
| 15 |
+
|
| 16 |
+
---
|
| 17 |
+
|
| 18 |
+
## 🏗️ Architecture
|
| 19 |
+
|
| 20 |
+
```text
|
| 21 |
+
User Question
|
| 22 |
+
│
|
| 23 |
+
▼
|
| 24 |
+
[Qdrant: semantic chunks] + [Neo4j: graph proximity]
|
| 25 |
+
│
|
| 26 |
+
▼
|
| 27 |
+
Reranker (cosine + path proximity + freshness + degree)
|
| 28 |
+
│
|
| 29 |
+
▼
|
| 30 |
+
Answer Generator (OpenAI)
|
| 31 |
+
│
|
| 32 |
+
▼
|
| 33 |
+
Evidence Subgraph + Answer + Citations
|
app.py
ADDED
|
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI
|
| 2 |
+
from pydantic import BaseModel, Field
|
| 3 |
+
from typing import List, Literal
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
import os, json
|
| 6 |
+
|
| 7 |
+
from text import chunk_text
|
| 8 |
+
from vec import embed_and_upsert, search
|
| 9 |
+
from kg import (
|
| 10 |
+
extract_and_insert,
|
| 11 |
+
get_subgraph,
|
| 12 |
+
compute_path_proximity,
|
| 13 |
+
compute_degree_norm,
|
| 14 |
+
)
|
| 15 |
+
from rerank import rerank_candidates
|
| 16 |
+
from eval import evaluate
|
| 17 |
+
from utils import compute_freshness
|
| 18 |
+
|
| 19 |
+
from dotenv import load_dotenv
|
| 20 |
+
from openai import OpenAI
|
| 21 |
+
|
| 22 |
+
# Load .env first so the key below sees it; override=True lets .env win over
# any pre-set environment variables.
load_dotenv(override=True)

# OpenAI client for answer generation. An empty key is tolerated here: the
# client constructs fine, and /ask catches the failing call and falls back
# to its stub answer.
key = os.environ.get("OPENAI_API_KEY", "").strip()
client = OpenAI(api_key=key)

app = FastAPI()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Schemas for Pydantic + structured output
class DocInput(BaseModel):
    """Payload for /add_doc: raw document text plus provenance metadata."""

    text: str
    # Free-form provenance tag for the document.
    source: str = "user"
    # BUG FIX: the original `timestamp: datetime = datetime.now()` evaluated
    # ONCE at import time, stamping every request with process start time.
    # default_factory re-evaluates per instance, so each request gets its
    # actual receipt time.
    timestamp: datetime = Field(default_factory=datetime.now)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class QuestionInput(BaseModel):
    """Payload for /ask: the question plus re-ranking weights."""

    # The user's natural-language question.
    question: str
    # Re-ranking weights (defaults sum to 1.0): cosine similarity, graph
    # path proximity, freshness decay, and inverse node degree.
    w_cos: float = 0.60
    w_path: float = 0.20
    w_fresh: float = 0.15
    w_deg: float = 0.05
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# LLM output requirement (enforcing this with JSON output + Pydantic)
class LLMAnswer(BaseModel):
    """Schema the LLM's JSON reply must validate against (via model_validate)."""

    answer: str = Field(..., description="One-sentence final answer")
    citations: List[str] = Field(
        default_factory=list,
        description="Evidence IDs like E1, E3 that support the answer",
    )
    graph_reasoning: str = Field(
        "", description="How the graph helped, or 'Not used'"
    )
    # Defaults to "Low" so a missing field reads as least-confident.
    confidence: Literal["High", "Medium", "Low"] = "Low"
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# Helpers for the explanation on the controls (weights)
|
| 59 |
+
def _get_scores(c, w_cos, w_path, w_fresh, w_deg):
|
| 60 |
+
cos = float(c.get("cosine", c.get("cosine_sim", 0.0)) or 0.0)
|
| 61 |
+
pp = float(c.get("path_proximity", 0.0) or 0.0)
|
| 62 |
+
fr = float(c.get("freshness_decay", 0.0) or 0.0)
|
| 63 |
+
dg = float(c.get("degree_norm", 0.0) or 0.0)
|
| 64 |
+
final = w_cos * cos + w_path * pp + w_fresh * fr + w_deg * dg
|
| 65 |
+
return cos, pp, fr, dg, final
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _build_knobs_breakdown(numbered, w_cos, w_path, w_fresh, w_deg):
    """Build (knobs_line, knobs_explain) strings for the UI.

    Describes how the current weights produced the top ranking, using the
    top candidate and — when available — the runner-up. Returns ("", "")
    for an empty candidate list.
    """
    if not numbered:
        return "", ""

    top_idx, top_cand = numbered[0]
    cos1, pp1, fr1, dg1, fin1 = _get_scores(top_cand, w_cos, w_path, w_fresh, w_deg)

    # Shared sentence prefix used by both explain variants below.
    weights_prefix = (
        f"With weights (cos {w_cos:.2f}, path {w_path:.2f}, "
        f"fresh {w_fresh:.2f}, deg {w_deg:.2f}), "
    )

    ru_piece = ""
    if len(numbered) > 1:
        ru_idx, ru_cand = numbered[1]
        cos2, pp2, fr2, dg2, fin2 = _get_scores(ru_cand, w_cos, w_path, w_fresh, w_deg)
        margin = fin1 - fin2
        ru_piece = f"; Runner-up E{ru_idx}={fin2:.3f}; Margin={margin:+.3f}"

        # Weighted score gaps between winner and runner-up, largest first.
        # The literal's order matters: sort is stable, so exact ties keep
        # (path, freshness, cosine, degree).
        gaps = sorted(
            [
                ("path", w_path * (pp1 - pp2), pp1, pp2, w_path),
                ("freshness", w_fresh * (fr1 - fr2), fr1, fr2, w_fresh),
                ("cosine", w_cos * (cos1 - cos2), cos1, cos2, w_cos),
                ("degree", w_deg * (dg1 - dg2), dg1, dg2, w_deg),
            ],
            key=lambda row: row[1],
            reverse=True,
        )
        # At most three positive drivers above a small noise floor.
        drivers = [f"{name} ({gap:+.3f})" for name, gap, *_ in gaps if gap > 0.002][:3]
        if drivers:
            top_names = ", ".join(drivers)
        else:
            top_names = "mostly cosine similarity (others were negligible)"
        explain = (
            weights_prefix
            + f"E{top_idx} leads by {margin:+.3f}. Biggest lifts vs E{ru_idx}: {top_names}."
        )
    else:
        # No runner-up: still report the winner's blended score.
        explain = weights_prefix + f"the top candidate E{top_idx} scored {fin1:.3f}."

    knobs_line = (
        f"Weights→ cos {w_cos:.2f}, path {w_path:.2f}, fresh {w_fresh:.2f}, deg {w_deg:.2f}. "
        f"E{top_idx} final={fin1:.3f} = {w_cos:.2f}×{cos1:.3f} + {w_path:.2f}×{pp1:.3f} + "
        f"{w_fresh:.2f}×{fr1:.3f} + {w_deg:.2f}×{dg1:.3f}{ru_piece}; Cosine-only(E{top_idx})={cos1:.3f}."
    )
    return knobs_line, explain
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# API Endpoints
@app.get("/metrics")
def metrics_endpoint():
    """Run the built-in evaluation set and return its metrics (or the failure)."""
    logs = []
    try:
        results = evaluate()
    except Exception as e:
        # Evaluation touches external services; report the failure, don't 500.
        logs.append(f"⚠️ Metrics failed: {e}")
        return {"status": "error", "logs": logs}
    logs.append("✅ Ran evaluation set")
    return {"status": "ok", "results": results, "logs": logs}
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
@app.post("/add_doc")
|
| 134 |
+
def add_doc_endpoint(doc: DocInput):
|
| 135 |
+
logs = ["📥 Received document"]
|
| 136 |
+
text, source, timestamp = doc.text, doc.source, doc.timestamp
|
| 137 |
+
|
| 138 |
+
# 1) Chunk
|
| 139 |
+
chunks = chunk_text(text)
|
| 140 |
+
logs.append(f"✂️ Chunked into {len(chunks)} pieces")
|
| 141 |
+
|
| 142 |
+
# 2) Embed + store
|
| 143 |
+
embed_and_upsert(chunks, source=source, timestamp=timestamp.isoformat())
|
| 144 |
+
logs.append(f"🧮 Embedded + stored in Qdrant (source={source}, ts={timestamp})")
|
| 145 |
+
|
| 146 |
+
# 3) Extract triples and feed to Neo4j
|
| 147 |
+
neo4j_logs = extract_and_insert(chunks, source=source, timestamp=str(timestamp))
|
| 148 |
+
logs.extend(neo4j_logs or ["🌐 No entities/relations extracted for Neo4j"])
|
| 149 |
+
return {"status": "ok", "logs": logs}
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
@app.post("/ask")
|
| 153 |
+
def ask_endpoint(query: QuestionInput):
|
| 154 |
+
logs = []
|
| 155 |
+
q = query.question
|
| 156 |
+
logs.append(f"❓ Received question: {q}")
|
| 157 |
+
|
| 158 |
+
# Retrieve
|
| 159 |
+
candidates = search(q, top_k=5)
|
| 160 |
+
logs.append(f"🔎 Retrieved {len(candidates)} from Qdrant")
|
| 161 |
+
|
| 162 |
+
# Graph aware features??
|
| 163 |
+
for c in candidates:
|
| 164 |
+
c["path_proximity"] = compute_path_proximity(q, c["chunk"])
|
| 165 |
+
c["degree_norm"] = compute_degree_norm(c["chunk"])
|
| 166 |
+
c["freshness_decay"] = compute_freshness(c.get("timestamp"))
|
| 167 |
+
|
| 168 |
+
# Rerank
|
| 169 |
+
reranked, rerank_logs = rerank_candidates(
|
| 170 |
+
candidates,
|
| 171 |
+
w_cos=query.w_cos,
|
| 172 |
+
w_path=query.w_path,
|
| 173 |
+
w_fresh=query.w_fresh,
|
| 174 |
+
w_deg=query.w_deg,
|
| 175 |
+
)
|
| 176 |
+
logs.append("📊 Applied graph-aware re-ranking")
|
| 177 |
+
logs.extend(rerank_logs)
|
| 178 |
+
|
| 179 |
+
# Evidence subgraph (≤2 hops)
|
| 180 |
+
triples = get_subgraph(q, source=None)
|
| 181 |
+
logs.append(f"🌐 Subgraph triples: {len(triples)}")
|
| 182 |
+
|
| 183 |
+
# Prepare evidence numbering for citations
|
| 184 |
+
numbered = [(i + 1, c) for i, c in enumerate(reranked)]
|
| 185 |
+
TOP_N = 2 # TODO -> expermient with more
|
| 186 |
+
reranked = reranked[:TOP_N]
|
| 187 |
+
numbered = [(i + 1, c) for i, c in enumerate(reranked)]
|
| 188 |
+
evidence_for_prompt = [f"[E{i}] {c['chunk']}" for i, c in numbered]
|
| 189 |
+
evidence_for_ui = [f"[E{i}] {c['chunk']}" for i, c in numbered]
|
| 190 |
+
|
| 191 |
+
knobs_line, knobs_explain = _build_knobs_breakdown(
|
| 192 |
+
numbered, query.w_cos, query.w_path, query.w_fresh, query.w_deg
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
# LLM answer (OpenAI, structured JSON -> Pydantic)
|
| 196 |
+
if reranked:
|
| 197 |
+
triples_text = "\n".join([f"({s}) -[{r}]-> ({o})" for s, r, o in triples])
|
| 198 |
+
|
| 199 |
+
# Schema friendly request
|
| 200 |
+
prompt = f"""
|
| 201 |
+
You are a precise QA assistant that MUST use BOTH the retrieved evidence and the graph triples.
|
| 202 |
+
|
| 203 |
+
Question:
|
| 204 |
+
{q}
|
| 205 |
+
|
| 206 |
+
Retrieved Evidence (ranked by importance, highest first):
|
| 207 |
+
{chr(10).join(evidence_for_prompt)}
|
| 208 |
+
|
| 209 |
+
Knowledge Graph Triples:
|
| 210 |
+
{triples_text}
|
| 211 |
+
|
| 212 |
+
Instructions:
|
| 213 |
+
- E1 is the most relevant, E2 is second-most, and so on.
|
| 214 |
+
- Prefer evidence with a lower number if multiple sources conflict.
|
| 215 |
+
- If supported, produce a single-sentence answer.
|
| 216 |
+
- Cite supporting evidence IDs (e.g., E1, E2).
|
| 217 |
+
- If the graph helped, say how; else "Not used".
|
| 218 |
+
- If not supported, return "I don’t know..." with Low confidence.
|
| 219 |
+
|
| 220 |
+
Return ONLY a JSON object matching this schema:
|
| 221 |
+
{{
|
| 222 |
+
"answer": "string",
|
| 223 |
+
"citations": ["E1","E2"],
|
| 224 |
+
"graph_reasoning": "string",
|
| 225 |
+
"confidence": "High|Medium|Low"
|
| 226 |
+
}}
|
| 227 |
+
""".strip()
|
| 228 |
+
|
| 229 |
+
logs.append("📝 Built prompt with evidence + graph")
|
| 230 |
+
try:
|
| 231 |
+
comp = client.chat.completions.create(
|
| 232 |
+
model="gpt-4o-mini",
|
| 233 |
+
messages=[
|
| 234 |
+
{"role": "system", "content": "Respond ONLY with a JSON object."},
|
| 235 |
+
{"role": "user", "content": prompt},
|
| 236 |
+
],
|
| 237 |
+
# Ensures valid JSON
|
| 238 |
+
response_format={"type": "json_object"},
|
| 239 |
+
temperature=0,
|
| 240 |
+
max_tokens=300,
|
| 241 |
+
)
|
| 242 |
+
raw_json = comp.choices[0].message.content or "{}"
|
| 243 |
+
data = json.loads(raw_json)
|
| 244 |
+
|
| 245 |
+
# Validate and normalize with Pydantic
|
| 246 |
+
parsed = LLMAnswer.model_validate(data)
|
| 247 |
+
|
| 248 |
+
# Build display string for the UI card
|
| 249 |
+
citations_txt = ", ".join(parsed.citations) if parsed.citations else "None"
|
| 250 |
+
answer_text = (
|
| 251 |
+
f"{parsed.answer}\n"
|
| 252 |
+
f"Citations: {citations_txt}\n"
|
| 253 |
+
f"Graph reasoning: {parsed.graph_reasoning or '—'}\n"
|
| 254 |
+
f"Confidence: {parsed.confidence}\n"
|
| 255 |
+
f"Knobs: {knobs_line or '—'}\n"
|
| 256 |
+
f"Knobs explain: {knobs_explain or '—'}"
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
answer = answer_text
|
| 260 |
+
logs.append("🤖 Called OpenAI")
|
| 261 |
+
logs.append("🧠 Generated final answer")
|
| 262 |
+
except Exception as e:
|
| 263 |
+
top_chunk = reranked[0]["chunk"] if reranked else "No evidence"
|
| 264 |
+
answer = (
|
| 265 |
+
f"Based on evidence: {top_chunk}\n"
|
| 266 |
+
f"Citations: None\n"
|
| 267 |
+
f"Graph reasoning: Not used\n"
|
| 268 |
+
f"Confidence: Low\n"
|
| 269 |
+
f"Knobs: {knobs_line or '—'}\n"
|
| 270 |
+
f"Knobs explain: {knobs_explain or '—'}"
|
| 271 |
+
)
|
| 272 |
+
logs.append(f"⚠️ OpenAI failed, fallback to stub ({e})")
|
| 273 |
+
else:
|
| 274 |
+
answer = (
|
| 275 |
+
"No evidence found.\n"
|
| 276 |
+
"Citations: None\n"
|
| 277 |
+
"Graph reasoning: Not used\n"
|
| 278 |
+
"Confidence: Low\n"
|
| 279 |
+
f"Knobs: {knobs_line or '—'}\n"
|
| 280 |
+
f"Knobs explain: {knobs_explain or '—'}"
|
| 281 |
+
)
|
| 282 |
+
evidence_for_ui = []
|
| 283 |
+
logs.append("⚠️ No evidence, answer is empty")
|
| 284 |
+
|
| 285 |
+
# Build D3 JSON
|
| 286 |
+
node_map = {}
|
| 287 |
+
links = []
|
| 288 |
+
for s, r, o in triples:
|
| 289 |
+
node_map.setdefault(s, {"id": s})
|
| 290 |
+
node_map.setdefault(o, {"id": o})
|
| 291 |
+
links.append({"source": s, "target": o, "label": r})
|
| 292 |
+
subgraph_json = {"nodes": list(node_map.values()), "links": links}
|
| 293 |
+
|
| 294 |
+
# Server side SVG fallback in case D3 fails to render
|
| 295 |
+
import networkx as nx
|
| 296 |
+
|
| 297 |
+
G = nx.DiGraph()
|
| 298 |
+
for s, r, o in triples:
|
| 299 |
+
G.add_node(s)
|
| 300 |
+
G.add_node(o)
|
| 301 |
+
G.add_edge(s, o, label=r)
|
| 302 |
+
|
| 303 |
+
pos = nx.spring_layout(G, seed=42)
|
| 304 |
+
width, height, pad = 720, 420, 40
|
| 305 |
+
xs = [p[0] for p in pos.values()] or [0.0]
|
| 306 |
+
ys = [p[1] for p in pos.values()] or [0.0]
|
| 307 |
+
minx, maxx = min(xs), max(xs)
|
| 308 |
+
miny, maxy = min(ys), max(ys)
|
| 309 |
+
rangex = (maxx - minx) or 1.0
|
| 310 |
+
rangey = (maxy - miny) or 1.0
|
| 311 |
+
|
| 312 |
+
def sx(x): return pad + (x - minx) / rangex * (width - 2 * pad)
|
| 313 |
+
def sy(y): return pad + (y - miny) / rangey * (height - 2 * pad)
|
| 314 |
+
|
| 315 |
+
parts = []
|
| 316 |
+
parts.append(
|
| 317 |
+
f'<svg width="{width}" height="{height}" viewBox="0 0 {width} {height}" '
|
| 318 |
+
f'xmlns="http://www.w3.org/2000/svg">'
|
| 319 |
+
)
|
| 320 |
+
parts.append(
|
| 321 |
+
"""
|
| 322 |
+
<defs>
|
| 323 |
+
<marker id="arrow" markerUnits="strokeWidth" markerWidth="10" markerHeight="8"
|
| 324 |
+
viewBox="0 0 10 8" refX="10" refY="4" orient="auto">
|
| 325 |
+
<path d="M0 0 L10 4 L0 8 z" fill="#999"/>
|
| 326 |
+
</marker>
|
| 327 |
+
<style>
|
| 328 |
+
.edge { stroke:#999; stroke-width:1.5; }
|
| 329 |
+
.nodelabel { font:12px sans-serif; fill:#ddd; }
|
| 330 |
+
.edgelabel { font:10px sans-serif; fill:#bbb; }
|
| 331 |
+
.node { fill:#69b3a2; stroke:#2dd4bf; stroke-width:1; }
|
| 332 |
+
</style>
|
| 333 |
+
</defs>
|
| 334 |
+
"""
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
for u, v, data in G.edges(data=True):
|
| 338 |
+
x1, y1 = sx(pos[u][0]), sy(pos[u][1])
|
| 339 |
+
x2, y2 = sx(pos[v][0]), sy(pos[v][1])
|
| 340 |
+
parts.append(
|
| 341 |
+
f'<line class="edge" x1="{x1:.1f}" y1="{y1:.1f}" '
|
| 342 |
+
f'x2="{x2:.1f}" y2="{y2:.1f}" marker-end="url(#arrow)"/>'
|
| 343 |
+
)
|
| 344 |
+
mx, my = (x1 + x2) / 2.0, (y1 + y2) / 2.0
|
| 345 |
+
lbl = (data.get("label") or "").replace("&", "&").replace("<", "<")
|
| 346 |
+
parts.append(
|
| 347 |
+
f'<text class="edgelabel" x="{mx:.1f}" y="{my:.1f}" text-anchor="middle">{lbl}</text>'
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
for n in G.nodes():
|
| 351 |
+
x, y = sx(pos[n][0]), sy(pos[n][1])
|
| 352 |
+
node_txt = str(n).replace("&", "&").replace("<", "<")
|
| 353 |
+
r = max(16, len(node_txt) * 4)
|
| 354 |
+
parts.append(f'<circle class="node" cx="{x:.1f}" cy="{y:.1f}" r="{r}"/>')
|
| 355 |
+
parts.append(
|
| 356 |
+
f'<text class="nodelabel" x="{x:.1f}" y="{y + r + 14:.1f}" text-anchor="middle">{node_txt}</text>'
|
| 357 |
+
)
|
| 358 |
+
parts.append("</svg>")
|
| 359 |
+
subgraph_svg = "".join(parts)
|
| 360 |
+
|
| 361 |
+
logs.append(f"📦 Subgraph JSON dump: {subgraph_json}")
|
| 362 |
+
|
| 363 |
+
return {
|
| 364 |
+
"answer": answer,
|
| 365 |
+
"evidence": evidence_for_ui,
|
| 366 |
+
"subgraph_svg": subgraph_svg, # fallback
|
| 367 |
+
"subgraph_json": subgraph_json, # for D3 in UI
|
| 368 |
+
"logs": logs,
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
@app.get("/healthz")
|
| 373 |
+
def healthz():
|
| 374 |
+
return {"ok": True}
|
eval.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
from vec import search
|
| 3 |
+
from rerank import rerank_candidates
|
| 4 |
+
from kg import compute_path_proximity, compute_degree_norm
|
| 5 |
+
from utils import compute_freshness
|
| 6 |
+
|
| 7 |
+
# A toy dataset: {question: [expected substrings]}. A retrieval counts as a
# "hit" when any expected substring appears (case-insensitively) in a chunk.
EVAL_SET = [
    {"q": "Who founded OpenAI?", "answers": ["Elon Musk", "Sam Altman"]},
    {"q": "What did OpenAI release?", "answers": ["GPT-4o", "Whisper", "SORA"]},
    # NOTE(review): Instagram was acquired by Facebook/Meta, not Google —
    # possibly an intentional distractor for the demo corpus; confirm.
    {"q": "What did Google acquire?", "answers": ["YouTube", "Instagram"]},
]
|
| 13 |
+
|
| 14 |
+
def evaluate():
    """Compare baseline (cosine-only) retrieval against the hybrid reranker.

    Runs every question in EVAL_SET and returns hit@10, a rank-based gain
    score, a (stubbed) citation-correctness rate, and the average vector
    search latency in seconds.
    """
    baseline_hits, hybrid_hits = [], []
    baseline_ndcg, hybrid_ndcg = [], []
    citation_correctness = []
    latencies = []

    for item in EVAL_SET:
        q, golds = item["q"], item["answers"]

        # Baseline (cosine only); latency timed around the vector search only.
        start = time.time()
        baseline = search(q, top_k=10)
        latencies.append(time.time() - start)

        # Did we hit a gold in top-10?
        hit = any(any(g.lower() in c["chunk"].lower() for g in golds) for c in baseline)
        baseline_hits.append(1 if hit else 0)

        # NOTE(review): labelled nDCG@10 but computed as a sum of reciprocal
        # ranks (1/rank) of relevant items — no log2 discount and no ideal-DCG
        # normalization, so values are not true nDCG; confirm intent.
        scores = []
        for rank, c in enumerate(baseline, 1):
            rel = 1 if any(g.lower() in c["chunk"].lower() for g in golds) else 0
            if rel:
                scores.append(1 / (rank))
        baseline_ndcg.append(sum(scores))

        # Hybrid (cosine + path + freshness + degree). Features are written
        # onto the same dicts the baseline list holds (in-place mutation).
        for c in baseline:
            c["path_proximity"] = compute_path_proximity(q, c["chunk"])
            c["degree_norm"] = compute_degree_norm(c["chunk"])
            c["freshness_decay"] = compute_freshness(c.get("timestamp"))

        # Uses rerank_candidates' default weights, not the UI-tuned ones.
        reranked, _ = rerank_candidates(baseline)
        hit = any(any(g.lower() in c["chunk"].lower() for g in golds) for c in reranked[:10])
        hybrid_hits.append(1 if hit else 0)

        # Same reciprocal-rank score for the reranked order (see NOTE above).
        scores = []
        for rank, c in enumerate(reranked, 1):
            rel = 1 if any(g.lower() in c["chunk"].lower() for g in golds) else 0
            if rel:
                scores.append(1 / (rank))
        hybrid_ndcg.append(sum(scores))

        # TODO -> Citation correctness (currently hard-coded to perfect).
        citation_correctness.append(1)

    return {
        "baseline": {
            "hit@10": sum(baseline_hits)/len(baseline_hits),
            "nDCG@10": sum(baseline_ndcg)/len(baseline_ndcg),
        },
        "hybrid": {
            "hit@10": sum(hybrid_hits)/len(hybrid_hits),
            "nDCG@10": sum(hybrid_ndcg)/len(hybrid_ndcg),
        },
        "citation_correctness": sum(citation_correctness)/len(citation_correctness),
        "avg_latency_sec": sum(latencies)/len(latencies),
    }
|
kg.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, re, json
|
| 2 |
+
from dotenv import load_dotenv
|
| 3 |
+
from neo4j import GraphDatabase
|
| 4 |
+
import spacy
|
| 5 |
+
|
| 6 |
+
# Groq is optional: without the package (or an API key) extraction falls
# back to spaCy.
try:
    from groq import Groq
except Exception:
    Groq = None

load_dotenv()
# Small English pipeline (tokenizer, tagger, parser, NER) used by all the
# scoring and extraction helpers below.
nlp = spacy.load("en_core_web_sm")

# Neo4j connection settings from the environment (Aura or self-hosted).
uri = os.getenv("NEO4J_URI")
user = os.getenv("NEO4J_USER")
password = os.getenv("NEO4J_PASS")
database = os.getenv("NEO4J_DATABASE", "neo4j")

# Options for "spacy" and "groq"
KG_EXTRACTOR = os.getenv("KG_EXTRACTOR", "spacy").strip().lower()

# Groq config (to extract the triplets)
GROQ_API_KEY = os.getenv("GROQ_API_KEY", "").strip()
GROQ_MODEL = os.getenv("GROQ_MODEL", "openai/gpt-oss-20b").strip()
# None when the package is missing or the key is unset — callers must check.
_groq = Groq(api_key=GROQ_API_KEY) if (Groq and GROQ_API_KEY) else None

driver = GraphDatabase.driver(uri, auth=(user, password))
|
| 28 |
+
|
| 29 |
+
# Lowercased words whose presence in a span hints it names an organisation
# (used by entity_label_for_text as a heuristic fallback).
ORG_HINTS = {
    "inc","corp","corporation","ltd","llc","bank","securities","university",
    "labs","institute","tech","technologies","systems","solutions","group"
}
|
| 33 |
+
|
| 34 |
+
# scoring helpers
def compute_path_proximity(question, chunk):
    """Binary entity-overlap score: 1.0 when the question and the chunk
    mention at least one common named entity (case-insensitive exact text
    match), else 0.0.
    """
    question_ents = {ent.text.lower() for ent in nlp(question).ents}
    chunk_ents = {ent.text.lower() for ent in nlp(chunk).ents}
    return 0.0 if question_ents.isdisjoint(chunk_ents) else 1.0
|
| 41 |
+
|
| 42 |
+
def compute_degree_norm(chunk):
    """Inverse-degree score in (0, 1] for the entities mentioned in *chunk*.

    Looks up each spaCy entity in Neo4j by its normalized name and averages
    the matched nodes' degrees; returns 1 / (1 + avg_degree), so chunks
    dominated by graph hubs score LOWER. Returns 0.0 when the chunk has no
    entities or none are present in the graph.
    """
    doc = nlp(chunk)
    ents = [ent.text for ent in doc.ents]
    if not ents:
        return 0.0
    degrees = []
    with driver.session(database=database) as session:
        for e in ents:
            # count { (n)--() } counts relationships in either direction;
            # LIMIT 1 keeps .single() happy if several nodes share the key.
            res = session.run("""
                MATCH (n {name_lc:$name_lc})
                RETURN count { (n)--() } AS deg
                LIMIT 1
            """, name_lc=normalize_key(e))
            rec = res.single()
            if rec and rec["deg"] is not None:
                degrees.append(rec["deg"])
    if not degrees:
        return 0.0
    return 1.0 / (1 + sum(degrees) / len(degrees))
|
| 61 |
+
|
| 62 |
+
# label + text helpers
def get_label(ent_label: str) -> str:
    """Translate a spaCy NER tag (e.g. "PERSON", "ORG") into a Neo4j node label.

    Unrecognized tags map to the catch-all "Entity".
    """
    spacy_to_node = {
        "PERSON": "Person",
        "ORG": "Org",
        "GPE": "Location",
        "NORP": "Group",
        "FAC": "Facility",
        "LOC": "Location",
        "PRODUCT": "Product",
        "EVENT": "Event",
        "WORK_OF_ART": "Work",
        "LAW": "Law",
        "LANGUAGE": "Language",
        "DATE": "Date",
        "TIME": "Time",
        "PERCENT": "Percent",
        "MONEY": "Money",
        "QUANTITY": "Quantity",
        "ORDINAL": "Ordinal",
        "CARDINAL": "Number",
    }
    return spacy_to_node.get(ent_label, "Entity")
|
| 72 |
+
|
| 73 |
+
def normalize_key(text: str) -> str:
    """Canonical lookup key: trimmed, inner whitespace collapsed, lowercased."""
    collapsed = re.sub(r"\s+", " ", text.strip())
    return collapsed.lower()
|
| 75 |
+
|
| 76 |
+
def entity_label_for_text(text: str, doc) -> str:
    """Best-effort node label for *text*: the doc's NER label when the span
    matches an entity exactly, otherwise keyword/capitalization heuristics.
    """
    matched = next((ent for ent in doc.ents if ent.text == text), None)
    if matched is not None:
        return get_label(matched.label_)
    lowered = {word.lower() for word in text.split()}
    if not lowered.isdisjoint(ORG_HINTS):
        return "Org"
    # A single capitalized word is most likely a person name.
    looks_like_name = bool(text) and text[0].isupper() and " " not in text
    return "Person" if looks_like_name else "Entity"
|
| 86 |
+
|
| 87 |
+
def span_text_for_token(tok, doc) -> str:
    """Readable text span for *tok*: its covering named entity when one
    exists, otherwise the token joined with its compound/modifier children.
    """
    for ent in doc.ents:
        # Entity spans are [start, end); membership means the entity covers tok.
        if ent.start <= tok.i < ent.end:
            return ent.text
    left_mods = [child for child in tok.lefts if child.dep_ in ("compound", "amod", "flat", "nmod")]
    right_mods = [child for child in tok.rights if child.dep_ in ("compound", "flat", "nmod")]
    pieces = sorted(left_mods + [tok] + right_mods, key=lambda t: t.i)
    return " ".join(p.text for p in pieces if p.pos_ != "PUNCT")
|
| 95 |
+
|
| 96 |
+
def subjects_for_verb(v):
    """Collect subject tokens of verb *v*, following conjunctions.

    A conjoined verb with no subject of its own ("A founded X and built Y")
    borrows the subjects of its head verb recursively.
    """
    subjects = [child for child in v.lefts if child.dep_ in ("nsubj", "nsubjpass", "csubj")]
    if not subjects and v.dep_ == "conj":
        subjects = subjects_for_verb(v.head)
    expanded = []
    for subj in subjects:
        expanded.append(subj)
        # Conjoined subjects ("A and B founded...") count too.
        expanded.extend(subj.conjuncts)
    return expanded
|
| 105 |
+
|
| 106 |
+
def objects_for_verb(v):
    """Collect object tokens of verb *v*, including objects of its
    prepositions and any conjoined objects.
    """
    objects = [child for child in v.rights if child.dep_ in ("dobj", "attr", "pobj", "dative", "oprd")]
    # "acquired X from Y" — pull prepositional objects in as well.
    for prep in [child for child in v.rights if child.dep_ == "prep"]:
        objects.extend(w for w in prep.rights if w.dep_ == "pobj")
    expanded = []
    for obj in objects:
        expanded.append(obj)
        expanded.extend(obj.conjuncts)
    return expanded
|
| 115 |
+
|
| 116 |
+
# spaCy extractor
def _extract_triples_spacy(text: str):
    """Dependency-parse *text* and emit subject/relation/object triple dicts.

    Every VERB token becomes a relation (its uppercased lemma); each of its
    subjects is paired with each of its objects, with spans widened to the
    covering entity and labelled via entity_label_for_text.
    """
    doc = nlp(text)
    triples = []
    for verb in doc:
        if verb.pos_ != "VERB":
            continue
        subjects = subjects_for_verb(verb)
        objects = objects_for_verb(verb)
        if not subjects or not objects:
            continue
        relation = verb.lemma_.upper()
        for subj in subjects:
            subj_text = span_text_for_token(subj, doc)
            subj_label = entity_label_for_text(subj_text, doc)
            for obj in objects:
                obj_text = span_text_for_token(obj, doc)
                obj_label = entity_label_for_text(obj_text, doc)
                triples.append({
                    "subject": subj_text, "subject_label": subj_label,
                    "relation": relation,
                    "object": obj_text, "object_label": obj_label
                })
    return triples
|
| 139 |
+
|
| 140 |
+
# Groq extractor (structured)
# JSON Schema embedded in the extraction prompt so the model returns a
# {"triples": [...]} object; subject/relation/object are required per triple,
# the two label fields are optional.
_GROQ_SCHEMA = {
    "type": "object",
    "properties": {
        "triples": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "subject": {"type": "string"},
                    "subject_label": {"type": "string"},
                    "relation": {"type": "string"},
                    "object": {"type": "string"},
                    "object_label": {"type": "string"},
                },
                "required": ["subject", "relation", "object"]
            }
        }
    },
    "required": ["triples"]
}
|
| 161 |
+
|
| 162 |
+
# Node labels the extractors may emit; anything else collapses to "Entity".
# Keep in sync with get_label()'s mapping values.
_ALLOWED_LABELS = (
    "Person","Org","Location","Group","Facility","Product","Event","Work",
    "Law","Language","Date","Time","Percent","Money","Quantity","Ordinal",
    "Number","Entity"
)
|
| 167 |
+
|
| 168 |
+
def _extract_triples_groq(text: str):
    """Ask the configured Groq model for subject/relation/object triples.

    Returns a list of normalized triple dicts. Raises RuntimeError when Groq
    is not configured. Relations are uppercased with spaces replaced by
    underscores; labels outside _ALLOWED_LABELS collapse to "Entity".
    """
    if not _groq:
        raise RuntimeError("Groq not configured. Set GROQ_API_KEY or use KG_EXTRACTOR=spacy.")
    prompt = f"""
Extract concise subject-RELATION-object triples from the text.

Rules:
- Use a SINGLE UPPERCASE token for relation (e.g., ACQUIRE, FOUND, PARTNER_WITH).
- Provide subject_label and object_label using this set: {_ALLOWED_LABELS}.
- Merge duplicates; at most 8 triples per chunk.
- Return ONLY JSON matching this schema:

{json.dumps(_GROQ_SCHEMA, indent=2)}

Text:
\"\"\"{text}\"\"\"
"""
    resp = _groq.chat.completions.create(
        model=GROQ_MODEL,
        messages=[
            {"role": "system", "content": "You are an information extractor. Output strictly valid JSON."},
            {"role": "user", "content": prompt},
        ],
        temperature=0,
        max_tokens=600,
        response_format={"type": "json_object"},
    )
    raw = resp.choices[0].message.content or "{}"
    try:
        data = json.loads(raw)
    except Exception:
        # Salvage pass: slice out the outermost {...} if the model wrapped
        # its JSON in extra text; otherwise fall back to an empty result.
        start = raw.find("{"); end = raw.rfind("}")
        data = json.loads(raw[start:end+1]) if start != -1 and end != -1 else {"triples": []}
    triples = data.get("triples", [])
    out = []
    for t in triples:
        subj = (t.get("subject") or "").strip()
        obj = (t.get("object") or "").strip()
        # Normalize relation to a single SNAKE_CASE token.
        rel = (t.get("relation") or "").strip().upper().replace(" ", "_")
        if not subj or not obj or not rel:
            continue
        sl = t.get("subject_label") or "Entity"
        ol = t.get("object_label") or "Entity"
        if sl not in _ALLOWED_LABELS: sl = "Entity"
        if ol not in _ALLOWED_LABELS: ol = "Entity"
        out.append({"subject": subj, "subject_label": sl, "relation": rel,
                    "object": obj, "object_label": ol})
    return out
|
| 216 |
+
|
| 217 |
+
# Inserts
|
| 218 |
+
def _insert_triples(triples, source: str, timestamp: str, logs: list):
    """Persist subject-relation-object triples into Neo4j.

    Nodes are MERGEd on a normalized lowercase name key; the relationship's
    source and timestamp are refreshed on every insert. One human-readable
    log line is appended to *logs* per triple.
    """
    if not triples:
        return
    with driver.session(database=database) as session:
        for triple in triples:
            subj, subj_label = triple["subject"], triple["subject_label"]
            obj, obj_label = triple["object"], triple["object_label"]
            rel_type = triple["relation"]
            # Labels and relationship types cannot be Cypher parameters,
            # so they are interpolated; all values remain parameterized.
            cypher = f"""
            MERGE (a:{subj_label} {{name_lc:$a_key}})
            ON CREATE SET a.name = $a_name
            MERGE (b:{obj_label} {{name_lc:$b_key}})
            ON CREATE SET b.name = $b_name
            MERGE (a)-[r:{rel_type}]->(b)
            ON CREATE SET r.source=$source, r.ts=$ts
            SET r.source=$source, r.ts=$ts
            """
            session.run(
                cypher,
                a_key=normalize_key(subj), a_name=subj,
                b_key=normalize_key(obj), b_name=obj,
                source=source, ts=timestamp
            )
            logs.append(f"🌐 Inserted ({subj}:{subj_label})-[:{rel_type}]->({obj}:{obj_label}) [src={source}, ts={timestamp}]")
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
# Public API used by app.py
|
| 245 |
+
def extract_and_insert(chunks, source="user", timestamp=None):
    """
    Extract triples (Groq or spaCy) and insert into Neo4j with source and timestamp.

    Returns a list of human-readable log lines describing what happened.
    """
    from datetime import datetime

    stamp = timestamp if timestamp is not None else datetime.now().isoformat()
    groq_enabled = (KG_EXTRACTOR == "groq") and _groq is not None

    logs = []
    for chunk in chunks:
        extracted = []
        if groq_enabled:
            try:
                extracted = _extract_triples_groq(chunk)
                logs.append(f"🤝 Groq extracted {len(extracted)} triples")
            except Exception as e:
                logs.append(f"⚠️ Groq extraction failed: {e}. Falling back to spaCy.")
        if not extracted:
            # Groq disabled, failed, or returned nothing: use the spaCy path.
            extracted = _extract_triples_spacy(chunk)
            logs.append(f"🧠 spaCy extracted {len(extracted)} triples")
        _insert_triples(extracted, source=source, timestamp=stamp, logs=logs)
    return logs
|
| 268 |
+
|
| 269 |
+
def test_connection():
    """Smoke-test the Neo4j connection by running a trivial query."""
    query = "RETURN 'Connected to Neo4j!' AS msg"
    with driver.session(database=database) as session:
        record = session.run(query).single()
        print(record["msg"])
|
| 273 |
+
|
| 274 |
+
def get_subgraph(question: str, source: str | None = None, limit: int = 24, evidence_chunks: list[str] | None = None):
    """
    Subgraph search by using NER (spaCy) and words from evidence chunks.

    Collects entity mentions from the question (falling back to nouns and
    proper nouns when NER finds nothing) plus entities from any evidence
    chunks, then pulls every relationship touching a matching node.
    Returns a de-duplicated list of (subject, relation, object) tuples.
    """
    question_doc = nlp(question)
    mentions = [ent.text for ent in question_doc.ents]
    if not mentions:
        mentions = [tok.text for tok in question_doc if tok.pos_ in ("PROPN", "NOUN")]

    for chunk in evidence_chunks or []:
        mentions.extend(ent.text for ent in nlp(chunk).ents)

    # Normalize + deduplicate before hitting the database.
    keys = list({normalize_key(m) for m in mentions if m.strip()})

    # The query text is loop-invariant, so build it once.
    source_clause = " AND r.source = $source " if source else ""
    cypher = (
        """
        MATCH (a)-[r]-(b)
        WHERE (
            (a.name_lc IS NOT NULL AND a.name_lc CONTAINS $k) OR
            (b.name_lc IS NOT NULL AND b.name_lc CONTAINS $k) OR
            toLower(a.name) CONTAINS $k OR
            toLower(b.name) CONTAINS $k
        )
        """
        + source_clause
        + """
        RETURN DISTINCT a, type(r) AS rel, b
        LIMIT $limit
        """
    )

    triples = set()
    with driver.session(database=database) as session:
        for key in keys:
            params = {"k": key, "limit": limit}
            if source:
                params["source"] = source
            for rec in session.run(cypher, **params):
                a_node, rel_type, b_node = rec["a"], rec["rel"], rec["b"]
                triples.add((
                    a_node.get("name", a_node.get("name_lc", "")),
                    rel_type,
                    b_node.get("name", b_node.get("name_lc", "")),
                ))
    return list(triples)
|
requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=4.45.0
|
| 2 |
+
fastapi==0.111.0
|
| 3 |
+
uvicorn==0.30.1
|
| 4 |
+
neo4j==5.23.0
|
| 5 |
+
qdrant-client==1.9.1
|
| 6 |
+
sentence-transformers==2.7.0
|
| 7 |
+
openai>=1.40.2
|
| 8 |
+
httpx>=0.27.2
|
| 9 |
+
python-dotenv==1.0.1
|
| 10 |
+
spacy==3.8.7
|
| 11 |
+
groq
|
| 12 |
+
en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl
|
rerank.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def rerank_candidates(candidates, w_cos=0.60, w_path=0.20, w_fresh=0.15, w_deg=0.05):
    """
    Rerank chunks with a hybrid scoring formula.
    Weights are configurable from the ui.

    Each candidate dict gains a "final_score" key; returns the candidates
    sorted by that score (descending) plus one log line per candidate.
    """
    logs = []
    scored = []
    for position, cand in enumerate(candidates, 1):
        cosine = cand.get("cosine", 0)
        path = cand.get("path_proximity", 0)
        fresh = cand.get("freshness_decay", 0)
        degree = cand.get("degree_norm", 0)

        total = w_cos * cosine + w_path * path + w_fresh * fresh + w_deg * degree
        cand["final_score"] = total
        scored.append(cand)

        logs.append(
            f"Candidate {position}: "
            f"cosine={cosine:.3f}, "
            f"path={path:.3f}, "
            f"freshness={fresh:.3f}, "
            f"degree={degree:.3f} "
            f"→ final={total:.3f}"
        )

    scored.sort(key=lambda item: item["final_score"], reverse=True)
    return scored, logs
|
| 30 |
+
|
text.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def chunk_text(text: str, chunk_size: int = 200, overlap: int = 50) -> list[str]:
    """
    Break *text* into overlapping chunks for the embedding pipeline.

    Each chunk is at most *chunk_size* characters long and consecutive
    chunks share *overlap* characters of context.

    Args:
        text: the raw document text.
        chunk_size: maximum characters per chunk (must be positive).
        overlap: characters shared between consecutive chunks
            (must satisfy 0 <= overlap < chunk_size).

    Returns:
        A list of chunk strings (empty list for empty input).

    Raises:
        ValueError: if the parameters would make the loop stall
            (previously overlap >= chunk_size caused an infinite loop).
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if overlap < 0 or overlap >= chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

    step = chunk_size - overlap  # guaranteed positive by the checks above
    chunks = []
    start = 0
    while start < len(text):
        chunks.append(text[start:start + chunk_size])
        start += step
    return chunks
|
ui.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import requests
|
| 3 |
+
import html
|
| 4 |
+
import os
|
| 5 |
+
from datetime import date as _date
|
| 6 |
+
|
| 7 |
+
API_URL = os.getenv("API_URL", "http://127.0.0.1:8000")
|
| 8 |
+
|
| 9 |
+
# Helpers
|
| 10 |
+
def _parse_answer_sections(answer_text: str):
    """Split a raw answer string into labelled sections.

    Recognised line prefixes (case-insensitive): "Citations:",
    "Graph reasoning:" / "GraphReasoning:", "Confidence:",
    "Knobs explain:", "Knobs:". Everything else joins the main answer.
    """
    sections = {
        "main": "",
        "citations": "",
        "graph_reasoning": "",
        "confidence": "",
        "knobs": "",
        "knobs_explain": "",
    }
    # Order matters: "knobs explain:" must win over the shorter "knobs:".
    prefix_map = (
        ("citations:", "citations"),
        ("graph reasoning:", "graph_reasoning"),
        ("graphreasoning:", "graph_reasoning"),
        ("confidence:", "confidence"),
        ("knobs explain:", "knobs_explain"),
        ("knobs:", "knobs"),
    )
    body_parts = []
    for raw in (answer_text or "").splitlines():
        line = raw.strip()
        if not line:
            continue
        lowered = line.lower()
        for prefix, key in prefix_map:
            if lowered.startswith(prefix):
                sections[key] = line.split(":", 1)[1].strip()
                break
        else:
            body_parts.append(line)
    sections["main"] = " ".join(body_parts).strip() or (answer_text or "").strip()
    return sections
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _confidence_class(conf: str) -> str:
    """Map a confidence label to its CSS badge class (prefix match)."""
    level = (conf or "").strip().lower()
    for prefix, css_class in (
        ("high", "badge-high"),
        ("medium", "badge-medium"),
        ("low", "badge-low"),
    ):
        if level.startswith(prefix):
            return css_class
    return "badge-none"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _render_answer_card(answer_text: str) -> str:
    """Render the parsed answer sections as a styled HTML card."""
    sec = _parse_answer_sections(answer_text)
    badge_cls = _confidence_class(sec["confidence"])
    # Escape everything that came from the model / API before embedding.
    esc = html.escape
    main = esc(sec["main"])
    citations = esc(sec["citations"] or "None")
    greason = esc(sec["graph_reasoning"] or "—")
    conf = esc(sec["confidence"] or "—")
    knobs = esc(sec["knobs"] or "—")
    knobs_explain = esc(sec["knobs_explain"] or "—")
    return f"""
    <div class="card">
      <div class="card-title">Answer</div>
      <div class="answer">{main}</div>
      <div class="meta">
        <span class="badge {badge_cls}">{conf}</span>
      </div>
      <div class="sub"><b>Citations:</b> {citations}</div>
      <div class="sub"><b>Graph reasoning:</b> {greason}</div>
      <div class="sub"><b>Knobs effect:</b> {knobs}</div>
      <div class="sub"><b>Knobs explain:</b> {knobs_explain}</div>
    </div>
    """
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _render_evidence_markdown(evidence_list):
    """Format ranked evidence chunks as a numbered Markdown list."""
    if not evidence_list:
        return "_No evidence returned._"
    entries = [
        f"**E{rank}.** {chunk.strip()}"
        for rank, chunk in enumerate(evidence_list, 1)
    ]
    return "\n\n".join(entries)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _wrap_svg(svg: str) -> str:
    """Wrap server-rendered SVG in the graph container, or show a placeholder."""
    looks_like_svg = bool(svg) and "<svg" in svg
    if not looks_like_svg:
        return "<div class='graph-empty'>No graph</div>"
    return f"""<div class="graph-wrap">{svg}</div>"""
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def metrics_ui():
    """
    Fetch /metrics from the API and format the evaluation results as Markdown.

    Never raises: network failures and malformed payloads are returned as
    an error string so the Gradio callback keeps working.
    """
    try:
        # The request itself can fail (API down, DNS, refused connection);
        # previously that exception escaped the callback and broke the UI.
        resp = requests.get(f"{API_URL}/metrics")
    except Exception as e:
        return f"Error: {e}"
    try:
        j = resp.json()
        if j.get("status") != "ok":
            return f"Error: {j}"
        r = j["results"]
        return f"""
### 📊 Evaluation Results

**Baseline (cosine-only)**
- hit@10: {r['baseline']['hit@10']:.2f}
- nDCG@10: {r['baseline']['nDCG@10']:.2f}

**Hybrid (GraphRAG)**
- hit@10: {r['hybrid']['hit@10']:.2f}
- nDCG@10: {r['hybrid']['nDCG@10']:.2f}

**Other**
- Citation correctness: {r['citation_correctness']:.2f}
- Avg latency (s): {r['avg_latency_sec']:.2f}
"""
    except Exception as e:
        return f"Error: {e}\nRaw: {resp.text[:500]}"
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def add_doc_ui(text, source="user", date_val=None, time_val=None):
    """
    POST a document to /add_doc, optionally stamped with a user-picked
    date/time, and return the server's ingestion logs as one string.

    Never raises: request failures are returned as an error string so the
    Gradio callback keeps working.
    """
    payload = {"text": text, "source": source}

    # Build ISO timestamp if a date was picked.
    ts_iso = ""
    if date_val:
        # gr.Date may hand back a datetime.date or a plain string.
        dstr = date_val.isoformat() if isinstance(date_val, _date) else str(date_val)
        tstr = (time_val or "00:00").strip()
        if len(tstr) == 5:  # HH:MM -> add seconds
            tstr = f"{tstr}:00"
        ts_iso = f"{dstr}T{tstr}Z"

    if ts_iso:
        payload["timestamp"] = ts_iso

    try:
        # Previously the POST was outside any try: a connection error
        # crashed the callback, and `resp` was undefined in the handler.
        resp = requests.post(f"{API_URL}/add_doc", json=payload)
    except Exception as e:
        return f"Error: {e}"
    try:
        j = resp.json()
        return "\n".join(j.get("logs", [])) or "No logs."
    except Exception as e:
        return f"Error: {e}\nRaw response: {resp.text[:500]}"
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def ask_ui(question, w_cos, w_path, w_fresh, w_deg):
    """
    POST a question plus rerank weights to /ask and return the five UI
    outputs: (answer card HTML, evidence markdown, logs text, graph HTML,
    graph JSON).

    Never raises: request and parse failures both produce the same
    fallback tuple so the Gradio callback keeps working.
    """
    payload = {
        "question": question,
        "w_cos": w_cos,
        "w_path": w_path,
        "w_fresh": w_fresh,
        "w_deg": w_deg,
    }

    def _error_tuple(err):
        # Shared fallback for both network and JSON-decoding failures.
        return (
            _render_answer_card("I don’t know based on the given evidence.\nConfidence: Low"),
            "_No evidence returned._",
            err,
            "<div id='graph' style='height:420px'></div>",
            {},
        )

    try:
        # Previously the POST was outside any try: a connection error
        # crashed the callback instead of showing the fallback card.
        resp = requests.post(f"{API_URL}/ask", json=payload)
    except Exception as e:
        return _error_tuple(f"Error: {e}")
    try:
        j = resp.json()
    except Exception as e:
        return _error_tuple(f"Error: {e}\nRaw response: {resp.text[:500]}")

    answer_html = _render_answer_card(j.get("answer", ""))
    evidence_md = _render_evidence_markdown(j.get("evidence", []))
    logs_txt = "\n".join(j.get("logs", [])) or "No logs."

    # D3 container but if no data fall back to server SVG.
    graph_json = j.get("subgraph_json", {})
    if graph_json and graph_json.get("nodes"):
        graph_html_value = "<div id='graph' style='height:420px'></div>"
    else:
        graph_html_value = _wrap_svg(j.get("subgraph_svg", ""))

    return (answer_html, evidence_md, logs_txt, graph_html_value, graph_json)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# UI
|
| 179 |
+
# Top-level Gradio Blocks app: three tabs (ingest document, ask question,
# metrics) plus a client-side D3 force-graph renderer for the evidence
# subgraph. The custom CSS implements the dark theme and card styling.
with gr.Blocks(
    css="""
    /* Layout & theme */
    body { background: #0b0f14; color: #e6edf3; }
    .gradio-container { max-width: 1180px !important; }
    .section-title { font-size: 22px; font-weight: 700; margin: 6px 0 12px; }

    /* Cards */
    .card { background: #0f1720; border: 1px solid #1f2a36; border-radius: 14px; padding: 14px; }
    .card-title { font-size: 16px; letter-spacing: .3px; color: #9fb3c8; margin-bottom: 8px; text-transform: uppercase; }
    .answer { font-size: 18px; line-height: 1.5; margin-bottom: 8px; }
    .sub { color: #a8b3bf; margin-top: 6px; font-size: 14px; }

    /* Badges */
    .badge { padding: 3px 10px; border-radius: 999px; font-size: 12px; font-weight: 700; display: inline-block; }
    .badge-high { background: #12391a; color: #6ee787; border: 1px solid #285f36; }
    .badge-medium { background: #3a2b13; color: #ffd277; border: 1px solid #6b4e1f; }
    .badge-low { background: #3b1616; color: #ff9492; border: 1px solid #6b2020; }
    .badge-none { background: #223; color: #9fb3c8; border: 1px solid #334; }

    /* Graph */
    .graph-wrap { background: #0f1720; border: 1px solid #1f2a36; border-radius: 14px;
                  padding: 12px; height: 460px; overflow: auto; }
    .graph-empty { color: #9fb3c8; font-style: italic; padding: 16px; }

    /* Logs */
    #logs-box textarea {
      font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", monospace !important;
      max-height: 280px !important;
      overflow-y: auto !important;
    }
    """
) as demo:
    gr.Markdown("### 🚀 GraphRAG — Live Demo")

    # Tab 1: paste raw text; the backend chunks, embeds and graph-extracts it.
    with gr.Tab("Add Document"):
        with gr.Row():
            with gr.Column(scale=3):
                text_in = gr.Textbox(
                    label="Document",
                    lines=10,
                    placeholder="Paste text to inject into Graph + Vector DB…",
                )
            with gr.Column(scale=1):
                source_in = gr.Textbox(label="Source", value="user")

                # gr.Date / gr.Time are not present in all Gradio versions;
                # fall back to plain textboxes when they are missing.
                if hasattr(gr, "Date"):
                    ts_date = gr.Date(label="Date (optional)")
                else:
                    ts_date = gr.Textbox(label="Date (YYYY-MM-DD, optional)")

                if hasattr(gr, "Time"):
                    ts_time = gr.Time(label="Time (optional)", value="00:00")
                else:
                    ts_time = gr.Textbox(label="Time (HH:MM, optional)", value="00:00")

        add_btn = gr.Button("Add Doc", variant="primary")
        add_logs = gr.Textbox(label="Ingestion Logs", lines=14, elem_id="logs-box")

        add_btn.click(
            add_doc_ui,
            inputs=[text_in, source_in, ts_date, ts_time],
            outputs=add_logs
        )


    # Tab 2: ask a question with tunable rerank weights.
    with gr.Tab("Ask Question"):
        with gr.Row():
            q_in = gr.Textbox(
                label="Question", placeholder="e.g., Who acquired Instagram?"
            )
            ask_btn = gr.Button("Ask", variant="primary")

        # These four sliders map 1:1 onto rerank_candidates' weights.
        with gr.Accordion("Rerank Weights", open=False):
            w_cos = gr.Slider(0, 1, value=0.60, step=0.05, label="Cosine weight")
            w_path = gr.Slider(0, 1, value=0.20, step=0.05, label="Path proximity weight")
            w_fresh = gr.Slider(0, 1, value=0.15, step=0.05, label="Freshness weight")
            w_deg = gr.Slider(0, 1, value=0.05, step=0.05, label="Degree norm weight")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("<div class='section-title'>Answer</div>")
                ans_html = gr.HTML(value=_render_answer_card("Ask something to see results."))

                evid = gr.Accordion("Evidence (ranked)", open=True)
                with evid:
                    evid_md = gr.Markdown()

                logs = gr.Accordion("Debug logs", open=False)
                with logs:
                    logs_txt = gr.Textbox(lines=14, elem_id="logs-box")

            with gr.Column(scale=1):
                gr.Markdown("<div class='section-title'>Evidence Graph</div>")
                # The D3 script below renders into this #graph container;
                # graph_data is hidden and only ferries JSON to the client.
                graph_html = gr.HTML(value="<div id='graph' style='height:600px'></div>")
                graph_data = gr.JSON(label="graph-data", visible=False)

        ask_btn.click(
            ask_ui,
            inputs=[q_in, w_cos, w_path, w_fresh, w_deg],
            outputs=[ans_html, evid_md, logs_txt, graph_html, graph_data],
        )

    # Tab 3: run the baseline-vs-hybrid evaluation on demand.
    with gr.Tab("Metrics"):
        metrics_btn = gr.Button("Run Evaluation", variant="primary")
        metrics_out = gr.Markdown("Click run to evaluate baseline vs hybrid.")
        metrics_btn.click(metrics_ui, inputs=[], outputs=metrics_out)

    # D3 rendering for the knowledge graph.
    # Client-side arrow function: receives the graph JSON, lazy-loads D3
    # from the CDN if needed, and draws a zoomable force-directed graph.
    DRAW_JS = r"""
    (value) => {
      const el = document.querySelector("#graph");
      if (!el) return null;
      el.innerHTML = "";

      if (!value || !value.nodes || value.nodes.length === 0) {
        el.innerHTML = "<div class='graph-empty'>No graph</div>";
        return null;
      }

      function ensureD3(cb) {
        if (window.d3) return cb();
        const s = document.createElement("script");
        s.src = "https://cdn.jsdelivr.net/npm/d3@7";
        s.onload = cb;
        document.head.appendChild(s);
      }

      ensureD3(() => {
        const width = el.clientWidth || 900;
        const height = 600;

        const svg = d3.select(el).append("svg")
          .attr("viewBox", [0, 0, width, height])
          .attr("preserveAspectRatio", "xMidYMid meet")
          .style("width", "100%")
          .style("height", "100%");

        // Create zoomable container
        const container = svg.append("g");

        // Enable zoom & pan
        svg.call(
          d3.zoom()
            .scaleExtent([0.2, 3]) // zoom limits (20%–300%)
            .on("zoom", (event) => {
              container.attr("transform", event.transform);
            })
        );

        const sim = d3.forceSimulation(value.nodes)
          .force("link", d3.forceLink(value.links).id(d => d.id).distance(140).strength(0.4))
          .force("charge", d3.forceManyBody().strength(-220))
          .force("center", d3.forceCenter(width / 2, height / 2));

        const link = container.append("g")
          .attr("stroke", "#999")
          .attr("stroke-opacity", 0.6)
          .selectAll("line")
          .data(value.links)
          .enter().append("line")
          .attr("stroke-width", 1.5);

        const edgeLabels = container.append("g")
          .selectAll("text")
          .data(value.links)
          .enter().append("text")
          .attr("font-size", 10)
          .attr("fill", "#bbb")
          .text(d => d.label);

        const node = container.append("g")
          .selectAll("circle")
          .data(value.nodes)
          .enter().append("circle")
          .attr("r", 12)
          .attr("fill", "#69b3a2")
          .attr("stroke", "#2dd4bf")
          .attr("stroke-width", 1.2)
          .call(d3.drag()
            .on("start", (event, d) => { if (!event.active) sim.alphaTarget(0.3).restart(); d.fx = d.x; d.fy = d.y; })
            .on("drag", (event, d) => { d.fx = event.x; d.fy = event.y; })
            .on("end", (event, d) => { if (!event.active) sim.alphaTarget(0); d.fx = null; d.fy = null; })
          );

        const labels = container.append("g")
          .selectAll("text")
          .data(value.nodes)
          .enter().append("text")
          .attr("font-size", 12)
          .attr("fill", "#ddd")
          .attr("dy", 18)
          .attr("text-anchor", "middle")
          .text(d => d.id);

        sim.on("tick", () => {
          link
            .attr("x1", d => d.source.x)
            .attr("y1", d => d.source.y)
            .attr("x2", d => d.target.x)
            .attr("y2", d => d.target.y);

          edgeLabels
            .attr("x", d => (d.source.x + d.target.x) / 2)
            .attr("y", d => (d.source.y + d.target.y) / 2);

          node
            .attr("cx", d => d.x)
            .attr("cy", d => d.y);

          labels
            .attr("x", d => d.x)
            .attr("y", d => d.y);
        });
      });

      return null;
    }
    """


    # The .change handler echoes graph_data back to itself so the chained
    # .then(...) fires the client-side draw with the fresh JSON.
    # NOTE(review): writing a component's own output from its .change
    # handler looks like it could re-trigger the event — confirm Gradio's
    # loop protection for this pattern.
    graph_data.change(lambda x: x, inputs=graph_data, outputs=graph_data).then(
        None, inputs=graph_data, outputs=None, js=DRAW_JS
    )

if __name__ == "__main__":
    demo.launch()
|
| 406 |
+
|
utils.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime, timezone
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
def compute_freshness(ts, half_life_days=30):
    """
    Exponential decay score between 0 and 1.
    - 1.0 -> now (future timestamps are clamped to 1.0)
    - 0.5 -> half_life_days old
    - Approaches 0 -> when docs get very old
    - 0.5 -> fallback when ts is missing or unparseable

    Accepts ISO-8601 strings (a trailing 'Z' is normalized) or datetime
    objects; naive datetimes are assumed to be UTC.
    """
    if not ts:
        return 0.5
    if isinstance(ts, str):
        ts = ts.replace("Z", "+00:00")
        try:
            ts = datetime.fromisoformat(ts)
        except Exception:
            return 0.5
    if ts.tzinfo is None:
        ts = ts.replace(tzinfo=timezone.utc)

    age_days = (datetime.now(timezone.utc) - ts).total_seconds() / 86400.0
    lam = math.log(2) / max(float(half_life_days), 1.0)
    # Clamp: a future timestamp gives a negative age and would otherwise
    # score above 1.0, violating the documented [0, 1] range.
    return min(1.0, math.exp(-lam * age_days))
|
vec.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import uuid
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
from sentence_transformers import SentenceTransformer
|
| 5 |
+
from qdrant_client import QdrantClient
|
| 6 |
+
from qdrant_client.http import models
|
| 7 |
+
|
| 8 |
+
import uuid
|
| 9 |
+
from datetime import datetime, timezone
|
| 10 |
+
|
| 11 |
+
load_dotenv()
|
| 12 |
+
QDRANT_URL = os.getenv("QDRANT_URL")
|
| 13 |
+
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
|
| 14 |
+
COLLECTION = "docs"
|
| 15 |
+
|
| 16 |
+
model = SentenceTransformer("all-MiniLM-L6-v2") #fast embedder
|
| 17 |
+
qdrant = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
|
| 18 |
+
|
| 19 |
+
# Ensure collection exists (safe init .... we dont wipe on reload)
|
| 20 |
+
def _ensure_collection():
    """Create the Qdrant collection on first run; leave existing data intact."""
    try:
        info = qdrant.get_collection(collection_name=COLLECTION)
        if not info:
            raise Exception("Collection not found")
    except Exception:
        print(f"⚠️ Collection '{COLLECTION}' not found. Creating fresh collection...")
        # 384 dims matches the all-MiniLM-L6-v2 embedder configured above.
        vector_config = models.VectorParams(
            size=384,
            distance=models.Distance.COSINE
        )
        qdrant.create_collection(
            collection_name=COLLECTION,
            vectors_config=vector_config,
        )

_ensure_collection()
|
| 36 |
+
|
| 37 |
+
def embed_and_upsert(chunks, source="user", timestamp=None):
    """Embed text chunks and upsert them into the Qdrant collection.

    Each point carries its text, source tag, timestamp and chunk index as
    payload. The timestamp defaults to "now" in UTC when not supplied.
    """
    if timestamp is None:
        timestamp = datetime.now(timezone.utc).isoformat()

    vectors = model.encode(chunks).tolist()
    points = [
        models.PointStruct(
            id=str(uuid.uuid4()),
            vector=vec,
            payload={
                "text": chunk,
                "source": source,
                "timestamp": timestamp,
                "chunk_id": idx,
            },
        )
        for idx, (chunk, vec) in enumerate(zip(chunks, vectors))
    ]
    qdrant.upsert(collection_name=COLLECTION, points=points, wait=True)
    print(f"✅ Stored {len(points)} chunks in Qdrant (source={source}, ts={timestamp})")
    return True
|
| 59 |
+
|
| 60 |
+
def search(query: str, top_k: int = 5):
    """Semantic search over the collection.

    Returns up to *top_k* dicts with keys: chunk, cosine, timestamp, source.
    """
    query_vec = model.encode([query])[0].tolist()
    hits = qdrant.search(
        collection_name=COLLECTION,
        query_vector=query_vec,
        limit=top_k,
        with_payload=True
    )
    formatted = []
    for hit in hits:
        payload = hit.payload
        formatted.append({
            "chunk": payload.get("text", ""),
            "cosine": hit.score,
            "timestamp": payload.get("timestamp"),
            "source": payload.get("source")
        })
    return formatted
|