Spaces:
Sleeping
Sleeping
File size: 3,531 Bytes
81726c9 0314823 81726c9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 | # rag/retriever.py
import os
from typing import Optional

import torch
from dotenv import load_dotenv
from pinecone import Pinecone
from sentence_transformers import SentenceTransformer
load_dotenv()

# ── Config ───────────────────────────────────────────────────────────────────
# Settings come from the environment (load_dotenv() above pulls in a .env
# file if present); the index name falls back to "study-saathi".
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
PINECONE_INDEX = os.getenv("PINECONE_INDEX", "study-saathi")
EMBEDDING_MODEL = "intfloat/multilingual-e5-large"
TOP_K = 10  # default number of chunks requested by retrieve()

# ── Device ───────────────────────────────────────────────────────────────────
# Run the embedder on the GPU whenever PyTorch reports one is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ── Load Embedding Model ─────────────────────────────────────────────────────
# Loaded once at import time; every retrieve() call reuses this instance.
print("[INFO] Loading embedding model...")
embedder = SentenceTransformer(EMBEDDING_MODEL, device=device)

# ── Pinecone Setup ───────────────────────────────────────────────────────────
# Client and index handle are module-level so callers share one connection.
pc = Pinecone(api_key=PINECONE_API_KEY)
index = pc.Index(PINECONE_INDEX)
# ── Retrieve ─────────────────────────────────────────────────────────────────
def retrieve(query: str, topic: Optional[str] = None, top_k: int = TOP_K) -> list:
    """
    Retrieve the chunks most similar to ``query`` from the Pinecone index.

    - query : user's question or topic
    - topic : exact-match filter on the ``topic`` metadata field
              (filename-based, e.g. "deadlocks"); ``None`` or "" means no filter
    - top_k : maximum number of chunks to return; values below 1 return []
              without contacting Pinecone

    Returns a list of dicts with keys ``text``, ``topic`` (``"unknown"`` when
    the match carries no topic metadata) and ``score`` (rounded to 4 decimal
    places), in the order returned by Pinecone. Matches whose metadata has no
    ``text`` are skipped because they contribute nothing to the LLM context,
    so fewer than ``top_k`` chunks may come back.
    """
    # Pinecone rejects top_k < 1; answer that case locally.
    if top_k < 1:
        return []

    # e5 models expect a "query: " prefix on search queries;
    # normalize_embeddings=True yields a unit-length vector.
    query_embedding = embedder.encode(
        f"query: {query}",
        normalize_embeddings=True,
    ).tolist()

    # Only filter by topic when one was given.
    filter_dict = {"topic": {"$eq": topic}} if topic else None

    results = index.query(
        vector=query_embedding,
        top_k=top_k,
        include_metadata=True,
        filter=filter_dict,
    )

    chunks = []
    for match in results["matches"]:
        # A vector stored without metadata must not break the whole query:
        # treat missing/None metadata as empty and skip text-less matches.
        metadata = match.get("metadata") or {}
        text = metadata.get("text")
        if text is None:
            continue
        chunks.append({
            "text": text,
            "topic": metadata.get("topic", "unknown"),
            "score": round(match["score"], 4),
        })
    return chunks
# ── Format for LLM ───────────────────────────────────────────────────────────
def format_context(chunks: list) -> str:
    """Render retrieved chunks as a single context string for the LLM.

    Each chunk becomes a header line ``[Chunk N | Topic: T | Score: S]``
    (N counts from 1) followed by the chunk's text on the next line; chunks
    are separated by one blank line. An empty ``chunks`` yields ``""``.
    """
    blocks = []
    for number, chunk in enumerate(chunks, start=1):
        header = f"[Chunk {number} | Topic: {chunk['topic']} | Score: {chunk['score']}]"
        blocks.append(f"{header}\n{chunk['text']}")
    return "\n\n".join(blocks)
# ── Quick Test ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Manual smoke test: run one query against one topic and print the context.
    query = "Explain Process Registers?"
    topic = "ch-01-updated"  # change to match your filename
    chunks = retrieve(query, topic=topic)
    if not chunks:
        # e.g. the topic filter matched nothing; say so instead of printing
        # an empty string.
        print(f"[INFO] No chunks found for topic {topic!r}.")
    else:
        context = format_context(chunks)
        print(context)