Spaces:
Sleeping
Sleeping
# rag/retriever.py
import os

import torch
from dotenv import load_dotenv
from pinecone import Pinecone
from sentence_transformers import SentenceTransformer

load_dotenv()

# ── Config ───────────────────────────────────────────────────────────────────
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
PINECONE_INDEX = os.getenv("PINECONE_INDEX", "study-saathi")
EMBEDDING_MODEL = "intfloat/multilingual-e5-large"
TOP_K = 10

# Fail fast with an actionable message instead of an opaque SDK error at
# the first Pinecone call when the key is missing.
if not PINECONE_API_KEY:
    raise RuntimeError(
        "PINECONE_API_KEY is not set; add it to the environment or a .env file"
    )

# ── Device ───────────────────────────────────────────────────────────────────
device = "cuda" if torch.cuda.is_available() else "cpu"

# ── Load Embedding Model ─────────────────────────────────────────────────────
print("[INFO] Loading embedding model...")
embedder = SentenceTransformer(EMBEDDING_MODEL, device=device)

# ── Pinecone Setup ───────────────────────────────────────────────────────────
pc = Pinecone(api_key=PINECONE_API_KEY)
index = pc.Index(PINECONE_INDEX)
# ── Retrieve ─────────────────────────────────────────────────────────────────
def retrieve(query: str, topic: str | None = None, top_k: int = TOP_K) -> list[dict]:
    """
    Retrieve the most relevant chunks for *query* from Pinecone.

    Args:
        query: User's question or topic text.
        topic: Optional filename-based topic filter (e.g. "deadlocks").
        top_k: Number of chunks to return.

    Returns:
        List of dicts with "text", "topic" and "score" keys, in Pinecone's
        relevance order (best match first).
    """
    # e5-family models require a "query: " prefix on question embeddings.
    query_embedding = embedder.encode(
        f"query: {query}",
        normalize_embeddings=True,
    ).tolist()

    # Apply a metadata filter only when a topic was supplied.
    filter_dict = {"topic": {"$eq": topic}} if topic else None

    results = index.query(
        vector=query_embedding,
        top_k=top_k,
        include_metadata=True,
        filter=filter_dict,
    )

    return [
        {
            "text": match["metadata"]["text"],
            "topic": match["metadata"].get("topic", "unknown"),
            "score": round(match["score"], 4),
        }
        for match in results["matches"]
    ]
# ── Format for LLM ───────────────────────────────────────────────────────────
def format_context(chunks: list) -> str:
    """Assemble retrieved chunks into one context string for the LLM prompt.

    Each chunk is rendered as a bracketed header (index, topic, score)
    followed by its text; chunks are separated by a blank line.
    """
    sections = []
    for position, chunk in enumerate(chunks, start=1):
        header = f"[Chunk {position} | Topic: {chunk['topic']} | Score: {chunk['score']}]"
        sections.append(f"{header}\n{chunk['text']}")
    return "\n\n".join(sections)
# ── Quick Test ───────────────────────────────────────────────────────────────
if __name__ == "__main__":

    def _demo() -> None:
        """Run one sample retrieval and print the formatted context."""
        sample_query = "Explain Process Registers?"
        sample_topic = "ch-01-updated"  # change to match your filename
        retrieved = retrieve(sample_query, topic=sample_topic)
        print(format_context(retrieved))

    _demo()