# Recruitment / graph_function.py
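# LangGraph node functions for the recruitment assistant: query routing, query
# rewriting, document retrieval and grading, RAG answer generation, hallucination
# grading, and a non-RAG fallback answer.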
from dotenv import load_dotenv
load_dotenv()
from prompt import (
    RouteQuery,
    GradeDocuments,
    GenerateAnswer,
    GradeHallucinations,
    ExtractFilter,
    route_chain,
    transform_query_chain,
    grade_documents_chain,
    gen_normal_answer_chain,
    gen_answer_rag_chain,
    grade_hallucinations_chain,
    extract_filter_chain,
)
from config.database import vector_store
from langchain_core.documents import Document
from helper import convert_list_context_source_to_str
from logger import logger
from langgraph.graph.message import AnyMessage, add_messages
from typing import TypedDict, Literal
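

# Shared state passed between graph nodes: the (possibly rewritten) user query,
# routing decision, chat history, retrieved documents, metadata filter, model
# response, grading result, answer language, and the id of the selected document.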
class State(TypedDict):
    user_query: AnyMessage
    route: str
    messages_history: list
    documents: list[Document]
    filter: dict
    llm_response: AnyMessage
    grade_response: Literal["yes", "no"]
    language: str
    document_id_selected: str
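

# Decide which branch handles the query, based on the datasource chosen by route_chain.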
def route_fn(state: State):
    question = state["user_query"].content
    route_response: RouteQuery = route_chain.invoke({"question": question})
    logger.info(f"Route response: {route_response}")
    return {"route": route_response.datasource}
def transform_query_fn(state: State):
    question = state["user_query"].content
    chat_history = state["messages_history"]
    transform_response = transform_query_chain.invoke(
        {"question": question, "chat_history": chat_history}
    )
    logger.info(f"Transform response: {transform_response}")
    return {"user_query": transform_response}
def retrieve_document_fn(state: State):
    question = state["user_query"].content
    history = state["messages_history"]
    filter = state.get("filter", None)
    if not filter:
        filter_response: ExtractFilter = extract_filter_chain.invoke(
            {"question": question, "history": history}
        )
        logger.info(f"Extract filter response: {filter_response}")
        job_title = filter_response.job_title
        job_level = filter_response.job_level
        filter = {}
        if job_title:
            filter["title"] = job_title
        if job_level:
            filter["level"] = job_level
    retriever = vector_store.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"k": 5, "score_threshold": 0.0},
    )
    documents = retriever.invoke(question, filter=filter)
    logger.info(f"Retrieved documents: {documents}")
    return {"documents": documents}
def grade_document_fn(state: State):
    question = state["user_query"].content
    documents = state["documents"]
    inputs_batch = [
        {"question": question, "document": doc.page_content} for doc in documents
    ]
    grade_response: list[GradeDocuments] = grade_documents_chain.batch(inputs_batch)
    logger.info(f"Grade response: {grade_response}")
    document_index = [
        index for index, doc in enumerate(grade_response) if doc.binary_score == "yes"
    ]
    filtered_documents = [documents[i] for i in document_index]
    return {"documents": filtered_documents}
def generate_answer_rag_fn(state: State):
    question = state["user_query"].content
    documents = state["documents"]
    language = state["language"]
    # Fall back to an empty context so the chain call never hits an undefined variable.
    context_str = convert_list_context_source_to_str(documents) if documents else ""
    gen_answer_response: GenerateAnswer = gen_answer_rag_chain.invoke(
        {"question": question, "context": context_str, "language": language}
    )
    logger.info(f"Generate answer response: {gen_answer_response}")
    id_selected = None
    if gen_answer_response.selected_document_index is not None:
        id_selected = documents[gen_answer_response.selected_document_index].metadata["id"]
    logger.info(f"Document id selected: {id_selected}")
    return {
        "llm_response": gen_answer_response.answer,
        "document_id_selected": id_selected,
    }
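

# Grade the generated answer against the question for hallucinations, returning
# the chain's binary "yes"/"no" score.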
def grade_hallucinations_fn(state: State):
    question = state["user_query"].content
    llm_response = state["llm_response"]
    grade_response: GradeHallucinations = grade_hallucinations_chain.invoke(
        {"question": question, "generation": llm_response}
    )
    logger.info(f"Grade hallucinations response: {grade_response}")
    return {"grade_response": grade_response.binary_score}
def gen_answer_normal_fn(state: State):
    question = state["user_query"].content
    history = state["messages_history"]
    gen_answer_response = gen_normal_answer_chain.invoke(
        {"question": question, "history": history}
    )
    logger.info(f"Generate answer response: {gen_answer_response}")
    return {"llm_response": gen_answer_response.content}