# NOTE(review): removed paste artifacts ("Spaces:", "Runtime error" x2) left by a
# code-sharing/export tool; they were not part of the program.
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from langchain.chains import ConversationalRetrievalChain
from langchain.chains.conversational_retrieval.base import _get_chat_history
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.schema import BaseMessage, BaseRetriever, Document

from chains.azure_openai import CustomAzureOpenAI
from config import (
    DEPLOYMENT_ID,
    OPENAI_API_BASE,
    OPENAI_API_KEY,
    OPENAI_API_TYPE,
    OPENAI_API_VERSION,
)
from prompts.custom_chain import HUMAN_PROMPT_TEMPLATE, SYSTEM_PROMPT_TEMPLATE
class MultiQueriesChain(LLMChain):
    """LLM chain that expands one user question into several search queries.

    Runs against the configured Azure OpenAI deployment at temperature 0.0,
    with a system + human chat prompt taken from ``prompts.custom_chain``.
    """

    llm = CustomAzureOpenAI(
        deployment_name=DEPLOYMENT_ID,
        openai_api_type=OPENAI_API_TYPE,
        openai_api_base=OPENAI_API_BASE,
        openai_api_version=OPENAI_API_VERSION,
        openai_api_key=OPENAI_API_KEY,
        temperature=0.0,
    )

    # System message carries the instructions; human message carries the question.
    prompt = ChatPromptTemplate.from_messages([
        SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT_TEMPLATE),
        HumanMessagePromptTemplate.from_template(HUMAN_PROMPT_TEMPLATE),
    ])


# Module-level singleton, instantiated at import time; referenced by the
# (currently commented-out) multi-query retrieval path below.
llm_chain = MultiQueriesChain()
class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
    """Conversational retrieval chain that tags retrieved docs for citation.

    Each retrieved document's ``page_content`` is prefixed with a 1-based
    rank marker ``[n]`` so the answering LLM can cite sources by number.
    """

    # Index to connect to.
    retriever: BaseRetriever
    # When set, trailing docs are dropped until the combined token count fits
    # (enforced by the inherited _reduce_tokens_below_limit).
    max_tokens_limit: Optional[int] = None

    def _get_docs(
        self,
        question: str,
        inputs: Dict[str, Any]
    ) -> List[Document]:
        """Retrieve documents for *question*, numbered for citation.

        Args:
            question: The (possibly rephrased) standalone question.
            inputs: Full chain inputs; unused here but part of the base
                class's _get_docs contract.

        Returns:
            The retriever's documents, each rewritten in place so that
            ``page_content`` reads ``"[n] <content>"`` and
            ``metadata["source"]`` is coerced to a plain string, truncated
            to the token limit.
        """
        docs = self.retriever.get_relevant_documents(question)
        # Inline a citation rank into each doc (the rank lives in
        # page_content itself; no separate attribute is added).
        for rank, doc in enumerate(docs, start=1):
            # strip stray U+FFFD replacement chars left by bad decoding
            content = doc.page_content.strip("�")
            doc.page_content = f'[{rank}] {content}'
            # f-string keeps the original str(...) coercion of the source value
            doc.metadata["source"] = f'{doc.metadata["source"]}'
        return self._reduce_tokens_below_limit(docs)
    # NOTE(review): disabled alternative _get_docs — expands the question into
    # multiple queries via llm_chain, retrieves k=3 docs per query, and
    # deduplicates by (page_content, metadata). Kept for reference.
    # def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
    #     results = llm_chain.predict(question=question) + "\n"
    #     print(results)
    #     queries = list(map(lambda x: x.strip(), results.split(', ')))
    #     docs = []
    #     print(queries)
    #     for query in queries[:3]:
    #         self.retriever.search_kwargs = {"k": 3}
    #         doc = self.retriever.get_relevant_documents(query)
    #         docs.extend(doc)
    #     unique_documents_dict = {
    #         (doc.page_content, tuple(sorted(doc.metadata.items()))): doc
    #         for doc in docs
    #     }
    #     unique_documents = list(unique_documents_dict.values())
    #     return self._reduce_tokens_below_limit(unique_documents)