"""Agent workflow that chains retrieval, grading, query rewriting, web search
and answer generation as a LangGraph state graph.

Flow (as wired in ``AgentWorkflow._build_graph``)::

    START -> retrieve -> grade_documents -+-> generate -> END
                                          |
                                          +-> transform_query
                                                -> web_search_node -> generate
"""

import os
from typing import List, Literal, Optional

from typing_extensions import TypedDict
from pydantic import BaseModel, Field
from langchain.schema import Document
from langchain_core.output_parsers import StrOutputParser
from langgraph.graph import END, StateGraph, START

from project.pipeline.rag import RAGPipeline
from project.utils.model_loader import ModelLoader
from project.prompts.prompt_template import ROUTER_PROMPT, WEB_SEARCH_PROMPT
from project.logger.logging import get_logger

logger = get_logger(__name__)

# Graph node name for the web-search step. It must differ from every key of
# ``GraphState`` (which has a ``web_search`` key): LangGraph refuses to
# register a node whose name is already used as a state key.
_WEB_SEARCH_NODE = "web_search_node"

# Returned by ``AgentWorkflow.run`` when no node produced a generation.
_NO_ANSWER = "No answer generated"

_GRADE_PROMPT = """You are a grader assessing relevance of a retrieved document to a user question.
If the document contains keywords or semantic meaning related to the question, grade it as relevant.
Give ONLY a binary score 'yes' or 'no' to indicate whether the document is relevant to the question.

Retrieved document: {document}

User question: {question}

Answer (yes or no):"""

_REWRITE_PROMPT = """You are a question re-writer that converts an input question to a better version optimized for web search.
Look at the input and try to reason about the underlying semantic intent/meaning.
Provide only the improved question without any explanation.

Initial question: {question}

Improved question:"""


class GradeDocuments(BaseModel):
    """Schema for a binary relevance grade.

    NOTE(review): not referenced by the graders in this file, which parse the
    raw LLM text with ``StrOutputParser`` instead; kept for external users.
    """

    binary_score: str = Field(
        description="Documents are relevant to the question, 'yes' or 'no'"
    )


class GraphState(TypedDict):
    """State passed between the nodes of the workflow graph."""

    # Current (possibly rewritten) user question.
    question: str
    # Answer produced by the ``generate`` node.
    generation: str
    # "Yes"/"No" flag set by ``grade_documents`` when a document was rejected.
    web_search: str
    # Documents retrieved, kept after grading, or added by web search.
    documents: List[Document]


