# Provenance (HuggingFace upload metadata): awinml — "Upload 107 files", commit 336f4a9 (verified)
from typing import Any, Optional
import weave
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
from rag_pipelines.prompts import RETRIEVAL_CRITIC_PROMPT, RetrievalCriticResult
class RetrievalCritic:
    """Judge how well retrieved document texts support a user question.

    Wraps a structured-output LLM chain that assigns each context passage a
    support-level verdict, and filters passages down to the configured set of
    acceptable levels.

    Attributes:
        llm (ChatGroq): Language model used to grade each passage.
        prompt (ChatPromptTemplate): System prompt template for the critic.
        retrieval_critic_chain: Prompt piped into the structured-output model.
        support_levels (list[str]): Verdicts that count as "keep this passage".
    """

    def __init__(self, llm: ChatGroq, support_levels: Optional[list[str]] = None) -> None:
        """Build the critic chain and record which verdicts to keep.

        Args:
            llm (ChatGroq): Pre-configured ChatGroq model for grading.
            support_levels (Optional[list[str]]): Verdicts treated as relevant.
                A falsy value (None or empty list) falls back to the default
                ["fully-supported", "partially-supported", "no-support"].
        """
        # Falsy input (None or []) falls back to the permissive default set.
        self.support_levels = (
            support_levels if support_levels else ["fully-supported", "partially-supported", "no-support"]
        )
        self.llm = llm
        self.prompt = ChatPromptTemplate.from_messages([("system", RETRIEVAL_CRITIC_PROMPT)])
        # Structured output forces the model's reply into a RetrievalCriticResult.
        self.retrieval_critic_chain = self.prompt | self.llm.with_structured_output(RetrievalCriticResult)

    @weave.op()
    def score_context(self, question: str, context: str) -> str:
        """Grade one passage's relevance to the question.

        Args:
            question (str): The user question being answered.
            context (str): A single retrieved passage to grade.

        Returns:
            str: The model's verdict — one of "fully-supported",
                "partially-supported", or "no-support".
        """
        verdict = self.retrieval_critic_chain.invoke({"question": question, "context": context})
        return verdict.decision

    @weave.op()
    def __call__(self, state: dict[str, Any]) -> dict[str, Any]:
        """Drop context passages whose verdict is not an accepted support level.

        Args:
            state (dict[str, Any]): Pipeline state with keys:
                - "question" (str): The user question.
                - "documents" (list[Document]): Retrieved documents (passed through).
                - "context" (list[str]): Passage texts to grade and filter.

        Returns:
            dict[str, Any]: New state with the same "question" and "documents"
                and a "context" list restricted to accepted passages.
        """
        question: str = state["question"]
        docs: list[Document] = state["documents"]
        kept: list[str] = []
        # Grade each passage individually; order of surviving passages is preserved.
        for passage in state["context"]:
            if self.score_context(question, passage) in self.support_levels:
                kept.append(passage)
        return {
            "question": question,
            "context": kept,
            "documents": docs,
        }