from typing import Any, Optional

import weave
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq

from rag_pipelines.prompts import RETRIEVAL_CRITIC_PROMPT, RetrievalCriticResult


class RetrievalCritic:
    """Evaluates the relevance of retrieved documents with respect to a user question.

    Uses a language model chain to assess document relevance and filters documents based
    on the configured support levels.

    Attributes:
        llm (ChatGroq): Language model instance used for relevance assessment.
        prompt (ChatPromptTemplate): Template for structuring the critic evaluation prompt.
        retrieval_critic_chain (RunnableSequence): Configured LangChain pipeline for evaluation.
        support_levels (list[str]): Support levels considered relevant when filtering.
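
    Example:
        A minimal usage sketch; the Groq model name and sample text are
        illustrative assumptions rather than values defined by this module::

            critic = RetrievalCritic(llm=ChatGroq(model="llama-3.1-8b-instant"))
            decision = critic.score_context(
                question="What does the retrieval critic do?",
                context="The retrieval critic scores each retrieved passage.",
            )
            # decision is "fully-supported", "partially-supported", or "no-support"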
| """ |
|
|
| def __init__(self, llm: ChatGroq, support_levels: Optional[list[str]] = None) -> None: |
| """Initialize the retrieval critic with language model and configuration. |
| |
| Args: |
| llm (ChatGroq): Pre-configured ChatGroq instance for evaluation processing. |
| support_levels (Optional[list[str]]): list of acceptable support levels. |
| Defaults to ["fully-supported", "partially-supported", "no-support"]. |
| Documents will be filtered to only include these levels. |
| """ |
| self.llm = llm |
| self.prompt = ChatPromptTemplate.from_messages([("system", RETRIEVAL_CRITIC_PROMPT)]) |
| self.retrieval_critic_chain = self.prompt | self.llm.with_structured_output(RetrievalCriticResult) |
| self.support_levels = support_levels or ["fully-supported", "partially-supported", "no-support"] |

    @weave.op()
    def score_context(self, question: str, context: str) -> str:
        """Evaluate the relevance of a single document context to a question.

        Args:
            question (str): User question to evaluate against.
            context (str): Document text content to assess for relevance.

        Returns:
            str: Support level decision from the model. Possible values are
                "fully-supported", "partially-supported", or "no-support".
        """
        result = self.retrieval_critic_chain.invoke({"question": question, "context": context})
        return result.decision

    @weave.op()
    def __call__(self, state: dict[str, Any]) -> dict[str, Any]:
        """Filter document contexts based on their relevance to the user question.

        Processes a state dictionary containing the question, documents, and contexts,
        returning a new state whose contexts are filtered by the configured support levels.

        Args:
            state (dict[str, Any]): Input processing state containing:
                - "question" (str): Original user question.
                - "documents" (list[Document]): Retrieved document objects (passed through).
                - "context" (list[str]): Extracted document texts to filter.

        Returns:
            dict[str, Any]: Output state with filtered contexts. Contains:
                - "question" (str): Original question.
                - "documents" (list[Document]): Document objects from the input, unchanged.
                - "context" (list[str]): Document texts whose support level matches the configured levels.
        """
        question: str = state["question"]
        documents: list[Document] = state["documents"]
        relevant_context: list[str] = state["context"]

        filtered_context: list[str] = [
            context for context in relevant_context if self.score_context(question, context) in self.support_levels
        ]

        return {
            "question": question,
            "context": filtered_context,
            "documents": documents,
        }
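

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module's public API. The Groq model
    # name, the sample question, and the state contents are illustrative
    # assumptions; running this requires a valid GROQ_API_KEY in the environment.
    critic = RetrievalCritic(
        llm=ChatGroq(model="llama-3.1-8b-instant"),
        support_levels=["fully-supported", "partially-supported"],
    )
    demo_state = {
        "question": "What does the retrieval critic do?",
        "documents": [Document(page_content="The critic scores each retrieved passage.")],
        "context": ["The critic scores each retrieved passage."],
    }
    print(critic(demo_state)["context"])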