Multi-Rag / src /graphs /builder.py
VashuTheGreat2's picture
Upload folder using huggingface_hub
9c90775 verified
Raw
History Blame Contribute Delete
3.03 kB
from langgraph.graph import StateGraph, START, END
from src.memory import memory
from src.states.Main_State import State
from src.nodes.main_nodes import (
orchastrator_node,
query_generation_node,
retreiver_node,
is_retreived_data_enough,
web_search_node,
document_refiner,
get_chat_node_content,
chat_node
)
graph = StateGraph(State)
# ===================== Adding Nodes =====================
graph.add_node(
"orchastrator",
orchastrator_node
)
graph.add_node(
"query_generation",
query_generation_node
)
graph.add_node(
"retreiver",
retreiver_node
)
graph.add_node(
"relevance_checker",
is_retreived_data_enough
)
graph.add_node(
"web_search",
web_search_node
)
graph.add_node(
"document_refiner",
document_refiner
)
graph.add_node(
"context_builder",
get_chat_node_content
)
graph.add_node(
"chat",
chat_node
)
# ===================== Adding Edges =====================
graph.add_edge(
START,
"orchastrator"
)
graph.add_conditional_edges(
"orchastrator",
lambda state: (
"query_generation"
if state["require_db_search"]
else "chat"
),
{
"query_generation": "query_generation",
"chat": "chat"
}
)
graph.add_edge(
"query_generation",
"retreiver"
)
graph.add_edge(
"retreiver",
"relevance_checker"
)
graph.add_conditional_edges(
"relevance_checker",
lambda state: (
"document_refiner"
if state["relevance"] == "CORRECT" or state["relevance"] == "AMBIGUOUS"
else "web_search"
),
{
"document_refiner": "document_refiner",
"web_search": "web_search"
}
)
graph.add_edge(
"document_refiner",
"context_builder"
)
graph.add_edge(
"web_search",
"context_builder"
)
graph.add_edge(
"context_builder",
"chat"
)
graph.add_edge(
"chat",
END
)
graph = graph.compile(
checkpointer=memory
)
try:
graph.get_graph().draw_mermaid_png(
output_file_path="graph_visualization.png"
)
except Exception as e:
logging.exception(
f"Failed to generate graph visualization: {e}"
)
async def deleteThread(thread_id: str):
try:
cp = memory
state = await cp.aget_tuple(config={'configurable': {'thread_id': thread_id}})
if state is None:
logging.info(f"Thread {thread_id} not found, nothing to delete.")
return False
await cp.adelete_thread(thread_id=thread_id)
logging.info(f"Thread {thread_id} deleted successfully.")
return True
except Exception as e:
logging.error(f"Error deleting thread {thread_id}: {e}")
return False
async def load_conversation(thread_id):
try:
state = graph.get_state(config={'configurable': {'thread_id': thread_id}})
return state.values.get('messages', [])
except Exception as e:
logging.error(f"Error loading conversation: {e}")
return []