# ABAO77's picture
# Upload 55 files
# 83a18c0 verified
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph
from .func import (
StateRAGAccuracy,
StateRAGSpeed,
trim_history,
route,
transform_query,
retrieve_document,
grade_document,
generate_answer_rag,
grade_hallucinations,
generate_answer_rag_primary_chat,
gen_answer_normal,
)
# class PrimaryChatBotAccuracy:
# def __init__(self):
# self.builder = StateGraph(StateRAGAccuracy)
# @staticmethod
# def routing_after_route(state: StateRAGAccuracy):
# if state["route_response"] == "vectorstore":
# return "transform_query"
# else:
# return "generate_answer_normal"
# @staticmethod
# def routing_after_retrieve_document(state: StateRAGAccuracy):
# return (
# "grade_document"
# if len(state["documents"]) != 0
# else "generate_answer_normal"
# )
# @staticmethod
# def route_after_grade_document(state: StateRAGAccuracy):
# return (
# "generate_answer_rag"
# if len(state["documents"]) != 0
# else "generate_answer_normal"
# )
# @staticmethod
# def routing_check_pass_grade_hallucinations(state: StateRAGAccuracy):
# return END if state["grade_response"] == "yes" else "generate_answer_normal"
# def node(self):
# self.builder.add_node("trim_history", trim_history)
# self.builder.add_node("route", route)
# self.builder.add_node("transform_query", transform_query)
# self.builder.add_node("retrieve_document", retrieve_document)
# self.builder.add_node("grade_document", grade_document)
# self.builder.add_node("generate_answer_rag", generate_answer_rag)
# self.builder.add_node("grade_hallucinations", grade_hallucinations)
# self.builder.add_node("generate_answer_normal", gen_answer_normal)
# def edge(self):
# self.builder.add_edge(START, "trim_history")
# # self.builder.add_edge("trim_history", "route")
# # self.builder.add_conditional_edges(
# # "route",
# # self.routing_after_route,
# # {
# # "transform_query": "transform_query",
# # "generate_answer_normal": "generate_answer_normal",
# # },
# # )
# self.builder.add_edge("trim_history", "transform_query")
# self.builder.add_edge("transform_query", "retrieve_document")
# self.builder.add_conditional_edges(
# "retrieve_document",
# self.routing_after_retrieve_document,
# {
# "grade_document": "grade_document",
# "generate_answer_normal": "generate_answer_normal",
# },
# )
# # self.builder.add_conditional_edges(
# # "grade_document",
# # self.route_after_grade_document,
# # {
# # "generate_answer_rag": "generate_answer_rag",
# # "generate_answer_normal": "generate_answer_normal",
# # },
# # )
# self.builder.add_edge("generate_answer_rag", "grade_hallucinations")
# self.builder.add_conditional_edges(
# "grade_hallucinations",
# self.routing_check_pass_grade_hallucinations,
# {
# END: END,
# "generate_answer_normal": "generate_answer_normal",
# },
# )
# self.builder.add_edge("generate_answer_normal", END)
# def __call__(self) -> CompiledStateGraph:
# self.node()
# self.edge()
# return self.builder.compile()
class PrimaryChatBotSpeed:
    """Build the speed-oriented RAG graph used for primary chat.

    The pipeline is strictly linear (no conditional branches)::

        START -> trim_history -> transform_query -> retrieve_document
              -> generate_answer_rag_primary_chat -> END

    Instances are callables. The first call wires nodes and edges into
    ``self.builder`` and compiles it. The compiled graph is cached, so later
    calls return the same object. Without the cache, a second call would
    re-register the same node names on the same builder; LangGraph's
    ``StateGraph.add_node`` raises on duplicate names (confirm for the pinned
    langgraph version).
    """

    def __init__(self):
        self.builder = StateGraph(StateRAGSpeed)
        # Compiled graph, populated by the first __call__ and reused afterwards.
        self._compiled = None

    @staticmethod
    def routing_after_retrieve_document(state: StateRAGSpeed):
        """Pick the answer node after retrieval based on ``state["documents"]``.

        Returns ``"generate_answer_rag"`` when documents were retrieved,
        otherwise ``"generate_answer_normal"``.

        NOTE(review): this router is not wired into the graph. Neither
        returned name is a node registered by :meth:`node`; the RAG node here
        is ``"generate_answer_rag_primary_chat"``. Wiring this router as-is
        would therefore fail.
        """
        return (
            "generate_answer_rag"
            if state["documents"]
            else "generate_answer_normal"
        )

    @staticmethod
    def routing_after_gen_answer_rag(state: StateRAGSpeed):
        """Route after the RAG answer; currently always terminates the run.

        NOTE(review): not wired into the graph. An earlier variant fell back
        to ``"generate_answer_normal"`` when ``state["document_id_selected"]``
        was falsy; that fallback is disabled.
        """
        return END

    def node(self) -> None:
        """Register the pipeline's nodes on ``self.builder``."""
        self.builder.add_node("trim_history", trim_history)
        self.builder.add_node("transform_query", transform_query)
        self.builder.add_node("retrieve_document", retrieve_document)
        self.builder.add_node(
            "generate_answer_rag_primary_chat", generate_answer_rag_primary_chat
        )
        # NOTE: the "generate_answer_normal" fallback node is intentionally
        # not registered; this graph always answers via RAG.

    def edge(self) -> None:
        """Connect the registered nodes into one linear path from START to END."""
        self.builder.add_edge(START, "trim_history")
        self.builder.add_edge("trim_history", "transform_query")
        self.builder.add_edge("transform_query", "retrieve_document")
        self.builder.add_edge("retrieve_document", "generate_answer_rag_primary_chat")
        self.builder.add_edge("generate_answer_rag_primary_chat", END)

    def __call__(self) -> CompiledStateGraph:
        """Return the compiled graph, building and compiling it on first use.

        Returns:
            The compiled LangGraph graph. The same instance is returned on
            every call.
        """
        if self._compiled is None:
            self.node()
            self.edge()
            self._compiled = self.builder.compile()
        return self._compiled
