GraphRAG / backend /graphrag.py
Sanjam19's picture
clean initial commit
a4ab72e
Raw
History Blame Contribute Delete
3.79 kB
import json
import networkx as nx
import os
from groq import Groq
from dotenv import load_dotenv
# Load variables from a .env file (if one is found) into the process
# environment, then create the Groq client that query_graphrag() uses.
load_dotenv()
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
def build_graph(
    chunks_path: str = "data/chunks.json",
    entities_path: str = "data/entities.json",
) -> nx.Graph:
    """Build the retrieval graph from pre-extracted chunk and entity JSON files.

    Node types (stored in each node's ``type`` attribute):

    * ``"company"`` -- keyed by the company name.
    * ``"filing"``  -- keyed by the document name.
    * ``"chunk"``   -- keyed by ``chunk_id``; carries ``text``, ``company``
      and ``doc`` attributes.
    * ``"entity"``  -- keyed by ``"<label>_<first 50 chars of text>"``;
      carries ``label`` and ``text`` attributes.

    Edges carry a ``rel`` attribute: company -FILED- filing,
    filing -CONTAINS- chunk, and chunk -MENTIONS- entity.

    Args:
        chunks_path: JSON file whose top-level value is an iterable of chunk
            records with ``chunk_id``, ``text``, ``company`` and ``doc_name``
            keys.
        entities_path: JSON file whose top-level value is an iterable of
            records with a ``chunk_id`` key and an ``entities`` list of
            dicts holding ``label`` and ``text``.

    Returns:
        The populated (undirected) ``networkx.Graph``.
    """
    # Context managers close the file handles as soon as parsing finishes
    # instead of leaving that to the garbage collector.
    with open(chunks_path, encoding="utf-8") as f:
        chunks = json.load(f)
    with open(entities_path, encoding="utf-8") as f:
        entities = json.load(f)

    G = nx.Graph()

    for c in chunks:
        G.add_node(
            c["chunk_id"],
            type="chunk",
            text=c["text"],
            company=c["company"],
            doc=c["doc_name"],
        )
        # Re-adding an existing company/filing node is harmless: networkx
        # just updates its attributes.
        G.add_node(c["company"], type="company")
        G.add_node(c["doc_name"], type="filing")
        G.add_edge(c["company"], c["doc_name"], rel="FILED")
        G.add_edge(c["doc_name"], c["chunk_id"], rel="CONTAINS")

    for e in entities:
        for ent in e["entities"]:
            # Entity identity is label + a 50-char text prefix, so distinct
            # long mentions sharing that prefix collapse into one node (the
            # last-seen full text wins for the ``text`` attribute).
            ent_id = f"{ent['label']}_{ent['text'][:50]}"
            G.add_node(
                ent_id,
                type="entity",
                label=ent["label"],
                text=ent["text"],
            )
            # NOTE: if e["chunk_id"] is not a chunk from chunks_path,
            # add_edge silently creates a bare node with no "type" attribute.
            G.add_edge(e["chunk_id"], ent_id, rel="MENTIONS")

    return G
# Words ignored when extracting query keywords. They are compared against
# the lower-cased token after stripping "?.,".
_STOPWORDS = frozenset({
    "what", "was", "the", "in",
    "of", "a", "an", "is",
    "are", "how", "did",
    "does", "we", "if",
    "that", "you", "as",
    "based", "on", "which",
    "has", "for", "by", "per",
})


def _extract_keywords(question: str) -> list:
    """Return the distinct normalised keywords of *question*, in first-seen order."""
    keywords = []
    for word in question.split():
        kw = word.lower().strip("?.,")
        # Filter on the *normalised* token. Checking the raw token let
        # "is?"/"in," through as keywords and turned "..." into "", which is
        # a substring of every text and so matched every chunk.
        if len(kw) > 2 and kw not in _STOPWORDS:
            keywords.append(kw)
    # Drop repeats so a word repeated in the question is not counted twice.
    return list(dict.fromkeys(keywords))


def _rank_chunks(G: nx.Graph, keywords: list, top_k: int) -> list:
    """Return up to *top_k* chunk node ids, best keyword hit count first."""
    chunk_scores = {}
    for node, data in G.nodes(data=True):
        if data.get("type") != "chunk":
            continue
        haystack = (
            data.get("text", "") + " " + data.get("company", "")
        ).lower()
        # Plain substring test (e.g. "net" also matches "internet").
        score = sum(1 for kw in keywords if kw in haystack)
        if score > 0:
            chunk_scores[node] = score

    # sorted() is stable even with reverse=True, so tied chunks keep the
    # graph's node iteration order.
    ranked = sorted(chunk_scores, key=chunk_scores.get, reverse=True)[:top_k]
    if ranked:
        return ranked

    # No keyword matched any chunk: fall back to the first chunk node(s) so
    # the model still receives some context.
    return [
        n for n, d in G.nodes(data=True) if d.get("type") == "chunk"
    ][:top_k]


def query_graphrag(
    question: str,
    G: nx.Graph,
    *,
    top_k: int = 1,
) -> dict:
    """Answer *question* with the LLM, using keyword-ranked graph chunks as context.

    Chunk nodes (``type == "chunk"``) are scored by how many distinct query
    keywords occur as substrings of their lower-cased text plus company name.
    The ``top_k`` best-scoring chunks are joined into the prompt. If no chunk
    matches, the first ``top_k`` chunk nodes are used instead. The prompt is
    sent through the module-level Groq ``client``.

    Args:
        question: Natural-language question.
        G: Graph to search, typically the one returned by ``build_graph``.
        top_k: Maximum number of chunks to place in the prompt
            (keyword-only; defaults to 1).

    Returns:
        dict with keys:

        * ``answer`` -- the model's reply text.
        * ``input_tokens`` -- whitespace-delimited word count of the prompt
          (an approximation, not a tokenizer count).
        * ``total_tokens`` -- the total token usage reported by the API.
        * ``context_chunks`` -- the number of chunks included in the prompt.

    Raises:
        ValueError: If ``top_k`` is less than 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    keywords = _extract_keywords(question)
    top_chunks = _rank_chunks(G, keywords, top_k)

    context = "\n\n".join(G.nodes[n].get("text", "") for n in top_chunks)
    prompt = f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"

    # Word count, kept for comparability with other pipelines; the real
    # token usage is reported separately via total_tokens.
    input_tokens = len(prompt.split())

    response = client.chat.completions.create(
        model="llama-3.1-8b-instant",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=200,
    )

    return {
        "answer": response.choices[0].message.content,
        "input_tokens": input_tokens,
        "total_tokens": response.usage.total_tokens,
        "context_chunks": len(top_chunks),
    }
if __name__ == "__main__":
    # Smoke run: build the graph from the data/ files, report its size, and
    # answer one sample question.
    graph = build_graph()
    node_count = graph.number_of_nodes()
    edge_count = graph.number_of_edges()
    print(f"Graph: {node_count} nodes, {edge_count} edges")

    sample_question = "What was Apple's revenue in 2022?"
    print(query_graphrag(sample_question, graph))