# NOTE: removed non-Python upload artifact lines ("ABAO77's picture / Upload 55 files / 83a18c0 verified")
# that were scraped into this file and broke parsing.
import os
from typing import TypedDict, Optional, List, Literal
from langchain_core.documents import Document
from src.utils.helper import (
fake_token_counter,
convert_list_context_source_to_str,
convert_message,
)
from src.utils.logger import logger
from langchain_core.messages import trim_messages, AnyMessage
from src.config.vector_store import (
vector_store_chatbot,
vector_store_fresher,
vector_store_tutor,
)
from .prompt import (
RouteQuery,
route_chain,
transform_query_chain,
ExtractFilter,
extract_filter_chain,
GradeDocuments,
GenerateAnswer,
GradeHallucinations,
gen_normal_answer_chain,
gen_answer_rag_chain,
grade_documents_chain,
gen_answer_rag_chain,
grade_documents_chain,
grade_hallucinations_chain,
gen_answer_rag_tutor_chain,
gen_answer_rag_primary_chat_chain,
)
class StateRAGAccuracy(TypedDict):
    """Graph state for the accuracy-oriented RAG flow (route → retrieve → grade → answer)."""
    user_query: str | AnyMessage          # raw question; rewritten in place by transform_query
    route_response: str                   # datasource chosen by route (currently always "vectorstore")
    messages_history: list                # prior chat turns; trimmed by trim_history
    documents: list[Document]             # documents fetched by retrieve_document
    filter: dict                          # metadata filter for retrieval; empty dict means no filter
    llm_response: AnyMessage              # generated answer text
    grade_response: Literal["yes", "no"]  # hallucination-grader verdict
    language: str                         # target language for the generated answer
    document_id_selected: Optional[List]  # indices of documents the LLM cited
    topic: str                            # topic hint passed to every chain
class StateRAGSpeed(TypedDict):
    """Graph state for the speed-oriented RAG flow (no routing/grading fields)."""
    user_query: str | AnyMessage          # raw question; rewritten in place by transform_query
    messages_history: list                # prior chat turns; trimmed by trim_history
    documents: list[Document]             # documents fetched by retrieval
    filter: dict                          # metadata filter for retrieval; empty dict means no filter
    llm_response: AnyMessage              # generated answer text
    language: str                         # target language for the generated answer
    document_id_selected: Optional[List]  # indices of documents the LLM cited
    topic: str                            # topic hint passed to every chain
def trim_history(state: StateRAGAccuracy | StateRAGSpeed):
    """Trim the chat history so it fits the configured token budget.

    Returns a state update with ``messages_history`` reduced to the most
    recent turns, or an empty list when there is no usable history.
    """
    raw_history = state.get("messages_history")
    history = convert_message(raw_history) if raw_history else None
    if not history:
        # No history (or conversion produced nothing) — normalize to [].
        return {"messages_history": []}
    # HISTORY_TOKEN_LIMIT comes from the environment; default 2000 tokens.
    token_budget = int(os.getenv("HISTORY_TOKEN_LIMIT", 2000))
    trimmed = trim_messages(
        history,
        strategy="last",
        token_counter=fake_token_counter,
        max_tokens=token_budget,
        start_on="human",
        end_on="ai",
        include_system=False,
        allow_partial=False,
    )
    return {"messages_history": trimmed}
async def route(state: StateRAGAccuracy):
    """Pick the datasource for the query.

    LLM-based routing is currently disabled: every query is sent straight
    to the vector store. The original routing logic is kept below
    (commented out) so it can be re-enabled easily.
    """
    logger.info("routing")  # fixed: was an f-string with no placeholders
    # question = state["user_query"]
    # chat_history = state.get("messages_history", None)
    # route_response: RouteQuery = await route_chain.ainvoke(
    #     {
    #         "question": question,
    #         "chat_history": chat_history,
    #         "topic": state["topic"],
    #     }
    # )
    # logger.info(f"Route response: {route_response.datasource}")
    # return {"route_response": route_response.datasource}
    return {"route_response": "vectorstore"}
async def transform_query(state: StateRAGAccuracy | StateRAGSpeed):
    """Rewrite the user query into a standalone question using chat history."""
    payload = {
        "question": state["user_query"],
        "history": state.get("messages_history", None),
        "topic": state["topic"],
    }
    rewritten = await transform_query_chain.ainvoke(payload)
    logger.info(f"Transform response: {rewritten.content}")
    # The rewritten question replaces the original query in the state.
    return {"user_query": rewritten.content}
