# GraphRAG-Live / app.py
# Uploaded by aayush226 ("Upload 10 files", commit 4d9fcca, verified)
from fastapi import FastAPI
from pydantic import BaseModel, Field
from typing import List, Literal
from datetime import datetime
import os, json
from text import chunk_text
from vec import embed_and_upsert, search
from kg import (
extract_and_insert,
get_subgraph,
compute_path_proximity,
compute_degree_norm,
)
from rerank import rerank_candidates
from eval import evaluate
from utils import compute_freshness
from dotenv import load_dotenv
from openai import OpenAI
# Load variables from .env, letting them take precedence over variables that
# are already set in the process environment.
load_dotenv(override=True)
# An absent key becomes "" (never None), so constructing the client below does
# not depend on the variable being present.
key = (os.environ.get("OPENAI_API_KEY") or "").strip()
client = OpenAI(api_key=key)
app = FastAPI()
# Pydantic schemas for the request bodies and the structured LLM output
class DocInput(BaseModel):
    # Request body for /add_doc: the raw text to ingest plus provenance.
    text: str
    source: str = "user"
    # default_factory is called for every request that omits a timestamp.
    # A plain `= datetime.now()` default is evaluated once at import time, so
    # every such document would be stamped with the server start time
    # (presumably skewing the freshness score computed later in /ask).
    timestamp: datetime = Field(default_factory=datetime.now)
class QuestionInput(BaseModel):
    # Request body for /ask.
    question: str
    # Re-ranking weights for cosine similarity, graph path proximity,
    # freshness decay and node degree; the defaults add up to 1.0.
    w_cos: float = 0.6
    w_path: float = 0.2
    w_fresh: float = 0.15
    w_deg: float = 0.05
# Required LLM output shape (enforced via OpenAI JSON mode + Pydantic validation)
class LLMAnswer(BaseModel):
    # Structured reply requested from the LLM; only `answer` is required.
    answer: str = Field(description="One-sentence final answer")
    citations: List[str] = Field(
        description="Evidence IDs like E1, E3 that support the answer",
        default_factory=list,
    )
    graph_reasoning: str = Field(
        default="",
        description="How the graph helped, or 'Not used'",
    )
    # Any value outside this literal set fails validation.
    confidence: Literal["High", "Medium", "Low"] = Field(default="Low")
# Helpers that explain how the weight controls (knobs) shaped the ranking
def _get_scores(c, w_cos, w_path, w_fresh, w_deg):
cos = float(c.get("cosine", c.get("cosine_sim", 0.0)) or 0.0)
pp = float(c.get("path_proximity", 0.0) or 0.0)
fr = float(c.get("freshness_decay", 0.0) or 0.0)
dg = float(c.get("degree_norm", 0.0) or 0.0)
final = w_cos * cos + w_path * pp + w_fresh * fr + w_deg * dg
return cos, pp, fr, dg, final
def _build_knobs_breakdown(numbered, w_cos, w_path, w_fresh, w_deg):
"""
Returns (knobs_line, knobs_explain) strings. Uses top 1 only and runner up if available.
"""
if not numbered:
return "", ""
idx1, c1 = numbered[0]
cos1, pp1, fr1, dg1, fin1 = _get_scores(c1, w_cos, w_path, w_fresh, w_deg)
# Optional runner up
ru_piece, explain = "", ""
if len(numbered) > 1:
idx2, c2 = numbered[1]
cos2, pp2, fr2, dg2, fin2 = _get_scores(c2, w_cos, w_path, w_fresh, w_deg)
margin = fin1 - fin2
ru_piece = f"; Runner-up E{idx2}={fin2:.3f}; Margin={margin:+.3f}"
# Contribution of the deltas (weighted)
deltas = [
("path", w_path * (pp1 - pp2), pp1, pp2, w_path),
("freshness", w_fresh * (fr1 - fr2), fr1, fr2, w_fresh),
("cosine", w_cos * (cos1 - cos2), cos1, cos2, w_cos),
("degree", w_deg * (dg1 - dg2), dg1, dg2, w_deg),
]
deltas.sort(key=lambda x: x[1], reverse=True)
# Pick top positive drivers
drivers = [f"{name} ({d:+.3f})" for name, d, *_ in deltas if d > 0.002][:3]
# A short natural language sentence
if drivers:
top_names = ", ".join(drivers)
else:
top_names = "mostly cosine similarity (others were negligible)"
explain = (
f"With weights (cos {w_cos:.2f}, path {w_path:.2f}, fresh {w_fresh:.2f}, deg {w_deg:.2f}), "
f"E{idx1} leads by {margin:+.3f}. Biggest lifts vs E{idx2}: {top_names}."
)
else:
# No runner up but sstill provide a brief note
explain = (
f"With weights (cos {w_cos:.2f}, path {w_path:.2f}, fresh {w_fresh:.2f}, deg {w_deg:.2f}), "
f"the top candidate E{idx1} scored {fin1:.3f}."
)
knobs_line = (
f"Weights→ cos {w_cos:.2f}, path {w_path:.2f}, fresh {w_fresh:.2f}, deg {w_deg:.2f}. "
f"E{idx1} final={fin1:.3f} = {w_cos:.2f}×{cos1:.3f} + {w_path:.2f}×{pp1:.3f} + "
f"{w_fresh:.2f}×{fr1:.3f} + {w_deg:.2f}×{dg1:.3f}{ru_piece}; Cosine-only(E{idx1})={cos1:.3f}."
)
return knobs_line, explain
# API Endpoints
@app.get("/metrics")
def metrics_endpoint():
    """Run the evaluation set via ``evaluate()`` and report its results.

    Exceptions from ``evaluate()`` are caught and returned as
    ``{"status": "error", "logs": [...]}`` instead of propagating.
    """
    try:
        results = evaluate()
    except Exception as e:
        return {"status": "error", "logs": [f"⚠️ Metrics failed: {e}"]}
    return {"status": "ok", "results": results, "logs": ["✅ Ran evaluation set"]}
