# mini-rag / rag_core.py
# (uploaded by navyamehta, commit 33f5651 — header converted to comments so the module parses)
import os
from typing import List, Dict, Any, Tuple
from dotenv import load_dotenv  # third-party: reads key=value pairs from a .env file
from llm import LLMProvider  # project-local: embeddings, reranking, chat completion
from pinecone_client import PineconeClient  # project-local: vector index management and search
# Populate os.environ (API keys, index names) from .env at import time,
# before LLMProvider / PineconeClient are instantiated.
load_dotenv()
def _build_prompt(query: str, contexts: List[str]) -> List[Dict[str, str]]:
system = (
"You are a helpful assistant. Answer the user's question using the provided context. "
"If the answer isn't in the context, say you don't know. Be concise."
)
context_block = "\n\n".join([f"[Source {i+1}]\n{c}" for i, c in enumerate(contexts)])
user = f"Question: {query}\n\nContext:\n{context_block}"
return [
{"role": "system", "content": system},
{"role": "user", "content": user},
]
def _build_citation_prompt(query: str, contexts: List[str]) -> List[Dict[str, str]]:
system = (
"You are a helpful assistant. Answer the user's question using the provided context. "
"IMPORTANT: Use inline citations [1], [2], [3] etc. to reference specific sources. "
"Each citation number should correspond to the source number from the context. "
"If the answer isn't in the context, say you don't know. Be concise and accurate."
)
context_block = "\n\n".join([f"[Source {i+1}]\n{c}" for i, c in enumerate(contexts)])
user = f"Question: {query}\n\nContext:\n{context_block}\n\nAnswer with inline citations [1], [2], etc.:"
return [
{"role": "system", "content": system},
{"role": "user", "content": user},
]
class RAGCore:
    """End-to-end RAG pipeline: embed query -> vector search -> optional rerank -> generate.

    Wires an ``LLMProvider`` (embeddings, reranking, chat) to a ``PineconeClient``
    (vector index). Credentials are expected in the environment (loaded from .env).
    """

    def __init__(self) -> None:
        self.llm = LLMProvider()  # embeddings / rerank / chat backend
        self.pc = PineconeClient()  # vector index client

    def ensure_index(self, embedding_dim: int) -> None:
        """Create the Pinecone index if missing; dimension must match the embedder's output."""
        self.pc.ensure_index(dimension=embedding_dim)

    def retrieve(self, query: str, top_k: int = 5, rerank: bool = True) -> Tuple[List[Dict[str, Any]], List[str]]:
        """Embed *query* and fetch the ``top_k`` nearest chunks from the index.

        Returns ``(docs, contexts)``: ``docs`` is a list of dicts with keys
        ``id``/``text``/``score``/``metadata``; ``contexts`` is the bare chunk
        texts in final (post-rerank, if enabled) order.
        """
        q_vec = self.llm.embed_texts([query])[0]
        results = self.pc.query(vector=q_vec, top_k=top_k)
        matches = results.get("matches", [])
        docs: List[Dict[str, Any]] = []
        for m in matches:
            md = m.get("metadata", {}) or {}
            docs.append({
                "id": m.get("id"),
                "text": md.get("text", ""),
                # `or 0.0` also covers an explicit score of None,
                # which float(None) would raise TypeError on.
                "score": float(m.get("score") or 0.0),
                "metadata": md,
            })
        if rerank:
            docs = self.llm.rerank(query, docs)
        contexts = [d["text"] for d in docs]
        return docs, contexts

    def generate(self, query: str, contexts: List[str]) -> str:
        """Generate a plain answer grounded in *contexts*."""
        # Consistency/robustness fix: mirror generate_with_citations — don't
        # spend an LLM call on a prompt with an empty context block.
        if not contexts:
            return "No relevant context found to answer this question."
        messages = _build_prompt(query, contexts)
        return self.llm.chat(messages)

    def generate_with_citations(self, query: str, contexts: List[str]) -> str:
        """Generate an answer with inline citations [1], [2], etc."""
        if not contexts:
            return "No relevant context found to answer this question."
        messages = _build_citation_prompt(query, contexts)
        return self.llm.chat(messages)