# Provenance (scraped page header): uploaded by ABAO77 — "Upload 38 files",
# commit 24dcddf (verified).
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph
from .func import (
StateRAGAccuracy,
StateRAGSpeed,
trim_history,
route,
transform_query,
retrieve_document,
grade_document,
generate_answer_rag,
grade_hallucinations,
gen_answer_normal,
)
class PrimaryChatBotAccuracy:
    """Build the accuracy-oriented RAG workflow as a LangGraph state graph.

    Topology wired up by :meth:`edge` (node callables come from ``.func``)::

        START -> trim_history -> route
        route --(route_response == "vectorstore")--> transform_query
        route --(anything else)--------------------> generate_answer_normal
        transform_query -> retrieve_document
        retrieve_document --(documents non-empty)--> grade_document
        retrieve_document --(no documents)---------> generate_answer_normal
        grade_document --(documents non-empty)--> generate_answer_rag
        grade_document --(no documents)---------> generate_answer_normal
        generate_answer_rag -> grade_hallucinations
        grade_hallucinations --(grade_response == "yes")--> END
        grade_hallucinations --(anything else)-----------> generate_answer_normal
        generate_answer_normal -> END

    Calling an instance registers nodes/edges once and returns the compiled
    graph; repeated calls return the same compiled graph instead of trying to
    re-register nodes on the builder (which ``StateGraph`` rejects).
    """

    def __init__(self):
        self.builder = StateGraph(StateRAGAccuracy)
        # Filled in by __call__ on first use so later calls are safe.
        self._compiled = None

    @staticmethod
    def routing_after_route(state: StateRAGAccuracy):
        """Choose the branch after the ``route`` node.

        Returns ``"transform_query"`` when ``state["route_response"]`` is
        ``"vectorstore"``, otherwise ``"generate_answer_normal"``.
        """
        if state["route_response"] == "vectorstore":
            return "transform_query"
        return "generate_answer_normal"

    @staticmethod
    def routing_after_retrieve_document(state: StateRAGAccuracy):
        """Go to ``grade_document`` if any documents were retrieved.

        A missing, ``None`` or empty ``documents`` entry routes to
        ``generate_answer_normal`` instead of raising.
        """
        return (
            "grade_document"
            if state.get("documents")
            else "generate_answer_normal"
        )

    @staticmethod
    def route_after_grade_document(state: StateRAGAccuracy):
        """Go to ``generate_answer_rag`` if documents remain after grading.

        A missing, ``None`` or empty ``documents`` entry routes to
        ``generate_answer_normal``.
        """
        # NOTE(review): presumably grade_document (in .func) removes documents
        # judged irrelevant, so this re-check reflects the graded set — confirm.
        return (
            "generate_answer_rag"
            if state.get("documents")
            else "generate_answer_normal"
        )

    @staticmethod
    def routing_check_pass_grade_hallucinations(state: StateRAGAccuracy):
        """Finish when ``state["grade_response"]`` is ``"yes"``.

        Any other grade routes to ``generate_answer_normal`` as a fallback.
        """
        return END if state["grade_response"] == "yes" else "generate_answer_normal"

    def node(self):
        """Register every node of the accuracy workflow on the builder."""
        self.builder.add_node("trim_history", trim_history)
        self.builder.add_node("route", route)
        self.builder.add_node("transform_query", transform_query)
        self.builder.add_node("retrieve_document", retrieve_document)
        self.builder.add_node("grade_document", grade_document)
        self.builder.add_node("generate_answer_rag", generate_answer_rag)
        self.builder.add_node("grade_hallucinations", grade_hallucinations)
        self.builder.add_node("generate_answer_normal", gen_answer_normal)

    def edge(self):
        """Wire the static and conditional edges described in the class docstring."""
        self.builder.add_edge(START, "trim_history")
        self.builder.add_edge("trim_history", "route")
        # Router decides between retrieval (vectorstore) and a direct answer.
        self.builder.add_conditional_edges(
            "route",
            self.routing_after_route,
            {
                "transform_query": "transform_query",
                "generate_answer_normal": "generate_answer_normal",
            },
        )
        self.builder.add_edge("transform_query", "retrieve_document")
        # Skip grading entirely when retrieval returned nothing.
        self.builder.add_conditional_edges(
            "retrieve_document",
            self.routing_after_retrieve_document,
            {
                "grade_document": "grade_document",
                "generate_answer_normal": "generate_answer_normal",
            },
        )
        self.builder.add_conditional_edges(
            "grade_document",
            self.route_after_grade_document,
            {
                "generate_answer_rag": "generate_answer_rag",
                "generate_answer_normal": "generate_answer_normal",
            },
        )
        self.builder.add_edge("generate_answer_rag", "grade_hallucinations")
        # A failed hallucination check falls back to the non-RAG answer.
        self.builder.add_conditional_edges(
            "grade_hallucinations",
            self.routing_check_pass_grade_hallucinations,
            {
                END: END,
                "generate_answer_normal": "generate_answer_normal",
            },
        )
        self.builder.add_edge("generate_answer_normal", END)

    def __call__(self) -> CompiledStateGraph:
        """Build (on first call only) and return the compiled graph.

        Repeated calls return the cached compiled graph rather than adding
        the same nodes to ``self.builder`` a second time.
        """
        # getattr keeps this working for instances built without __init__.
        compiled = getattr(self, "_compiled", None)
        if compiled is None:
            self.node()
            self.edge()
            compiled = self._compiled = self.builder.compile()
        return compiled
