File size: 3,531 Bytes
81726c9
 
 
 
 
 
 
 
 
 
 
 
 
 
0314823
81726c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# rag/retriever.py

import os
import torch
from dotenv import load_dotenv
from pinecone import Pinecone
from sentence_transformers import SentenceTransformer

# Load variables from a .env file (if present) into the process environment
# so the os.getenv() calls below can see them; must run before the Config reads.
load_dotenv()

# ── Config ───────────────────────────────────────────────────────────────────
# No default for the API key: if it is unset, None is passed to Pinecone() below.
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
PINECONE_INDEX   = os.getenv("PINECONE_INDEX", "study-saathi")
# Multilingual e5 model; retrieve() prefixes queries with "query: " for it.
EMBEDDING_MODEL  = "intfloat/multilingual-e5-large"
TOP_K            = 10   # default number of chunks retrieve() asks Pinecone for

# ── Device ───────────────────────────────────────────────────────────────────
# Run the embedder on the GPU when CUDA is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# ── Load Embedding Model ──────────────────────────────────────────────────────
# NOTE: this happens at import time, so merely importing this module loads the
# model (and may download its weights on first use) — expect a slow import.
print("[INFO] Loading embedding model...")
embedder = SentenceTransformer(EMBEDDING_MODEL, device=device)

# ── Pinecone Setup ────────────────────────────────────────────────────────────
# A single client and index handle, created once at import time and shared by
# every retrieve() call in this module.
pc    = Pinecone(api_key=PINECONE_API_KEY)
index = pc.Index(PINECONE_INDEX)

# ── Retrieve ──────────────────────────────────────────────────────────────────
def retrieve(query: str, topic: str = None, top_k: int = TOP_K) -> list:
    """
    Retrieve the chunks most similar to ``query`` from the Pinecone index.

    Args:
        query: user's question or topic. It is embedded with the e5
            "query: " prefix and L2-normalised before searching.
        topic: optional filename-based topic filter (e.g. "deadlocks").
            Any falsy value (None or "") disables filtering.
        top_k: maximum number of chunks to return. Values <= 0 return an
            empty list without querying Pinecone (which rejects them).

    Returns:
        A list of dicts with keys:
          - "text": the chunk text from the match metadata
          - "topic": the match's topic, or "unknown" when it has none
          - "score": the similarity score rounded to 4 decimals
        Entries are in the order Pinecone returned them. Matches whose
        metadata is missing or has no (or empty) "text" are skipped, because
        they contribute nothing usable to the LLM context.
    """
    if top_k <= 0:
        return []

    # e5 requires "query: " prefix for questions
    query_embedding = embedder.encode(
        f"query: {query}",
        normalize_embeddings=True,
    ).tolist()

    # Only constrain the search when a topic was actually supplied.
    filter_dict = {"topic": {"$eq": topic}} if topic else None

    results = index.query(
        vector=query_embedding,
        top_k=top_k,
        include_metadata=True,
        filter=filter_dict,
    )

    chunks = []
    for match in results["matches"]:
        # Vectors upserted without metadata (or without a "text" field) used to
        # crash the whole call with KeyError/TypeError; skip them instead.
        metadata = match.get("metadata") or {}
        text = metadata.get("text")
        if not text:
            continue
        chunks.append({
            "text":  text,
            "topic": metadata.get("topic", "unknown"),
            "score": round(match["score"], 4),
        })

    return chunks


# ── Format for LLM ────────────────────────────────────────────────────────────
def format_context(chunks: list) -> str:
    """
    Render retrieved chunks as one context string for the LLM.

    Each chunk becomes a header line ``[Chunk N | Topic: T | Score: S]``
    (N counts from 1), followed on the next line by the chunk's text.
    Consecutive chunks are separated by a blank line. An empty ``chunks``
    list produces an empty string.
    """
    sections = []
    for number, chunk in enumerate(chunks, start=1):
        header = f"[Chunk {number} | Topic: {chunk['topic']} | Score: {chunk['score']}]"
        sections.append(f"{header}\n{chunk['text']}")
    return "\n\n".join(sections)


# ── Quick Test ────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import argparse

    # Smoke test. Running with no arguments reproduces the original hard-coded
    # query/topic; override from the command line, e.g.
    #   python rag/retriever.py "What is a deadlock?" --topic deadlocks --top-k 5
    parser = argparse.ArgumentParser(description="Smoke-test Pinecone retrieval.")
    parser.add_argument(
        "query",
        nargs="?",
        default="Explain Process Registers?",
        help="question to embed and search for",
    )
    parser.add_argument(
        "--topic",
        default="ch-01-updated",
        help="filename-based topic filter; pass '' to search all topics",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=TOP_K,
        help="number of chunks to retrieve",
    )
    args = parser.parse_args()

    chunks = retrieve(args.query, topic=args.topic, top_k=args.top_k)

    if not chunks:
        # Previously an empty result printed a bare blank line, which is easy
        # to mistake for success; make the miss explicit.
        print(f"[WARN] No chunks found (topic filter: {args.topic or 'none'}).")
    else:
        print(format_context(chunks))