Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import json | |
| from typing import Any, Dict, List, Optional | |
| from langchain_core.language_models import BaseLanguageModel | |
| from langchain_core.prompts import BasePromptTemplate | |
| from langchain_core.pydantic_v1 import Field | |
| from langchain.callbacks.manager import CallbackManagerForChainRun | |
| from langchain.chains.base import Chain | |
| from langchain.chains.llm import LLMChain | |
| from langchain.chains.qa_generation.prompt import PROMPT_SELECTOR | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter | |
class QAGenerationChain(Chain):
    """Base class for question-answer generation chains.

    The chain splits the input text into chunks with ``text_splitter``, sends
    every chunk through ``llm_chain`` and parses each generation as JSON.
    """

    llm_chain: LLMChain
    """LLM Chain that generates responses from user input and context."""
    text_splitter: TextSplitter = Field(
        default=RecursiveCharacterTextSplitter(chunk_overlap=500)
    )
    """Text splitter that splits the input into chunks."""
    input_key: str = "text"
    """Key of the input to the chain."""
    output_key: str = "questions"
    """Key of the output of the chain."""
    # NOTE(review): `k` is not read anywhere in this class as shown; confirm
    # whether a caller or subclass relies on it before removing or wiring it in.
    k: Optional[int] = None
    """Number of questions to generate."""

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[BasePromptTemplate] = None,
        **kwargs: Any,
    ) -> QAGenerationChain:
        """
        Create a QAGenerationChain from a language model.

        Args:
            llm: a language model
            prompt: a prompt template; when omitted (or falsy) the prompt is
                obtained from ``PROMPT_SELECTOR.get_prompt(llm)``
            **kwargs: additional arguments forwarded to the class constructor

        Returns:
            a QAGenerationChain class
        """
        _prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
        chain = LLMChain(llm=llm, prompt=_prompt)
        return cls(llm_chain=chain, **kwargs)

    @property
    def _chain_type(self) -> str:
        """Chain type identifier; not implemented for this chain."""
        raise NotImplementedError

    @property
    def input_keys(self) -> List[str]:
        """Input keys expected by the chain: a single-item list with ``input_key``."""
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Output keys produced by the chain: a single-item list with ``output_key``."""
        return [self.output_key]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, List]:
        """Generate question/answer data for the text stored under ``input_key``.

        Args:
            inputs: chain inputs; ``inputs[self.input_key]`` holds the text.
            run_manager: optional callback manager passed on to the LLM chain.

        Returns:
            A dict mapping ``output_key`` to a list holding one parsed JSON
            value per text chunk.

        Raises:
            KeyError: if ``input_key`` is missing from ``inputs``.
            json.JSONDecodeError: if a generation is not valid JSON.
        """
        docs = self.text_splitter.create_documents([inputs[self.input_key]])
        # One prompt input per chunk; the chunk text is passed under "text".
        results = self.llm_chain.generate(
            [{"text": d.page_content} for d in docs], run_manager=run_manager
        )
        # Only the first candidate generation of each chunk is parsed.
        qa = [json.loads(res[0].text) for res in results.generations]
        return {self.output_key: qa}