@app.post("/add_doc")
def add_doc_endpoint(doc: DocInput):
    """Ingest one document into both stores.

    The text is chunked. The chunks are embedded and upserted into the vector
    store (Qdrant, per the log line). Triples are then extracted from the same
    chunks into the graph (Neo4j). Returns ``{"status": "ok", "logs": [...]}``.
    """
    source = doc.source
    ts = doc.timestamp
    logs = ["📥 Received document"]

    # 1) Chunk
    pieces = chunk_text(doc.text)
    logs.append(f"✂️ Chunked into {len(pieces)} pieces")

    # 2) Embed + store.
    # NOTE(review): the vector store receives isoformat() while the graph below
    # receives str() (space-separated); confirm both consumers accept theirs.
    embed_and_upsert(pieces, source=source, timestamp=ts.isoformat())
    logs.append(f"🧮 Embedded + stored in Qdrant (source={source}, ts={ts})")

    # 3) Extract triples into the knowledge graph
    graph_logs = extract_and_insert(pieces, source=source, timestamp=str(ts))
    if graph_logs:
        logs.extend(graph_logs)
    else:
        logs.append("🌐 No entities/relations extracted for Neo4j")
    return {"status": "ok", "logs": logs}
# <defs> block shared by every rendered subgraph SVG (arrow marker + styles).
_SUBGRAPH_SVG_DEFS = """
<defs>
  <marker id="arrow" markerUnits="strokeWidth" markerWidth="10" markerHeight="8"
          viewBox="0 0 10 8" refX="10" refY="4" orient="auto">
    <path d="M0 0 L10 4 L0 8 z" fill="#999"/>
  </marker>
  <style>
    .edge { stroke:#999; stroke-width:1.5; }
    .nodelabel { font:12px sans-serif; fill:#ddd; }
    .edgelabel { font:10px sans-serif; fill:#bbb; }
    .node { fill:#69b3a2; stroke:#2dd4bf; stroke-width:1; }
  </style>
</defs>
"""


def _xml_text(value):
    """Escape ``value`` for use as SVG/XML character data (&, <, >)."""
    return str(value).replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")


def _build_ask_prompt(question, evidence_lines, triples_text):
    """Assemble the QA prompt that grounds the LLM in ranked evidence plus graph triples."""
    lines = [
        "You are a precise QA assistant that MUST use BOTH the retrieved evidence and the graph triples.",
        "Question:",
        question,
        "Retrieved Evidence (ranked by importance, highest first):",
        "\n".join(evidence_lines),
        "Knowledge Graph Triples:",
        triples_text,
        "Instructions:",
        "- E1 is the most relevant, E2 is second-most, and so on.",
        "- Prefer evidence with a lower number if multiple sources conflict.",
        "- If supported, produce a single-sentence answer.",
        "- Cite supporting evidence IDs (e.g., E1, E2).",
        '- If the graph helped, say how; else "Not used".',
        '- If not supported, return "I don’t know..." with Low confidence.',
        "Return ONLY a JSON object matching this schema:",
        "{",
        '"answer": "string",',
        '"citations": ["E1","E2"],',
        '"graph_reasoning": "string",',
        '"confidence": "High|Medium|Low"',
        "}",
    ]
    return "\n".join(lines)


