# GraphRAG / app.py
# Last commit: 567197e ("Update app.py", verified) by Sanjam19 — file size 5.59 kB.
import os
import json
import time
import gradio as gr
import networkx as nx
from groq import Groq
from dotenv import load_dotenv
# Pull variables from a local .env file (python-dotenv) into os.environ, if one exists.
load_dotenv()
# -----------------------------
# API
# -----------------------------
# Shared Groq client used by both query functions in this module; the key is
# read from GROQ_API_KEY (possibly populated by load_dotenv above).
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
# -----------------------------
# Fallback data (avoid startup crash)
# -----------------------------
# On a fresh checkout there may be no data/ directory or JSON files. Seed
# minimal placeholders so the loading code below can open them instead of
# raising FileNotFoundError. Files that already exist are never overwritten.
#
# exist_ok=True already makes makedirs a no-op when the directory exists, so
# no separate os.path.exists() check is needed.
os.makedirs("data", exist_ok=True)
if not os.path.exists("data/chunks.json"):
    # One demo chunk with the keys build_graph() reads
    # (chunk_id, company, doc_name, text) plus the other record fields.
    dummy_chunks = [
        {
            "chunk_id": "demo_1",
            "company": "apple",
            "doc_name": "sample_doc",
            "period": "2024",
            "text": "Apple reported strong revenue growth in FY2024.",
            "question": "",
            "answer": "",
            "source": "demo",
        }
    ]
    with open("data/chunks.json", "w", encoding="utf-8") as f:
        json.dump(dummy_chunks, f)
if not os.path.exists("data/entities.json"):
    with open("data/entities.json", "w", encoding="utf-8") as f:
        json.dump([], f)
# -----------------------------
# Load data
# -----------------------------
def _read_json_file(path):
    """Parse and return the JSON document stored at *path* (read as UTF-8)."""
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


# Chunk records; build_graph() iterates these and reads fields via .get().
chunks = _read_json_file("data/chunks.json")
# NOTE(review): `entities` is loaded but not referenced anywhere else in this file.
entities = _read_json_file("data/entities.json")
# -----------------------------
# Build graph
# -----------------------------
def build_graph(chunk_list=None):
    """Build the company / filing / chunk graph used for retrieval.

    Every chunk record contributes three nodes and two undirected edges::

        company --FILED-- doc_name --CONTAINS-- chunk_id

    Args:
        chunk_list: Iterable of chunk dicts with optional keys ``company``,
            ``doc_name``, ``chunk_id`` and ``text``. Defaults to the
            module-level ``chunks`` loaded from ``data/chunks.json``.

    Returns:
        A ``networkx.Graph``. Chunk nodes carry ``type="chunk"``, ``text``
        and ``company`` attributes; company and filing nodes carry
        ``type="company"`` and ``type="filing"`` respectively.
    """
    if chunk_list is None:
        chunk_list = chunks
    graph = nx.Graph()
    for index, c in enumerate(chunk_list):
        # `or` rather than a .get() default so explicit JSON nulls / empty
        # strings also fall back; retrieval concatenates text and company as
        # strings and would fail on None.
        company = c.get("company") or "unknown"
        doc = c.get("doc_name") or "unknown_doc"
        text = c.get("text") or ""
        # A per-position fallback id: with a single shared "unknown_chunk"
        # id, every id-less chunk would map to the same node and overwrite
        # the previous chunk's text.
        cid = c.get("chunk_id") or f"unknown_chunk_{index}"
        graph.add_node(cid, type="chunk", text=text, company=company)
        graph.add_node(company, type="company")
        graph.add_node(doc, type="filing")
        graph.add_edge(company, doc, rel="FILED")
        graph.add_edge(doc, cid, rel="CONTAINS")
    return graph


G = build_graph()
# Known company names, lower-case.
# NOTE(review): not referenced anywhere else in this file.
COMPANIES = "apple microsoft amazon google tesla 3m boeing amd".split()

# Lower-case filler words that query_graphrag() drops from the question
# before keyword matching.
STOPWORDS = set(
    "what was the in of a an is are how did does we if that you as "
    "based on which has for by per".split()
)
# -----------------------------
# LLM Query
# -----------------------------
def query_llm(question, model=None):
    """Ask the model *question* directly, with no retrieved context.

    Args:
        question: User question, sent verbatim as the only user message.
        model: Groq model id. When omitted, uses the ``GROQ_MODEL``
            environment variable, falling back to ``"llama3-70b-8192"``.
            NOTE(review): Groq has announced deprecation of
            llama3-70b-8192; set GROQ_MODEL if the default is rejected.

    Returns:
        ``(answer, total_tokens, latency_seconds)``. The latency is the
        elapsed time of the API call rounded to 2 decimals. Tokens fall back
        to 0 and the answer to "" if the response omits them.
    """
    model = model or os.getenv("GROQ_MODEL", "llama3-70b-8192")
    # perf_counter is monotonic and high-resolution; time.time() can jump
    # with system clock adjustments and is not meant for measuring durations.
    t0 = time.perf_counter()
    r = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": question}],
        max_tokens=200,
    )
    latency = round(time.perf_counter() - t0, 2)
    # The SDK types `usage` and `message.content` as optional; degrade to
    # 0 / "" instead of raising AttributeError in the UI callback.
    usage = r.usage
    return (
        r.choices[0].message.content or "",
        usage.total_tokens if usage is not None else 0,
        latency,
    )
