File size: 1,958 Bytes
e23acaf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

from src.embeddings.embedding_factory import get_text_embedding
from src.retrieval.vector_store import VectorStoreFactory
from src.llm.llm_factory import get_llm
from src.utils.logger import get_logger

# Module-level logger, created via the project's logging helper and named after this module.
logger = get_logger(__name__)


def format_docs(docs):
    """Join the ``page_content`` of every document into one string.

    Args:
        docs: Iterable of objects exposing a ``page_content`` attribute.

    Returns:
        The contents in iteration order, separated by a blank line
        (``"\\n\\n"``); an empty string when ``docs`` yields nothing.
    """
    contents = [item.page_content for item in docs]
    separator = "\n\n"
    return separator.join(contents)


def main():
    """Run an interactive retrieval-augmented question-answering loop.

    The embedding, vector store, retriever (5 chunks per query), LLM and
    prompt chain are built once. Then, for every question read from stdin:

    1. the retriever is invoked with the question,
    2. the first 300 characters of each retrieved chunk are printed,
    3. the chunks are joined into a context and passed to the chain,
    4. the answer and each chunk's ``source`` / ``page`` metadata are printed.

    The loop ends when the user types ``exit`` (any letter case, surrounding
    whitespace ignored) or when stdin is closed or interrupted
    (Ctrl-D / Ctrl-C). Blank input is ignored and the prompt is shown again.
    """
    logger.info("Starting RAG query interface...")

    embedding = get_text_embedding()
    vectordb = VectorStoreFactory.create(embedding)
    retriever = vectordb.as_retriever(search_kwargs={"k": 5})

    llm = get_llm()

    prompt = ChatPromptTemplate.from_template("""
You are an anatomy tutor.

Answer the question using ONLY the context below.

Context:
{context}

Question:
{question}
""")

    rag_chain = prompt | llm | StrOutputParser()

    while True:
        try:
            query = input("\nAsk a question (or type 'exit'): ").strip()
        except (EOFError, KeyboardInterrupt):
            # stdin closed or user pressed Ctrl-C: end the session quietly
            # instead of dumping a traceback.
            print()
            break

        if query.lower() == "exit":
            break

        if not query:
            # Nothing to ask; avoid a pointless retrieval + LLM round trip.
            continue

        # Step 1: retrieve the chunks relevant to the question.
        docs = retriever.invoke(query)

        # Debug aid: preview what was retrieved (kept on purpose for now).
        print("\nRETRIEVED CHUNKS:\n")
        for d in docs:
            print(d.page_content[:300])
            print("------")

        # Step 2: build the context with the shared helper and query the LLM.
        response = rag_chain.invoke({
            "context": format_docs(docs),
            "question": query,
        })

        # Step 3: print the answer.
        print("\nANSWER:\n")
        print(response)

        # Step 4: print where each chunk came from; missing metadata keys
        # fall back to "unknown".
        print("\nSOURCES:\n")
        for doc in docs:
            source = doc.metadata.get("source", "unknown")
            page = doc.metadata.get("page", "unknown")
            print(f"{source} — page {page}")


# Start the interactive loop only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()