Spaces:
Build error
Build error
| from langgraph.graph import StateGraph, START, END | |
| from langgraph.graph.state import CompiledStateGraph | |
| from .func import ( | |
| StateRAGAccuracy, | |
| StateRAGSpeed, | |
| trim_history, | |
| route, | |
| transform_query, | |
| retrieve_document, | |
| grade_document, | |
| generate_answer_rag, | |
| grade_hallucinations, | |
| gen_answer_normal, | |
| ) | |
class PrimaryChatBotAccuracy:
    """Builder for the accuracy-oriented RAG chat graph.

    Calling an instance registers all nodes and edges on a
    ``StateGraph(StateRAGAccuracy)`` and returns the compiled graph.

    Flow:
    ``trim_history -> route``. When routed to the vector store, the flow continues
    ``transform_query -> retrieve_document -> grade_document ->
    generate_answer_rag -> grade_hallucinations -> END``. Each stage can fall
    back to ``generate_answer_normal``, which always ends the run.

    Note: ``__call__`` mutates ``self.builder``, so each instance is meant to
    be called once.
    """

    def __init__(self):
        self.builder = StateGraph(StateRAGAccuracy)

    # The routing callbacks only read the graph state. They are static so
    # that ``self.<callback>`` is a plain one-argument callable when handed
    # to ``add_conditional_edges``. Without ``@staticmethod``, the instance
    # would be bound into ``state`` and LangGraph's call would raise TypeError.

    @staticmethod
    def routing_after_route(state: StateRAGAccuracy) -> str:
        """Choose the branch after the ``route`` node.

        Returns ``"transform_query"`` when ``state["route_response"]`` is
        ``"vectorstore"``. Otherwise returns ``"generate_answer_normal"``.
        """
        if state["route_response"] == "vectorstore":
            return "transform_query"
        return "generate_answer_normal"

    @staticmethod
    def routing_after_retrieve_document(state: StateRAGAccuracy) -> str:
        """Grade the retrieved documents if there are any, else answer normally."""
        # Truthiness also treats a ``None`` document list as "nothing found".
        return "grade_document" if state["documents"] else "generate_answer_normal"

    @staticmethod
    def route_after_grade_document(state: StateRAGAccuracy) -> str:
        """Answer from documents if any remain after grading, else answer normally."""
        # Presumably ``grade_document`` drops irrelevant documents from
        # ``state["documents"]``. TODO confirm in .func.
        return (
            "generate_answer_rag" if state["documents"] else "generate_answer_normal"
        )

    @staticmethod
    def routing_check_pass_grade_hallucinations(state: StateRAGAccuracy) -> str:
        """Finish if the hallucination grader said ``"yes"``, else answer normally."""
        return END if state["grade_response"] == "yes" else "generate_answer_normal"

    def node(self):
        """Register every node callable on the builder."""
        self.builder.add_node("trim_history", trim_history)
        self.builder.add_node("route", route)
        self.builder.add_node("transform_query", transform_query)
        self.builder.add_node("retrieve_document", retrieve_document)
        self.builder.add_node("grade_document", grade_document)
        self.builder.add_node("generate_answer_rag", generate_answer_rag)
        self.builder.add_node("grade_hallucinations", grade_hallucinations)
        self.builder.add_node("generate_answer_normal", gen_answer_normal)

    def edge(self):
        """Wire the fixed and conditional edges between the registered nodes."""
        self.builder.add_edge(START, "trim_history")
        self.builder.add_edge("trim_history", "route")
        self.builder.add_conditional_edges(
            "route",
            self.routing_after_route,
            {
                "transform_query": "transform_query",
                "generate_answer_normal": "generate_answer_normal",
            },
        )
        self.builder.add_edge("transform_query", "retrieve_document")
        self.builder.add_conditional_edges(
            "retrieve_document",
            self.routing_after_retrieve_document,
            {
                "grade_document": "grade_document",
                "generate_answer_normal": "generate_answer_normal",
            },
        )
        self.builder.add_conditional_edges(
            "grade_document",
            self.route_after_grade_document,
            {
                "generate_answer_rag": "generate_answer_rag",
                "generate_answer_normal": "generate_answer_normal",
            },
        )
        self.builder.add_edge("generate_answer_rag", "grade_hallucinations")
        self.builder.add_conditional_edges(
            "grade_hallucinations",
            self.routing_check_pass_grade_hallucinations,
            {
                END: END,
                "generate_answer_normal": "generate_answer_normal",
            },
        )
        self.builder.add_edge("generate_answer_normal", END)

    def __call__(self) -> CompiledStateGraph:
        """Register nodes and edges, then return the compiled graph."""
        self.node()
        self.edge()
        return self.builder.compile()
