import os
from typing import TypedDict, Optional, List, Literal

from langchain_core.documents import Document
from langchain_core.messages import trim_messages, AnyMessage

from src.utils.helper import (
    fake_token_counter,
    convert_list_context_source_to_str,
    convert_message,
)
from src.utils.logger import logger
from src.config.vector_store import (
    vector_store_chatbot,
    vector_store_fresher,
    vector_store_tutor,
)
from .prompt import (
    RouteQuery,
    route_chain,
    transform_query_chain,
    ExtractFilter,
    extract_filter_chain,
    GradeDocuments,
    GenerateAnswer,
    GradeHallucinations,
    gen_normal_answer_chain,
    gen_answer_rag_chain,
    grade_documents_chain,
    grade_hallucinations_chain,
    gen_answer_rag_tutor_chain,
)


class StateRAGAccuracy(TypedDict):
    user_query: str | AnyMessage
    route_response: str
    messages_history: list
    documents: list[Document]
    filter: dict
    llm_response: AnyMessage
    grade_response: Literal["yes", "no"]
    language: str
    document_id_selected: Optional[List]
    topic: str


class StateRAGSpeed(TypedDict):
    user_query: str | AnyMessage
    messages_history: list
    documents: list[Document]
    filter: dict
    llm_response: AnyMessage
    language: str
    document_id_selected: Optional[List]
    topic: str


def trim_history(state: StateRAGAccuracy | StateRAGSpeed):
    """Trim the chat history to the configured token budget before it is used downstream."""
    history = (
        convert_message(state["messages_history"])
        if state.get("messages_history")
        else None
    )
    if not history:
        return {"messages_history": []}
    chat_message_history = trim_messages(
        history,
        strategy="last",
        token_counter=fake_token_counter,
        max_tokens=int(os.getenv("HISTORY_TOKEN_LIMIT", "2000")),
        start_on="human",
        end_on="ai",
        include_system=False,
        allow_partial=False,
    )
    return {"messages_history": chat_message_history}


async def route(state: StateRAGAccuracy):
    """Decide which datasource should handle the query."""
    logger.info("Routing query")
    question = state["user_query"]
    chat_history = state.get("messages_history", None)
    route_response: RouteQuery = await route_chain.ainvoke(
        {
            "question": question,
            "chat_history": chat_history,
            "topic": state["topic"],
        }
    )
    logger.info(f"Route response: {route_response.datasource}")
    return {"route_response": route_response.datasource}


async def transform_query(state: StateRAGAccuracy | StateRAGSpeed):
    """Rewrite the user query into a standalone, retrieval-friendly question."""
    question = state["user_query"]
    chat_history = state.get("messages_history", None)
    transform_response = await transform_query_chain.ainvoke(
        {
            "question": question,
            "chat_history": chat_history,
            "topic": state["topic"],
        }
    )
    logger.info(f"Transform response: {transform_response.content}")
    return {"user_query": transform_response.content}


async def retrieve_document(state: StateRAGAccuracy):
    """Retrieve candidate documents: the tutor store when a filter is set, otherwise the chatbot store."""
    question = state["user_query"]
    filter = state.get("filter", {})
    logger.info(f"Filter: {filter}")
    if filter:
        retriever = vector_store_tutor.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={"k": 5, "score_threshold": 0.3},
        )
    else:
        retriever = vector_store_chatbot.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={"k": 5, "score_threshold": 0.3},
        )
    documents = retriever.invoke(question, filter=filter)
    show_doc = "\n=============\n".join([doc.page_content for doc in documents])
    logger.info(f"Retrieved documents: {show_doc}")
    return {"documents": documents}


async def grade_document(state: StateRAGAccuracy):
    """Keep only the documents the grader marks as relevant to the question."""
    question = state["user_query"]
    documents = state["documents"]
    inputs_batch = [
        {"question": question, "document": doc.page_content} for doc in documents
    ]
    grade_document_response: list[GradeDocuments] = await grade_documents_chain.abatch(
        inputs_batch
    )
    logger.info(f"Grade response: {grade_document_response}")
    document_index = [
        index
        for index, doc in enumerate(grade_document_response)
        if doc.binary_score == "yes"
    ]
    filtered_documents = [documents[i] for i in document_index]
    return {"documents": filtered_documents}


async def generate_answer_rag(state: StateRAGAccuracy):
    """Generate an answer grounded in the retrieved context."""
    question = state["user_query"]
    documents = state["documents"]
    language = state["language"]
    context_str = convert_list_context_source_to_str(documents)
    gen_answer_response: GenerateAnswer = await gen_answer_rag_tutor_chain.ainvoke(
        {
            "question": question,
            "context": context_str,
            "language": language,
            "topic": state["topic"],
        }
    )
    logger.info(f"Generate answer response: {gen_answer_response}")
    id_selected = gen_answer_response.selected_document_index
    return {
        "llm_response": gen_answer_response.answer,
        "document_id_selected": id_selected,
    }


async def grade_hallucinations(state: StateRAGAccuracy):
    """Check whether the generated answer is grounded in the question and context."""
    question = state["user_query"]
    llm_response = state["llm_response"]
    grade_response: GradeHallucinations = await grade_hallucinations_chain.ainvoke(
        {"question": question, "generation": llm_response}
    )
    return {"grade_response": grade_response.binary_score}


async def gen_answer_normal(state: StateRAGAccuracy):
    """Answer from the model's own knowledge when retrieval is not used."""
    question = state["user_query"]
    history = state["messages_history"]
    gen_answer_response = await gen_normal_answer_chain.ainvoke(
        {
            "question": question,
            "history": history,
            "topic": state["topic"],
        }
    )
    # Vietnamese footer: "Source of information: the AI's own knowledge".
    final_response = (
        gen_answer_response.content + "\n**Nguồn thông tin: Kiến thức của AI**"
    )
    return {"llm_response": final_response}
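

# Illustrative only: a minimal sketch of how the accuracy-path nodes above might be
# wired into a LangGraph StateGraph. The node names, branch conditions, and the
# assumed meanings of the `route_response` / `grade_response` values below are not
# taken from this module; the project's real graph may be built elsewhere and differ.
def build_accuracy_graph_sketch():
    from langgraph.graph import StateGraph, START, END

    builder = StateGraph(StateRAGAccuracy)
    builder.add_node("trim_history", trim_history)
    builder.add_node("route", route)
    builder.add_node("transform_query", transform_query)
    builder.add_node("retrieve_document", retrieve_document)
    builder.add_node("grade_document", grade_document)
    builder.add_node("generate_answer_rag", generate_answer_rag)
    builder.add_node("grade_hallucinations", grade_hallucinations)
    builder.add_node("gen_answer_normal", gen_answer_normal)

    builder.add_edge(START, "trim_history")
    builder.add_edge("trim_history", "route")
    # Assumption: the router emits "vectorstore" for queries that need retrieval.
    builder.add_conditional_edges(
        "route",
        lambda state: "transform_query"
        if state["route_response"] == "vectorstore"
        else "gen_answer_normal",
    )
    builder.add_edge("transform_query", "retrieve_document")
    builder.add_edge("retrieve_document", "grade_document")
    builder.add_edge("grade_document", "generate_answer_rag")
    builder.add_edge("generate_answer_rag", "grade_hallucinations")
    # Assumption: "yes" means the answer is grounded, so the run can finish;
    # otherwise regenerate the answer.
    builder.add_conditional_edges(
        "grade_hallucinations",
        lambda state: END if state["grade_response"] == "yes" else "generate_answer_rag",
    )
    builder.add_edge("gen_answer_normal", END)
    return builder.compile()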