async def retrieve_document(state: StateRAGAccuracy):
    """Fetch documents relevant to the (possibly rewritten) user query.

    Uses the tutor vector store when a metadata filter is present, otherwise
    the general chatbot store. Returns at most 5 documents above a 0.3
    similarity-score threshold.
    """
    question = state["user_query"]
    doc_filter = state.get("filter", {})  # renamed: `filter` shadowed the builtin
    logger.info(f"Filter: {doc_filter}")
    # Store choice depends only on whether a filter was supplied.
    store = vector_store_tutor if doc_filter else vector_store_chatbot
    retriever = store.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"k": 5, "score_threshold": 0.3},
    )
    # Use the async API so this coroutine does not block the event loop
    # (the original called the synchronous `invoke`).
    documents = await retriever.ainvoke(question, filter=doc_filter)
    show_doc = " \n =============\n".join([doc.page_content for doc in documents])
    logger.info(f"Retrieved documents: {show_doc}")
    return {"documents": documents}
async def grade_document(state: StateRAGAccuracy):
    """Grade retrieved documents for relevance — currently disabled.

    The LLM-based grading below is commented out, so this node is a no-op:
    it returns an empty state update and every retrieved document passes
    through unfiltered.
    """
    # question = state["user_query"]
    # documents = state["documents"]
    # inputs_bach = [
    #     {"question": question, "document": doc.page_content} for doc in documents
    # ]
    # grade_document_response: list[GradeDocuments] = await grade_documents_chain.abatch(
    #     inputs_bach
    # )
    # logger.info(f"Grade response: {grade_document_response}")
    # document_index = [
    #     index
    #     for index, doc in enumerate(grade_document_response)
    #     if doc.binary_score == "yes"
    # ]
    # filtered_documents = [documents[i] for i in document_index]
    # return {"documents": filtered_documents}
    return {}
async def generate_answer_rag(state: StateRAGAccuracy):
    """Generate an answer grounded in the retrieved documents (tutor chain)."""
    docs = state["documents"]
    # Fall back to a fixed "no documents found" context when retrieval was empty.
    context_str = (
        convert_list_context_source_to_str(docs)
        if docs
        else "Không tìm thấy tài liệu"
    )
    response: GenerateAnswer = await gen_answer_rag_tutor_chain.ainvoke(
        {
            "question": state["user_query"],
            "context": context_str,
            "language": state["language"],
            "topic": state["topic"],
        }
    )
    logger.info(f"Generate answer response: {response}")
    return {
        "llm_response": response.answer,
        "document_id_selected": response.selected_document_index,
    }
async def generate_answer_rag_primary_chat(state: StateRAGAccuracy):
    """Generate a grounded answer using the primary-chat chain."""
    docs = state["documents"]
    # Fall back to a fixed "no documents found" context when retrieval was empty.
    context_str = (
        convert_list_context_source_to_str(docs)
        if docs
        else "Không tìm thấy tài liệu"
    )
    response: GenerateAnswer = await gen_answer_rag_primary_chat_chain.ainvoke(
        {
            "question": state["user_query"],
            "context": context_str,
            "language": state["language"],
            "topic": state["topic"],
        }
    )
    logger.info(f"Generate answer response: {response}")
    return {
        "llm_response": response.answer,
        "document_id_selected": response.selected_document_index,
    }
async def grade_hallucinations(state: StateRAGAccuracy):
    """Ask the grader chain whether the generated answer fits the question."""
    payload = {
        "question": state["user_query"],
        "generation": state["llm_response"],
    }
    verdict: GradeHallucinations = await grade_hallucinations_chain.ainvoke(payload)
    # binary_score is "yes"/"no" per the GradeHallucinations schema.
    return {"grade_response": verdict.binary_score}
async def gen_answer_normal(state: StateRAGAccuracy):
    """Answer from the model's own knowledge, without retrieval."""
    reply = await gen_normal_answer_chain.ainvoke(
        {
            "question": state["user_query"],
            "history": state["messages_history"],
            "topic": state["topic"],
        }
    )
    # Tag the reply with the fixed source-attribution footer (Vietnamese:
    # "Source of information: AI's knowledge").
    final_response = (
        reply.content + "\n**Nguồn thông tin: Kiến thức của AI**"
    )
    return {"llm_response": final_response}