class PrimaryChatBotSpeed:
    """Builder for the speed-oriented RAG chat graph.

    Calling an instance registers all nodes and edges on a
    ``StateGraph(StateRAGSpeed)`` and returns the compiled graph.

    Flow:
    ``trim_history -> transform_query -> retrieve_document ->
    generate_answer_rag -> END``. Unlike the accuracy graph, there is no
    routing, document grading, or hallucination check. Empty retrieval or a
    RAG answer without ``document_id_selected`` falls back to
    ``generate_answer_normal``, which always ends the run.

    Note: ``__call__`` mutates ``self.builder``, so each instance is meant to
    be called once.
    """

    def __init__(self):
        self.builder = StateGraph(StateRAGSpeed)

    # The routing callbacks only read the graph state. They are static so
    # that ``self.<callback>`` is a plain one-argument callable when handed
    # to ``add_conditional_edges``. Without ``@staticmethod``, the instance
    # would be bound into ``state`` and LangGraph's call would raise TypeError.

    @staticmethod
    def routing_after_retrieve_document(state: StateRAGSpeed) -> str:
        """Answer from documents if any were retrieved, else answer normally."""
        # Truthiness also treats a ``None`` document list as "nothing found".
        return (
            "generate_answer_rag" if state["documents"] else "generate_answer_normal"
        )

    @staticmethod
    def routing_after_gen_answer_rag(state: StateRAGSpeed) -> str:
        """Finish if the RAG answer selected a document, else answer normally."""
        return END if state["document_id_selected"] else "generate_answer_normal"

    def node(self):
        """Register every node callable on the builder."""
        self.builder.add_node("trim_history", trim_history)
        self.builder.add_node("transform_query", transform_query)
        self.builder.add_node("retrieve_document", retrieve_document)
        self.builder.add_node("generate_answer_rag", generate_answer_rag)
        self.builder.add_node("generate_answer_normal", gen_answer_normal)

    def edge(self):
        """Wire the fixed and conditional edges between the registered nodes."""
        self.builder.add_edge(START, "trim_history")
        self.builder.add_edge("trim_history", "transform_query")
        self.builder.add_edge("transform_query", "retrieve_document")
        self.builder.add_conditional_edges(
            "retrieve_document",
            self.routing_after_retrieve_document,
            {
                "generate_answer_rag": "generate_answer_rag",
                "generate_answer_normal": "generate_answer_normal",
            },
        )
        self.builder.add_conditional_edges(
            "generate_answer_rag",
            self.routing_after_gen_answer_rag,
            {
                END: END,
                "generate_answer_normal": "generate_answer_normal",
            },
        )
        self.builder.add_edge("generate_answer_normal", END)

    def __call__(self) -> CompiledStateGraph:
        """Register nodes and edges, then return the compiled graph."""
        self.node()
        self.edge()
        return self.builder.compile()
def _compile_graph(builder_cls):
    """Instantiate the graph-builder class *builder_cls* and return its compiled graph."""
    builder = builder_cls()
    return builder()


# Both pipelines are compiled once, as a side effect of importing this
# module. The speed graph is built first, then the accuracy graph.
rag_speed = _compile_graph(PrimaryChatBotSpeed)
rag_accuracy = _compile_graph(PrimaryChatBotAccuracy)