File size: 4,631 Bytes
f01124b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from dotenv import load_dotenv
load_dotenv()
from prompt import (
    RouteQuery,
    GradeDocuments,
    GenerateAnswer,
    GradeHallucinations,
    ExtractFilter,
    route_chain,
    transform_query_chain,
    grade_documents_chain,
    gen_normal_answer_chain,
    gen_answer_rag_chain,
    grade_hallucinations_chain,
    extract_filter_chain,
)
from config.database import vector_store
from langchain_core.documents import Document
from prompt import GradeDocuments
from helper import convert_list_context_source_to_str
from logger import logger
from langgraph.graph.message import AnyMessage, add_messages
from typing import TypedDict, Literal



class State(TypedDict):
    """Shared graph state passed between the LangGraph node functions below."""

    user_query: AnyMessage  # current (possibly rewritten) user question
    route: str  # datasource chosen by route_fn (RouteQuery.datasource)
    messages_history: list  # prior conversation messages fed to the chains
    documents: list[Document]  # retrieved, then relevance-filtered, documents
    filter: dict  # metadata filter for the vector search (may be absent/empty)
    llm_response: AnyMessage  # generated answer (NOTE(review): nodes actually store str — confirm)
    grade_response: Literal["yes", "no"]  # hallucination-grader verdict
    language: str  # target language for the generated answer
    document_id_selected: str  # metadata["id"] of the document the LLM picked, or None


def route_fn(state: State):
    """Decide which datasource should handle the user's query."""
    query_text = state["user_query"].content
    decision: RouteQuery = route_chain.invoke({"question": query_text})
    logger.info(f"Route response: {decision}")
    return {"route": decision.datasource}


def transform_query_fn(state: State):
    """Rewrite the user's question into a standalone query using chat history."""
    payload = {
        "question": state["user_query"].content,
        "chat_history": state["messages_history"],
    }
    rewritten = transform_query_chain.invoke(payload)
    logger.info(f"Transform response: {rewritten}")
    return {"user_query": rewritten}


def retrieve_document_fn(state: State):
    """Retrieve candidate documents for the current question.

    When the state carries no metadata filter, one is extracted from the
    question and history via ``extract_filter_chain`` (job title / level).
    Returns the retrieved documents under the ``documents`` key.
    """
    question = state["user_query"].content
    history = state["messages_history"]
    # Renamed from ``filter`` to avoid shadowing the builtin.
    search_filter = state.get("filter", None)
    if not search_filter:
        filter_response: ExtractFilter = extract_filter_chain.invoke(
            {"question": question, "history": history}
        )
        logger.info(f"Extract filter response: {filter_response}")
        search_filter = {}
        if filter_response.job_title:
            search_filter["title"] = filter_response.job_title
        if filter_response.job_level:
            search_filter["level"] = filter_response.job_level
    retriever = vector_store.as_retriever(
        search_type="similarity_score_threshold",
        # NOTE(review): score_threshold=0.0 accepts every match, so the
        # threshold search type is effectively plain top-k — confirm intended.
        search_kwargs={"k": 5, "score_threshold": 0.0},
    )
    documents = retriever.invoke(question, filter=search_filter)
    logger.info(f"Retrieved documents: {documents}")
    return {"documents": documents}


def grade_document_fn(state: State):
    """Keep only the retrieved documents graded as relevant to the question."""
    question = state["user_query"].content
    documents = state["documents"]
    batch_inputs = [
        {"question": question, "document": doc.page_content} for doc in documents
    ]
    grades: list[GradeDocuments] = grade_documents_chain.batch(batch_inputs)
    logger.info(f"Grade response: {grades}")
    # Pair each document with its grade and keep the "yes" ones, in order.
    relevant = [
        doc for doc, grade in zip(documents, grades) if grade.binary_score == "yes"
    ]
    return {"documents": relevant}


def generate_answer_rag_fn(state: State):
    """Generate a RAG answer from the graded documents.

    Bug fix: the original assigned ``context_str`` only when ``documents``
    was non-empty, yet always passed it to the chain — an empty document
    list raised UnboundLocalError. An empty context string is now used.
    Also guards the LLM-supplied document index against out-of-range values.
    """
    question = state["user_query"].content
    documents = state["documents"]
    language = state["language"]
    context_str = convert_list_context_source_to_str(documents) if documents else ""

    gen_answer_response: GenerateAnswer = gen_answer_rag_chain.invoke(
        {"question": question, "context": context_str, "language": language}
    )
    logger.info(f"Generate answer response: {gen_answer_response}")
    id_selected = None
    selected_index = gen_answer_response.selected_document_index
    # The index comes from the LLM; only trust it if it addresses a real doc.
    if selected_index is not None and 0 <= selected_index < len(documents):
        id_selected = documents[selected_index].metadata["id"]
    logger.info(f"Document id selected: {id_selected}")
    return {
        "llm_response": gen_answer_response.answer,
        "document_id_selected": id_selected,
    }


def grade_hallucinations_fn(state: State):
    """Grade whether the generated answer is grounded for the question."""
    payload = {
        "question": state["user_query"].content,
        "generation": state["llm_response"],
    }
    grade_response: GradeHallucinations = grade_hallucinations_chain.invoke(payload)
    logger.info(f"Grade hallucinations response: {grade_response}")
    return {"grade_response": grade_response.binary_score}


def gen_answer_normal_fn(state: State):
    """Answer a non-RAG question directly, using the conversation history."""
    reply = gen_normal_answer_chain.invoke(
        {
            "question": state["user_query"].content,
            "history": state["messages_history"],
        }
    )
    logger.info(f"Generate answer response: {reply}")
    return {"llm_response": reply.content}