def _query_llm(prompt):
    """Send ``prompt`` to OpenAI in JSON mode and validate the reply as ``LLMAnswer``.

    Raises whatever the OpenAI client, ``json.loads`` or Pydantic raise. The
    caller treats any exception as "fall back to the stub answer".
    """
    comp = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": "Respond ONLY with a JSON object."},
            {"role": "user", "content": prompt},
        ],
        response_format={"type": "json_object"},  # JSON mode: reply is a JSON object
        temperature=0,
        max_tokens=300,
    )
    raw_json = comp.choices[0].message.content or "{}"
    # An empty reply becomes "{}", which fails validation (`answer` is required).
    return LLMAnswer.model_validate(json.loads(raw_json))


def _format_answer_card(headline, citations, graph_reasoning, confidence,
                        knobs_line, knobs_explain):
    """Render the multi-line answer text shown on the UI card."""
    return (
        f"{headline}\n"
        f"Citations: {citations}\n"
        f"Graph reasoning: {graph_reasoning}\n"
        f"Confidence: {confidence}\n"
        f"Knobs: {knobs_line or '—'}\n"
        f"Knobs explain: {knobs_explain or '—'}"
    )


def _subgraph_to_d3(triples):
    """Convert ``(subject, relation, object)`` triples into D3 ``nodes``/``links`` JSON."""
    node_map = {}
    links = []
    for s, r, o in triples:
        node_map.setdefault(s, {"id": s})
        node_map.setdefault(o, {"id": o})
        links.append({"source": s, "target": o, "label": r})
    return {"nodes": list(node_map.values()), "links": links}


def _render_subgraph_svg(triples, width=720, height=420, pad=40):
    """Render triples as a static SVG string, used when D3 fails to render in the UI.

    Nodes are placed with networkx's spring layout using a fixed seed (42), so
    the layout is reproducible for a given graph. Several relations between
    the same ordered pair are merged into one edge whose label lists them
    comma-separated, because a DiGraph stores only one edge per pair.
    """
    import networkx as nx

    G = nx.DiGraph()
    for s, r, o in triples:
        G.add_node(s)
        G.add_node(o)
        if G.has_edge(s, o):
            labels = G[s][o]["labels"]
            if r not in labels:
                labels.append(r)
        else:
            G.add_edge(s, o, labels=[r])

    pos = nx.spring_layout(G, seed=42)
    xs = [p[0] for p in pos.values()] or [0.0]
    ys = [p[1] for p in pos.values()] or [0.0]
    minx, maxx = min(xs), max(xs)
    miny, maxy = min(ys), max(ys)
    rangex = (maxx - minx) or 1.0
    rangey = (maxy - miny) or 1.0

    def sx(x):
        return pad + (x - minx) / rangex * (width - 2 * pad)

    def sy(y):
        return pad + (y - miny) / rangey * (height - 2 * pad)

    def radius(node):
        # Size the circle from the raw label, not its (longer) escaped form.
        return max(16, len(str(node)) * 4)

    parts = [
        f'<svg width="{width}" height="{height}" viewBox="0 0 {width} {height}" '
        f'xmlns="http://www.w3.org/2000/svg">',
        _SUBGRAPH_SVG_DEFS,
    ]

    for u, v, data in G.edges(data=True):
        cx1, cy1 = sx(pos[u][0]), sy(pos[u][1])
        cx2, cy2 = sx(pos[v][0]), sy(pos[v][1])
        x1, y1, x2, y2 = cx1, cy1, cx2, cy2
        dx, dy = cx2 - cx1, cy2 - cy1
        dist = (dx * dx + dy * dy) ** 0.5
        r_u, r_v = radius(u), radius(v)
        # Node circles are opaque and drawn after the edges. An edge that ends
        # at the target's centre therefore hides its arrowhead, so trim both
        # ends to the circle rims (skipped when the circles overlap).
        if dist > r_u + r_v:
            x1 += dx / dist * r_u
            y1 += dy / dist * r_u
            x2 -= dx / dist * r_v
            y2 -= dy / dist * r_v
        parts.append(
            f'<line class="edge" x1="{x1:.1f}" y1="{y1:.1f}" '
            f'x2="{x2:.1f}" y2="{y2:.1f}" marker-end="url(#arrow)"/>'
        )
        mx, my = (cx1 + cx2) / 2.0, (cy1 + cy2) / 2.0
        lbl = _xml_text(", ".join(str(x) for x in data["labels"] if x))
        parts.append(
            f'<text class="edgelabel" x="{mx:.1f}" y="{my:.1f}" text-anchor="middle">{lbl}</text>'
        )

    for n in G.nodes():
        x, y = sx(pos[n][0]), sy(pos[n][1])
        r = radius(n)
        parts.append(f'<circle class="node" cx="{x:.1f}" cy="{y:.1f}" r="{r}"/>')
        parts.append(
            f'<text class="nodelabel" x="{x:.1f}" y="{y + r + 14:.1f}" '
            f'text-anchor="middle">{_xml_text(n)}</text>'
        )

    parts.append("</svg>")
    return "".join(parts)


