from typing import List, Literal, Optional, TypedDict
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import END, START, StateGraph
from pydantic import BaseModel, Field
from qdrant_client.http.models import (
FieldCondition,
Filter,
MatchValue,
)
# Local `clients` module (assumed) exposes a configured chat model and a
# Qdrant-backed LangChain vector store.
from clients import LLM, VECTOR_STORE


class RetrievalState(TypedDict):
    """State for the agentic retrieval graph."""

    original_query: str  # the user's query as asked; used for grading
    current_query: str  # the (possibly rewritten) query used for retrieval
    category: Optional[str]  # optional metadata filters for Qdrant
    topic: Optional[str]
    documents: List[Document]  # raw retrieval results
    relevant_documents: List[Document]  # documents that passed grading
    generation: str  # final formatted response
    retry_count: int  # rewrite attempts made so far
    max_retries: int  # cap on rewrite attempts


class GradeDocuments(BaseModel):
    """Grade whether a document is relevant to the query."""

    is_relevant: Literal["yes", "no"] = Field(
        description="Is the document relevant to the query? 'yes' or 'no'"
    )
    reason: str = Field(description="Brief reason for the relevance decision")
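

# Illustrative example of what the structured grader returns (the query,
# document text, and reason below are made up):
#   grader = LLM.with_structured_output(GradeDocuments)
#   grader.invoke("Query: cat care\nDocument: a note about dog grooming")
#   -> GradeDocuments(is_relevant="no", reason="Document is about dogs, not cats")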


def retrieve_documents(state: RetrievalState) -> RetrievalState:
    """Retrieve documents from the vector store, optionally filtered by metadata."""
    query = state["current_query"]
    category = state.get("category")
    topic = state.get("topic")

    # Build a Qdrant payload filter. langchain-qdrant stores document metadata
    # under the "metadata." payload prefix, hence the key names below.
    conditions = []
    if category:
        conditions.append(
            FieldCondition(key="metadata.category", match=MatchValue(value=category))
        )
    if topic:
        conditions.append(
            FieldCondition(key="metadata.topic", match=MatchValue(value=topic))
        )
    qdrant_filter = Filter(must=conditions) if conditions else None

    documents = VECTOR_STORE.similarity_search(
        query,
        k=5,
        filter=qdrant_filter,
    )
    return {**state, "documents": documents}


def grade_documents(state: RetrievalState) -> RetrievalState:
    """Grade each retrieved document for relevance using the LLM."""
    # Grade against the original query, not the rewritten one, so the
    # relevance target stays anchored to the user's actual intent.
    query = state["original_query"]
    documents = state["documents"]
    if not documents:
        return {**state, "relevant_documents": []}

    # Create a grader that returns a structured GradeDocuments object
    grader_llm = LLM.with_structured_output(GradeDocuments)
    grading_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are a grader assessing the relevance of a retrieved document to a user query.
If the document contains keywords or semantic meaning related to the query, grade it as relevant.
Be lenient: even partial relevance should be marked 'yes'.
Only mark 'no' if the document is completely unrelated.""",
            ),
            (
                "human",
                """Query: {query}

Document content: {document}

Is this document relevant to the query?""",
            ),
        ]
    )

    relevant_docs = []
    for doc in documents:
        try:
            result = grader_llm.invoke(
                grading_prompt.format_messages(
                    query=query,
                    document=doc.page_content[:1000],  # limit content length
                )
            )
            if result.is_relevant == "yes":
                relevant_docs.append(doc)
        except Exception:
            # If grading fails, include the document (fail-safe)
            relevant_docs.append(doc)
    return {**state, "relevant_documents": relevant_docs}


def rewrite_query(state: RetrievalState) -> RetrievalState:
    """Rewrite the query for better retrieval."""
    original_query = state["original_query"]
    previous_query = state["current_query"]
    retry_count = state["retry_count"]

    # Show the model the previous attempt as well; otherwise every retry
    # rewrites the same original query and can produce the same result.
    rewrite_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are an expert at reformulating search queries.
Given the original query and the previous attempt, generate a better search query that might retrieve more relevant documents.
Focus on:
- Extracting key concepts and entities
- Using synonyms or related terms
- Being more specific or more general as appropriate
Do not repeat the previous attempt. Return ONLY the rewritten query, nothing else.""",
            ),
            (
                "human",
                "Original query: {query}\nPrevious attempt: {previous}\n\nRewritten query:",
            ),
        ]
    )

    response = LLM.invoke(
        rewrite_prompt.format_messages(query=original_query, previous=previous_query)
    )
    new_query = response.content.strip()
    return {
        **state,
        "current_query": new_query,
        "retry_count": retry_count + 1,
    }


def generate_response(state: RetrievalState) -> RetrievalState:
    """Format the relevant documents into the final response."""
    relevant_docs = state["relevant_documents"]
    if not relevant_docs:
        return {**state, "generation": "No relevant memories found."}

    # Format documents as a numbered list tagged with category/topic metadata
    formatted = []
    for i, doc in enumerate(relevant_docs, 1):
        meta = doc.metadata
        formatted.append(
            f"{i}. [{meta.get('category', 'N/A')}/{meta.get('topic', 'N/A')}]: {doc.page_content}"
        )
    return {**state, "generation": "\n".join(formatted)}


def should_retry(state: RetrievalState) -> Literal["rewrite", "generate"]:
    """Decide whether to retry with a rewritten query."""
    relevant_docs = state["relevant_documents"]
    retry_count = state["retry_count"]
    max_retries = state["max_retries"]

    # If we have relevant docs, generate the response
    if relevant_docs:
        return "generate"
    # If no relevant docs and retries remain, rewrite the query
    if retry_count < max_retries:
        return "rewrite"
    # Max retries reached; generate the (empty) response
    return "generate"


def build_retrieval_graph():
    """Assemble and compile the retrieval graph.

    Flow: retrieve -> grade -> (rewrite -> retrieve, or generate -> END).
    """
    workflow = StateGraph(RetrievalState)

    # Add nodes
    workflow.add_node("retrieve", retrieve_documents)
    workflow.add_node("grade", grade_documents)
    workflow.add_node("rewrite", rewrite_query)
    workflow.add_node("generate", generate_response)

    # Add edges
    workflow.add_edge(START, "retrieve")
    workflow.add_edge("retrieve", "grade")
    workflow.add_conditional_edges(
        "grade",
        should_retry,
        {
            "rewrite": "rewrite",
            "generate": "generate",
        },
    )
    workflow.add_edge("rewrite", "retrieve")
    workflow.add_edge("generate", END)

    return workflow.compile()
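

# --- Usage sketch (illustrative; not part of the module) ---
# A minimal example of invoking the compiled graph. The caller seeds
# `current_query`, `retry_count`, and `max_retries`; the graph populates the
# remaining keys. The query string and filter values below are made up.
if __name__ == "__main__":
    graph = build_retrieval_graph()
    result = graph.invoke(
        {
            "original_query": "What did I say about project deadlines?",
            "current_query": "What did I say about project deadlines?",
            "category": None,  # e.g. "work" to filter on metadata.category
            "topic": None,
            "documents": [],
            "relevant_documents": [],
            "generation": "",
            "retry_count": 0,
            "max_retries": 2,
        }
    )
    print(result["generation"])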