import json import networkx as nx import os from groq import Groq from dotenv import load_dotenv load_dotenv() client = Groq(api_key=os.getenv("GROQ_API_KEY")) def build_graph() -> nx.Graph: chunks = json.load( open( "data/chunks.json", encoding="utf-8" ) ) entities = json.load( open( "data/entities.json", encoding="utf-8" ) ) G = nx.Graph() for c in chunks: G.add_node( c["chunk_id"], type="chunk", text=c["text"], company=c["company"], doc=c["doc_name"] ) G.add_node( c["company"], type="company" ) G.add_node( c["doc_name"], type="filing" ) G.add_edge( c["company"], c["doc_name"], rel="FILED" ) G.add_edge( c["doc_name"], c["chunk_id"], rel="CONTAINS" ) for e in entities: for ent in e["entities"]: ent_id = ( f"{ent['label']}_" f"{ent['text'][:50]}" ) G.add_node( ent_id, type="entity", label=ent["label"], text=ent["text"] ) G.add_edge( e["chunk_id"], ent_id, rel="MENTIONS" ) return G def query_graphrag( question: str, G: nx.Graph ) -> dict: stopwords = { "what", "was", "the", "in", "of", "a", "an", "is", "are", "how", "did", "does", "we", "if", "that", "you", "as", "based", "on", "which", "has", "for", "by", "per" } keywords = [ w.lower().strip("?.,") for w in question.split() if ( w.lower() not in stopwords and len(w) > 2 ) ] chunk_scores = {} for node, data in G.nodes(data=True): if data.get("type") != "chunk": continue text = ( data.get("text", "") + " " + data.get("company", "") ).lower() score = sum( 1 for kw in keywords if kw in text ) if score > 0: chunk_scores[node] = score top_chunks = sorted( chunk_scores, key=chunk_scores.get, reverse=True )[:1] if not top_chunks: top_chunks = [ n for n, d in G.nodes(data=True) if d.get("type") == "chunk" ][:1] context = "\n\n".join([ G.nodes[n].get( "text", "" ) for n in top_chunks ]) prompt = ( f"Context:\n{context}\n\n" f"Question: {question}\n" f"Answer:" ) input_tokens = len( prompt.split() ) response = client.chat.completions.create( model="llama-3.1-8b-instant", messages=[{ "role": "user", "content": prompt }], max_tokens=200 ) answer = ( response .choices[0] .message.content ) total_tokens = ( response .usage .total_tokens ) return { "answer": answer, "input_tokens": input_tokens, "total_tokens": total_tokens, "context_chunks": len(top_chunks) } if __name__ == "__main__": G = build_graph() print( f"Graph: " f"{G.number_of_nodes()} nodes, " f"{G.number_of_edges()} edges" ) result = query_graphrag( "What was Apple's revenue in 2022?", G ) print(result)