File size: 1,501 Bytes
c0f74f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import os
import sys

# Resolve the directory two levels above the current working directory and
# treat it as the project root.
# NOTE(review): this is derived from os.getcwd(), so the result depends on
# where the process is launched from rather than on this file's location --
# confirm the expected launch directory.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))

# Put the project root first on the import path. Presumably this is what lets
# the `models` and `data` imports below resolve -- verify against the repo layout.
sys.path.insert(0, project_root)


from models.llm import get_llm
from langchain_core.prompts import ChatPromptTemplate
from models.llm import get_llm
from langchain_core.output_parsers import StrOutputParser
from data.dataingestion import load_all_pdfs
# Loaded once, when this module is imported; route_node passes this value into
# the router prompt's {documents} placeholder.
# NOTE(review): the return type of load_all_pdfs() is not visible here -- if it
# is full document content rather than topic names, the prompt could become
# very large; confirm.
document = load_all_pdfs()

def route_node(state):
    """Classify the latest user question as a 'rag' or 'wiki' query.

    Builds a router prompt, runs it through the chat model returned by
    ``get_llm`` and maps the model's textual reply onto a routing label.

    Args:
        state: Dict-like graph state. It must provide ``"messages"`` (a
            non-empty sequence whose last item exposes ``.content``) and
            ``"api_key"`` (forwarded to ``get_llm``).

    Returns:
        ``{"source": "rag"}`` when the model's reply contains ``rag``
        (matched case-insensitively), otherwise ``{"source": "wiki"}``.
        Note that the fallback label is ``"wiki"`` even though the prompt
        asks the model to answer ``'wikipedia'``.

    Raises:
        ValueError: If ``state`` has no truthy ``"api_key"`` entry.
    """
    question = state["messages"][-1].content

    api_key = state.get("api_key")
    if not api_key:
        raise ValueError("API Key not found in state.")

    model = get_llm(api=api_key)

    prompt = ChatPromptTemplate.from_messages([
        ("system","""You are an expert router.

        Your task is to classify the user's question based on its content:

        1. 'rag': If the question is related to the topics provided in these documents : {documents}

        2. 'wikipedia': If the question is about general knowledge, history, people, or events.

        Return ONLY a single word string: 'rag' or 'wikipedia'.

        """),
        ("user","{question}")
    ])

    route_chain = prompt | model | StrOutputParser()

    # `document` is the module-level value loaded at import time; it fills the
    # {documents} placeholder of the system prompt.
    raw_route = route_chain.invoke({"question": question, "documents": document})

    # Models do not reliably honour the requested lower-case single word
    # (e.g. "RAG" or "'Rag'."), so normalise the reply before matching.
    # Anything that matched the old case-sensitive check still matches.
    normalized_route = raw_route.strip().lower()

    if "rag" in normalized_route:
        decision = "rag"
        print("routing to rag")
    else:
        # Any reply that does not mention "rag" falls back to Wikipedia.
        decision = "wiki"
        print("routing to wikipedia")

    return {"source": decision}
def route_decision(state):
    """Return the routing label stored under ``state["source"]``.

    Raises:
        KeyError: If ``state`` has no ``"source"`` entry.
    """
    chosen_source = state["source"]
    return chosen_source