"""LangGraph workflows for the primary chatbot.

Two compiled graphs are exported at import time:

* ``rag_accuracy`` -- trim_history -> route -> (transform_query ->
  retrieve_document -> grade_document -> generate_answer_rag ->
  grade_hallucinations), falling back to ``generate_answer_normal`` whenever
  a routing check fails.
* ``rag_speed`` -- a shorter pipeline without routing or grading:
  trim_history -> transform_query -> retrieve_document -> generate_answer_rag,
  also falling back to ``generate_answer_normal``.
"""

from typing import Optional

from langgraph.graph import END, START, StateGraph
from langgraph.graph.state import CompiledStateGraph

from .func import (
    StateRAGAccuracy,
    StateRAGSpeed,
    trim_history,
    route,
    transform_query,
    retrieve_document,
    grade_document,
    generate_answer_rag,
    grade_hallucinations,
    gen_answer_normal,
)


class _ChatBotGraphBase:
    """Build a ``StateGraph`` once and hand out the same compiled graph.

    Subclasses set ``_state_schema`` and implement :meth:`node` and
    :meth:`edge`. Calling the instance registers nodes and edges and compiles
    the graph on the first call only; later calls return the cached compiled
    graph. Re-running ``node()`` on the same builder would try to add every
    node a second time, which LangGraph rejects.
    """

    _state_schema: type

    def __init__(self):
        self.builder = StateGraph(self._state_schema)
        self._compiled: Optional[CompiledStateGraph] = None

    def node(self):
        """Register the graph's nodes on ``self.builder``."""
        raise NotImplementedError

    def edge(self):
        """Register the graph's edges on ``self.builder``."""
        raise NotImplementedError

    def __call__(self) -> CompiledStateGraph:
        """Return the compiled graph, building and compiling it on first use."""
        if self._compiled is None:
            # Nodes first, then edges, matching the original build order.
            self.node()
            self.edge()
            self._compiled = self.builder.compile()
        return self._compiled


class PrimaryChatBotAccuracy(_ChatBotGraphBase):
    """Accuracy-oriented RAG graph over ``StateRAGAccuracy``.

    Flow::

        START -> trim_history -> route
        route --"vectorstore"--> transform_query -> retrieve_document
        route --otherwise------> generate_answer_normal
        retrieve_document --documents--> grade_document
        retrieve_document --none-------> generate_answer_normal
        grade_document    --documents--> generate_answer_rag
        grade_document    --none-------> generate_answer_normal
        generate_answer_rag -> grade_hallucinations
        grade_hallucinations --"yes"--> END
        grade_hallucinations --else---> generate_answer_normal
        generate_answer_normal -> END
    """

    _state_schema = StateRAGAccuracy

    @staticmethod
    def routing_after_route(state: StateRAGAccuracy) -> str:
        """Go to retrieval when ``route_response`` is "vectorstore", else answer plainly."""
        if state["route_response"] == "vectorstore":
            return "transform_query"
        return "generate_answer_normal"

    @staticmethod
    def routing_after_retrieve_document(state: StateRAGAccuracy) -> str:
        """Grade the retrieved documents, or answer plainly when there are none."""
        return "grade_document" if state["documents"] else "generate_answer_normal"

    @staticmethod
    def route_after_grade_document(state: StateRAGAccuracy) -> str:
        """Answer from ``documents`` if any remain after grading, else answer plainly."""
        # NOTE(review): presumably grade_document filters state["documents"];
        # confirm in .func.
        return "generate_answer_rag" if state["documents"] else "generate_answer_normal"

    @staticmethod
    def routing_check_pass_grade_hallucinations(state: StateRAGAccuracy) -> str:
        """Finish when ``grade_response`` is "yes", otherwise answer plainly."""
        return END if state["grade_response"] == "yes" else "generate_answer_normal"

    def node(self):
        """Register every node of the accuracy graph."""
        self.builder.add_node("trim_history", trim_history)
        self.builder.add_node("route", route)
        self.builder.add_node("transform_query", transform_query)
        self.builder.add_node("retrieve_document", retrieve_document)
        self.builder.add_node("grade_document", grade_document)
        self.builder.add_node("generate_answer_rag", generate_answer_rag)
        self.builder.add_node("grade_hallucinations", grade_hallucinations)
        self.builder.add_node("generate_answer_normal", gen_answer_normal)

    def edge(self):
        """Wire the accuracy graph as described in the class docstring."""
        self.builder.add_edge(START, "trim_history")
        self.builder.add_edge("trim_history", "route")
        self.builder.add_conditional_edges(
            "route",
            self.routing_after_route,
            {
                "transform_query": "transform_query",
                "generate_answer_normal": "generate_answer_normal",
            },
        )
        self.builder.add_edge("transform_query", "retrieve_document")
        self.builder.add_conditional_edges(
            "retrieve_document",
            self.routing_after_retrieve_document,
            {
                "grade_document": "grade_document",
                "generate_answer_normal": "generate_answer_normal",
            },
        )
        self.builder.add_conditional_edges(
            "grade_document",
            self.route_after_grade_document,
            {
                "generate_answer_rag": "generate_answer_rag",
                "generate_answer_normal": "generate_answer_normal",
            },
        )
        self.builder.add_edge("generate_answer_rag", "grade_hallucinations")
        self.builder.add_conditional_edges(
            "grade_hallucinations",
            self.routing_check_pass_grade_hallucinations,
            {
                END: END,
                "generate_answer_normal": "generate_answer_normal",
            },
        )
        self.builder.add_edge("generate_answer_normal", END)


