import os
import json
import string
import time

import gradio as gr
import networkx as nx
from groq import Groq
from dotenv import load_dotenv

load_dotenv()

# -----------------------------
# API
# -----------------------------
client = Groq(api_key=os.getenv("GROQ_API_KEY"))

# Both the plain-LLM and the GraphRAG paths use the same model and token cap
# so the comparison is like-for-like. GROQ_MODEL overrides the default.
# NOTE(review): Groq has announced deprecation of "llama3-70b-8192" — confirm
# it is still served; if not, set GROQ_MODEL (e.g. "llama-3.3-70b-versatile").
MODEL = os.getenv("GROQ_MODEL", "llama3-70b-8192")
MAX_TOKENS = 200

# -----------------------------
# Fallback data (avoid startup crash)
# -----------------------------
# Seed a minimal dataset so the app can start without prepared data files.
os.makedirs("data", exist_ok=True)

if not os.path.exists("data/chunks.json"):
    dummy_chunks = [
        {
            "chunk_id": "demo_1",
            "company": "apple",
            "doc_name": "sample_doc",
            "period": "2024",
            "text": "Apple reported strong revenue growth in FY2024.",
            "question": "",
            "answer": "",
            "source": "demo",
        }
    ]
    with open("data/chunks.json", "w", encoding="utf-8") as f:
        json.dump(dummy_chunks, f)

if not os.path.exists("data/entities.json"):
    with open("data/entities.json", "w", encoding="utf-8") as f:
        json.dump([], f)

# -----------------------------
# Load data
# -----------------------------
with open("data/chunks.json", encoding="utf-8") as f:
    chunks = json.load(f)

with open("data/entities.json", encoding="utf-8") as f:
    entities = json.load(f)  # loaded but not referenced by the code below


# -----------------------------
# Build graph
# -----------------------------
def build_graph():
    """Build an undirected company -> filing -> chunk graph from ``chunks``.

    Each chunk record becomes a ``type="chunk"`` node keyed by its chunk_id and
    carrying its ``text`` and ``company``. The chunk is linked to its filing
    node (``doc_name``) by a ``CONTAINS`` edge, and the filing is linked to its
    company node by a ``FILED`` edge.

    Missing or null fields fall back to placeholder values. ``None`` cannot be
    a NetworkX node, and a ``None`` text would break retrieval's string
    concatenation.

    NOTE(review): all node kinds share one id namespace. A company name equal
    to a doc_name/chunk_id overwrites the node's ``type``. Chunks with no
    chunk_id collapse into a single "unknown_chunk" node.

    Returns:
        networkx.Graph: the populated graph.
    """
    graph = nx.Graph()
    for c in chunks:
        company = c.get("company") or "unknown"
        doc = c.get("doc_name") or "unknown_doc"
        cid = c.get("chunk_id") or "unknown_chunk"
        text = c.get("text") or ""

        graph.add_node(cid, type="chunk", text=text, company=company)
        graph.add_node(company, type="company")
        graph.add_node(doc, type="filing")

        graph.add_edge(company, doc, rel="FILED")
        graph.add_edge(doc, cid, rel="CONTAINS")
    return graph


G = build_graph()

COMPANIES = [
    "apple", "microsoft", "amazon", "google",
    "tesla", "3m", "boeing", "amd",
]

STOPWORDS = {
    "what", "was", "the", "in", "of", "a", "an", "is", "are",
    "how", "did", "does", "we", "if", "that", "you", "as",
    "based", "on", "which", "has", "for", "by", "per",
}

# Characters trimmed from both ends of each question token (ASCII punctuation
# plus typographic quotes).
_TOKEN_PUNCTUATION = string.punctuation + "\u201c\u201d\u2018\u2019"


def _extract_keywords(question):
    """Return the lowercase search keywords contained in *question*.

    Tokens are split on whitespace and lowercased. Surrounding punctuation is
    stripped first, then a trailing possessive ('s) is removed. Only after
    that are stopwords and short tokens dropped. Filtering after stripping
    matters: otherwise "was?" slips past STOPWORDS, and a token like "..."
    becomes "", which matches every chunk. Tokens shorter than 3 characters
    are kept only when they are known company names (e.g. "3m").
    """
    keywords = []
    for raw in question.split():
        word = raw.lower().strip(_TOKEN_PUNCTUATION)
        for suffix in ("'s", "\u2019s"):
            if word.endswith(suffix):
                word = word[: -len(suffix)]
                break
        if not word or word in STOPWORDS:
            continue
        if len(word) > 2 or word in COMPANIES:
            keywords.append(word)
    return keywords


