# NOTE(review): the three lines below are repository-page artifacts (uploader
# metadata scraped along with the file) and are not Python; commented out so
# the module parses:
#   ABAO77's picture / Upload 42 files / 7365dcb verified
import os
from typing import TypedDict, Optional, List, Literal
from langchain_core.documents import Document
from src.utils.helper import (
fake_token_counter,
convert_list_context_source_to_str,
convert_message,
)
from src.utils.logger import logger
from langchain_core.messages import trim_messages, AnyMessage
from src.config.vector_store import (
vector_store_chatbot,
vector_store_fresher,
vector_store_tutor,
)
from .prompt import (
RouteQuery,
route_chain,
transform_query_chain,
ExtractFilter,
extract_filter_chain,
GradeDocuments,
GenerateAnswer,
GradeHallucinations,
gen_normal_answer_chain,
gen_answer_rag_chain,
grade_documents_chain,
gen_answer_rag_chain,
grade_documents_chain,
grade_hallucinations_chain,
gen_answer_rag_tutor_chain,
)
class StateRAGAccuracy(TypedDict):
    """Graph state shared by the accuracy-oriented RAG flow.

    Each node returns a partial dict of these keys; the graph runtime merges
    the delta back into the state.
    """
    # Current question; transform_query overwrites it with the rewritten form.
    user_query: str | AnyMessage
    # Datasource label chosen by the route node (RouteQuery.datasource).
    route_response: str
    # Prior conversation turns; trimmed by trim_history.
    messages_history: list
    # Documents produced by retrieve_document, then filtered by grade_document.
    documents: list[Document]
    # Metadata filter applied at retrieval time (empty dict means no filter).
    filter: dict
    # Answer produced by the generation nodes.
    llm_response: AnyMessage
    # Verdict from the hallucination grader.
    grade_response: Literal["yes", "no"]
    # Target language for the generated answer.
    language: str
    # Source-document selection reported by the answer generator.
    document_id_selected: Optional[List]
    # Conversation topic forwarded to the prompt chains.
    topic: str
class StateRAGSpeed(TypedDict):
    """Graph state for the speed-oriented RAG flow.

    Same keys as StateRAGAccuracy minus the routing and grading fields
    (no route_response / grade_response — those nodes are skipped for speed).
    """
    # Current question; transform_query overwrites it with the rewritten form.
    user_query: str | AnyMessage
    # Prior conversation turns; trimmed by trim_history.
    messages_history: list
    # Documents retrieved for the query.
    documents: list[Document]
    # Metadata filter applied at retrieval time (empty dict means no filter).
    filter: dict
    # Answer produced by the generation nodes.
    llm_response: AnyMessage
    # Target language for the generated answer.
    language: str
    # Source-document selection reported by the answer generator.
    document_id_selected: Optional[List]
    # Conversation topic forwarded to the prompt chains.
    topic: str
def trim_history(state: StateRAGAccuracy | StateRAGSpeed):
    """Trim the conversation history down to the configured token budget.

    Keeps the most recent turns (strategy="last"), starting on a human
    message and ending on an AI message, never splitting a message.
    Returns a state delta with the trimmed ``messages_history`` (an empty
    list when there is no usable history).
    """
    raw_history = state.get("messages_history")
    if not raw_history:
        return {"messages_history": []}

    messages = convert_message(raw_history)
    if not messages:
        return {"messages_history": []}

    # Budget is configurable via the environment; defaults to 2000 tokens.
    token_budget = int(os.getenv("HISTORY_TOKEN_LIMIT", 2000))
    trimmed = trim_messages(
        messages,
        strategy="last",
        token_counter=fake_token_counter,
        max_tokens=token_budget,
        start_on="human",
        end_on="ai",
        include_system=False,
        allow_partial=False,
    )
    return {"messages_history": trimmed}
async def route(state: StateRAGAccuracy):
    """Classify the user query and choose a datasource for downstream nodes.

    Invokes the routing chain with the question, chat history and topic, and
    stores the chosen datasource name in ``route_response``.
    """
    logger.info("routing")
    payload = {
        "question": state["user_query"],
        "chat_history": state.get("messages_history", None),
        "topic": state["topic"],
    }
    decision: RouteQuery = await route_chain.ainvoke(payload)
    logger.info(f"Route response: {decision.datasource}")
    return {"route_response": decision.datasource}
