# DocuBot / agents/workflow.py
# Last commit fcbb0ab by MaheshLEO4: "updated files"
import logging
import re
from typing import Annotated, Dict, List, TypedDict

from langchain.schema import Document
from langgraph.graph import END, StateGraph

from .research_agent import ResearchAgent
from .relevance_checker import RelevanceChecker
from .verification_agent import VerificationAgent
# Module-scoped logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
class AgentState(TypedDict):
    """Shared state flowing through the nodes of the AgentWorkflow graph.

    Each node method returns a partial dict containing only the keys it
    updates; LangGraph applies that dict as an update to this state.
    """

    question: str  # the user's question
    documents: List[Document]  # documents retrieved for the question
    draft_answer: str  # latest answer produced by the research step
    verification_report: str  # latest report text from the verification step
    is_relevant: bool  # set by the relevance check; gates the research step
    retry_count: int  # research passes run so far; bounds the research/verify loop
class AgentWorkflow:
    """LangGraph pipeline: relevance check -> research -> verification.

    Routing:

    * ``check_relevance`` -> ``research`` when the question is judged
      relevant, otherwise the graph ends without researching.
    * ``research`` -> ``verify`` unconditionally.
    * ``verify`` -> ``research`` again until the verification report marks
      the draft as supported or ``max_research_attempts`` research passes
      have run.
    """

    # Previously hard-coded limit: one initial research pass plus two retries.
    DEFAULT_MAX_RESEARCH_ATTEMPTS = 3
    # Class-level fallback so the routing logic works even on instances
    # created without running __init__.
    max_research_attempts = DEFAULT_MAX_RESEARCH_ATTEMPTS

    # Accepts "Supported: YES" as well as decorated variants such as
    # "**Supported:** YES" or "supported: yes". The leading \b keeps
    # "Unsupported" from matching; the trailing \b rejects e.g. "YESNO".
    _SUPPORTED_YES = re.compile(r"\bsupported\b[\s*:\-]*yes\b", re.IGNORECASE)

    def __init__(self, max_research_attempts: int = DEFAULT_MAX_RESEARCH_ATTEMPTS):
        """Create the agents and compile the graph.

        Args:
            max_research_attempts: Maximum number of research passes per
                question (the first pass plus retries).

        Raises:
            ValueError: If ``max_research_attempts`` is less than 1.
        """
        if max_research_attempts < 1:
            raise ValueError("max_research_attempts must be >= 1")
        self.max_research_attempts = max_research_attempts
        self.researcher = ResearchAgent()
        self.verifier = VerificationAgent()
        self.relevance_checker = RelevanceChecker()
        self.workflow = self.build_workflow()

    def build_workflow(self):
        """Assemble the state graph and return the result of ``compile()``."""
        builder = StateGraph(AgentState)
        builder.add_node("check_relevance", self._check_relevance_step)
        builder.add_node("research", self._research_step)
        builder.add_node("verify", self._verification_step)
        builder.set_entry_point("check_relevance")
        builder.add_conditional_edges(
            "check_relevance",
            self._route_after_relevance,
            {"research": "research", "end": END},
        )
        builder.add_edge("research", "verify")
        builder.add_conditional_edges(
            "verify",
            self._decide_next_step,
            {"re_research": "research", "end": END},
        )
        return builder.compile()

    @staticmethod
    def _route_after_relevance(state: "AgentState") -> str:
        """Return ``"research"`` if the question was judged relevant, else ``"end"``."""
        return "research" if state["is_relevant"] else "end"

    def _check_relevance_step(self, state: "AgentState") -> Dict:
        """Decide whether the retrieved documents can answer the question.

        With no retrieved documents there is nothing to ground an answer in,
        so the relevance checker is skipped and the question is marked
        irrelevant. Otherwise any checker result other than ``"NO_MATCH"``
        counts as relevant.

        Also resets ``retry_count`` so the research loop starts from zero.
        """
        documents = state["documents"]
        if not documents:
            return {"is_relevant": False, "retry_count": 0}
        verdict = self.relevance_checker.check(state["question"], documents)
        return {"is_relevant": verdict != "NO_MATCH", "retry_count": 0}

    def _research_step(self, state: "AgentState") -> Dict:
        """Draft an answer and count this research pass.

        ``retry_count`` is incremented on every pass, including the first,
        so it equals the number of research passes run so far.
        """
        result = self.researcher.generate(state["question"], state["documents"])
        return {
            "draft_answer": result["draft_answer"],
            "retry_count": state.get("retry_count", 0) + 1,
        }

    def _verification_step(self, state: "AgentState") -> Dict:
        """Check the current draft against the documents and store the report text."""
        result = self.verifier.check(state["draft_answer"], state["documents"])
        return {"verification_report": result["verification_report"]}

    def _decide_next_step(self, state: "AgentState") -> str:
        """Route after verification.

        Returns:
            ``"end"`` if the report marks the draft as supported, or if
            ``max_research_attempts`` research passes have already run;
            ``"re_research"`` otherwise.
        """
        report = state.get("verification_report") or ""
        if self._SUPPORTED_YES.search(report):
            return "end"
        if state.get("retry_count", 0) >= self.max_research_attempts:
            return "end"
        return "re_research"

    def full_pipeline(self, question: str, retriever):
        """Retrieve documents for *question* and run the compiled graph.

        Args:
            question: The user question.
            retriever: Object exposing ``invoke(question)`` that returns the
                documents to ground the answer in (presumably a LangChain
                retriever -- confirm against callers).

        Returns:
            The final state returned by the compiled graph's ``invoke``,
            carrying the ``AgentState`` keys (``draft_answer``,
            ``verification_report``, ``is_relevant``, ``retry_count``, ...).
        """
        docs = retriever.invoke(question)  # replaces deprecated get_relevant_documents()
        initial_state = {
            "question": question,
            "documents": docs,
            "draft_answer": "",
            "verification_report": "",
            "is_relevant": False,
            "retry_count": 0,
        }
        return self.workflow.invoke(initial_state)