# github-actions[bot]
# Deploy from GitHub Actions
# dfa6a46
import os
from typing import List, Literal
from typing_extensions import TypedDict
from pydantic import BaseModel, Field
from langchain.schema import Document
from langchain_core.output_parsers import StrOutputParser
from langgraph.graph import END, StateGraph, START
from project.pipeline.rag import RAGPipeline
from project.utils.model_loader import ModelLoader
from project.prompts.prompt_template import ROUTER_PROMPT, WEB_SEARCH_PROMPT
from project.logger.logging import get_logger
# Module-level logger obtained from the project's logging helper; used by every
# AgentWorkflow step to trace graph execution.
logger = get_logger(__name__)
class GradeDocuments(BaseModel):
    # Structured-output schema for a binary document-relevance grade.
    # NOTE(review): AgentWorkflow in this file does not use it; its grader parses
    # the LLM's raw text output instead. Documented with comments rather than a
    # docstring because a pydantic model's docstring becomes its JSON-schema
    # description.
    binary_score: str = Field(description="Documents are relevant to the question, 'yes' or 'no'")
class GraphState(TypedDict):
    """State dictionary passed between the nodes of the AgentWorkflow graph."""

    # Current question; transform_query may replace it with a rewritten version.
    question: str
    # Answer text produced by the generate node.
    generation: str
    # "Yes"/"No" flag written by grade_documents.
    web_search: str
    # LangChain ``Document`` objects, not plain strings: nodes read
    # ``.page_content`` and the web-search node appends ``Document(...)``.
    # Quoted as a forward reference so the annotation is resolved lazily.
    documents: List["Document"]
