# RAG-lab / src / rag / pipeline.py
# mechark
# refac: clearer code, fix output
# ead947b
import logging
from src.rag.retriever import ArxivRetriever
from src.rag.llm import get_chain
from src.rag.reranker import ArxivReranker
from src.core.config import settings
# Configure root logging at INFO level; this runs once, at import time.
logging.basicConfig(level=logging.INFO)
# Module-level reranker and retriever instances, constructed at import time
# and reused by every answer_question() call in this module.
reranker = ArxivReranker()
retriever = ArxivRetriever()
def create_context(docs) -> str:
    """Create the context string passed to the LLM from documents.

    Each document contributes a three-line entry (title, content, year of
    publication) followed by a blank line.

    Args:
        docs: Iterable of document objects exposing a ``metadata`` mapping
            (read via ``.get`` for the 'Titles' and 'Years' keys) and a
            ``page_content`` attribute.

    Returns:
        The concatenated entries, or an empty string when ``docs`` is empty.
        Missing metadata falls back to 'No Title' / 'Unknown'.
    """
    # Build the result in one pass with str.join instead of repeated ``+=``,
    # which can degrade to quadratic time as the context grows.
    return "".join(
        f"Title: {doc.metadata.get('Titles', 'No Title')}\n"
        f"Content: {doc.page_content}\n"
        f"Year of Publication: {doc.metadata.get('Years', 'Unknown')}\n\n"
        for doc in docs
    )
def extract_citations(docs) -> list[dict]:
    """Extract paper titles and years for citations.

    Args:
        docs: Iterable of document objects exposing a ``metadata`` mapping
            (read via ``.get`` for the 'Titles' and 'Years' keys).

    Returns:
        One dict per document, in input order, with keys 'title' and 'year'.
        Missing metadata falls back to 'No Title' / 'Unknown'. An empty
        input yields an empty list.
    """
    return [
        {
            "title": doc.metadata.get("Titles", "No Title"),
            "year": doc.metadata.get("Years", "Unknown"),
        }
        for doc in docs
    ]
def answer_question(question: str) -> tuple[str, list[dict]]:
    """Answer a question using retrieved and reranked documents.

    The question is sent to the module-level retriever, the results are
    reranked and trimmed, and the resulting context is passed with the
    question to the chain returned by ``get_chain()``.

    Note: this sets ``retriever.k`` on the shared module-level retriever on
    every call.

    Args:
        question: The user's question.

    Returns:
        tuple: (answer, citations) where citations is a list of dicts with
        'title' and 'year'. If any step raises an ``Exception``, the error is
        logged and the answer is an apology message containing the error
        text, with an empty citations list.
    """
    try:
        # Retrieve more documents than needed for reranking
        retriever.k = settings.RETRIEVER_K_BEFORE_RERANK
        results = retriever.invoke(question)

        # Rerank and get top k
        reranked_results = reranker.rerank_documents(
            question, results, top_k=settings.RETRIEVER_K_AFTER_RERANK
        )

        context = create_context(reranked_results)
        citations = extract_citations(reranked_results)
        # Lazy %-formatting: the message is only built if the record is emitted.
        logging.info("Constructed context for LLM: %s", context)

        chain = get_chain()
        response = chain.invoke({"context": context, "question": question})

        # Extract just the content from the response message
        answer_text = response.content if hasattr(response, "content") else str(response)
        return answer_text, citations
    except Exception as e:
        # Boundary handler: callers always get a (message, []) tuple back.
        # logging.exception keeps the traceback, which logging.error dropped.
        logging.exception("Error occurred while answering question: %s", e)
        error_msg = f"Sorry, an error occurred while processing your request: {str(e)}"
        return error_msg, []