aayush226 commited on
Commit
4d9fcca
·
verified ·
1 Parent(s): dbeed48

Upload 10 files

Browse files
Files changed (10) hide show
  1. README.md +33 -14
  2. app.py +374 -0
  3. eval.py +71 -0
  4. kg.py +315 -0
  5. requirements.txt +12 -0
  6. rerank.py +30 -0
  7. text.py +13 -0
  8. ui.py +406 -0
  9. utils.py +24 -0
  10. vec.py +76 -0
README.md CHANGED
@@ -1,14 +1,33 @@
1
- ---
2
- title: GraphRAG Live
3
- emoji: 📊
4
- colorFrom: green
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 5.46.1
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: Hybrid RAG with Graph + Vectors
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GraphRAG-Live
2
+
3
+ **Hybrid Retrieval-Augmented Generation (RAG) with Graph + Vectors.**
4
+
5
+ This project shows how knowledge graphs (Neo4j Aura) and vector databases (Qdrant) can be combined with re-ranking heuristics to build a smarter, cheaper and more explainable RAG system.
6
+
7
+ ---
8
+
9
+ ## 🌟 Features
10
+ - **Hybrid Retrieval:** Combines semantic search (Qdrant) with graph proximity scoring (Neo4j).
11
+ - **Dynamic Knowledge Injection:** Add new documents on the fly → pipeline updates instantly.
12
+ - **Evidence Subgraphs:** Each answer includes a small 2-hop evidence graph.
13
+ - **Metrics Dashboard:** Compare GraphRAG vs. baseline RAG on hit@10, nDCG@10, citation correctness.
14
+ - **Hosted Demo:** Deployed via Hugging Face Spaces (Gradio UI).
15
+
16
+ ---
17
+
18
+ ## 🏗️ Architecture
19
+
20
+ ```text
21
+ User Question
22
+
23
+
24
+ [Qdrant: semantic chunks] + [Neo4j: graph proximity]
25
+
26
+
27
+ Reranker (cosine + path proximity + freshness + degree)
28
+
29
+
30
+ Answer Generator (OpenAI)
31
+
32
+
33
+ Evidence Subgraph + Answer + Citations
app.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel, Field
3
+ from typing import List, Literal
4
+ from datetime import datetime
5
+ import os, json
6
+
7
+ from text import chunk_text
8
+ from vec import embed_and_upsert, search
9
+ from kg import (
10
+ extract_and_insert,
11
+ get_subgraph,
12
+ compute_path_proximity,
13
+ compute_degree_norm,
14
+ )
15
+ from rerank import rerank_candidates
16
+ from eval import evaluate
17
+ from utils import compute_freshness
18
+
19
+ from dotenv import load_dotenv
20
+ from openai import OpenAI
21
+
22
# Load environment variables from .env, letting them override any
# already-set process variables (useful for local development).
load_dotenv(override=True)

# OpenAI client used for answer generation. The key may be empty here;
# presumably the app is allowed to boot without credentials and fail at
# request time instead — TODO confirm that is intentional.
key = os.environ.get("OPENAI_API_KEY", "").strip()
client = OpenAI(api_key=key)

app = FastAPI()
28
+
29
+
30
# Schemas for Pydantic + structured output
class DocInput(BaseModel):
    """Payload for POST /add_doc: raw document text plus provenance."""

    text: str
    source: str = "user"
    # Fix: use default_factory so each request gets the current time.
    # A plain `= datetime.now()` default is evaluated once at import,
    # so every document without an explicit timestamp would share the
    # app's boot time.
    timestamp: datetime = Field(default_factory=datetime.now)
35
+
36
+
37
class QuestionInput(BaseModel):
    """Payload for POST /ask: the question plus the four re-ranking weights.

    The weights mirror the UI sliders and are forwarded verbatim to
    `rerank_candidates` (cosine, graph-path proximity, freshness, degree).
    """

    question: str
    w_cos: float = 0.60    # semantic (cosine) similarity weight
    w_path: float = 0.20   # graph path-proximity weight
    w_fresh: float = 0.15  # freshness-decay weight
    w_deg: float = 0.05    # inverse node-degree weight
43
+
44
+
45
# LLM output contract (enforced with JSON response_format + Pydantic)
class LLMAnswer(BaseModel):
    """Validated shape of the model's structured JSON answer."""

    answer: str = Field(..., description="One-sentence final answer")
    citations: List[str] = Field(
        default_factory=list,
        description="Evidence IDs like E1, E3 that support the answer",
    )
    graph_reasoning: str = Field(
        "", description="How the graph helped, or 'Not used'"
    )
    confidence: Literal["High", "Medium", "Low"] = "Low"
56
+
57
+
58
+ # Helpers for the explanation on the controls (weights)
59
+ def _get_scores(c, w_cos, w_path, w_fresh, w_deg):
60
+ cos = float(c.get("cosine", c.get("cosine_sim", 0.0)) or 0.0)
61
+ pp = float(c.get("path_proximity", 0.0) or 0.0)
62
+ fr = float(c.get("freshness_decay", 0.0) or 0.0)
63
+ dg = float(c.get("degree_norm", 0.0) or 0.0)
64
+ final = w_cos * cos + w_path * pp + w_fresh * fr + w_deg * dg
65
+ return cos, pp, fr, dg, final
66
+
67
+
68
def _build_knobs_breakdown(numbered, w_cos, w_path, w_fresh, w_deg):
    """Build the two explanation strings shown under the answer card.

    Returns (knobs_line, knobs_explain). Uses only the top candidate and,
    when present, the runner-up; both strings are empty when *numbered*
    is empty.
    """
    if not numbered:
        return "", ""

    top_idx, top = numbered[0]
    cos1, pp1, fr1, dg1, fin1 = _get_scores(top, w_cos, w_path, w_fresh, w_deg)

    runner_piece = ""
    if len(numbered) > 1:
        ru_idx, ru = numbered[1]
        cos2, pp2, fr2, dg2, fin2 = _get_scores(ru, w_cos, w_path, w_fresh, w_deg)
        margin = fin1 - fin2
        runner_piece = f"; Runner-up E{ru_idx}={fin2:.3f}; Margin={margin:+.3f}"

        # Weighted per-feature deltas, largest first: which knob gave the
        # top candidate its lead over the runner-up?
        deltas = sorted(
            [
                ("path", w_path * (pp1 - pp2), pp1, pp2, w_path),
                ("freshness", w_fresh * (fr1 - fr2), fr1, fr2, w_fresh),
                ("cosine", w_cos * (cos1 - cos2), cos1, cos2, w_cos),
                ("degree", w_deg * (dg1 - dg2), dg1, dg2, w_deg),
            ],
            key=lambda item: item[1],
            reverse=True,
        )
        # Keep at most three clearly-positive drivers.
        drivers = [f"{name} ({d:+.3f})" for name, d, *_ in deltas if d > 0.002][:3]
        top_names = (
            ", ".join(drivers)
            if drivers
            else "mostly cosine similarity (others were negligible)"
        )
        explain = (
            f"With weights (cos {w_cos:.2f}, path {w_path:.2f}, fresh {w_fresh:.2f}, deg {w_deg:.2f}), "
            f"E{top_idx} leads by {margin:+.3f}. Biggest lifts vs E{ru_idx}: {top_names}."
        )
    else:
        # No runner-up, but still provide a brief note.
        explain = (
            f"With weights (cos {w_cos:.2f}, path {w_path:.2f}, fresh {w_fresh:.2f}, deg {w_deg:.2f}), "
            f"the top candidate E{top_idx} scored {fin1:.3f}."
        )

    knobs_line = (
        f"Weights→ cos {w_cos:.2f}, path {w_path:.2f}, fresh {w_fresh:.2f}, deg {w_deg:.2f}. "
        f"E{top_idx} final={fin1:.3f} = {w_cos:.2f}×{cos1:.3f} + {w_path:.2f}×{pp1:.3f} + "
        f"{w_fresh:.2f}×{fr1:.3f} + {w_deg:.2f}×{dg1:.3f}{runner_piece}; Cosine-only(E{top_idx})={cos1:.3f}."
    )
    return knobs_line, explain
118
+
119
+
120
+ # API Endpoints
121
+ @app.get("/metrics")
122
+ def metrics_endpoint():
123
+ logs = []
124
+ try:
125
+ results = evaluate()
126
+ logs.append("✅ Ran evaluation set")
127
+ return {"status": "ok", "results": results, "logs": logs}
128
+ except Exception as e:
129
+ logs.append(f"⚠️ Metrics failed: {e}")
130
+ return {"status": "error", "logs": logs}
131
+
132
+
133
+ @app.post("/add_doc")
134
+ def add_doc_endpoint(doc: DocInput):
135
+ logs = ["📥 Received document"]
136
+ text, source, timestamp = doc.text, doc.source, doc.timestamp
137
+
138
+ # 1) Chunk
139
+ chunks = chunk_text(text)
140
+ logs.append(f"✂️ Chunked into {len(chunks)} pieces")
141
+
142
+ # 2) Embed + store
143
+ embed_and_upsert(chunks, source=source, timestamp=timestamp.isoformat())
144
+ logs.append(f"🧮 Embedded + stored in Qdrant (source={source}, ts={timestamp})")
145
+
146
+ # 3) Extract triples and feed to Neo4j
147
+ neo4j_logs = extract_and_insert(chunks, source=source, timestamp=str(timestamp))
148
+ logs.extend(neo4j_logs or ["🌐 No entities/relations extracted for Neo4j"])
149
+ return {"status": "ok", "logs": logs}
150
+
151
+
152
+ @app.post("/ask")
153
+ def ask_endpoint(query: QuestionInput):
154
+ logs = []
155
+ q = query.question
156
+ logs.append(f"❓ Received question: {q}")
157
+
158
+ # Retrieve
159
+ candidates = search(q, top_k=5)
160
+ logs.append(f"🔎 Retrieved {len(candidates)} from Qdrant")
161
+
162
+ # Graph aware features??
163
+ for c in candidates:
164
+ c["path_proximity"] = compute_path_proximity(q, c["chunk"])
165
+ c["degree_norm"] = compute_degree_norm(c["chunk"])
166
+ c["freshness_decay"] = compute_freshness(c.get("timestamp"))
167
+
168
+ # Rerank
169
+ reranked, rerank_logs = rerank_candidates(
170
+ candidates,
171
+ w_cos=query.w_cos,
172
+ w_path=query.w_path,
173
+ w_fresh=query.w_fresh,
174
+ w_deg=query.w_deg,
175
+ )
176
+ logs.append("📊 Applied graph-aware re-ranking")
177
+ logs.extend(rerank_logs)
178
+
179
+ # Evidence subgraph (≤2 hops)
180
+ triples = get_subgraph(q, source=None)
181
+ logs.append(f"🌐 Subgraph triples: {len(triples)}")
182
+
183
+ # Prepare evidence numbering for citations
184
+ numbered = [(i + 1, c) for i, c in enumerate(reranked)]
185
+ TOP_N = 2 # TODO -> expermient with more
186
+ reranked = reranked[:TOP_N]
187
+ numbered = [(i + 1, c) for i, c in enumerate(reranked)]
188
+ evidence_for_prompt = [f"[E{i}] {c['chunk']}" for i, c in numbered]
189
+ evidence_for_ui = [f"[E{i}] {c['chunk']}" for i, c in numbered]
190
+
191
+ knobs_line, knobs_explain = _build_knobs_breakdown(
192
+ numbered, query.w_cos, query.w_path, query.w_fresh, query.w_deg
193
+ )
194
+
195
+ # LLM answer (OpenAI, structured JSON -> Pydantic)
196
+ if reranked:
197
+ triples_text = "\n".join([f"({s}) -[{r}]-> ({o})" for s, r, o in triples])
198
+
199
+ # Schema friendly request
200
+ prompt = f"""
201
+ You are a precise QA assistant that MUST use BOTH the retrieved evidence and the graph triples.
202
+
203
+ Question:
204
+ {q}
205
+
206
+ Retrieved Evidence (ranked by importance, highest first):
207
+ {chr(10).join(evidence_for_prompt)}
208
+
209
+ Knowledge Graph Triples:
210
+ {triples_text}
211
+
212
+ Instructions:
213
+ - E1 is the most relevant, E2 is second-most, and so on.
214
+ - Prefer evidence with a lower number if multiple sources conflict.
215
+ - If supported, produce a single-sentence answer.
216
+ - Cite supporting evidence IDs (e.g., E1, E2).
217
+ - If the graph helped, say how; else "Not used".
218
+ - If not supported, return "I don’t know..." with Low confidence.
219
+
220
+ Return ONLY a JSON object matching this schema:
221
+ {{
222
+ "answer": "string",
223
+ "citations": ["E1","E2"],
224
+ "graph_reasoning": "string",
225
+ "confidence": "High|Medium|Low"
226
+ }}
227
+ """.strip()
228
+
229
+ logs.append("📝 Built prompt with evidence + graph")
230
+ try:
231
+ comp = client.chat.completions.create(
232
+ model="gpt-4o-mini",
233
+ messages=[
234
+ {"role": "system", "content": "Respond ONLY with a JSON object."},
235
+ {"role": "user", "content": prompt},
236
+ ],
237
+ # Ensures valid JSON
238
+ response_format={"type": "json_object"},
239
+ temperature=0,
240
+ max_tokens=300,
241
+ )
242
+ raw_json = comp.choices[0].message.content or "{}"
243
+ data = json.loads(raw_json)
244
+
245
+ # Validate and normalize with Pydantic
246
+ parsed = LLMAnswer.model_validate(data)
247
+
248
+ # Build display string for the UI card
249
+ citations_txt = ", ".join(parsed.citations) if parsed.citations else "None"
250
+ answer_text = (
251
+ f"{parsed.answer}\n"
252
+ f"Citations: {citations_txt}\n"
253
+ f"Graph reasoning: {parsed.graph_reasoning or '—'}\n"
254
+ f"Confidence: {parsed.confidence}\n"
255
+ f"Knobs: {knobs_line or '—'}\n"
256
+ f"Knobs explain: {knobs_explain or '—'}"
257
+ )
258
+
259
+ answer = answer_text
260
+ logs.append("🤖 Called OpenAI")
261
+ logs.append("🧠 Generated final answer")
262
+ except Exception as e:
263
+ top_chunk = reranked[0]["chunk"] if reranked else "No evidence"
264
+ answer = (
265
+ f"Based on evidence: {top_chunk}\n"
266
+ f"Citations: None\n"
267
+ f"Graph reasoning: Not used\n"
268
+ f"Confidence: Low\n"
269
+ f"Knobs: {knobs_line or '—'}\n"
270
+ f"Knobs explain: {knobs_explain or '—'}"
271
+ )
272
+ logs.append(f"⚠️ OpenAI failed, fallback to stub ({e})")
273
+ else:
274
+ answer = (
275
+ "No evidence found.\n"
276
+ "Citations: None\n"
277
+ "Graph reasoning: Not used\n"
278
+ "Confidence: Low\n"
279
+ f"Knobs: {knobs_line or '—'}\n"
280
+ f"Knobs explain: {knobs_explain or '—'}"
281
+ )
282
+ evidence_for_ui = []
283
+ logs.append("⚠️ No evidence, answer is empty")
284
+
285
+ # Build D3 JSON
286
+ node_map = {}
287
+ links = []
288
+ for s, r, o in triples:
289
+ node_map.setdefault(s, {"id": s})
290
+ node_map.setdefault(o, {"id": o})
291
+ links.append({"source": s, "target": o, "label": r})
292
+ subgraph_json = {"nodes": list(node_map.values()), "links": links}
293
+
294
+ # Server side SVG fallback in case D3 fails to render
295
+ import networkx as nx
296
+
297
+ G = nx.DiGraph()
298
+ for s, r, o in triples:
299
+ G.add_node(s)
300
+ G.add_node(o)
301
+ G.add_edge(s, o, label=r)
302
+
303
+ pos = nx.spring_layout(G, seed=42)
304
+ width, height, pad = 720, 420, 40
305
+ xs = [p[0] for p in pos.values()] or [0.0]
306
+ ys = [p[1] for p in pos.values()] or [0.0]
307
+ minx, maxx = min(xs), max(xs)
308
+ miny, maxy = min(ys), max(ys)
309
+ rangex = (maxx - minx) or 1.0
310
+ rangey = (maxy - miny) or 1.0
311
+
312
+ def sx(x): return pad + (x - minx) / rangex * (width - 2 * pad)
313
+ def sy(y): return pad + (y - miny) / rangey * (height - 2 * pad)
314
+
315
+ parts = []
316
+ parts.append(
317
+ f'<svg width="{width}" height="{height}" viewBox="0 0 {width} {height}" '
318
+ f'xmlns="http://www.w3.org/2000/svg">'
319
+ )
320
+ parts.append(
321
+ """
322
+ <defs>
323
+ <marker id="arrow" markerUnits="strokeWidth" markerWidth="10" markerHeight="8"
324
+ viewBox="0 0 10 8" refX="10" refY="4" orient="auto">
325
+ <path d="M0 0 L10 4 L0 8 z" fill="#999"/>
326
+ </marker>
327
+ <style>
328
+ .edge { stroke:#999; stroke-width:1.5; }
329
+ .nodelabel { font:12px sans-serif; fill:#ddd; }
330
+ .edgelabel { font:10px sans-serif; fill:#bbb; }
331
+ .node { fill:#69b3a2; stroke:#2dd4bf; stroke-width:1; }
332
+ </style>
333
+ </defs>
334
+ """
335
+ )
336
+
337
+ for u, v, data in G.edges(data=True):
338
+ x1, y1 = sx(pos[u][0]), sy(pos[u][1])
339
+ x2, y2 = sx(pos[v][0]), sy(pos[v][1])
340
+ parts.append(
341
+ f'<line class="edge" x1="{x1:.1f}" y1="{y1:.1f}" '
342
+ f'x2="{x2:.1f}" y2="{y2:.1f}" marker-end="url(#arrow)"/>'
343
+ )
344
+ mx, my = (x1 + x2) / 2.0, (y1 + y2) / 2.0
345
+ lbl = (data.get("label") or "").replace("&", "&amp;").replace("<", "&lt;")
346
+ parts.append(
347
+ f'<text class="edgelabel" x="{mx:.1f}" y="{my:.1f}" text-anchor="middle">{lbl}</text>'
348
+ )
349
+
350
+ for n in G.nodes():
351
+ x, y = sx(pos[n][0]), sy(pos[n][1])
352
+ node_txt = str(n).replace("&", "&amp;").replace("<", "&lt;")
353
+ r = max(16, len(node_txt) * 4)
354
+ parts.append(f'<circle class="node" cx="{x:.1f}" cy="{y:.1f}" r="{r}"/>')
355
+ parts.append(
356
+ f'<text class="nodelabel" x="{x:.1f}" y="{y + r + 14:.1f}" text-anchor="middle">{node_txt}</text>'
357
+ )
358
+ parts.append("</svg>")
359
+ subgraph_svg = "".join(parts)
360
+
361
+ logs.append(f"📦 Subgraph JSON dump: {subgraph_json}")
362
+
363
+ return {
364
+ "answer": answer,
365
+ "evidence": evidence_for_ui,
366
+ "subgraph_svg": subgraph_svg, # fallback
367
+ "subgraph_json": subgraph_json, # for D3 in UI
368
+ "logs": logs,
369
+ }
370
+
371
+
372
+ @app.get("/healthz")
373
+ def healthz():
374
+ return {"ok": True}
eval.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from vec import search
3
+ from rerank import rerank_candidates
4
+ from kg import compute_path_proximity, compute_degree_norm
5
+ from utils import compute_freshness
6
+
7
+ # A toy dataset: {question: [expected substrings]}
8
+ EVAL_SET = [
9
+ {"q": "Who founded OpenAI?", "answers": ["Elon Musk", "Sam Altman"]},
10
+ {"q": "What did OpenAI release?", "answers": ["GPT-4o", "Whisper", "SORA"]},
11
+ {"q": "What did Google acquire?", "answers": ["YouTube", "Instagram"]},
12
+ ]
13
+
14
+ def evaluate():
15
+ baseline_hits, hybrid_hits = [], []
16
+ baseline_ndcg, hybrid_ndcg = [], []
17
+ citation_correctness = []
18
+ latencies = []
19
+
20
+ for item in EVAL_SET:
21
+ q, golds = item["q"], item["answers"]
22
+
23
+ # Baseline (cosine only)
24
+ start = time.time()
25
+ baseline = search(q, top_k=10)
26
+ latencies.append(time.time() - start)
27
+
28
+ # Did we hit a gold in top-10?
29
+ hit = any(any(g.lower() in c["chunk"].lower() for g in golds) for c in baseline)
30
+ baseline_hits.append(1 if hit else 0)
31
+
32
+ # nDCG@10
33
+ scores = []
34
+ for rank, c in enumerate(baseline, 1):
35
+ rel = 1 if any(g.lower() in c["chunk"].lower() for g in golds) else 0
36
+ if rel:
37
+ scores.append(1 / (rank))
38
+ baseline_ndcg.append(sum(scores))
39
+
40
+ # Hybrid (cosine + path + freshness + degree)
41
+ for c in baseline:
42
+ c["path_proximity"] = compute_path_proximity(q, c["chunk"])
43
+ c["degree_norm"] = compute_degree_norm(c["chunk"])
44
+ c["freshness_decay"] = compute_freshness(c.get("timestamp"))
45
+
46
+ reranked, _ = rerank_candidates(baseline)
47
+ hit = any(any(g.lower() in c["chunk"].lower() for g in golds) for c in reranked[:10])
48
+ hybrid_hits.append(1 if hit else 0)
49
+
50
+ scores = []
51
+ for rank, c in enumerate(reranked, 1):
52
+ rel = 1 if any(g.lower() in c["chunk"].lower() for g in golds) else 0
53
+ if rel:
54
+ scores.append(1 / (rank))
55
+ hybrid_ndcg.append(sum(scores))
56
+
57
+ # TODO -> Citation correctness
58
+ citation_correctness.append(1)
59
+
60
+ return {
61
+ "baseline": {
62
+ "hit@10": sum(baseline_hits)/len(baseline_hits),
63
+ "nDCG@10": sum(baseline_ndcg)/len(baseline_ndcg),
64
+ },
65
+ "hybrid": {
66
+ "hit@10": sum(hybrid_hits)/len(hybrid_hits),
67
+ "nDCG@10": sum(hybrid_ndcg)/len(hybrid_ndcg),
68
+ },
69
+ "citation_correctness": sum(citation_correctness)/len(citation_correctness),
70
+ "avg_latency_sec": sum(latencies)/len(latencies),
71
+ }
kg.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, re, json
2
+ from dotenv import load_dotenv
3
+ from neo4j import GraphDatabase
4
+ import spacy
5
+
6
+ try:
7
+ from groq import Groq
8
+ except Exception:
9
+ Groq = None
10
+
11
load_dotenv()
# Small English pipeline used for NER and dependency parsing throughout.
nlp = spacy.load("en_core_web_sm")

# Neo4j connection settings (Aura or self-hosted), read from the environment.
uri = os.getenv("NEO4J_URI")
user = os.getenv("NEO4J_USER")
password = os.getenv("NEO4J_PASS")
database = os.getenv("NEO4J_DATABASE", "neo4j")

# Options for "spacy" and "groq"
KG_EXTRACTOR = os.getenv("KG_EXTRACTOR", "spacy").strip().lower()

# Groq config (to extract the triplets)
GROQ_API_KEY = os.getenv("GROQ_API_KEY", "").strip()
GROQ_MODEL = os.getenv("GROQ_MODEL", "openai/gpt-oss-20b").strip()
# Only instantiate the client when both the SDK and a key are available;
# extract_and_insert falls back to spaCy when this stays None.
_groq = Groq(api_key=GROQ_API_KEY) if (Groq and GROQ_API_KEY) else None

# Module-level driver shared by every session opened below.
driver = GraphDatabase.driver(uri, auth=(user, password))

# Lowercase tokens hinting that a bare name is an organization rather than
# a person (used by entity_label_for_text).
ORG_HINTS = {
    "inc","corp","corporation","ltd","llc","bank","securities","university",
    "labs","institute","tech","technologies","systems","solutions","group"
}
33
+
34
+ # scoring helpers
35
def compute_path_proximity(question, chunk):
    """Binary graph-proximity signal.

    Returns 1.0 when the question and the chunk mention at least one
    common named entity (case-insensitive spaCy NER match), else 0.0.
    """
    shared = (
        {ent.text.lower() for ent in nlp(question).ents}
        & {ent.text.lower() for ent in nlp(chunk).ents}
    )
    return 1.0 if shared else 0.0
41
+
42
def compute_degree_norm(chunk):
    """Inverse-degree score for the entities mentioned in *chunk*.

    Looks up each spaCy entity in Neo4j by its normalized name, averages
    the node degrees found, and returns 1 / (1 + avg_degree) — so more
    sparsely connected (more specific) entities score closer to 1.0.
    Returns 0.0 when the chunk has no entities or none exist in the graph.
    """
    doc = nlp(chunk)
    ents = [ent.text for ent in doc.ents]
    if not ents:
        return 0.0
    degrees = []
    with driver.session(database=database) as session:
        for e in ents:
            # count { (n)--() } is the node's degree in either direction.
            res = session.run("""
                MATCH (n {name_lc:$name_lc})
                RETURN count { (n)--() } AS deg
                LIMIT 1
            """, name_lc=normalize_key(e))
            rec = res.single()
            if rec and rec["deg"] is not None:
                degrees.append(rec["deg"])
    if not degrees:
        return 0.0
    return 1.0 / (1 + sum(degrees) / len(degrees))
61
+
62
+ # label + text helpers
63
def get_label(ent_label: str) -> str:
    """Map a spaCy NER label (e.g. "PERSON") onto a Neo4j node label.

    Anything unrecognized falls back to the generic "Entity".
    """
    mapping = {
        "PERSON": "Person",
        "ORG": "Org",
        "GPE": "Location",
        "NORP": "Group",
        "FAC": "Facility",
        "LOC": "Location",
        "PRODUCT": "Product",
        "EVENT": "Event",
        "WORK_OF_ART": "Work",
        "LAW": "Law",
        "LANGUAGE": "Language",
        "DATE": "Date",
        "TIME": "Time",
        "PERCENT": "Percent",
        "MONEY": "Money",
        "QUANTITY": "Quantity",
        "ORDINAL": "Ordinal",
        "CARDINAL": "Number",
    }
    return mapping.get(ent_label, "Entity")
72
+
73
def normalize_key(text: str) -> str:
    """Canonical graph-lookup key: trimmed, whitespace-collapsed, lowercase."""
    trimmed = text.strip()
    return re.sub(r"\s+", " ", trimmed).lower()
75
+
76
def entity_label_for_text(text: str, doc) -> str:
    """Best-effort Neo4j label for *text* given its spaCy *doc*.

    Preference order: an exact NER match wins; then keyword hints mark
    organizations; a single capitalized word is presumed a Person; and
    anything else becomes the generic "Entity".
    """
    match = next((ent for ent in doc.ents if ent.text == text), None)
    if match is not None:
        return get_label(match.label_)
    if {w.lower() for w in text.split()} & ORG_HINTS:
        return "Org"
    single_capitalized = bool(text) and text[0].isupper() and " " not in text
    return "Person" if single_capitalized else "Entity"
86
+
87
def span_text_for_token(tok, doc) -> str:
    """Expand a head token into a readable span of text.

    If the token sits inside a named entity, the entity text wins.
    Otherwise close dependents (compounds/modifiers) are glued on in
    document order, skipping punctuation.
    """
    for ent in doc.ents:
        if ent.start <= tok.i < ent.end:
            return ent.text
    left_deps = ("compound", "amod", "flat", "nmod")
    right_deps = ("compound", "flat", "nmod")
    members = [w for w in tok.lefts if w.dep_ in left_deps]
    members.append(tok)
    members.extend(w for w in tok.rights if w.dep_ in right_deps)
    members.sort(key=lambda t: t.i)
    return " ".join(t.text for t in members if t.pos_ != "PUNCT")
95
+
96
def subjects_for_verb(v):
    """Collect subject tokens for verb *v*, expanded through conjunctions.

    A conjoined verb with no subject of its own inherits the subjects of
    its head (e.g. "Alice founded X and launched Y").
    """
    subjects = [w for w in v.lefts if w.dep_ in ("nsubj", "nsubjpass", "csubj")]
    if not subjects and v.dep_ == "conj":
        # Recurse up to the coordinated head verb.
        subjects = subjects_for_verb(v.head)
    expanded = []
    for subj in subjects:
        expanded.append(subj)
        expanded.extend(subj.conjuncts)
    return expanded
105
+
106
def objects_for_verb(v):
    """Collect object tokens for verb *v*.

    Gathers direct/indirect objects plus objects of the verb's
    prepositions, each expanded through its conjunctions.
    """
    found = [w for w in v.rights if w.dep_ in ("dobj", "attr", "pobj", "dative", "oprd")]
    for prep in (w for w in v.rights if w.dep_ == "prep"):
        found.extend(w for w in prep.rights if w.dep_ == "pobj")
    expanded = []
    for obj in found:
        expanded.append(obj)
        expanded.extend(obj.conjuncts)
    return expanded
115
+
116
+ # spaCy extractor
117
def _extract_triples_spacy(text: str):
    """Dependency-parse *text* and emit subject-RELATION-object triples.

    Every verb contributes one relation (its uppercased lemma); each
    subject/object pair around it becomes one triple dict with labels.
    """
    doc = nlp(text)
    found = []
    for tok in doc:
        if tok.pos_ != "VERB":
            continue
        subs = subjects_for_verb(tok)
        objs = objects_for_verb(tok)
        if not subs or not objs:
            # A relation needs both ends.
            continue
        rel = tok.lemma_.upper()
        for s in subs:
            s_text = span_text_for_token(s, doc)
            s_label = entity_label_for_text(s_text, doc)
            for o in objs:
                o_text = span_text_for_token(o, doc)
                o_label = entity_label_for_text(o_text, doc)
                found.append({
                    "subject": s_text, "subject_label": s_label,
                    "relation": rel,
                    "object": o_text, "object_label": o_label,
                })
    return found
139
+
140
+ # Groq extractor (structured)
141
# JSON Schema the Groq extractor is asked to conform to: an object with a
# single "triples" array. subject/relation/object are required; the two
# label fields are optional and validated separately below.
_GROQ_SCHEMA = {
    "type": "object",
    "properties": {
        "triples": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "subject": {"type": "string"},
                    "subject_label": {"type": "string"},
                    "relation": {"type": "string"},
                    "object": {"type": "string"},
                    "object_label": {"type": "string"},
                },
                "required": ["subject", "relation", "object"]
            }
        }
    },
    "required": ["triples"]
}

# Node labels accepted from the LLM; anything outside this set is coerced
# to the generic "Entity" (mirrors the mapping in get_label).
_ALLOWED_LABELS = (
    "Person","Org","Location","Group","Facility","Product","Event","Work",
    "Law","Language","Date","Time","Percent","Money","Quantity","Ordinal",
    "Number","Entity"
)
167
+
168
def _extract_triples_groq(text: str):
    """Ask the Groq LLM to extract triples from *text* as strict JSON.

    Returns a list of triple dicts (subject/relation/object plus labels),
    coercing labels outside _ALLOWED_LABELS to "Entity" and relations to
    UPPER_SNAKE_CASE. Raises RuntimeError when Groq is not configured.
    """
    if not _groq:
        raise RuntimeError("Groq not configured. Set GROQ_API_KEY or use KG_EXTRACTOR=spacy.")
    prompt = f"""
Extract concise subject-RELATION-object triples from the text.

Rules:
- Use a SINGLE UPPERCASE token for relation (e.g., ACQUIRE, FOUND, PARTNER_WITH).
- Provide subject_label and object_label using this set: {_ALLOWED_LABELS}.
- Merge duplicates; at most 8 triples per chunk.
- Return ONLY JSON matching this schema:

{json.dumps(_GROQ_SCHEMA, indent=2)}

Text:
\"\"\"{text}\"\"\"
"""
    resp = _groq.chat.completions.create(
        model=GROQ_MODEL,
        messages=[
            {"role": "system", "content": "You are an information extractor. Output strictly valid JSON."},
            {"role": "user", "content": prompt},
        ],
        temperature=0,
        max_tokens=600,
        response_format={"type": "json_object"},
    )
    raw = resp.choices[0].message.content or "{}"
    try:
        data = json.loads(raw)
    except Exception:
        # Model occasionally wraps the JSON in prose: salvage the outermost
        # {...} span; otherwise fall back to an empty triple list.
        start = raw.find("{"); end = raw.rfind("}")
        data = json.loads(raw[start:end+1]) if start != -1 and end != -1 else {"triples": []}
    triples = data.get("triples", [])
    out = []
    for t in triples:
        subj = (t.get("subject") or "").strip()
        obj = (t.get("object") or "").strip()
        rel = (t.get("relation") or "").strip().upper().replace(" ", "_")
        if not subj or not obj or not rel:
            # Skip incomplete triples rather than inserting half-edges.
            continue
        sl = t.get("subject_label") or "Entity"
        ol = t.get("object_label") or "Entity"
        # Clamp unexpected labels to the generic node type.
        if sl not in _ALLOWED_LABELS: sl = "Entity"
        if ol not in _ALLOWED_LABELS: ol = "Entity"
        out.append({"subject": subj, "subject_label": sl, "relation": rel,
                    "object": obj, "object_label": ol})
    return out
216
+
217
+ # Inserts
218
def _insert_triples(triples, source: str, timestamp: str, logs: list):
    """MERGE each triple into Neo4j, tagging the relationship with its
    provenance (source, ts) and appending one log line per insert.

    Cypher cannot parameterize node labels or relationship types, so they
    are interpolated into the query text. The relation (and, in the Groq
    path, the labels) ultimately comes from LLM output — untrusted input —
    so identifiers are sanitized first to prevent Cypher injection and
    syntax errors from characters like '-' or quotes.
    """
    if not triples:
        return

    def _safe_ident(name, fallback):
        # Keep only characters legal in an unquoted Cypher identifier;
        # identifiers must not start with a digit.
        cleaned = re.sub(r"[^A-Za-z0-9_]", "_", name or "")
        if not cleaned or cleaned[0].isdigit():
            return fallback
        return cleaned

    with driver.session(database=database) as session:
        for t in triples:
            s_text = t["subject"]
            o_text = t["object"]
            s_label = _safe_ident(t["subject_label"], "Entity")
            o_label = _safe_ident(t["object_label"], "Entity")
            rel = _safe_ident(t["relation"], "RELATED_TO")
            cypher = f"""
                MERGE (a:{s_label} {{name_lc:$a_key}})
                ON CREATE SET a.name = $a_name
                MERGE (b:{o_label} {{name_lc:$b_key}})
                ON CREATE SET b.name = $b_name
                MERGE (a)-[r:{rel}]->(b)
                ON CREATE SET r.source=$source, r.ts=$ts
                SET r.source=$source, r.ts=$ts
            """
            # Node names ARE parameterized — only identifiers are inlined.
            session.run(
                cypher,
                a_key=normalize_key(s_text), a_name=s_text,
                b_key=normalize_key(o_text), b_name=o_text,
                source=source, ts=timestamp
            )
            logs.append(f"🌐 Inserted ({s_text}:{s_label})-[:{rel}]->({o_text}:{o_label}) [src={source}, ts={timestamp}]")
242
+
243
+
244
+ # Public API used by app.py
245
def extract_and_insert(chunks, source="user", timestamp=None):
    """Extract triples from each chunk and persist them into Neo4j.

    Uses the Groq extractor when configured, falling back to spaCy when
    Groq is unavailable, fails, or returns nothing. Every insert carries
    the given *source*/*timestamp* provenance (defaulting to "now").
    Returns the accumulated human-readable log lines.
    """
    from datetime import datetime
    if timestamp is None:
        timestamp = datetime.now().isoformat()

    logs = []
    prefer_groq = (KG_EXTRACTOR == "groq") and _groq is not None
    for chunk in chunks:
        triples = []
        if prefer_groq:
            try:
                triples = _extract_triples_groq(chunk)
            except Exception as e:
                logs.append(f"⚠️ Groq extraction failed: {e}. Falling back to spaCy.")
            else:
                logs.append(f"🤝 Groq extracted {len(triples)} triples")
        if not triples:
            # Groq disabled, failed, or came back empty: use the parser.
            triples = _extract_triples_spacy(chunk)
            logs.append(f"🧠 spaCy extracted {len(triples)} triples")
        _insert_triples(triples, source=source, timestamp=timestamp, logs=logs)
    return logs
268
+
269
def test_connection():
    """Smoke-test the Neo4j connection: run a trivial query and print the
    result. Propagates driver errors if the database is unreachable."""
    with driver.session(database=database) as session:
        msg = session.run("RETURN 'Connected to Neo4j!' AS msg").single()
        print(msg["msg"])
273
+
274
def get_subgraph(question: str, source: str | None = None, limit: int = 24, evidence_chunks: list[str] | None = None):
    """
    Subgraph search by using NER (spaCy) and words from evidence chunks.

    Returns a de-duplicated list of (subject, relation, object) name
    triples whose endpoints fuzzily match (lowercase substring) any entity
    found in the question or the optional evidence chunks. When *source*
    is given, only relationships tagged with that source are returned.
    *limit* caps rows per matched entity, not the overall total.
    """
    results = set()
    doc = nlp(question)
    # Fall back to proper nouns/nouns when NER finds no entities.
    entities = [ent.text for ent in doc.ents] or [t.text for t in doc if t.pos_ in ("PROPN","NOUN")]

    if evidence_chunks:
        for chunk in evidence_chunks:
            c_doc = nlp(chunk)
            entities.extend([ent.text for ent in c_doc.ents])

    # Normalize + deduplicate
    entities = list({normalize_key(e) for e in entities if e.strip()})

    with driver.session(database=database) as session:
        for e in entities:
            # Substring match against both the normalized key and the raw
            # display name, on either endpoint of any relationship; the
            # optional source filter is spliced into the WHERE clause.
            cypher = """
                MATCH (a)-[r]-(b)
                WHERE (
                    (a.name_lc IS NOT NULL AND a.name_lc CONTAINS $k) OR
                    (b.name_lc IS NOT NULL AND b.name_lc CONTAINS $k) OR
                    toLower(a.name) CONTAINS $k OR
                    toLower(b.name) CONTAINS $k
                )
            """ + (" AND r.source = $source " if source else "") + """
                RETURN DISTINCT a, type(r) AS rel, b
                LIMIT $limit
            """
            params = {"k": e, "limit": limit}
            if source:
                params["source"] = source

            for rec in session.run(cypher, **params):
                a, rel, b = rec["a"], rec["rel"], rec["b"]
                # Prefer the display name; fall back to the normalized key.
                results.add((
                    a.get("name", a.get("name_lc","")),
                    rel,
                    b.get("name", b.get("name_lc",""))
                ))
    return list(results)
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.45.0
2
+ fastapi==0.111.0
3
+ uvicorn==0.30.1
4
+ neo4j==5.23.0
5
+ qdrant-client==1.9.1
6
+ sentence-transformers==2.7.0
7
+ openai>=1.40.2
8
+ httpx>=0.27.2
9
+ python-dotenv==1.0.1
10
+ spacy==3.8.7
11
+ groq
12
+ en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl
rerank.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def rerank_candidates(candidates, w_cos=0.60, w_path=0.20, w_fresh=0.15, w_deg=0.05):
    """Score and sort candidate chunks with the hybrid formula.

    final = w_cos·cosine + w_path·path_proximity
            + w_fresh·freshness_decay + w_deg·degree_norm

    Missing features count as 0. Each candidate dict is mutated in place
    with a "final_score" key. The weights are configurable from the UI.
    Returns (candidates sorted by final_score descending, per-candidate logs).
    """
    logs = []
    for position, cand in enumerate(candidates, 1):
        total = (
            w_cos * cand.get("cosine", 0)
            + w_path * cand.get("path_proximity", 0)
            + w_fresh * cand.get("freshness_decay", 0)
            + w_deg * cand.get("degree_norm", 0)
        )
        cand["final_score"] = total
        logs.append(
            f"Candidate {position}: "
            f"cosine={cand.get('cosine',0):.3f}, "
            f"path={cand.get('path_proximity',0):.3f}, "
            f"freshness={cand.get('freshness_decay',0):.3f}, "
            f"degree={cand.get('degree_norm',0):.3f} "
            f"→ final={total:.3f}"
        )

    ordered = sorted(candidates, key=lambda c: c["final_score"], reverse=True)
    return ordered, logs
30
+
text.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def chunk_text(text: str, chunk_size: int = 200, overlap: int = 50) -> list[str]:
    """Split *text* into overlapping chunks for the embedding pipeline.

    Each chunk is at most *chunk_size* characters, and consecutive chunks
    share *overlap* characters so sentences straddling a boundary are
    still seen whole by at least one chunk.

    Args:
        text: The raw document text.
        chunk_size: Maximum characters per chunk (must exceed overlap).
        overlap: Characters shared between consecutive chunks.

    Returns:
        The list of chunk strings (empty for empty input).

    Raises:
        ValueError: If overlap >= chunk_size (the loop would never advance).

    Fixes vs. original: the return annotation said ``str`` although a list
    is returned, and ``overlap >= chunk_size`` caused an infinite loop.
    """
    step = chunk_size - overlap
    if step <= 0:
        raise ValueError("overlap must be smaller than chunk_size")
    chunks = []
    start = 0
    while start < len(text):
        chunks.append(text[start:start + chunk_size])
        start += step  # advance, keeping `overlap` chars of shared context
    return chunks
ui.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests
3
+ import html
4
+ import os
5
+ from datetime import date as _date
6
+
7
# Backend FastAPI base URL; override via the API_URL env var in deployment.
API_URL = os.getenv("API_URL", "http://127.0.0.1:8000")
8
+
9
+ # Helpers
10
+ def _parse_answer_sections(answer_text: str):
11
+ lines = [l.strip() for l in (answer_text or "").splitlines() if l.strip()]
12
+ out = {
13
+ "main": "",
14
+ "citations": "",
15
+ "graph_reasoning": "",
16
+ "confidence": "",
17
+ "knobs": "",
18
+ "knobs_explain": "",
19
+ }
20
+ main_parts = []
21
+ for ln in lines:
22
+ ll = ln.lower()
23
+ if ll.startswith("citations:"):
24
+ out["citations"] = ln.split(":", 1)[1].strip()
25
+ elif ll.startswith("graph reasoning:") or ll.startswith("graphreasoning:"):
26
+ out["graph_reasoning"] = ln.split(":", 1)[1].strip()
27
+ elif ll.startswith("confidence:"):
28
+ out["confidence"] = ln.split(":", 1)[1].strip()
29
+ elif ll.startswith("knobs explain:"):
30
+ out["knobs_explain"] = ln.split(":", 1)[1].strip()
31
+ elif ll.startswith("knobs:"):
32
+ out["knobs"] = ln.split(":", 1)[1].strip()
33
+ else:
34
+ main_parts.append(ln)
35
+ out["main"] = " ".join(main_parts).strip() or (answer_text or "").strip()
36
+ return out
37
+
38
+
39
+ def _confidence_class(conf: str) -> str:
40
+ c = (conf or "").strip().lower()
41
+ if c.startswith("high"):
42
+ return "badge-high"
43
+ if c.startswith("medium"):
44
+ return "badge-medium"
45
+ if c.startswith("low"):
46
+ return "badge-low"
47
+ return "badge-none"
48
+
49
+
50
def _render_answer_card(answer_text: str) -> str:
    """Render the raw answer text as a styled HTML card.

    Sections are extracted via _parse_answer_sections and the confidence
    label is mapped to a coloured badge. All model/user text is HTML-escaped
    before interpolation to prevent markup injection in the Gradio HTML
    component. Missing sections render as 'None' or an em dash.
    """
    sec = _parse_answer_sections(answer_text)
    conf_cls = _confidence_class(sec["confidence"])
    main = html.escape(sec["main"])
    citations = html.escape(sec["citations"] or "None")
    greason = html.escape(sec["graph_reasoning"] or "—")
    conf = html.escape(sec["confidence"] or "—")
    knobs = html.escape(sec["knobs"] or "—")
    knobs_explain = html.escape(sec["knobs_explain"] or "—")
    return f"""
    <div class="card">
      <div class="card-title">Answer</div>
      <div class="answer">{main}</div>
      <div class="meta">
        <span class="badge {conf_cls}">{conf}</span>
      </div>
      <div class="sub"><b>Citations:</b> {citations}</div>
      <div class="sub"><b>Graph reasoning:</b> {greason}</div>
      <div class="sub"><b>Knobs effect:</b> {knobs}</div>
      <div class="sub"><b>Knobs explain:</b> {knobs_explain}</div>
    </div>
    """
72
+
73
+
74
+ def _render_evidence_markdown(evidence_list):
75
+ if not evidence_list:
76
+ return "_No evidence returned._"
77
+ lines = []
78
+ for i, chunk in enumerate(evidence_list, 1):
79
+ chunk = chunk.strip()
80
+ lines.append(f"**E{i}.** {chunk}")
81
+ return "\n\n".join(lines)
82
+
83
+
84
+ def _wrap_svg(svg: str) -> str:
85
+ if not svg or "<svg" not in svg:
86
+ return "<div class='graph-empty'>No graph</div>"
87
+ return f"""<div class="graph-wrap">{svg}</div>"""
88
+
89
+
90
def metrics_ui():
    """Fetch /metrics from the backend and format the results as Markdown.

    Always returns a string (error text on failure) so the Gradio Markdown
    output has something sensible to display instead of a traceback.
    """
    # requests has no default timeout: without one, a dead backend would hang
    # the UI forever. A failed connection previously raised outside any
    # try-block and crashed the handler — catch it and report instead.
    try:
        resp = requests.get(f"{API_URL}/metrics", timeout=300)
    except requests.RequestException as e:
        return f"Error: {e}"
    try:
        j = resp.json()
        if j.get("status") != "ok":
            return f"Error: {j}"
        r = j["results"]
        return f"""
### 📊 Evaluation Results

**Baseline (cosine-only)**
- hit@10: {r['baseline']['hit@10']:.2f}
- nDCG@10: {r['baseline']['nDCG@10']:.2f}

**Hybrid (GraphRAG)**
- hit@10: {r['hybrid']['hit@10']:.2f}
- nDCG@10: {r['hybrid']['nDCG@10']:.2f}

**Other**
- Citation correctness: {r['citation_correctness']:.2f}
- Avg latency (s): {r['avg_latency_sec']:.2f}
"""
    except Exception as e:
        # Non-JSON or unexpectedly shaped response: show the raw prefix.
        return f"Error: {e}\nRaw: {resp.text[:500]}"
114
+
115
+
116
def add_doc_ui(text, source="user", date_val=None, time_val=None):
    """POST a new document to the backend /add_doc endpoint.

    Args:
        text: Raw document text to ingest.
        source: Free-form provenance tag stored with the document.
        date_val: Optional date (datetime.date or 'YYYY-MM-DD' string).
        time_val: Optional 'HH:MM' time string; defaults to midnight.

    Returns:
        The backend's ingestion logs joined with newlines, or an error
        message string (never raises, so the Gradio textbox always updates).
    """
    payload = {"text": text, "source": source}

    # Build an ISO-8601 UTC timestamp only if a date was picked.
    ts_iso = ""
    if date_val:
        if isinstance(date_val, _date):
            dstr = date_val.isoformat()
        else:
            dstr = str(date_val)

        tstr = (time_val or "00:00").strip()
        if len(tstr) == 5:  # HH:MM -> add seconds
            tstr = f"{tstr}:00"
        ts_iso = f"{dstr}T{tstr}Z"

    if ts_iso:
        payload["timestamp"] = ts_iso

    # Ingestion (embedding + graph writes) can be slow, but a timeout still
    # protects the UI from hanging on a dead backend; a connection failure
    # previously raised uncaught and crashed the handler.
    try:
        resp = requests.post(f"{API_URL}/add_doc", json=payload, timeout=300)
    except requests.RequestException as e:
        return f"Error: {e}"
    try:
        j = resp.json()
        return "\n".join(j.get("logs", [])) or "No logs."
    except Exception as e:
        return f"Error: {e}\nRaw response: {resp.text[:500]}"
141
+
142
+
143
def ask_ui(question, w_cos, w_path, w_fresh, w_deg):
    """Send a question plus rerank weights to /ask and unpack the response.

    Returns a 5-tuple matching the Gradio outputs:
    (answer_html, evidence_markdown, logs_text, graph_html, graph_json).
    """
    payload = {
        "question": question,
        "w_cos": w_cos,
        "w_path": w_path,
        "w_fresh": w_fresh,
        "w_deg": w_deg,
    }

    def _error_outputs(err):
        # Shared fallback so the UI always receives a well-formed 5-tuple.
        return (
            _render_answer_card("I don’t know based on the given evidence.\nConfidence: Low"),
            "_No evidence returned._",
            err,
            "<div id='graph' style='height:420px'></div>",
            {},
        )

    # requests has no default timeout — a hung backend would freeze the UI;
    # a connection error previously raised uncaught in the Gradio handler.
    try:
        resp = requests.post(f"{API_URL}/ask", json=payload, timeout=300)
    except requests.RequestException as e:
        return _error_outputs(f"Error: {e}")
    try:
        j = resp.json()
    except Exception as e:
        return _error_outputs(f"Error: {e}\nRaw response: {resp.text[:500]}")

    answer_html = _render_answer_card(j.get("answer", ""))
    evidence_md = _render_evidence_markdown(j.get("evidence", []))
    logs_txt = "\n".join(j.get("logs", [])) or "No logs."

    # Prefer client-side D3 rendering; if no node data came back, fall back
    # to the server-rendered SVG.
    graph_json = j.get("subgraph_json", {})
    if graph_json and graph_json.get("nodes"):
        graph_html_value = "<div id='graph' style='height:420px'></div>"
    else:
        graph_html_value = _wrap_svg(j.get("subgraph_svg", ""))

    return (answer_html, evidence_md, logs_txt, graph_html_value, graph_json)
176
+
177
+
178
# UI
# Gradio app with three tabs (Add Document / Ask Question / Metrics) and a
# client-side D3 force-graph renderer for the evidence subgraph.
with gr.Blocks(
    css="""
    /* Layout & theme */
    body { background: #0b0f14; color: #e6edf3; }
    .gradio-container { max-width: 1180px !important; }
    .section-title { font-size: 22px; font-weight: 700; margin: 6px 0 12px; }

    /* Cards */
    .card { background: #0f1720; border: 1px solid #1f2a36; border-radius: 14px; padding: 14px; }
    .card-title { font-size: 16px; letter-spacing: .3px; color: #9fb3c8; margin-bottom: 8px; text-transform: uppercase; }
    .answer { font-size: 18px; line-height: 1.5; margin-bottom: 8px; }
    .sub { color: #a8b3bf; margin-top: 6px; font-size: 14px; }

    /* Badges */
    .badge { padding: 3px 10px; border-radius: 999px; font-size: 12px; font-weight: 700; display: inline-block; }
    .badge-high { background: #12391a; color: #6ee787; border: 1px solid #285f36; }
    .badge-medium { background: #3a2b13; color: #ffd277; border: 1px solid #6b4e1f; }
    .badge-low { background: #3b1616; color: #ff9492; border: 1px solid #6b2020; }
    .badge-none { background: #223; color: #9fb3c8; border: 1px solid #334; }

    /* Graph */
    .graph-wrap { background: #0f1720; border: 1px solid #1f2a36; border-radius: 14px;
    padding: 12px; height: 460px; overflow: auto; }
    .graph-empty { color: #9fb3c8; font-style: italic; padding: 16px; }

    /* Logs */
    #logs-box textarea {
    font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", monospace !important;
    max-height: 280px !important;
    overflow-y: auto !important;
    }
    """
) as demo:
    gr.Markdown("### 🚀 GraphRAG — Live Demo")

    # --- Tab 1: inject a document into both the graph and vector stores ---
    with gr.Tab("Add Document"):
        with gr.Row():
            with gr.Column(scale=3):
                text_in = gr.Textbox(
                    label="Document",
                    lines=10,
                    placeholder="Paste text to inject into Graph + Vector DB…",
                )
            with gr.Column(scale=1):
                source_in = gr.Textbox(label="Source", value="user")

                # gr.Date / gr.Time do not exist in every Gradio release;
                # fall back to plain textboxes when they are missing.
                if hasattr(gr, "Date"):
                    ts_date = gr.Date(label="Date (optional)")
                else:
                    ts_date = gr.Textbox(label="Date (YYYY-MM-DD, optional)")

                if hasattr(gr, "Time"):
                    ts_time = gr.Time(label="Time (optional)", value="00:00")
                else:
                    ts_time = gr.Textbox(label="Time (HH:MM, optional)", value="00:00")

        add_btn = gr.Button("Add Doc", variant="primary")
        add_logs = gr.Textbox(label="Ingestion Logs", lines=14, elem_id="logs-box")

        add_btn.click(
            add_doc_ui,
            inputs=[text_in, source_in, ts_date, ts_time],
            outputs=add_logs
        )


    # --- Tab 2: ask a question, tune rerank weights, inspect evidence ---
    with gr.Tab("Ask Question"):
        with gr.Row():
            q_in = gr.Textbox(
                label="Question", placeholder="e.g., Who acquired Instagram?"
            )
            ask_btn = gr.Button("Ask", variant="primary")

        # Rerank knobs are passed straight through to the backend /ask call.
        with gr.Accordion("Rerank Weights", open=False):
            w_cos = gr.Slider(0, 1, value=0.60, step=0.05, label="Cosine weight")
            w_path = gr.Slider(0, 1, value=0.20, step=0.05, label="Path proximity weight")
            w_fresh = gr.Slider(0, 1, value=0.15, step=0.05, label="Freshness weight")
            w_deg = gr.Slider(0, 1, value=0.05, step=0.05, label="Degree norm weight")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("<div class='section-title'>Answer</div>")
                ans_html = gr.HTML(value=_render_answer_card("Ask something to see results."))

                evid = gr.Accordion("Evidence (ranked)", open=True)
                with evid:
                    evid_md = gr.Markdown()

                logs = gr.Accordion("Debug logs", open=False)
                with logs:
                    logs_txt = gr.Textbox(lines=14, elem_id="logs-box")

            with gr.Column(scale=1):
                gr.Markdown("<div class='section-title'>Evidence Graph</div>")
                # Placeholder div filled client-side by DRAW_JS below.
                graph_html = gr.HTML(value="<div id='graph' style='height:600px'></div>")
                # Hidden component used only to ferry subgraph JSON to the JS.
                graph_data = gr.JSON(label="graph-data", visible=False)

        ask_btn.click(
            ask_ui,
            inputs=[q_in, w_cos, w_path, w_fresh, w_deg],
            outputs=[ans_html, evid_md, logs_txt, graph_html, graph_data],
        )

    # --- Tab 3: run the offline evaluation (baseline vs hybrid) ---
    with gr.Tab("Metrics"):
        metrics_btn = gr.Button("Run Evaluation", variant="primary")
        metrics_out = gr.Markdown("Click run to evaluate baseline vs hybrid.")
        metrics_btn.click(metrics_ui, inputs=[], outputs=metrics_out)

    # D3 rendering for the knowledge graph
    # Runs in the browser: lazily loads d3 from jsDelivr, then draws a
    # zoomable/draggable force-directed graph into the #graph placeholder.
    DRAW_JS = r"""
    (value) => {
      const el = document.querySelector("#graph");
      if (!el) return null;
      el.innerHTML = "";

      if (!value || !value.nodes || value.nodes.length === 0) {
        el.innerHTML = "<div class='graph-empty'>No graph</div>";
        return null;
      }

      function ensureD3(cb) {
        if (window.d3) return cb();
        const s = document.createElement("script");
        s.src = "https://cdn.jsdelivr.net/npm/d3@7";
        s.onload = cb;
        document.head.appendChild(s);
      }

      ensureD3(() => {
        const width = el.clientWidth || 900;
        const height = 600;

        const svg = d3.select(el).append("svg")
          .attr("viewBox", [0, 0, width, height])
          .attr("preserveAspectRatio", "xMidYMid meet")
          .style("width", "100%")
          .style("height", "100%");

        // Create zoomable container
        const container = svg.append("g");

        // Enable zoom & pan
        svg.call(
          d3.zoom()
            .scaleExtent([0.2, 3]) // zoom limits (20%–300%)
            .on("zoom", (event) => {
              container.attr("transform", event.transform);
            })
        );

        const sim = d3.forceSimulation(value.nodes)
          .force("link", d3.forceLink(value.links).id(d => d.id).distance(140).strength(0.4))
          .force("charge", d3.forceManyBody().strength(-220))
          .force("center", d3.forceCenter(width / 2, height / 2));

        const link = container.append("g")
          .attr("stroke", "#999")
          .attr("stroke-opacity", 0.6)
          .selectAll("line")
          .data(value.links)
          .enter().append("line")
          .attr("stroke-width", 1.5);

        const edgeLabels = container.append("g")
          .selectAll("text")
          .data(value.links)
          .enter().append("text")
          .attr("font-size", 10)
          .attr("fill", "#bbb")
          .text(d => d.label);

        const node = container.append("g")
          .selectAll("circle")
          .data(value.nodes)
          .enter().append("circle")
          .attr("r", 12)
          .attr("fill", "#69b3a2")
          .attr("stroke", "#2dd4bf")
          .attr("stroke-width", 1.2)
          .call(d3.drag()
            .on("start", (event, d) => { if (!event.active) sim.alphaTarget(0.3).restart(); d.fx = d.x; d.fy = d.y; })
            .on("drag", (event, d) => { d.fx = event.x; d.fy = event.y; })
            .on("end", (event, d) => { if (!event.active) sim.alphaTarget(0); d.fx = null; d.fy = null; })
          );

        const labels = container.append("g")
          .selectAll("text")
          .data(value.nodes)
          .enter().append("text")
          .attr("font-size", 12)
          .attr("fill", "#ddd")
          .attr("dy", 18)
          .attr("text-anchor", "middle")
          .text(d => d.id);

        sim.on("tick", () => {
          link
            .attr("x1", d => d.source.x)
            .attr("y1", d => d.source.y)
            .attr("x2", d => d.target.x)
            .attr("y2", d => d.target.y);

          edgeLabels
            .attr("x", d => (d.source.x + d.target.x) / 2)
            .attr("y", d => (d.source.y + d.target.y) / 2);

          node
            .attr("cx", d => d.x)
            .attr("cy", d => d.y);

          labels
            .attr("x", d => d.x)
            .attr("y", d => d.y);
        });
      });

      return null;
    }
    """


    # NOTE(review): the identity .change handler exists only so the JS draw
    # step can be chained after graph_data updates — confirm this does not
    # re-trigger .change recursively in the pinned Gradio version.
    graph_data.change(lambda x: x, inputs=graph_data, outputs=graph_data).then(
        None, inputs=graph_data, outputs=None, js=DRAW_JS
    )

if __name__ == "__main__":
    demo.launch()
406
+
utils.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime, timezone
2
+ import math
3
+
4
def compute_freshness(ts, half_life_days=30):
    """Exponential-decay freshness score in (0, 1].

    - 1.0 for a document dated now
    - 0.5 at ``half_life_days`` old
    - approaches 0 for very old documents

    Missing or unparseable timestamps fall back to a neutral 0.5.
    Naive datetimes are assumed to be UTC.
    """
    if not ts:
        return 0.5
    if isinstance(ts, str):
        try:
            # 'Z' suffix is not accepted by fromisoformat on older Pythons.
            ts = datetime.fromisoformat(ts.replace("Z", "+00:00"))
        except Exception:
            return 0.5
    if ts.tzinfo is None:
        ts = ts.replace(tzinfo=timezone.utc)

    age_days = (datetime.now(timezone.utc) - ts).total_seconds() / 86400.0
    decay_rate = math.log(2) / max(float(half_life_days), 1.0)
    return math.exp(-decay_rate * age_days)
vec.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ from dotenv import load_dotenv
4
+ from sentence_transformers import SentenceTransformer
5
+ from qdrant_client import QdrantClient
6
+ from qdrant_client.http import models
7
+
8
+ import uuid
9
+ from datetime import datetime, timezone
10
+
11
# Qdrant connection settings come from the environment (.env file in dev).
load_dotenv()
QDRANT_URL = os.getenv("QDRANT_URL")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
COLLECTION = "docs"

# all-MiniLM-L6-v2 emits 384-dim embeddings — this must match the
# VectorParams(size=384) used when creating the collection below.
model = SentenceTransformer("all-MiniLM-L6-v2")  # fast embedder
qdrant = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)

# Ensure collection exists (safe init .... we dont wipe on reload)
def _ensure_collection():
    """Create the Qdrant 'docs' collection on first run; leave existing data intact."""
    try:
        coll_info = qdrant.get_collection(collection_name=COLLECTION)
        if not coll_info:
            raise Exception("Collection not found")
    except Exception:
        # Either the collection is missing or the lookup failed: (re)create it.
        print(f"⚠️ Collection '{COLLECTION}' not found. Creating fresh collection...")
        qdrant.create_collection(
            collection_name=COLLECTION,
            vectors_config=models.VectorParams(
                size=384,  # embedding dimension of all-MiniLM-L6-v2
                distance=models.Distance.COSINE
            ),
        )

# Run at import time so the collection exists before any upsert/search.
_ensure_collection()
36
+
37
def embed_and_upsert(chunks, source="user", timestamp=None):
    """Embed text chunks and upsert them into the Qdrant 'docs' collection.

    Args:
        chunks: Iterable of text chunks; one vector point is stored per chunk.
        source: Provenance tag stored in each point's payload.
        timestamp: ISO-8601 string applied to every chunk; defaults to the
            current UTC time.

    Returns:
        True once the upsert has been acknowledged (wait=True blocks until
        the write is applied server-side).
    """
    if timestamp is None:
        timestamp = datetime.now(timezone.utc).isoformat()

    embeddings = model.encode(chunks).tolist()
    points = []
    for i, (chunk, emb) in enumerate(zip(chunks, embeddings)):
        points.append(
            models.PointStruct(
                # Random UUIDs: re-ingesting the same text creates new points
                # rather than overwriting previous ones.
                id=str(uuid.uuid4()),
                vector=emb,
                payload={
                    "text": chunk,
                    "source": source,
                    "timestamp": timestamp,
                    "chunk_id": i  # position within the original document
                }
            )
        )
    qdrant.upsert(collection_name=COLLECTION, points=points, wait=True)
    print(f"✅ Stored {len(points)} chunks in Qdrant (source={source}, ts={timestamp})")
    return True
59
+
60
def search(query: str, top_k: int = 5):
    """Semantic (cosine) search over the Qdrant 'docs' collection.

    Args:
        query: Natural-language query; embedded with the same model used
            at ingest time.
        top_k: Maximum number of hits to return.

    Returns:
        List of dicts with 'chunk', 'cosine' (similarity score),
        'timestamp' and 'source' keys, ordered by descending score.
    """
    q_emb = model.encode([query])[0].tolist()
    # NOTE(review): QdrantClient.search is deprecated in newer client
    # versions in favour of query_points — confirm against the pinned
    # qdrant-client release before upgrading.
    results = qdrant.search(
        collection_name=COLLECTION,
        query_vector=q_emb,
        limit=top_k,
        with_payload=True
    )
    return [
        {
            "chunk": r.payload.get("text", ""),
            "cosine": r.score,
            "timestamp": r.payload.get("timestamp"),
            "source": r.payload.get("source")
        }
        for r in results
    ]