class AgentWorkflow:
    """Corrective-RAG agent wired together as a LangGraph ``StateGraph``.

    The graph compiled by ``_build_graph`` runs ``retrieve`` then
    ``grade_documents``. If any graded document survives it goes straight to
    ``generate``; otherwise it runs ``transform_query``, then the web-search
    node, then ``generate``. ``generate`` always ends the run.
    Call ``setup()`` before ``run()``.
    """

    # Node name for the web-search step. It must not equal any GraphState key:
    # LangGraph refuses to add a node whose name is already used as a state key,
    # and GraphState declares a ``web_search`` key (the grader's Yes/No flag).
    WEB_SEARCH_NODE = "web_search_node"

    # Number of leading characters of each document shown to the relevance grader.
    GRADE_CONTEXT_CHARS = 500

    def __init__(self, config_path: str = None):
        """Load the LLM and RAG pipeline and prepare the web-search tool and graders.

        Args:
            config_path: Passed through to ``ModelLoader`` and ``RAGPipeline``.
        """
        self.config_path = config_path
        self.model_loader = ModelLoader(config_path)
        self.llm = self.model_loader.load_llm()
        self.rag_pipeline = RAGPipeline(config_path)
        self.web_search_tool = None
        self._setup_web_search()
        # The compiled graph is created by setup(); run() refuses to start until then.
        self.workflow = None
        self.app = None
        self._setup_graders()
        logger.info("AgentWorkflow initialized")

    def _setup_web_search(self):
        """Create the Tavily search tool when ``TAVILY_API_KEY`` is set.

        Leaves ``self.web_search_tool`` as None (web search disabled) when the
        key is missing or the tool cannot be constructed.
        """
        tavily_key = os.getenv("TAVILY_API_KEY")
        if not tavily_key:
            logger.warning("TAVILY_API_KEY not found, web search disabled")
            return
        try:
            from langchain_community.tools.tavily_search import TavilySearchResults

            # The tool's result-count setting is ``max_results`` (``k`` is not one
            # of its fields).
            self.web_search_tool = TavilySearchResults(max_results=3)
            logger.info("Web search tool initialized")
        except Exception as e:
            logger.warning(f"Could not initialize web search: {str(e)}")
            self.web_search_tool = None

    def _setup_graders(self):
        """Store the grading / rewriting prompt templates and their LLM chains.

        Both chains are ``llm | StrOutputParser()`` and are invoked with a prompt
        string already filled via ``str.format``.
        """
        self.grade_prompt_text = """You are a grader assessing relevance of a retrieved document to a user question.
If the document contains keywords or semantic meaning related to the question, grade it as relevant.
Give ONLY a binary score 'yes' or 'no' to indicate whether the document is relevant to the question.
Retrieved document: {document}
User question: {question}
Answer (yes or no):"""
        self.retrieval_grader = self.llm | StrOutputParser()
        self.rewrite_prompt_text = """You are a question re-writer that converts an input question to a better version optimized for web search.
Look at the input and try to reason about the underlying semantic intent/meaning.
Provide only the improved question without any explanation.
Initial question: {question}
Improved question:"""
        self.question_rewriter = self.llm | StrOutputParser()

    def setup(self, pdf_path: str = None, use_attention_paper: bool = True):
        """Prepare the RAG pipeline, then compile the LangGraph workflow.

        Args:
            pdf_path: Forwarded to ``RAGPipeline.setup``.
            use_attention_paper: Forwarded to ``RAGPipeline.setup``.
        """
        self.rag_pipeline.setup(pdf_path=pdf_path, use_attention_paper=use_attention_paper)
        self._build_graph()
        logger.info("Agent workflow setup complete")

    @staticmethod
    def _parse_binary_grade(text: str) -> bool:
        """Return True when a grader reply answers "yes".

        The first standalone word "yes" or "no" (case-insensitive) decides the
        grade, so a reply such as "No, though it mentions eyes" is not mistaken
        for "yes" as a plain substring test would. A reply containing neither
        word is treated as "not relevant".
        """
        import re

        match = re.search(r"\b(yes|no)\b", (text or "").lower())
        return match is not None and match.group(1) == "yes"

    def retrieve(self, state: "GraphState"):
        """Node: fetch candidate documents for ``state["question"]`` from the RAG retriever."""
        logger.info("---RETRIEVE---")
        question = state["question"]
        documents = self.rag_pipeline.retriever.invoke(question)
        return {"documents": documents, "question": question}

    def grade_documents(self, state: "GraphState"):
        """Node: keep only the documents the LLM grader judges relevant.

        Returns:
            The filtered documents, the unchanged question, and ``web_search``
            set to "Yes" when any document was rejected or none survived,
            otherwise "No".
        """
        logger.info("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
        question = state["question"]
        documents = state["documents"]
        filtered_docs = []
        web_search = "No"
        for d in documents:
            prompt_filled = self.grade_prompt_text.format(
                document=d.page_content[: self.GRADE_CONTEXT_CHARS],
                question=question,
            )
            score = self.retrieval_grader.invoke(prompt_filled)
            if self._parse_binary_grade(score):
                logger.info("---GRADE: DOCUMENT RELEVANT---")
                filtered_docs.append(d)
            else:
                logger.info("---GRADE: DOCUMENT NOT RELEVANT---")
                web_search = "Yes"
        if not filtered_docs:
            # Nothing usable was retrieved at all: the flag should agree with the
            # routing decision made in decide_to_generate.
            web_search = "Yes"
        return {"documents": filtered_docs, "question": question, "web_search": web_search}

    def generate(self, state: "GraphState"):
        """Node: produce the answer by invoking the RAG pipeline's chain.

        NOTE(review): the chain is invoked with the question only; the graded
        and web-search documents in ``state`` are not passed to it, so web
        results reach the answer only if the chain obtains them itself. Confirm
        against RAGPipeline.
        """
        logger.info("---GENERATE---")
        question = state["question"]
        documents = state["documents"]
        generation = self.rag_pipeline.chain.invoke({"question": question})
        return {"documents": documents, "question": question, "generation": generation}

    def transform_query(self, state: "GraphState"):
        """Node: ask the LLM to rewrite the question for web search.

        Falls back to the original question if the rewrite comes back empty.
        """
        logger.info("---TRANSFORM QUERY---")
        question = state["question"]
        documents = state["documents"]
        prompt_filled = self.rewrite_prompt_text.format(question=question)
        better_question = self.question_rewriter.invoke(prompt_filled).strip()
        if not better_question:
            better_question = question
        return {"documents": documents, "question": better_question}

    def web_search(self, state: "GraphState"):
        """Node: append web-search results as one ``Document``, best-effort.

        The incoming ``documents`` list is never mutated; a new list is returned
        when a result is added. Missing tool, empty results, or any search error
        leave the documents unchanged (errors are logged).
        """
        logger.info("---WEB SEARCH---")
        question = state["question"]
        documents = state["documents"]
        if self.web_search_tool is None:
            logger.warning("Web search tool not available, skipping")
            return {"documents": documents, "question": question}
        try:
            response = self.web_search_tool.invoke({"query": question})
            # Expect a list of result dicts; anything else (e.g. an error string)
            # is treated as "no results" instead of being iterated character-wise.
            results = response if isinstance(response, list) else []
            snippets = [r["content"] for r in results if isinstance(r, dict) and r.get("content")]
            if not snippets:
                logger.warning("No results from web search")
                return {"documents": documents, "question": question}
            web_doc = Document(page_content="\n".join(snippets))
            documents = [*documents, web_doc]
        except Exception as e:
            logger.error(f"Web search failed: {str(e)}")
        return {"documents": documents, "question": question}

    def decide_to_generate(self, state: "GraphState") -> Literal["transform_query", "generate"]:
        """Conditional edge: generate if any graded document remains, else rewrite the query.

        Only the emptiness of ``documents`` is consulted; the ``web_search`` flag
        is not read here.
        """
        logger.info("---ASSESS GRADED DOCUMENTS---")
        documents = state.get("documents", [])
        if not documents:
            logger.info("---DECISION: NO RELEVANT DOCUMENTS, TRANSFORM QUERY---")
            return "transform_query"
        logger.info("---DECISION: RELEVANT DOCUMENTS FOUND, GENERATE---")
        return "generate"

    def _build_graph(self):
        """Wire the nodes and edges into a ``StateGraph`` and compile it into ``self.app``."""
        workflow = StateGraph(GraphState)
        workflow.add_node("retrieve", self.retrieve)
        workflow.add_node("grade_documents", self.grade_documents)
        workflow.add_node("generate", self.generate)
        workflow.add_node("transform_query", self.transform_query)
        workflow.add_node(self.WEB_SEARCH_NODE, self.web_search)
        workflow.add_edge(START, "retrieve")
        workflow.add_edge("retrieve", "grade_documents")
        workflow.add_conditional_edges(
            "grade_documents",
            self.decide_to_generate,
            {
                "transform_query": "transform_query",
                "generate": "generate",
            },
        )
        workflow.add_edge("transform_query", self.WEB_SEARCH_NODE)
        workflow.add_edge(self.WEB_SEARCH_NODE, "generate")
        workflow.add_edge("generate", END)
        self.app = workflow.compile()
        logger.info("LangGraph workflow compiled")

    def save_graph(self, output_path: str = "workflow.png"):
        """Render the compiled graph as a Mermaid PNG and write it to ``output_path``.

        Best-effort: failures are logged, not raised.
        """
        if self.app is None:
            logger.error("Failed to save graph: workflow not setup. Call setup() first.")
            return
        try:
            graph_image = self.app.get_graph().draw_mermaid_png()
            with open(output_path, "wb") as f:
                f.write(graph_image)
            logger.info(f"Workflow graph saved to {output_path}")
        except Exception as e:
            logger.error(f"Failed to save graph: {str(e)}")

    def run(self, question: str) -> str:
        """Stream the workflow for ``question`` and return the generated answer.

        Returns:
            The most recent ``generation`` emitted by a node, or
            "No answer generated" if none was produced.

        Raises:
            ValueError: if ``setup()`` has not been called.
        """
        if self.app is None:
            raise ValueError("Workflow not setup. Call setup() first.")
        final_generation = "No answer generated"
        for output in self.app.stream({"question": question}):
            for key, value in output.items():
                logger.info(f"Node '{key}' completed")
                if isinstance(value, dict) and "generation" in value:
                    final_generation = value["generation"]
        return final_generation