| from langchain.chains import RetrievalQA |
|
|
| from langflow.base.chains.model import LCChainComponent |
| from langflow.field_typing import Message |
| from langflow.inputs import BoolInput, DropdownInput, HandleInput, MultilineInput |
|
|
|
|
| class RetrievalQAComponent(LCChainComponent): |
| display_name = "Retrieval QA" |
| description = "Chain for question-answering querying sources from a retriever." |
| name = "RetrievalQA" |
| legacy: bool = True |
| icon = "LangChain" |
| inputs = [ |
| MultilineInput( |
| name="input_value", |
| display_name="Input", |
| info="The input value to pass to the chain.", |
| required=True, |
| ), |
| DropdownInput( |
| name="chain_type", |
| display_name="Chain Type", |
| info="Chain type to use.", |
| options=["Stuff", "Map Reduce", "Refine", "Map Rerank"], |
| value="Stuff", |
| advanced=True, |
| ), |
| HandleInput( |
| name="llm", |
| display_name="Language Model", |
| input_types=["LanguageModel"], |
| required=True, |
| ), |
| HandleInput( |
| name="retriever", |
| display_name="Retriever", |
| input_types=["Retriever"], |
| required=True, |
| ), |
| HandleInput( |
| name="memory", |
| display_name="Memory", |
| input_types=["BaseChatMemory"], |
| ), |
| BoolInput( |
| name="return_source_documents", |
| display_name="Return Source Documents", |
| value=False, |
| ), |
| ] |
|
|
| def invoke_chain(self) -> Message: |
| chain_type = self.chain_type.lower().replace(" ", "_") |
| if self.memory: |
| self.memory.input_key = "query" |
| self.memory.output_key = "result" |
|
|
| runnable = RetrievalQA.from_chain_type( |
| llm=self.llm, |
| chain_type=chain_type, |
| retriever=self.retriever, |
| memory=self.memory, |
| |
| |
| return_source_documents=True, |
| ) |
|
|
| result = runnable.invoke( |
| {"query": self.input_value}, |
| config={"callbacks": self.get_langchain_callbacks()}, |
| ) |
|
|
| source_docs = self.to_data(result.get("source_documents", keys=[])) |
| result_str = str(result.get("result", "")) |
| if self.return_source_documents and len(source_docs): |
| references_str = self.create_references_from_data(source_docs) |
| result_str = f"{result_str}\n{references_str}" |
| |
| self.status = {**result, "source_documents": source_docs, "output": result_str} |
| return result_str |
|
|