"""LangGraph node functions for the RAG chatbot: state schemas, history
trimming, query routing/rewriting, retrieval, grading, and answer generation."""

import os
from typing import TypedDict, Optional, List, Literal

from langchain_core.documents import Document
from langchain_core.messages import trim_messages, AnyMessage

from src.utils.helper import (
    fake_token_counter,
    convert_list_context_source_to_str,
    convert_message,
)
from src.utils.logger import logger
from src.config.vector_store import (
    vector_store_chatbot,
    vector_store_fresher,
    vector_store_tutor,
)
from .prompt import (
    RouteQuery,
    route_chain,
    transform_query_chain,
    ExtractFilter,
    extract_filter_chain,
    GradeDocuments,
    GenerateAnswer,
    GradeHallucinations,
    gen_normal_answer_chain,
    gen_answer_rag_chain,
    grade_documents_chain,
    grade_hallucinations_chain,
    gen_answer_rag_tutor_chain,
)

class StateRAGAccuracy(TypedDict):
    """Graph state for the accuracy-oriented RAG path (includes the routing
    and grading fields)."""

    user_query: str | AnyMessage
    route_response: str
    messages_history: list
    documents: list[Document]
    filter: dict
    llm_response: AnyMessage
    grade_response: Literal["yes", "no"]
    language: str
    document_id_selected: Optional[List]
    topic: str


class StateRAGSpeed(TypedDict):
    """Graph state for the speed-oriented RAG path (omits the routing and
    grading fields)."""

    user_query: str | AnyMessage
    messages_history: list
    documents: list[Document]
    filter: dict
    llm_response: AnyMessage
    language: str
    document_id_selected: Optional[List]
    topic: str

def trim_history(state: StateRAGAccuracy | StateRAGSpeed):
    """Trim the chat history to the configured token budget before it is fed
    to the LLM chains."""
    history = (
        convert_message(state["messages_history"])
        if state.get("messages_history")
        else None
    )
    if not history:
        return {"messages_history": []}
    chat_message_history = trim_messages(
        history,
        strategy="last",  # keep the most recent messages
        token_counter=fake_token_counter,
        max_tokens=int(os.getenv("HISTORY_TOKEN_LIMIT", 2000)),
        start_on="human",  # trimmed history must start with a human turn
        end_on="ai",  # ...and end with an AI turn
        include_system=False,
        allow_partial=False,
    )
    return {"messages_history": chat_message_history}

async def route(state: StateRAGAccuracy):
    """Decide which datasource the query should be answered from."""
    logger.info("routing")
    question = state["user_query"]
    chat_history = state.get("messages_history", None)
    route_response: RouteQuery = await route_chain.ainvoke(
        {
            "question": question,
            "chat_history": chat_history,
            "topic": state["topic"],
        }
    )
    logger.info(f"Route response: {route_response.datasource}")
    return {"route_response": route_response.datasource}

async def transform_query(state: StateRAGAccuracy | StateRAGSpeed):
    """Rewrite the user query into a standalone question for retrieval."""
    question = state["user_query"]
    chat_history = state.get("messages_history", None)
    transform_response = await transform_query_chain.ainvoke(
        {
            "question": question,
            "chat_history": chat_history,
            "topic": state["topic"],
        }
    )
    logger.info(f"Transform response: {transform_response.content}")
    return {"user_query": transform_response.content}

async def retrieve_document(state: StateRAGAccuracy):
    """Retrieve candidate documents, using the tutor store when a metadata
    filter is present and the general chatbot store otherwise."""
    question = state["user_query"]
    doc_filter = state.get("filter", {})  # renamed to avoid shadowing the builtin `filter`
    logger.info(f"Filter: {doc_filter}")
    if doc_filter:
        retriever = vector_store_tutor.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={"k": 5, "score_threshold": 0.3},
        )
    else:
        retriever = vector_store_chatbot.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={"k": 5, "score_threshold": 0.3},
        )
    # Use the async retriever call so this node does not block the event loop.
    documents = await retriever.ainvoke(question, filter=doc_filter)
    show_doc = " \n =============\n".join([doc.page_content for doc in documents])
    logger.info(f"Retrieved documents: {show_doc}")
    return {"documents": documents}

async def grade_document(state: StateRAGAccuracy):
    """Grade each retrieved document for relevance and keep only the ones the
    grader marks "yes"."""
    question = state["user_query"]
    documents = state["documents"]
    inputs_batch = [
        {"question": question, "document": doc.page_content} for doc in documents
    ]
    grade_document_response: list[GradeDocuments] = await grade_documents_chain.abatch(
        inputs_batch
    )
    logger.info(f"Grade response: {grade_document_response}")
    filtered_documents = [
        doc
        for doc, grade in zip(documents, grade_document_response)
        if grade.binary_score == "yes"
    ]
    return {"documents": filtered_documents}

async def generate_answer_rag(state: StateRAGAccuracy):
    """Generate the final answer from the graded documents and record which
    documents the LLM reports having used."""
    question = state["user_query"]
    documents = state["documents"]
    language = state["language"]
    context_str = convert_list_context_source_to_str(documents)
    gen_answer_response: GenerateAnswer = await gen_answer_rag_tutor_chain.ainvoke(
        {
            "question": question,
            "context": context_str,
            "language": language,
            "topic": state["topic"],
        }
    )
    logger.info(f"Generate answer response: {gen_answer_response}")
    id_selected = gen_answer_response.selected_document_index
    return {
        "llm_response": gen_answer_response.answer,
        # Return the indices reported by the LLM instead of a hardcoded value.
        "document_id_selected": id_selected,
    }

async def grade_hallucinations(state: StateRAGAccuracy):
    """Check whether the generated answer is grounded (binary yes/no)."""
    question = state["user_query"]
    llm_response = state["llm_response"]
    grade_response: GradeHallucinations = await grade_hallucinations_chain.ainvoke(
        {"question": question, "generation": llm_response}
    )
    return {"grade_response": grade_response.binary_score}

async def gen_answer_normal(state: StateRAGAccuracy):
    """Answer directly from the LLM's own knowledge (no retrieval)."""
    question = state["user_query"]
    history = state["messages_history"]
    gen_answer_response = await gen_normal_answer_chain.ainvoke(
        {
            "question": question,
            "history": history,
            "topic": state["topic"],
        }
    )
    final_response = (
        # Vietnamese suffix: "Information source: the AI's own knowledge"
        gen_answer_response.content + "\n**Nguồn thông tin: Kiến thức của AI**"
    )
    return {"llm_response": final_response}