class AgentWorkflow:
    """Build and run the retrieve / grade / rewrite / search / generate graph."""

    def __init__(self, config_path: Optional[str] = None):
        """Load the LLM and RAG pipeline and prepare the grader chains.

        Args:
            config_path: Path handed to ``ModelLoader`` and ``RAGPipeline``;
                ``None`` is passed through unchanged.
        """
        self.config_path = config_path
        self.model_loader = ModelLoader(config_path)
        self.llm = self.model_loader.load_llm()
        self.rag_pipeline = RAGPipeline(config_path)

        self.web_search_tool = None
        self._setup_web_search()

        # Both are populated by ``_build_graph`` (called from ``setup``).
        self.workflow = None
        self.app = None

        self._setup_graders()
        logger.info("AgentWorkflow initialized")

    def _setup_web_search(self):
        """Create the Tavily search tool if ``TAVILY_API_KEY`` is set.

        On a missing key or any initialization failure the tool stays
        ``None`` and web search is skipped at run time.
        """
        if not os.getenv("TAVILY_API_KEY"):
            logger.warning("TAVILY_API_KEY not found, web search disabled")
            return

        try:
            # Imported lazily so the module loads without the optional
            # community package installed.
            from langchain_community.tools.tavily_search import TavilySearchResults

            # ``max_results`` is the tool's documented result-count field.
            self.web_search_tool = TavilySearchResults(max_results=3)
            logger.info("Web search tool initialized")
        except Exception as e:
            logger.warning(f"Could not initialize web search: {str(e)}")
            self.web_search_tool = None

    def _setup_graders(self):
        """Store the prompt templates and build the LLM -> text chains."""
        self.grade_prompt_text = _GRADE_PROMPT
        self.retrieval_grader = self.llm | StrOutputParser()

        self.rewrite_prompt_text = _REWRITE_PROMPT
        self.question_rewriter = self.llm | StrOutputParser()

    def setup(self, pdf_path: Optional[str] = None, use_attention_paper: bool = True):
        """Set up the RAG pipeline, then compile the workflow graph.

        Args:
            pdf_path: Forwarded to ``RAGPipeline.setup``.
            use_attention_paper: Forwarded to ``RAGPipeline.setup``.
        """
        self.rag_pipeline.setup(
            pdf_path=pdf_path, use_attention_paper=use_attention_paper
        )
        self._build_graph()
        logger.info("Agent workflow setup complete")

    def retrieve(self, state: GraphState):
        """Graph node: fetch documents for the question from the retriever."""
        logger.info("---RETRIEVE---")
        question = state["question"]
        documents = self.rag_pipeline.retriever.invoke(question)
        return {"documents": documents, "question": question}

    def grade_documents(self, state: GraphState):
        """Graph node: keep only the documents the LLM grades as relevant.

        Returns the filtered documents plus a ``web_search`` flag that is
        "Yes" when at least one document was rejected.
        """
        logger.info("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
        question = state["question"]
        documents = state["documents"]

        filtered_docs = []
        web_search = "No"
        for doc in documents:
            # Only the first 500 characters are graded to keep prompts short.
            prompt_filled = self.grade_prompt_text.format(
                document=doc.page_content[:500], question=question
            )
            grade = self.retrieval_grader.invoke(prompt_filled).strip().lower()
            if "yes" in grade:
                logger.info("---GRADE: DOCUMENT RELEVANT---")
                filtered_docs.append(doc)
            else:
                logger.info("---GRADE: DOCUMENT NOT RELEVANT---")
                web_search = "Yes"

        # NOTE(review): ``decide_to_generate`` routes on whether any document
        # survived, not on this flag - confirm that is the intended policy.
        return {
            "documents": filtered_docs,
            "question": question,
            "web_search": web_search,
        }

    def generate(self, state: GraphState):
        """Graph node: produce the answer with the RAG pipeline's chain."""
        logger.info("---GENERATE---")
        question = state["question"]
        documents = state["documents"]
        # NOTE(review): only the question is passed to the chain, so the
        # graded / web-search documents in the state do not reach it from
        # here - verify against RAGPipeline.chain whether that is intended.
        generation = self.rag_pipeline.chain.invoke({"question": question})
        return {"documents": documents, "question": question, "generation": generation}

    def transform_query(self, state: GraphState):
        """Graph node: rewrite the question into a web-search-friendly form."""
        logger.info("---TRANSFORM QUERY---")
        question = state["question"]
        documents = state["documents"]

        prompt_filled = self.rewrite_prompt_text.format(question=question)
        better_question = self.question_rewriter.invoke(prompt_filled).strip()
        if not better_question:
            # An empty rewrite would send a blank query to web search.
            logger.warning("Query rewrite returned empty text, keeping original question")
            better_question = question
        return {"documents": documents, "question": better_question}

    def web_search(self, state: GraphState):
        """Graph node: append web search results as one extra document.

        The documents are returned unchanged when the tool is unavailable,
        the search fails, or it yields no usable content.
        """
        logger.info("---WEB SEARCH---")
        question = state["question"]
        # Copy so the list held by the incoming state is never mutated.
        documents = list(state["documents"])
        unchanged = {"documents": documents, "question": question}

        if self.web_search_tool is None:
            logger.warning("Web search tool not available, skipping")
            return unchanged

        try:
            response = self.web_search_tool.invoke({"query": question})
        except Exception as e:
            logger.error(f"Web search failed: {str(e)}")
            return unchanged

        if not response:
            logger.warning("No results from web search")
            return unchanged

        if isinstance(response, str):
            # NOTE(review): a plain string is presumably an error message
            # from the tool rather than a result list - confirm.
            logger.error(f"Web search failed: {response}")
            return unchanged

        snippets = [
            item["content"]
            for item in response
            if isinstance(item, dict) and "content" in item
        ]
        if not snippets:
            logger.warning("No results from web search")
            return unchanged

        documents.append(Document(page_content="\n".join(snippets)))
        return {"documents": documents, "question": question}

    def decide_to_generate(
        self, state: GraphState
    ) -> Literal["transform_query", "generate"]:
        """Router: rewrite the query when no document survived grading."""
        logger.info("---ASSESS GRADED DOCUMENTS---")
        if not state.get("documents", []):
            logger.info("---DECISION: NO RELEVANT DOCUMENTS, TRANSFORM QUERY---")
            return "transform_query"

        logger.info("---DECISION: RELEVANT DOCUMENTS FOUND, GENERATE---")
        return "generate"

    def _build_graph(self):
        """Wire the nodes and edges and compile the graph into ``self.app``."""
        workflow = StateGraph(GraphState)

        workflow.add_node("retrieve", self.retrieve)
        workflow.add_node("grade_documents", self.grade_documents)
        workflow.add_node("generate", self.generate)
        workflow.add_node("transform_query", self.transform_query)
        # Not named "web_search": that is a GraphState key and node names
        # may not reuse state keys.
        workflow.add_node(_WEB_SEARCH_NODE, self.web_search)

        workflow.add_edge(START, "retrieve")
        workflow.add_edge("retrieve", "grade_documents")
        workflow.add_conditional_edges(
            "grade_documents",
            self.decide_to_generate,
            {
                "transform_query": "transform_query",
                "generate": "generate",
            },
        )
        workflow.add_edge("transform_query", _WEB_SEARCH_NODE)
        workflow.add_edge(_WEB_SEARCH_NODE, "generate")
        workflow.add_edge("generate", END)

        self.workflow = workflow
        self.app = workflow.compile()
        logger.info("LangGraph workflow compiled")

    def save_graph(self, output_path: str = "workflow.png"):
        """Render the compiled graph to a PNG file.

        Failures (including calling this before ``setup``) are logged, not
        raised.
        """
        try:
            graph_image = self.app.get_graph().draw_mermaid_png()
            with open(output_path, "wb") as f:
                f.write(graph_image)
            logger.info(f"Workflow graph saved to {output_path}")
        except Exception as e:
            logger.error(f"Failed to save graph: {str(e)}")

    def run(self, question: str) -> str:
        """Run the workflow for a question and return the generated answer.

        Args:
            question: The user question to answer.

        Returns:
            The latest ``generation`` emitted by a node, or
            "No answer generated" if no node produced one.

        Raises:
            ValueError: If ``setup()`` has not been called yet.
        """
        if self.app is None:
            raise ValueError("Workflow not setup. Call setup() first.")

        final_generation = _NO_ANSWER
        for output in self.app.stream({"question": question}):
            for key, value in output.items():
                logger.info(f"Node '{key}' completed")
                if isinstance(value, dict) and "generation" in value:
                    final_generation = value["generation"]
        return final_generation