# -----------------------------
# GraphRAG Query
# -----------------------------
def _extract_keywords(question):
    """Return the distinct content words of *question*, lower-cased, in order.

    Each word is lower-cased and stripped of surrounding punctuation and a
    trailing possessive "'s" *before* the stopword and length filters run.
    This way "The," is dropped like "the", and "Apple's" becomes "apple".
    """
    keywords = []
    for raw in question.split():
        word = raw.lower().strip("?.,!;:\"'()")
        for suffix in ("'s", "\u2019s"):
            if word.endswith(suffix):
                word = word[: -len(suffix)]
                break
        if len(word) > 2 and word not in STOPWORDS:
            keywords.append(word)
    # dict.fromkeys de-duplicates while keeping first-seen order, so a word
    # repeated in the question is not counted twice for the same chunk.
    return list(dict.fromkeys(keywords))


def _top_chunk_ids(keywords, k=2):
    """Return up to *k* chunk node ids from ``G``, best keyword match first.

    A chunk's score is how many keywords occur as substrings of its text plus
    its company name (lower-cased). Equal scores keep graph insertion order.
    If no chunk scores above zero, the first *k* chunks in graph order are
    returned so the model still receives some context.
    """
    chunk_nodes = [
        (node, data)
        for node, data in G.nodes(data=True)
        if data.get("type") == "chunk"
    ]
    scores = {}
    for node, data in chunk_nodes:
        haystack = (
            (data.get("text") or "") + " " + (data.get("company") or "")
        ).lower()
        score = sum(1 for kw in keywords if kw in haystack)
        if score > 0:
            scores[node] = score
    if not scores:
        return [node for node, _ in chunk_nodes[:k]]
    # sorted() is stable, including with reverse=True, so ties keep order.
    return sorted(scores, key=scores.get, reverse=True)[:k]


def query_graphrag(question, model=None):
    """Answer *question* using context retrieved from the chunk graph ``G``.

    Extracts keywords from the question and picks the two best-matching
    chunk nodes. Their text and the question are then sent to the model in
    one prompt.

    Args:
        question: User question.
        model: Groq model id. When omitted, uses the ``GROQ_MODEL``
            environment variable, falling back to ``"llama3-70b-8192"``.

    Returns:
        ``(answer, total_tokens, latency_seconds)``. The latency covers
        retrieval plus the API call and is rounded to 2 decimals. Tokens fall
        back to 0 and the answer to "" if the response omits them.
    """
    model = model or os.getenv("GROQ_MODEL", "llama3-70b-8192")
    t0 = time.perf_counter()
    top = _top_chunk_ids(_extract_keywords(question), k=2)
    context = "\n\n".join(G.nodes[n].get("text") or "" for n in top)
    prompt = f"\nContext:\n{context}\nQuestion:\n{question}\nAnswer:\n"
    r = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=200,
    )
    latency = round(time.perf_counter() - t0, 2)
    usage = r.usage
    return (
        r.choices[0].message.content or "",
        usage.total_tokens if usage is not None else 0,
        latency,
    )
# -----------------------------
# Compare
# -----------------------------
def compare(question):
    """Run *question* through both pipelines and format the results for the UI.

    Returns:
        A 7-tuple of strings, in the order of the Gradio outputs:
        - LLM answer, LLM tokens, LLM latency;
        - GraphRAG answer, GraphRAG tokens, GraphRAG latency;
        - a token-delta summary.
        A blank or missing question returns seven "-" placeholders without
        calling the API.
    """
    if not (question or "").strip():
        return ("-",) * 7
    llm_ans, llm_tok, llm_lat = query_llm(question)
    grag_ans, grag_tok, grag_lat = query_graphrag(question)
    change = (
        round((llm_tok - grag_tok) / llm_tok * 100, 1)
        if llm_tok > 0 else 0
    )
    # GraphRAG prepends retrieved context to its prompt, so it can use *more*
    # tokens than the bare question. Report that as an increase rather than a
    # confusing negative "reduction"; abs() also avoids printing "-0.0".
    direction = "increase" if change < 0 else "reduction"
    return (
        llm_ans,
        str(llm_tok),
        f"{llm_lat}s",
        grag_ans,
        str(grag_tok),
        f"{grag_lat}s",
        f"{abs(change)}% token {direction}",
    )
# -----------------------------
# UI
# -----------------------------
def _result_panel(heading):
    """Create one results column: a heading plus answer, token and latency boxes.

    Call this while a Gradio layout context is open; the new column is
    attached to that context.
    """
    with gr.Column():
        gr.Markdown(heading)
        answer_box = gr.Textbox(lines=5)
        tokens_box = gr.Textbox(label="Tokens")
        latency_box = gr.Textbox(label="Latency")
    return answer_box, tokens_box, latency_box


with gr.Blocks(title="GraphRAG vs LLM") as demo:
    gr.Markdown("# GraphRAG vs LLM")
    inp = gr.Textbox(
        label="Question",
        placeholder="What was Apple's revenue?",
    )
    btn = gr.Button("Run")
    with gr.Row():
        llm_ans, llm_tok, llm_lat = _result_panel("### LLM Only")
        grag_ans, grag_tok, grag_lat = _result_panel("### GraphRAG")
    summary = gr.Textbox(label="Token Reduction")
    # Output order must match the 7-tuple returned by compare().
    result_boxes = [llm_ans, llm_tok, llm_lat, grag_ans, grag_tok, grag_lat, summary]
    btn.click(fn=compare, inputs=inp, outputs=result_boxes)
# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module (tests, other apps) does not launch a web server.
if __name__ == "__main__":
    # server_name="0.0.0.0" listens on all network interfaces; share=True also
    # requests a public gradio.live link.
    # NOTE(review): anyone with that public link can spend this app's Groq quota.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)