# rag/retriever.py
"""Retrieve relevant document chunks from Pinecone for a RAG pipeline.

Embeds user queries with a multilingual e5 model and searches a Pinecone
index, optionally filtering by a filename-derived ``topic`` metadata field.
"""

import os
from typing import Optional

import torch
from dotenv import load_dotenv
from pinecone import Pinecone
from sentence_transformers import SentenceTransformer

load_dotenv()

# ── Config ───────────────────────────────────────────────────────────────────
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
PINECONE_INDEX = os.getenv("PINECONE_INDEX", "study-saathi")
EMBEDDING_MODEL = "intfloat/multilingual-e5-large"
TOP_K = 10

# Fail fast with an actionable message instead of a cryptic Pinecone auth
# error later on.
if not PINECONE_API_KEY:
    raise RuntimeError("PINECONE_API_KEY is not set; add it to your .env file.")

# ── Device ───────────────────────────────────────────────────────────────────
device = "cuda" if torch.cuda.is_available() else "cpu"

# ── Load Embedding Model ─────────────────────────────────────────────────────
print("[INFO] Loading embedding model...")
embedder = SentenceTransformer(EMBEDDING_MODEL, device=device)

# ── Pinecone Setup ───────────────────────────────────────────────────────────
pc = Pinecone(api_key=PINECONE_API_KEY)
index = pc.Index(PINECONE_INDEX)


# ── Retrieve ─────────────────────────────────────────────────────────────────
def retrieve(query: str, topic: Optional[str] = None, top_k: int = TOP_K) -> list:
    """Retrieve relevant chunks from Pinecone.

    Args:
        query: User's question or topic.
        topic: Optional filename-based topic filter (e.g. "deadlocks").
        top_k: Number of chunks to return.

    Returns:
        A list of dicts with ``text``, ``topic``, and ``score`` keys,
        ordered as returned by Pinecone (descending similarity).
    """
    # e5 models require a "query: " prefix on questions to match how the
    # index-side passages were embedded.
    query_embedding = embedder.encode(
        f"query: {query}", normalize_embeddings=True
    ).tolist()

    # Build a metadata filter only when a topic is provided; filter=None
    # means "no filtering" to Pinecone.
    filter_dict = {"topic": {"$eq": topic}} if topic else None

    results = index.query(
        vector=query_embedding,
        top_k=top_k,
        include_metadata=True,
        filter=filter_dict,
    )

    # .get() on "topic" guards against vectors indexed without that field.
    return [
        {
            "text": match["metadata"]["text"],
            "topic": match["metadata"].get("topic", "unknown"),
            "score": round(match["score"], 4),
        }
        for match in results["matches"]
    ]


# ── Format for LLM ───────────────────────────────────────────────────────────
def format_context(chunks: list) -> str:
    """Joins retrieved chunks into a single context string for the LLM."""
    return "\n\n".join(
        f"[Chunk {i+1} | Topic: {c['topic']} | Score: {c['score']}]\n{c['text']}"
        for i, c in enumerate(chunks)
    )


# ── Quick Test ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    query = "Explain Process Registers?"
    topic = "ch-01-updated"  # change to match your filename
    chunks = retrieve(query, topic=topic)
    context = format_context(chunks)
    print(context)