class PrimaryChatBotSpeed(_ChatBotGraphBase):
    """Speed-oriented RAG graph over ``StateRAGSpeed`` (no routing or grading).

    Flow::

        START -> trim_history -> transform_query -> retrieve_document
        retrieve_document --documents--> generate_answer_rag
        retrieve_document --none-------> generate_answer_normal
        generate_answer_rag --document_id_selected--> END
        generate_answer_rag --otherwise-------------> generate_answer_normal
        generate_answer_normal -> END
    """

    _state_schema = StateRAGSpeed

    @staticmethod
    def routing_after_retrieve_document(state: StateRAGSpeed) -> str:
        """Answer from the retrieved documents, or answer plainly when there are none."""
        return "generate_answer_rag" if state["documents"] else "generate_answer_normal"

    @staticmethod
    def routing_after_gen_answer_rag(state: StateRAGSpeed) -> str:
        """Finish when ``document_id_selected`` is truthy, otherwise answer plainly."""
        # NOTE(review): any falsy value (None, empty list, or an id of 0) falls
        # back to the plain answer; confirm 0 is never a valid document id.
        return END if state["document_id_selected"] else "generate_answer_normal"

    def node(self):
        """Register every node of the speed graph."""
        self.builder.add_node("trim_history", trim_history)
        self.builder.add_node("transform_query", transform_query)
        self.builder.add_node("retrieve_document", retrieve_document)
        self.builder.add_node("generate_answer_rag", generate_answer_rag)
        self.builder.add_node("generate_answer_normal", gen_answer_normal)

    def edge(self):
        """Wire the speed graph as described in the class docstring."""
        self.builder.add_edge(START, "trim_history")
        self.builder.add_edge("trim_history", "transform_query")
        self.builder.add_edge("transform_query", "retrieve_document")
        self.builder.add_conditional_edges(
            "retrieve_document",
            self.routing_after_retrieve_document,
            {
                "generate_answer_rag": "generate_answer_rag",
                "generate_answer_normal": "generate_answer_normal",
            },
        )
        self.builder.add_conditional_edges(
            "generate_answer_rag",
            self.routing_after_gen_answer_rag,
            {
                END: END,
                "generate_answer_normal": "generate_answer_normal",
            },
        )
        self.builder.add_edge("generate_answer_normal", END)


# Compiled graphs, built once at import time.
rag_speed = PrimaryChatBotSpeed()()
rag_accuracy = PrimaryChatBotAccuracy()()