Spaces:
Sleeping
Sleeping
File size: 4,631 Bytes
f01124b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
from dotenv import load_dotenv
load_dotenv()
from prompt import (
RouteQuery,
GradeDocuments,
GenerateAnswer,
GradeHallucinations,
ExtractFilter,
route_chain,
transform_query_chain,
grade_documents_chain,
gen_normal_answer_chain,
gen_answer_rag_chain,
grade_hallucinations_chain,
extract_filter_chain,
)
from config.database import vector_store
from langchain_core.documents import Document
from prompt import GradeDocuments
from helper import convert_list_context_source_to_str
from logger import logger
from langgraph.graph.message import AnyMessage, add_messages
from typing import TypedDict, Literal
class State(TypedDict):
    """Shared graph state passed between LangGraph nodes in this pipeline."""

    # Latest user message; `.content` is read by every node.
    user_query: AnyMessage
    # Datasource decision produced by `route_fn` (value of RouteQuery.datasource).
    route: str
    # Prior conversation turns, fed to query-transform / normal-answer chains.
    messages_history: list
    # Documents retrieved from the vector store, later filtered by grading.
    documents: list[Document]
    # Metadata filter for retrieval (keys "title"/"level" when extracted).
    filter: dict
    # Answer produced by a generation chain.
    llm_response: AnyMessage
    # Hallucination-grader verdict ("yes" = grounded).
    grade_response: Literal["yes", "no"]
    # Target language for the RAG answer.
    language: str
    # Metadata "id" of the document the answer chain selected, if any.
    document_id_selected: str
def route_fn(state: State):
    """Classify the user's question and pick the datasource to route to.

    Returns a partial state update: {"route": <datasource name>}.
    """
    user_question = state["user_query"].content
    decision: RouteQuery = route_chain.invoke({"question": user_question})
    logger.info(f"Route response: {decision}")
    return {"route": decision.datasource}
def transform_query_fn(state: State):
    """Rewrite the user's question into a standalone query using chat history.

    Returns a partial state update replacing "user_query" with the rewrite.
    """
    rewritten = transform_query_chain.invoke(
        {
            "question": state["user_query"].content,
            "chat_history": state["messages_history"],
        }
    )
    logger.info(f"Transform response: {rewritten}")
    return {"user_query": rewritten}
def retrieve_document_fn(state: State):
    """Retrieve candidate documents for the query from the vector store.

    If no metadata filter is present in state, one is extracted from the
    question + history via `extract_filter_chain` (keys "title"/"level",
    each included only when non-empty).

    Returns a partial state update: {"documents": <retrieved docs>}.
    """
    question = state["user_query"].content
    history = state["messages_history"]
    # Renamed from `filter` to avoid shadowing the `filter` builtin.
    doc_filter = state.get("filter", None)
    if not doc_filter:
        filter_response: ExtractFilter = extract_filter_chain.invoke(
            {"question": question, "history": history}
        )
        logger.info(f"Extract filter response: {filter_response}")
        doc_filter = {}
        if filter_response.job_title:
            doc_filter["title"] = filter_response.job_title
        if filter_response.job_level:
            doc_filter["level"] = filter_response.job_level
    # NOTE(review): score_threshold=0.0 effectively disables threshold
    # filtering — confirm this is intentional.
    retriever = vector_store.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"k": 5, "score_threshold": 0.0},
    )
    documents = retriever.invoke(question, filter=doc_filter)
    logger.info(f"Retrieved documents: {documents}")
    return {"documents": documents}
def grade_document_fn(state: State):
    """Keep only the retrieved documents the grader marks relevant.

    Batches one grading call per document; a document survives when its
    grade has binary_score == "yes".

    Returns a partial state update: {"documents": <relevant docs>}.
    """
    question = state["user_query"].content
    documents = state["documents"]
    # Fixed typo: was `inputs_bach`.
    inputs_batch = [
        {"question": question, "document": doc.page_content} for doc in documents
    ]
    grade_response: list[GradeDocuments] = grade_documents_chain.batch(inputs_batch)
    logger.info(f"Grade response: {grade_response}")
    # Pair each document with its grade directly instead of building an
    # intermediate index list (grades come back in input order).
    filtered_documents = [
        doc
        for doc, grade in zip(documents, grade_response)
        if grade.binary_score == "yes"
    ]
    return {"documents": filtered_documents}
def generate_answer_rag_fn(state: State):
    """Generate a grounded answer from the retrieved documents.

    Bug fix: `context_str` was assigned only when `documents` was non-empty,
    so an empty document list raised UnboundLocalError at the `invoke` call.
    It now defaults to "" so the chain still runs with empty context.

    Also guards `selected_document_index` against out-of-range values
    before indexing into `documents`.

    Returns a partial state update with "llm_response" and
    "document_id_selected" (None when no valid document was selected).
    """
    question = state["user_query"].content
    documents = state["documents"]
    language = state["language"]
    context_str = ""
    if documents:
        context_str = convert_list_context_source_to_str(documents)
    gen_answer_response: GenerateAnswer = gen_answer_rag_chain.invoke(
        {"question": question, "context": context_str, "language": language}
    )
    logger.info(f"Generate answer response: {gen_answer_response}")
    id_selected = None
    index = gen_answer_response.selected_document_index
    if index is not None and 0 <= index < len(documents):
        id_selected = documents[index].metadata["id"]
    logger.info(f"Document id selected: {id_selected}")
    return {
        "llm_response": gen_answer_response.answer,
        "document_id_selected": id_selected,
    }
def grade_hallucinations_fn(state: State):
    """Grade whether the generated answer is grounded w.r.t. the question.

    Returns a partial state update: {"grade_response": "yes" | "no"}.
    """
    verdict: GradeHallucinations = grade_hallucinations_chain.invoke(
        {
            "question": state["user_query"].content,
            "generation": state["llm_response"],
        }
    )
    logger.info(f"Grade hallucinations response: {verdict}")
    return {"grade_response": verdict.binary_score}
def gen_answer_normal_fn(state: State):
    """Answer the question directly from chat history, without retrieval.

    Returns a partial state update: {"llm_response": <answer text>}.
    """
    chain_input = {
        "question": state["user_query"].content,
        "history": state["messages_history"],
    }
    reply = gen_normal_answer_chain.invoke(chain_input)
    logger.info(f"Generate answer response: {reply}")
    return {"llm_response": reply.content}
|