class PrimaryChatBotSpeed:
    """Build the speed-oriented RAG workflow as a LangGraph state graph.

    Unlike the accuracy variant there is no routing, document grading or
    hallucination check. Topology wired up by :meth:`edge` (node callables
    come from ``.func``)::

        START -> trim_history -> transform_query -> retrieve_document
        retrieve_document --(documents non-empty)--> generate_answer_rag
        retrieve_document --(no documents)---------> generate_answer_normal
        generate_answer_rag --(document_id_selected truthy)--> END
        generate_answer_rag --(otherwise)--------------------> generate_answer_normal
        generate_answer_normal -> END

    Calling an instance registers nodes/edges once and returns the compiled
    graph; repeated calls return the same compiled graph.
    """

    def __init__(self):
        self.builder = StateGraph(StateRAGSpeed)
        # Filled in by __call__ on first use so later calls are safe.
        self._compiled = None

    @staticmethod
    def routing_after_retrieve_document(state: StateRAGSpeed):
        """Go to ``generate_answer_rag`` if any documents were retrieved.

        A missing, ``None`` or empty ``documents`` entry routes to
        ``generate_answer_normal`` instead of raising.
        """
        return (
            "generate_answer_rag"
            if state.get("documents")
            else "generate_answer_normal"
        )

    @staticmethod
    def routing_after_gen_answer_rag(state: StateRAGSpeed):
        """Finish if the RAG step set a truthy ``document_id_selected``.

        Otherwise (missing or falsy) fall back to ``generate_answer_normal``.
        """
        # NOTE(review): presumably generate_answer_rag (in .func) sets this key
        # to the document it relied on — confirm against that node.
        return END if state.get("document_id_selected") else "generate_answer_normal"

    def node(self):
        """Register every node of the speed workflow on the builder."""
        self.builder.add_node("trim_history", trim_history)
        self.builder.add_node("transform_query", transform_query)
        self.builder.add_node("retrieve_document", retrieve_document)
        self.builder.add_node("generate_answer_rag", generate_answer_rag)
        self.builder.add_node("generate_answer_normal", gen_answer_normal)

    def edge(self):
        """Wire the static and conditional edges described in the class docstring."""
        self.builder.add_edge(START, "trim_history")
        self.builder.add_edge("trim_history", "transform_query")
        self.builder.add_edge("transform_query", "retrieve_document")
        self.builder.add_conditional_edges(
            "retrieve_document",
            self.routing_after_retrieve_document,
            {
                "generate_answer_rag": "generate_answer_rag",
                "generate_answer_normal": "generate_answer_normal",
            },
        )
        # No selected document after RAG -> fall back to the non-RAG answer.
        self.builder.add_conditional_edges(
            "generate_answer_rag",
            self.routing_after_gen_answer_rag,
            {
                END: END,
                "generate_answer_normal": "generate_answer_normal",
            },
        )
        self.builder.add_edge("generate_answer_normal", END)

    def __call__(self) -> CompiledStateGraph:
        """Build (on first call only) and return the compiled graph.

        Repeated calls return the cached compiled graph rather than adding
        the same nodes to ``self.builder`` a second time.
        """
        # getattr keeps this working for instances built without __init__.
        compiled = getattr(self, "_compiled", None)
        if compiled is None:
            self.node()
            self.edge()
            compiled = self._compiled = self.builder.compile()
        return compiled
# Public entry points: each workflow is built and compiled once, at import
# time (speed graph first, then accuracy graph).
rag_speed: CompiledStateGraph = PrimaryChatBotSpeed()()
rag_accuracy: CompiledStateGraph = PrimaryChatBotAccuracy()()