GraphRAG / backend /rag.py
Sanjam19's picture
clean initial commit
a4ab72e
Raw
History Blame Contribute Delete
1.95 kB
import chromadb
import json
from groq import Groq
from sentence_transformers import SentenceTransformer
import os
from dotenv import load_dotenv
# Load variables from a .env file (if one is found) into the process
# environment; this must run before GROQ_API_KEY is read on the next line.
load_dotenv()
# Groq API client used by query_rag for chat completions. os.getenv returns
# None when GROQ_API_KEY is unset.
# NOTE(review): presumably the client then fails at construction or on the
# first request -- confirm.
client = Groq(api_key=os.getenv("GROQ_API_KEY"))
# Sentence-embedding model shared by ingest_chunks (documents) and query_rag
# (questions), so both sides are embedded with the same model.
embedder = SentenceTransformer("all-MiniLM-L6-v2")
# Persistent ChromaDB client rooted at data/chroma (relative to the current
# working directory).
chroma = chromadb.PersistentClient(path="data/chroma")
# Collection that holds the ingested chunks; fetched if it already exists,
# created otherwise.
collection = chroma.get_or_create_collection("financebench")
def ingest_chunks() -> None:
    """Embed every chunk in ``data/chunks.json`` and upsert it into ChromaDB.

    The JSON file must contain a list of objects, each providing the keys
    ``text``, ``chunk_id``, ``company`` and ``doc_name``. The text is embedded
    with the module-level ``embedder`` and stored in the module-level
    ``collection`` under ``chunk_id``, with ``company`` and ``doc`` metadata.
    Because ``upsert`` is used, the function can be re-run with the same ids.

    Prints the number of ingested chunks when done.

    Raises:
        FileNotFoundError: if ``data/chunks.json`` does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
        KeyError: if a chunk is missing one of the required keys.
    """
    # Use a context manager so the file handle is closed deterministically,
    # and read as UTF-8 explicitly instead of the platform default encoding.
    with open("data/chunks.json", encoding="utf-8") as chunk_file:
        chunks = json.load(chunk_file)

    if not chunks:
        # Nothing to embed or store; skip the embedder and the database call
        # rather than handing them empty batches.
        print(f"Ingested {len(chunks)} chunks into ChromaDB")
        return

    texts = [c["text"] for c in chunks]
    ids = [c["chunk_id"] for c in chunks]
    metadatas = [
        {
            "company": c["company"],
            "doc": c["doc_name"],
        }
        for c in chunks
    ]

    # encode() output is converted to plain nested lists before being passed
    # to ChromaDB.
    embeddings = embedder.encode(texts).tolist()

    collection.upsert(
        documents=texts,
        ids=ids,
        embeddings=embeddings,
        metadatas=metadatas,
    )
    print(f"Ingested {len(chunks)} chunks into ChromaDB")
def query_rag(question: str, top_k: int = 5) -> dict:
    """Answer *question* from the chunks stored in the ChromaDB collection.

    The question is embedded with the module-level ``embedder``, up to
    ``top_k`` results are requested from ``collection``, and the returned
    documents are joined into a context that is sent to the Groq chat model
    together with the question.

    Args:
        question: The question to answer.
        top_k: Number of results requested from the collection.

    Returns:
        A dict with the keys ``answer`` (model reply), ``input_tokens``
        (whitespace-separated word count of the prompt), ``total_tokens``
        (as reported by the API response) and ``context_chunks`` (number of
        retrieved documents).
    """
    encoded_question = embedder.encode([question])
    query_vector = encoded_question.tolist()[0]

    hits = collection.query(query_embeddings=[query_vector], n_results=top_k)
    # Only one query embedding is sent, so its documents are at index 0.
    retrieved_docs = hits["documents"][0]

    context = "\n\n".join(retrieved_docs)
    prompt = f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"

    # NOTE(review): this is a word count, not a tokenizer-based token count;
    # confirm that callers expect this approximation.
    prompt_word_count = len(prompt.split())

    user_message = {"role": "user", "content": prompt}
    completion = client.chat.completions.create(
        model="llama-3.1-8b-instant",
        messages=[user_message],
        max_tokens=200,
    )

    return {
        "answer": completion.choices[0].message.content,
        "input_tokens": prompt_word_count,
        "total_tokens": completion.usage.total_tokens,
        "context_chunks": len(retrieved_docs),
    }
if __name__ == "__main__":
    # Script entry point: (re)build the vector store, then run one sample
    # question through the pipeline and print the resulting dict.
    ingest_chunks()
    result = query_rag("What was Apple's revenue in 2022?")
    print(result)