# study-sathi / rag/retriever.py
# Author: YousifCreates — "updated top_k value" (commit 0314823)
# rag/retriever.py
import os
from typing import Optional

import torch
from dotenv import load_dotenv
from pinecone import Pinecone
from sentence_transformers import SentenceTransformer
load_dotenv()
# ── Config ───────────────────────────────────────────────────────────────────
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
PINECONE_INDEX = os.getenv("PINECONE_INDEX", "study-saathi")
EMBEDDING_MODEL = "intfloat/multilingual-e5-large"
TOP_K = 10
# ── Device ───────────────────────────────────────────────────────────────────
device = "cuda" if torch.cuda.is_available() else "cpu"
# ── Load Embedding Model ──────────────────────────────────────────────────────
print("[INFO] Loading embedding model...")
embedder = SentenceTransformer(EMBEDDING_MODEL, device=device)
# ── Pinecone Setup ────────────────────────────────────────────────────────────
pc = Pinecone(api_key=PINECONE_API_KEY)
index = pc.Index(PINECONE_INDEX)
# ── Retrieve ──────────────────────────────────────────────────────────────────
def retrieve(query: str, topic: str = None, top_k: int = TOP_K) -> list:
"""
Retrieve relevant chunks from Pinecone.
- query : user's question or topic
- topic : filename-based topic filter (e.g. "deadlocks")
- top_k : number of chunks to return
"""
# e5 requires "query: " prefix for questions
query_embedding = embedder.encode(
f"query: {query}",
normalize_embeddings=True
).tolist()
# build filter only if topic is provided
filter_dict = {"topic": {"$eq": topic}} if topic else None
results = index.query(
vector=query_embedding,
top_k=top_k,
include_metadata=True,
filter=filter_dict
)
chunks = []
for match in results["matches"]:
chunks.append({
"text": match["metadata"]["text"],
"topic": match["metadata"].get("topic", "unknown"),
"score": round(match["score"], 4)
})
return chunks
# ── Format for LLM ────────────────────────────────────────────────────────────
def format_context(chunks: list) -> str:
"""Joins retrieved chunks into a single context string for the LLM."""
return "\n\n".join(
f"[Chunk {i+1} | Topic: {c['topic']} | Score: {c['score']}]\n{c['text']}"
for i, c in enumerate(chunks)
)
# ── Quick Test ────────────────────────────────────────────────────────────────
if __name__ == "__main__":
query = "Explain Process Registers?"
topic = "ch-01-updated" # change to match your filename
chunks = retrieve(query, topic=topic)
context = format_context(chunks)
print(context)