|
|
import logging |
|
|
|
|
|
from src.rag.retriever import ArxivRetriever |
|
|
from src.rag.llm import get_chain |
|
|
from src.rag.reranker import ArxivReranker |
|
|
from src.core.config import settings |
|
|
|
|
|
# Configure the root logger at import time so the INFO-level context log
# emitted by answer_question() is visible.
logging.basicConfig(level=logging.INFO)

# Module-level instances, created once at import and reused by every call to
# answer_question().
reranker = ArxivReranker()
retriever = ArxivRetriever()
|
|
|
|
|
|
|
|
def create_context(docs) -> str:
    """Create context string from documents.

    Each document contributes three lines (title, content, year of
    publication) followed by a blank line.

    Args:
        docs: Iterable of document objects exposing a ``metadata`` mapping
            and a ``page_content`` attribute.

    Returns:
        The concatenated entries in input order, or an empty string when
        ``docs`` is empty. A missing ``'Titles'`` key renders as
        ``'No Title'`` and a missing ``'Years'`` key as ``'Unknown'``.
    """
    # Join once instead of growing a string with ``+=`` on every iteration.
    return "".join(
        f"Title: {doc.metadata.get('Titles', 'No Title')}\n"
        f"Content: {doc.page_content}\n"
        f"Year of Publication: {doc.metadata.get('Years', 'Unknown')}\n\n"
        for doc in docs
    )
|
|
|
|
|
|
|
|
def extract_citations(docs) -> list[dict]:
    """Extract paper titles and years for citations.

    Args:
        docs: Iterable of document objects exposing a ``metadata`` mapping.

    Returns:
        One dict per document, in input order, with keys ``'title'`` and
        ``'year'``. A missing ``'Titles'`` key yields ``'No Title'`` and a
        missing ``'Years'`` key yields ``'Unknown'``.
    """
    return [
        {
            "title": doc.metadata.get("Titles", "No Title"),
            "year": doc.metadata.get("Years", "Unknown"),
        }
        for doc in docs
    ]
|
|
|
|
|
|
|
|
def answer_question(question: str) -> tuple[str, list[dict]]:
    """Answer a question using retrieved and reranked documents.

    The module-level ``retriever`` fetches candidate documents, the
    module-level ``reranker`` narrows them down, and the chain returned by
    ``get_chain()`` is invoked with the formatted context and the question.

    Args:
        question: The question text passed to the retriever, the reranker
            and the chain.

    Returns:
        tuple: (answer, citations) where citations is a list of dicts with
        'title' and 'year'. If any step raises, the failure is logged with
        its traceback and an apology message plus an empty citation list is
        returned instead of propagating the exception.
    """
    try:
        # The retrieval size is (re)applied from settings before every query;
        # the reranker then keeps only ``top_k`` of the retrieved documents.
        retriever.k = settings.RETRIEVER_K_BEFORE_RERANK
        results = retriever.invoke(question)

        reranked_results = reranker.rerank_documents(
            question, results, top_k=settings.RETRIEVER_K_AFTER_RERANK
        )

        context = create_context(reranked_results)
        citations = extract_citations(reranked_results)
        # Lazy %-formatting: the message is only built if the record is emitted.
        logging.info("Constructed context for LLM: %s", context)

        chain = get_chain()
        response = chain.invoke({"context": context, "question": question})

        # The chain result may expose its text via ``.content``; otherwise
        # fall back to its string form.
        answer_text = response.content if hasattr(response, "content") else str(response)
        return answer_text, citations
    except Exception as e:
        # Top-level boundary: log with the full traceback (logging.exception)
        # and degrade to a user-facing message rather than raising.
        logging.exception("Error occurred while answering question: %s", e)
        error_msg = f"Sorry, an error occurred while processing your request: {e}"
        return error_msg, []
|
|
|