async def transform_query(state: StateRAGAccuracy | StateRAGSpeed):
    """Rewrite the user query for retrieval, using chat history for context.

    Overwrites ``user_query`` in the state with the rewritten question.
    """
    payload = {
        "question": state["user_query"],
        "chat_history": state.get("messages_history", None),
        "topic": state["topic"],
    }
    rewritten = await transform_query_chain.ainvoke(payload)
    logger.info(f"Transform response: {rewritten.content}")
    return {"user_query": rewritten.content}
async def retrieve_document(state: StateRAGAccuracy):
    """Retrieve candidate documents for the (possibly rewritten) user query.

    Uses the tutor vector store when a metadata filter is present, otherwise
    the general chatbot store; both searches share one similarity
    configuration. Returns a state delta with the retrieved ``documents``.
    """
    question = state["user_query"]
    # Renamed from `filter` to avoid shadowing the builtin.
    metadata_filter = state.get("filter", {})
    logger.info(f"Filter: {metadata_filter}")

    # The original branches differed only in the vector store; the search
    # settings are hoisted so the two configurations cannot drift apart.
    store = vector_store_tutor if metadata_filter else vector_store_chatbot
    retriever = store.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"k": 5, "score_threshold": 0.3},
    )

    # Use the async retriever API: the original called the synchronous
    # `invoke` inside this coroutine, blocking the event loop.
    documents = await retriever.ainvoke(question, filter=metadata_filter)

    show_doc = " \n =============\n".join(doc.page_content for doc in documents)
    logger.info(f"Retrieved documents: {show_doc}")
    return {"documents": documents}
async def grade_document(state: StateRAGAccuracy):
    """Score each retrieved document for relevance and drop the rejects.

    Grades all documents in one batched chain call and keeps only those the
    grader marks with a ``binary_score`` of "yes".
    """
    question = state["user_query"]
    documents = state["documents"]

    grading_inputs = [
        {"question": question, "document": doc.page_content} for doc in documents
    ]
    grades: list[GradeDocuments] = await grade_documents_chain.abatch(grading_inputs)
    logger.info(f"Grade response: {grades}")

    # Pair each document with its grade and keep the approved ones.
    kept = [
        doc for doc, grade in zip(documents, grades) if grade.binary_score == "yes"
    ]
    return {"documents": kept}
async def generate_answer_rag(state: StateRAGAccuracy):
    """Generate a grounded answer from the graded documents.

    Builds a single context string from the documents and invokes the tutor
    RAG chain. Returns a state delta with the answer and the document
    selection the model reported.

    Bug fix: the original computed ``selected_document_index`` and then
    discarded it, returning a hard-coded ``"document_id_selected": 1``; the
    model's actual selection is now propagated.
    """
    question = state["user_query"]
    documents = state["documents"]
    language = state["language"]
    context_str = convert_list_context_source_to_str(documents)
    gen_answer_response: GenerateAnswer = await gen_answer_rag_tutor_chain.ainvoke(
        {
            "question": question,
            "context": context_str,
            "language": language,
            "topic": state["topic"],
        }
    )
    logger.info(f"Generate answer response: {gen_answer_response}")
    return {
        "llm_response": gen_answer_response.answer,
        "document_id_selected": gen_answer_response.selected_document_index,
    }
async def grade_hallucinations(state: StateRAGAccuracy):
    """Grade the generated answer against the question for hallucination.

    Returns a state delta with the grader's "yes"/"no" verdict in
    ``grade_response``.
    """
    verdict: GradeHallucinations = await grade_hallucinations_chain.ainvoke(
        {
            "question": state["user_query"],
            "generation": state["llm_response"],
        }
    )
    return {"grade_response": verdict.binary_score}
async def gen_answer_normal(state: StateRAGAccuracy):
    """Answer directly from the LLM without retrieval.

    Invokes the plain-answer chain with question, history and topic, then
    appends the fixed "AI knowledge" source attribution to the answer text.
    """
    answer = await gen_normal_answer_chain.ainvoke(
        {
            "question": state["user_query"],
            "history": state["messages_history"],
            "topic": state["topic"],
        }
    )
    # Attribution note appended verbatim (Vietnamese: "Source: AI knowledge").
    source_note = "\n**Nguồn thông tin: Kiến thức của AI**"
    return {"llm_response": answer.content + source_note}