| import logging |
| from typing import List |
|
|
| from langchain_core.callbacks import ( |
| AsyncCallbackManagerForRetrieverRun, |
| CallbackManagerForRetrieverRun, |
| ) |
| from langchain_core.documents import Document |
| from langchain_core.language_models import BaseLLM |
| from langchain_core.output_parsers import StrOutputParser |
| from langchain_core.prompts import BasePromptTemplate |
| from langchain_core.prompts.prompt import PromptTemplate |
| from langchain_core.retrievers import BaseRetriever |
| from langchain_core.runnables import Runnable |
|
|
logger = logging.getLogger(__name__)


# Instruction text for the query re-phrasing step. It is a single line of
# text (no embedded newlines) with one placeholder, ``{question}``, for the
# raw user query.
DEFAULT_TEMPLATE = (
    "You are an assistant tasked with taking a natural language "
    "query from a user and converting it into a query for a vectorstore. "
    "In this process, you strip out information that is not relevant for "
    "the retrieval task. Here is the user query: {question}"
)
|
|
| |
# Prompt object built once at import time from DEFAULT_TEMPLATE; it is the
# default value of the ``prompt`` argument of RePhraseQueryRetriever.from_llm.
DEFAULT_QUERY_PROMPT = PromptTemplate.from_template(template=DEFAULT_TEMPLATE)
|
|
|
|
class RePhraseQueryRetriever(BaseRetriever):
    """Retriever that re-phrases the query with an LLM before retrieving.

    The incoming query is first passed through ``llm_chain``; the chain's
    output is then used as the query for the wrapped ``retriever``.
    """

    # Retriever that is queried with the re-phrased question.
    retriever: BaseRetriever
    # Runnable that turns the raw user query into the re-phrased query.
    llm_chain: Runnable

    @classmethod
    def from_llm(
        cls,
        retriever: BaseRetriever,
        llm: BaseLLM,
        prompt: BasePromptTemplate = DEFAULT_QUERY_PROMPT,
    ) -> "RePhraseQueryRetriever":
        """Build a retriever whose chain is ``prompt | llm | StrOutputParser()``.

        The default prompt expects a single input variable: `question`.

        Args:
            retriever: retriever to query documents from
            llm: llm used to generate the re-phrased query
            prompt: prompt template for query generation; defaults to
                DEFAULT_QUERY_PROMPT

        Returns:
            RePhraseQueryRetriever
        """
        llm_chain = prompt | llm | StrOutputParser()
        return cls(
            retriever=retriever,
            llm_chain=llm_chain,
        )

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get relevant documents given a user question.

        Args:
            query: user question
            run_manager: callback manager of the current retriever run; a
                child manager is handed to each nested call

        Returns:
            Relevant documents for re-phrased question
        """
        re_phrased_question = self.llm_chain.invoke(
            query, {"callbacks": run_manager.get_child()}
        )
        # Lazy %-style args: the message is only formatted if INFO is enabled.
        logger.info("Re-phrased question: %s", re_phrased_question)
        return self.retriever.invoke(
            re_phrased_question, config={"callbacks": run_manager.get_child()}
        )

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Asynchronously get relevant documents given a user question.

        Mirrors ``_get_relevant_documents`` using the async entry points of
        the chain and of the wrapped retriever.

        Args:
            query: user question
            run_manager: async callback manager of the current retriever
                run; a child manager is handed to each nested call

        Returns:
            Relevant documents for re-phrased question
        """
        re_phrased_question = await self.llm_chain.ainvoke(
            query, {"callbacks": run_manager.get_child()}
        )
        logger.info("Re-phrased question: %s", re_phrased_question)
        return await self.retriever.ainvoke(
            re_phrased_question, config={"callbacks": run_manager.get_child()}
        )
|
|