from typing import List, Literal, Optional, TypedDict

from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import END, START, StateGraph
from pydantic import BaseModel, Field
from qdrant_client.http.models import (
    FieldCondition,
    Filter,
    MatchValue,
)

from clients import LLM, VECTOR_STORE


class RetrievalState(TypedDict):
    """State for the agentic retrieval graph."""

    original_query: str
    current_query: str
    category: Optional[str]
    topic: Optional[str]
    documents: List[Document]
    relevant_documents: List[Document]
    generation: str
    retry_count: int
    max_retries: int


class GradeDocuments(BaseModel):
    """Grade whether a document is relevant to the query."""

    is_relevant: Literal["yes", "no"] = Field(
        description="Is the document relevant to the query? 'yes' or 'no'"
    )
    reason: str = Field(description="Brief reason for the relevance decision")


def retrieve_documents(state: RetrievalState) -> RetrievalState:
    """Retrieve documents from the vector store."""
    query = state["current_query"]
    category = state.get("category")
    topic = state.get("topic")

    # Build the Qdrant metadata filter from whatever constraints are present
    conditions = []
    if category:
        conditions.append(
            FieldCondition(key="metadata.category", match=MatchValue(value=category))
        )
    if topic:
        conditions.append(
            FieldCondition(key="metadata.topic", match=MatchValue(value=topic))
        )
    qdrant_filter = Filter(must=conditions) if conditions else None

    documents = VECTOR_STORE.similarity_search(
        query,
        k=5,
        filter=qdrant_filter,
    )
    return {**state, "documents": documents}


def grade_documents(state: RetrievalState) -> RetrievalState:
    """Grade retrieved documents for relevance using the LLM."""
    query = state["original_query"]
    documents = state["documents"]

    if not documents:
        return {**state, "relevant_documents": []}

    # Create a grader with structured output
    grader_llm = LLM.with_structured_output(GradeDocuments)

    grading_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are a grader assessing relevance of a retrieved document to a user query.
If the document contains keywords or semantic meaning related to the query, grade it as relevant.
Be lenient - even partial relevance should be marked as 'yes'.
Only mark 'no' if the document is completely unrelated.""",
            ),
            (
                "human",
                """Query: {query}

Document content: {document}

Is this document relevant to the query?""",
            ),
        ]
    )

    relevant_docs = []
    for doc in documents:
        try:
            result = grader_llm.invoke(
                grading_prompt.format_messages(
                    query=query,
                    document=doc.page_content[:1000],  # Limit content length
                )
            )
            if result.is_relevant == "yes":
                relevant_docs.append(doc)
        except Exception:
            # If grading fails, include the document (fail-safe)
            relevant_docs.append(doc)

    return {**state, "relevant_documents": relevant_docs}


def rewrite_query(state: RetrievalState) -> RetrievalState:
    """Rewrite the query for better retrieval."""
    original_query = state["original_query"]
    retry_count = state["retry_count"]

    rewrite_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are an expert at reformulating search queries.
Given the original query, generate a better search query that might retrieve more relevant documents.

Focus on:
- Extracting key concepts and entities
- Using synonyms or related terms
- Being more specific or more general as appropriate

Return ONLY the rewritten query, nothing else.""",
            ),
            ("human", "Original query: {query}\n\nRewritten query:"),
        ]
    )

    response = LLM.invoke(rewrite_prompt.format_messages(query=original_query))
    new_query = response.content.strip()

    return {
        **state,
        "current_query": new_query,
        "retry_count": retry_count + 1,
    }


def generate_response(state: RetrievalState) -> RetrievalState:
    """Generate the final response from the relevant documents."""
    relevant_docs = state["relevant_documents"]

    if not relevant_docs:
        return {**state, "generation": "No relevant memories found."}

    # Format documents as a numbered list tagged with category/topic metadata
    formatted = []
    for i, doc in enumerate(relevant_docs, 1):
        meta = doc.metadata
        formatted.append(
            f"{i}. [{meta.get('category', 'N/A')}/{meta.get('topic', 'N/A')}]: {doc.page_content}"
        )

    return {**state, "generation": "\n".join(formatted)}


def should_retry(state: RetrievalState) -> Literal["rewrite", "generate"]:
    """Decide whether to retry with a rewritten query."""
    relevant_docs = state["relevant_documents"]
    retry_count = state["retry_count"]
    max_retries = state["max_retries"]

    # If we have relevant docs, generate the response
    if relevant_docs:
        return "generate"

    # If no relevant docs and retries remain, rewrite the query
    if retry_count < max_retries:
        return "rewrite"

    # Max retries reached, generate the (empty) response
    return "generate"


def build_retrieval_graph():
    """Build and compile the agentic retrieval graph."""
    workflow = StateGraph(RetrievalState)

    # Add nodes
    workflow.add_node("retrieve", retrieve_documents)
    workflow.add_node("grade", grade_documents)
    workflow.add_node("rewrite", rewrite_query)
    workflow.add_node("generate", generate_response)

    # Add edges: retrieve -> grade, then either rewrite (and loop back) or generate
    workflow.add_edge(START, "retrieve")
    workflow.add_edge("retrieve", "grade")
    workflow.add_conditional_edges(
        "grade",
        should_retry,
        {
            "rewrite": "rewrite",
            "generate": "generate",
        },
    )
    workflow.add_edge("rewrite", "retrieve")
    workflow.add_edge("generate", END)

    return workflow.compile()
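

# Example usage -- a minimal sketch, assuming `clients.py` exposes a configured
# chat model (LLM) and a Qdrant-backed vector store (VECTOR_STORE); the query
# text and max_retries value below are hypothetical, chosen for illustration.
if __name__ == "__main__":
    graph = build_retrieval_graph()
    initial_state: RetrievalState = {
        "original_query": "What programming languages does the user prefer?",
        "current_query": "What programming languages does the user prefer?",
        "category": None,  # optional metadata filters; None means no filtering
        "topic": None,
        "documents": [],
        "relevant_documents": [],
        "generation": "",
        "retry_count": 0,
        "max_retries": 2,  # allow at most two query rewrites before giving up
    }
    final_state = graph.invoke(initial_state)
    print(final_state["generation"])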