def _chat(content):
    """Send *content* as a single user message to MODEL.

    Returns:
        tuple: (response text, total tokens reported by the API).
    """
    r = client.chat.completions.create(
        model=MODEL,
        messages=[{"role": "user", "content": content}],
        max_tokens=MAX_TOKENS,
    )
    return r.choices[0].message.content, r.usage.total_tokens


# -----------------------------
# LLM Query
# -----------------------------
def query_llm(question):
    """Ask the model *question* directly, without any retrieved context.

    Returns:
        tuple: (answer text, total tokens, latency in seconds rounded to 2 dp).
    """
    t0 = time.perf_counter()
    answer, tokens = _chat(question)
    return answer, tokens, round(time.perf_counter() - t0, 2)


# -----------------------------
# GraphRAG Query
# -----------------------------
def query_graphrag(question, top_k=2):
    """Answer *question* using keyword-retrieved chunks from the graph as context.

    Each chunk node is scored by how many extracted keywords appear, as
    substrings, in its text plus company name. The *top_k* highest-scoring
    chunks form the prompt context. If no chunk matches, the first *top_k*
    chunk nodes in graph order are used instead.

    Returns:
        tuple: (answer text, total tokens, latency in seconds rounded to 2 dp).
    """
    t0 = time.perf_counter()
    keywords = _extract_keywords(question)

    chunk_scores = {}
    for node, data in G.nodes(data=True):
        if data.get("type") != "chunk":
            continue
        haystack = (data.get("text", "") + " " + data.get("company", "")).lower()
        score = sum(1 for kw in keywords if kw in haystack)
        if score > 0:
            chunk_scores[node] = score

    top = sorted(chunk_scores, key=chunk_scores.get, reverse=True)[:top_k]
    if not top:
        top = [n for n, d in G.nodes(data=True) if d.get("type") == "chunk"][:top_k]

    context = "\n\n".join(G.nodes[n].get("text", "") for n in top)

    prompt = f"""
Context:
{context}

Question:
{question}

Answer:
"""
    answer, tokens = _chat(prompt)
    return answer, tokens, round(time.perf_counter() - t0, 2)


# -----------------------------
# Compare
# -----------------------------
def compare(question):
    """Run both query paths on *question* and format the results for the UI.

    Returns:
        tuple: seven strings, in the order of the UI output widgets:
        LLM answer, LLM tokens, LLM latency, GraphRAG answer, GraphRAG tokens,
        GraphRAG latency, and the token-reduction summary. All seven are "-"
        when the question is empty or blank.

    The reduction is negative when the GraphRAG prompt (question plus
    retrieved context) uses more tokens than the bare question.
    """
    if not question or not question.strip():
        return "-", "-", "-", "-", "-", "-", "-"

    llm_ans, llm_tok, llm_lat = query_llm(question)
    grag_ans, grag_tok, grag_lat = query_graphrag(question)

    reduction = (
        round((llm_tok - grag_tok) / llm_tok * 100, 1)
        if llm_tok > 0
        else 0
    )

    return (
        llm_ans,
        str(llm_tok),
        f"{llm_lat}s",
        grag_ans,
        str(grag_tok),
        f"{grag_lat}s",
        f"{reduction}% token reduction",
    )


# -----------------------------
# UI
# -----------------------------
with gr.Blocks(title="GraphRAG vs LLM") as demo:
    gr.Markdown("# GraphRAG vs LLM")

    inp = gr.Textbox(
        label="Question",
        placeholder="What was Apple's revenue?",
    )
    btn = gr.Button("Run")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### LLM Only")
            llm_ans = gr.Textbox(lines=5)
            llm_tok = gr.Textbox(label="Tokens")
            llm_lat = gr.Textbox(label="Latency")

        with gr.Column():
            gr.Markdown("### GraphRAG")
            grag_ans = gr.Textbox(lines=5)
            grag_tok = gr.Textbox(label="Tokens")
            grag_lat = gr.Textbox(label="Latency")

    summary = gr.Textbox(label="Token Reduction")

    btn.click(
        compare,
        inputs=inp,
        outputs=[
            llm_ans, llm_tok, llm_lat,
            grag_ans, grag_tok, grag_lat,
            summary,
        ],
    )

# NOTE(review): share=True requests a public Gradio share link, so anyone with
# the link can trigger Groq API calls on this key — confirm this is intended.
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)