"""LangGraph pipelines for the primary chat agent and the tutor chat agent.

Both pipelines run the same linear sequence of nodes and differ only in the
state schema and the final answer-generation node.
"""

from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph

from .func import (
    StateRAGAccuracy,
    StateRAGSpeed,
    trim_history,
    route,
    transform_query,
    retrieve_document,
    grade_document,
    generate_answer_rag,
    grade_hallucinations,
    generate_answer_rag_primary_chat,
    gen_answer_normal,
)


class _LinearRAGGraph:
    """Builder for a linear retrieve-then-answer LangGraph pipeline.

    The compiled graph is::

        START -> trim_history -> transform_query -> retrieve_document
              -> <answer node> -> END

    Subclasses supply the state schema and the answer-generation node via
    ``super().__init__``. Calling an instance returns a compiled graph.
    """

    def __init__(self, state_schema, answer_node_name, answer_node):
        """Store the pipeline configuration and create an empty builder.

        Args:
            state_schema: State type passed to ``StateGraph``.
            answer_node_name: Graph name of the final answer-generation node.
            answer_node: Callable registered under ``answer_node_name``.
        """
        self._state_schema = state_schema
        self._answer_node_name = answer_node_name
        self._answer_node = answer_node
        # Set once __call__ has started populating ``self.builder``.
        self._populated = False
        self.builder = StateGraph(state_schema)

    # NOTE(review): neither router below is wired into the graph built by
    # ``edge()``, and both can return "generate_answer_normal", which is not a
    # registered node. They are kept only so existing references still resolve.
    @staticmethod
    def routing_after_retrieve_document(state: dict) -> str:
        """Pick the RAG answer node if any documents were retrieved."""
        return (
            "generate_answer_rag"
            if len(state["documents"]) != 0
            else "generate_answer_normal"
        )

    @staticmethod
    def routing_after_gen_answer_rag(state: dict) -> str:
        """Always finish the graph after the RAG answer node."""
        return END

    def node(self):
        """Register every node of the pipeline on ``self.builder``."""
        self.builder.add_node("trim_history", trim_history)
        self.builder.add_node("transform_query", transform_query)
        self.builder.add_node("retrieve_document", retrieve_document)
        self.builder.add_node(self._answer_node_name, self._answer_node)

    def edge(self):
        """Wire the nodes into a single linear path from START to END."""
        path = [
            START,
            "trim_history",
            "transform_query",
            "retrieve_document",
            self._answer_node_name,
            END,
        ]
        for src, dst in zip(path, path[1:]):
            self.builder.add_edge(src, dst)

    def __call__(self) -> CompiledStateGraph:
        """Build and compile the pipeline, returning the compiled graph.

        Safe to call more than once. ``StateGraph.add_node`` rejects a name
        that is already present, so a repeat call starts from a fresh builder
        instead of re-adding nodes to the one populated last time.
        """
        if self._populated:
            self.builder = StateGraph(self._state_schema)
        # Mark before populating so a failure mid-build still forces a
        # fresh builder on the next call.
        self._populated = True
        self.node()
        self.edge()
        return self.builder.compile()


class PrimaryChatBotSpeed(_LinearRAGGraph):
    """Pipeline over ``StateRAGSpeed`` answering with ``generate_answer_rag_primary_chat``."""

    def __init__(self):
        super().__init__(
            StateRAGSpeed,
            "generate_answer_rag_primary_chat",
            generate_answer_rag_primary_chat,
        )


class PrimaryChatBotAccuracy(_LinearRAGGraph):
    """Pipeline over ``StateRAGAccuracy`` answering with ``generate_answer_rag``."""

    def __init__(self):
        super().__init__(
            StateRAGAccuracy,
            "generate_answer_rag",
            generate_answer_rag,
        )


# Compiled graphs, built once at import time.
primary_chat_agent = PrimaryChatBotSpeed()()
tutor_chat_agent = PrimaryChatBotAccuracy()()