@app.post("/ask")
def ask_endpoint(query: QuestionInput):
    """Answer ``query.question`` with graph-aware retrieval plus an LLM.

    Pipeline:
    1. Vector search for the top 5 candidates.
    2. Compute per-candidate path-proximity, degree and freshness features.
    3. Re-rank the candidates with a weighted score.
    4. Keep the top ``TOP_N`` chunks as evidence E1..En.
    5. Ask OpenAI for a structured JSON answer, falling back to a stub answer
       if anything in the LLM step fails.

    Returns a dict with these keys:
    - ``answer``: multi-line display text.
    - ``evidence``: the numbered chunks.
    - ``subgraph_svg``: the static SVG fallback.
    - ``subgraph_json``: D3 nodes/links.
    - ``logs``: progress messages.
    """
    TOP_N = 2  # TODO: experiment with more evidence chunks
    q = query.question
    logs = [f"❓ Received question: {q}"]

    # Retrieve
    candidates = search(q, top_k=5)
    logs.append(f"🔎 Retrieved {len(candidates)} from Qdrant")

    # Graph-aware features consumed by the re-ranker
    for c in candidates:
        c["path_proximity"] = compute_path_proximity(q, c["chunk"])
        c["degree_norm"] = compute_degree_norm(c["chunk"])
        c["freshness_decay"] = compute_freshness(c.get("timestamp"))

    # Re-rank
    reranked, rerank_logs = rerank_candidates(
        candidates,
        w_cos=query.w_cos,
        w_path=query.w_path,
        w_fresh=query.w_fresh,
        w_deg=query.w_deg,
    )
    logs.append("📊 Applied graph-aware re-ranking")
    logs.extend(rerank_logs)

    # Evidence subgraph (≤2 hops)
    triples = get_subgraph(q, source=None)
    logs.append(f"🌐 Subgraph triples: {len(triples)}")

    # Number the evidence for citations: E1 is the best-ranked chunk.
    reranked = reranked[:TOP_N]
    numbered = list(enumerate(reranked, start=1))
    evidence = [f"[E{i}] {c['chunk']}" for i, c in numbered]
    knobs_line, knobs_explain = _build_knobs_breakdown(
        numbered, query.w_cos, query.w_path, query.w_fresh, query.w_deg
    )

    if reranked:
        triples_text = "\n".join(f"({s}) -[{r}]-> ({o})" for s, r, o in triples)
        prompt = _build_ask_prompt(q, evidence, triples_text)
        logs.append("📝 Built prompt with evidence + graph")
        try:
            parsed = _query_llm(prompt)
        except Exception as e:  # best-effort: API, JSON or validation failure
            answer = _format_answer_card(
                f"Based on evidence: {reranked[0]['chunk']}",
                "None",
                "Not used",
                "Low",
                knobs_line,
                knobs_explain,
            )
            logs.append(f"⚠️ OpenAI failed, fallback to stub ({e})")
        else:
            answer = _format_answer_card(
                parsed.answer,
                ", ".join(parsed.citations) if parsed.citations else "None",
                parsed.graph_reasoning or "—",
                parsed.confidence,
                knobs_line,
                knobs_explain,
            )
            logs.append("🤖 Called OpenAI")
            logs.append("🧠 Generated final answer")
    else:
        answer = _format_answer_card(
            "No evidence found.", "None", "Not used", "Low", knobs_line, knobs_explain
        )
        logs.append("⚠️ No evidence, answer is empty")

    subgraph_json = _subgraph_to_d3(triples)
    # Server-side SVG fallback in case D3 fails to render
    subgraph_svg = _render_subgraph_svg(triples)
    logs.append(f"📦 Subgraph JSON dump: {subgraph_json}")

    return {
        "answer": answer,
        "evidence": evidence,
        "subgraph_svg": subgraph_svg,  # fallback
        "subgraph_json": subgraph_json,  # for D3 in UI
        "logs": logs,
    }
@app.get("/healthz")
def healthz():
    """Liveness probe: unconditionally reports ``{"ok": True}``."""
    return dict(ok=True)