class PrimaryChatBotAccuracy:
    """Build the accuracy-oriented RAG graph (exported as the tutor agent).

    The pipeline is strictly linear (no conditional branches)::

        START -> trim_history -> transform_query -> retrieve_document
              -> generate_answer_rag -> END

    Instances are callables. The first call wires nodes and edges into
    ``self.builder`` and compiles it. The compiled graph is cached, so later
    calls return the same object. Without the cache, a second call would
    re-register the same node names on the same builder; LangGraph's
    ``StateGraph.add_node`` raises on duplicate names (confirm for the pinned
    langgraph version).
    """

    def __init__(self):
        self.builder = StateGraph(StateRAGAccuracy)
        # Compiled graph, populated by the first __call__ and reused afterwards.
        self._compiled = None

    @staticmethod
    def routing_after_retrieve_document(state: StateRAGAccuracy):
        """Pick the answer node after retrieval based on ``state["documents"]``.

        Returns ``"generate_answer_rag"`` when documents were retrieved,
        otherwise ``"generate_answer_normal"``.

        NOTE(review): this router is not wired into the graph, and
        ``"generate_answer_normal"`` is not a node registered by :meth:`node`.
        Wiring this router as-is would fail on the empty-documents branch.
        """
        return (
            "generate_answer_rag"
            if state["documents"]
            else "generate_answer_normal"
        )

    @staticmethod
    def routing_after_gen_answer_rag(state: StateRAGAccuracy):
        """Route after the RAG answer; currently always terminates the run.

        NOTE(review): not wired into the graph. An earlier variant fell back
        to ``"generate_answer_normal"`` when ``state["document_id_selected"]``
        was falsy; that fallback is disabled.
        """
        return END

    def node(self) -> None:
        """Register the pipeline's nodes on ``self.builder``."""
        self.builder.add_node("trim_history", trim_history)
        self.builder.add_node("transform_query", transform_query)
        self.builder.add_node("retrieve_document", retrieve_document)
        self.builder.add_node("generate_answer_rag", generate_answer_rag)
        # NOTE: the "generate_answer_normal" fallback node is intentionally
        # not registered; this graph always answers via RAG.

    def edge(self) -> None:
        """Connect the registered nodes into one linear path from START to END."""
        self.builder.add_edge(START, "trim_history")
        self.builder.add_edge("trim_history", "transform_query")
        self.builder.add_edge("transform_query", "retrieve_document")
        self.builder.add_edge("retrieve_document", "generate_answer_rag")
        self.builder.add_edge("generate_answer_rag", END)

    def __call__(self) -> CompiledStateGraph:
        """Return the compiled graph, building and compiling it on first use.

        Returns:
            The compiled LangGraph graph. The same instance is returned on
            every call.
        """
        if self._compiled is None:
            self.node()
            self.edge()
            self._compiled = self.builder.compile()
        return self._compiled
# Both graphs are built and compiled once, at import time; importers share
# these module-level compiled instances.

# Linear retrieve -> generate_answer_rag_primary_chat pipeline.
primary_chat_agent: CompiledStateGraph = PrimaryChatBotSpeed()()

# Linear retrieve -> generate_answer_rag pipeline.
tutor_chat_agent: CompiledStateGraph = PrimaryChatBotAccuracy()()