File size: 2,205 Bytes
0e9a6da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4d8214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e9a6da
 
 
 
 
 
 
 
 
 
 
c4d8214
0e9a6da
 
 
 
ead947b
 
 
0e9a6da
 
c4d8214
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import logging

from src.rag.retriever import ArxivRetriever
from src.rag.llm import get_chain
from src.rag.reranker import ArxivReranker
from src.core.config import settings

# Configure the root logger at INFO so this module's logging.info calls are
# emitted.
# NOTE(review): configuring the root logger at import time affects the whole
# process; consider moving this to the application entry point.
logging.basicConfig(level=logging.INFO)

# Shared reranker and retriever, created once when the module is imported and
# reused by every answer_question call. answer_question assigns retriever.k on
# this shared instance on each call.
reranker = ArxivReranker()
retriever = ArxivRetriever()


def create_context(docs) -> str:
    """Build the LLM context string from a sequence of documents.

    Each document contributes a three-line entry (title, content, year of
    publication) followed by a blank line. Missing metadata keys fall back to
    ``'No Title'`` for the title and ``'Unknown'`` for the year.

    Args:
        docs: Iterable of document objects exposing ``page_content`` and a
            ``metadata`` mapping with optional ``'Titles'`` and ``'Years'``
            keys.

    Returns:
        The concatenated context string; an empty string when ``docs`` is
        empty.
    """
    # Build every entry first and join once: repeated ``+=`` on a str inside
    # a loop is quadratic in the general case.
    entries = [
        f"Title: {doc.metadata.get('Titles', 'No Title')}\n"
        f"Content: {doc.page_content}\n"
        f"Year of Publication: {doc.metadata.get('Years', 'Unknown')}\n\n"
        for doc in docs
    ]
    return "".join(entries)


def extract_citations(docs) -> list[dict]:
    """Collect one citation record per document, preserving input order.

    Args:
        docs: Iterable of documents whose ``metadata`` mapping may contain
            ``'Titles'`` and ``'Years'`` keys.

    Returns:
        A list of ``{"title": ..., "year": ...}`` dicts, one per document,
        using ``'No Title'`` / ``'Unknown'`` when a key is absent.
    """
    return [
        {
            "title": meta.get('Titles', 'No Title'),
            "year": meta.get('Years', 'Unknown'),
        }
        for meta in (doc.metadata for doc in docs)
    ]


def answer_question(question: str) -> tuple[str, list[dict]]:
    """Answer a question using retrieved and reranked documents.

    Pipeline: retrieve ``settings.RETRIEVER_K_BEFORE_RERANK`` candidates,
    rerank them down to ``settings.RETRIEVER_K_AFTER_RERANK``, build a
    context string from the reranked documents, and invoke the LLM chain
    with that context and the question.

    Args:
        question: The user's question.

    Returns:
        tuple: (answer, citations) where citations is a list of dicts with
        'title' and 'year'. If any step raises, the exception is logged (with
        traceback) and ``(error_message, [])`` is returned instead of raising.
    """
    try:
        # Retrieve more documents than needed so the reranker has a larger
        # candidate pool.
        # NOTE(review): this mutates the shared module-level retriever on
        # every call.
        retriever.k = settings.RETRIEVER_K_BEFORE_RERANK
        results = retriever.invoke(question)

        # Rerank and keep only the top k.
        reranked_results = reranker.rerank_documents(
            question, results, top_k=settings.RETRIEVER_K_AFTER_RERANK
        )

        context = create_context(reranked_results)
        citations = extract_citations(reranked_results)
        # Lazy %-style args: the (potentially large) context is only
        # interpolated if the INFO record is actually emitted.
        logging.info("Constructed context for LLM: %s", context)

        chain = get_chain()
        response = chain.invoke({"context": context, "question": question})
        # Presumably a chat message object exposing ``.content``; otherwise
        # fall back to the string form of whatever the chain returned.
        answer_text = response.content if hasattr(response, 'content') else str(response)
        return answer_text, citations
    except Exception as e:
        # Top-level boundary: turn any failure into an error answer.
        # logging.exception logs at ERROR level *and* records the traceback,
        # which the previous logging.error call discarded.
        logging.exception("Error occurred while answering question: %s", e)
        # NOTE(review): the raw exception text is shown to the caller; consider
        # a generic message if internal details should not be exposed.
        error_msg = f"Sorry, an error occurred while processing your request: {str(e)}"
        return error_msg, []