| import os |
| import json |
| import time |
| import gradio as gr |
| import networkx as nx |
| from groq import Groq |
| from dotenv import load_dotenv |
|
|
# Load variables such as GROQ_API_KEY from a .env file (if one is found) into
# the process environment before the client is constructed.
load_dotenv()

# Shared Groq client used by both query paths. api_key is None when
# GROQ_API_KEY is not set in the environment.
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
|
|
| |
| |
| |
def _seed_json(path, payload):
    """Create *path* containing *payload* as JSON, unless the file already exists."""
    if os.path.exists(path):
        return
    with open(path, "w", encoding="utf-8") as fh:
        json.dump(payload, fh)


def _read_json(path):
    """Parse and return the UTF-8 JSON document stored at *path*."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


# First-run bootstrap: make sure data/ exists and holds both input files.
if not os.path.exists("data"):
    os.makedirs("data", exist_ok=True)

# Seed a single demo chunk so the graph has at least one chunk node.
_seed_json(
    "data/chunks.json",
    [
        dict(
            chunk_id="demo_1",
            company="apple",
            doc_name="sample_doc",
            period="2024",
            text="Apple reported strong revenue growth in FY2024.",
            question="",
            answer="",
            source="demo",
        )
    ],
)
_seed_json("data/entities.json", [])

# Chunk records (read by build_graph) and the entity list.
chunks = _read_json("data/chunks.json")
entities = _read_json("data/entities.json")
|
|
| |
| |
| |
def build_graph(records=None):
    """Build an undirected company -> filing -> chunk graph from chunk records.

    Args:
        records: Iterable of chunk dicts with optional keys ``company``,
            ``doc_name``, ``chunk_id`` and ``text``. Defaults to the
            module-level ``chunks`` list.

    Returns:
        A ``networkx.Graph`` whose nodes carry a ``type`` attribute of
        "company", "filing" or "chunk" (chunk nodes also carry ``text`` and
        ``company``), linked by edges whose ``rel`` is "FILED"
        (company-filing) or "CONTAINS" (filing-chunk).
    """
    if records is None:
        records = chunks

    graph = nx.Graph()
    for index, record in enumerate(records):
        # `or` (rather than a .get default) also replaces keys that are present
        # but None/empty: networkx rejects None as a node, and retrieval later
        # concatenates text and company as strings.
        company = record.get("company") or "unknown"
        doc = record.get("doc_name") or "unknown_doc"
        text = record.get("text") or ""

        # A shared fallback id would merge every id-less chunk into one node,
        # keeping only the last chunk's text; give each one a unique id.
        cid = record.get("chunk_id")
        if cid is None or cid == "":
            cid = f"unknown_chunk_{index}"

        # NOTE: company, filing and chunk ids share one node namespace. If two
        # of them are equal, add_node updates the existing node's attributes,
        # so the later call's `type` wins.
        graph.add_node(cid, type="chunk", text=text, company=company)
        graph.add_node(company, type="company")
        graph.add_node(doc, type="filing")

        graph.add_edge(company, doc, rel="FILED")
        graph.add_edge(doc, cid, rel="CONTAINS")

    return graph


# Module-level graph queried by the GraphRAG retriever.
G = build_graph()
|
|
# Lowercase company names (the keyword retriever does not consult this list).
COMPANIES = "apple microsoft amazon google tesla 3m boeing amd".split()

# Filler words dropped when extracting retrieval keywords from a question.
STOPWORDS = set(
    """
    what was the in of a an is are
    how did does we if that you as
    based on which has for by per
    """.split()
)
|
|
| |
| |
| |
def query_llm(question, model="llama3-70b-8192", max_tokens=200):
    """Ask the model *question* directly, with no retrieved context.

    Args:
        question: User question, sent verbatim as the only user message.
        model: Groq model id (the default is the original hard-coded one).
        max_tokens: Completion length cap passed to the API.

    Returns:
        Tuple ``(answer, total_tokens, latency_seconds)``. Latency covers the
        API call and is rounded to two decimals; ``total_tokens`` is 0 when
        the response carries no usage data.
    """
    # perf_counter is monotonic; time.time() can jump if the wall clock is
    # adjusted during the request, producing wrong or negative latencies.
    start = time.perf_counter()
    # NOTE(review): Groq retires model ids over time; confirm the default
    # model is still served.
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": question}],
        max_tokens=max_tokens,
    )
    latency = round(time.perf_counter() - start, 2)

    usage = response.usage
    total_tokens = usage.total_tokens if usage is not None else 0
    return response.choices[0].message.content, total_tokens, latency
|
|
| |
| |
| |
def _extract_keywords(question):
    """Return the de-duplicated, lowercase content words of *question*.

    Each whitespace-separated token is lowercased and stripped of surrounding
    punctuation/quotes and of a trailing possessive ("'s" or "\u2019s") BEFORE
    the stopword and minimum-length filters run. Filtering the cleaned word,
    rather than the raw token, has two effects:
      * "Apple's" yields "apple", which can match chunk text;
      * a punctuation-only token such as "???" is dropped instead of becoming
        "". The empty string is a substring of every text, so it would give
        every chunk a match.
    Words are returned in order of first appearance.
    """
    cleaned = []
    for token in question.split():
        word = token.lower().strip("?.,!;:\"'()\u2018\u2019\u201c\u201d")
        for suffix in ("'s", "\u2019s"):
            if word.endswith(suffix):
                word = word[: -len(suffix)]
                break
        if len(word) > 2 and word not in STOPWORDS:
            cleaned.append(word)
    # Drop repeats so a word used twice in the question is not scored twice.
    return list(dict.fromkeys(cleaned))


def query_graphrag(question, top_k=2, model="llama3-70b-8192", max_tokens=200):
    """Answer *question* using keyword-retrieved graph chunks as context.

    Chunk nodes of the module-level graph ``G`` are scored by how many
    distinct question keywords occur (as substrings) in their text plus
    company name. The ``top_k`` best chunks are joined into the prompt; if no
    chunk matches, the first ``top_k`` chunk nodes in graph order are used
    instead.

    Args:
        question: User question.
        top_k: Maximum number of chunks placed in the context.
        model: Groq model id (the default is the original hard-coded one).
        max_tokens: Completion length cap passed to the API.

    Returns:
        Tuple ``(answer, total_tokens, latency_seconds)``. Latency covers
        retrieval plus the API call and is rounded to two decimals;
        ``total_tokens`` is 0 when the response carries no usage data.
    """
    # Monotonic clock: immune to wall-clock adjustments mid-request.
    start = time.perf_counter()
    keywords = _extract_keywords(question)

    chunk_nodes = []
    chunk_scores = {}
    for node, data in G.nodes(data=True):
        if data.get("type") != "chunk":
            continue
        chunk_nodes.append(node)
        # `or ""` guards against attributes stored as None.
        searchable = f"{data.get('text') or ''} {data.get('company') or ''}".lower()
        score = sum(1 for kw in keywords if kw in searchable)
        if score > 0:
            chunk_scores[node] = score

    # sorted() is stable even with reverse=True, so ties keep graph order.
    top = sorted(chunk_scores, key=chunk_scores.get, reverse=True)[:top_k]
    if not top:
        top = chunk_nodes[:top_k]

    context = "\n\n".join(G.nodes[n].get("text") or "" for n in top)
    prompt = f"\nContext:\n{context}\n\nQuestion:\n{question}\n\nAnswer:\n"

    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=max_tokens,
    )
    latency = round(time.perf_counter() - start, 2)

    usage = response.usage
    total_tokens = usage.total_tokens if usage is not None else 0
    return response.choices[0].message.content, total_tokens, latency
|
|
| |
| |
| |
def compare(question):
    """Run *question* through both pipelines and format the results for the UI.

    Returns a 7-tuple in the order the Gradio outputs expect:
    (LLM answer, LLM tokens, LLM latency, GraphRAG answer, GraphRAG tokens,
    GraphRAG latency, token-reduction summary). A blank or whitespace-only
    question short-circuits to seven "-" placeholders without calling either
    pipeline.
    """
    if question.strip() == "":
        return ("-",) * 7

    base_answer, base_tokens, base_latency = query_llm(question)
    rag_answer, rag_tokens, rag_latency = query_graphrag(question)

    # Share of the direct call's tokens saved by GraphRAG. The value is
    # negative when the GraphRAG call used more tokens.
    if base_tokens > 0:
        saved = round((base_tokens - rag_tokens) / base_tokens * 100, 1)
    else:
        saved = 0

    return (
        base_answer,
        str(base_tokens),
        f"{base_latency}s",
        rag_answer,
        str(rag_tokens),
        f"{rag_latency}s",
        f"{saved}% token reduction",
    )
|
|
| |
| |
| |
# Gradio UI: one question box drives both pipelines; results appear side by side.
with gr.Blocks(title="GraphRAG vs LLM") as demo:
    gr.Markdown("# GraphRAG vs LLM")

    inp = gr.Textbox(
        label="Question",
        placeholder="What was Apple's revenue?",
    )
    btn = gr.Button("Run")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### LLM Only")
            # Labelled so each column's answer box is identifiable.
            llm_ans = gr.Textbox(label="Answer", lines=5)
            llm_tok = gr.Textbox(label="Tokens")
            llm_lat = gr.Textbox(label="Latency")

        with gr.Column():
            gr.Markdown("### GraphRAG")
            grag_ans = gr.Textbox(label="Answer", lines=5)
            grag_tok = gr.Textbox(label="Tokens")
            grag_lat = gr.Textbox(label="Latency")

    summary = gr.Textbox(label="Token Reduction")

    # Output order must match the 7-tuple returned by compare().
    btn.click(
        compare,
        inputs=inp,
        outputs=[
            llm_ans,
            llm_tok,
            llm_lat,
            grag_ans,
            grag_tok,
            grag_lat,
            summary,
        ],
    )


# Launch only when run as a script, so importing this module does not start
# (and block on) a web server.
if __name__ == "__main__":
    # NOTE(review): share=True publishes a public link to this app; anyone
    # holding it can issue queries billed to this GROQ_API_KEY. Confirm that
    # public exposure is intended.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)