Spaces:
Build error
Build error
| import boto3 | |
| from langchain_community.vectorstores import Qdrant | |
| from langchain_core.runnables.history import RunnableWithMessageHistory | |
| from langchain_core.runnables.base import RunnableLambda | |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from langchain_core.tools import StructuredTool | |
| from langchain_core.utils.function_calling import convert_to_openai_tool | |
| from langchain_core.messages import AIMessage | |
| from typing import List, Optional | |
| from chat_utils import get_init_modules, SYSTEM_PROMPT, SYSTEM_PROMPT_LOOP, ContextInput, Answer, get_vectorDB_module | |
| from langchain_core.documents.base import Document | |
| from langchain_core.runnables import ConfigurableField | |
| import qdrant_client.models as rest | |
| class EurLexChat: | |
def __init__(self, config: dict):
    """
    Build the chat pipeline from a configuration dictionary.

    Args:
        config (dict): Settings for the chat. The keys read here are
            "max_history_messages", "vectorDB" -> "class" and
            "llm" -> "class" / "max_context_size" / optional
            "use_context_function". The whole dictionary is also handed
            to get_init_modules.
    """
    self.config = config
    self.max_history_messages = self.config["max_history_messages"]
    self.vectorDB_class = self.config['vectorDB']['class']

    # Tool calling is only switched on when it is requested in the llm
    # settings AND the configured model class is ChatOpenAI.
    self.use_functions = (
        'use_context_function' in config["llm"]
        and config["llm"]["use_context_function"]
        and config["llm"]["class"] == "ChatOpenAI"
    )

    modules = get_init_modules(config)
    self.embedder, self.llm, self.chatDB_class, self.retriever, retriever_chain = modules
    self.max_context_size = config["llm"]["max_context_size"]

    # Main prompt: system instructions, stored history, then the new question.
    self.prompt = ChatPromptTemplate.from_messages([
        ("system", SYSTEM_PROMPT),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{question}"),
    ])
    # Fallback prompt used when the model asks for the tool a second time.
    self.prompt_loop = ChatPromptTemplate.from_messages([
        ("system", SYSTEM_PROMPT_LOOP),
        ("human", "History: {history}. Message:"),
    ])
    self.chain_loop_answer = self.prompt_loop | self.llm

    answering_llm = self.llm
    if self.use_functions:
        context_tool = StructuredTool.from_function(
            func=self.get_context,
            name="get_context",
            description=(
                "To be used whenever the provided context is empty or the user changes the topic of the conversation and you need the context for the topic. "
                "To be used whenever you need context to answer the question. "
                "Do not make up an answer, rely solely on the information provided by the context. "
            ),
            args_schema=ContextInput,
        )
        self.llm_with_functions = self.llm.bind(
            tools=[convert_to_openai_tool(context_tool)]
        )
        answering_llm = self.llm_with_functions

    # The history is trimmed after the prompt is rendered and before the model runs.
    chain = self.prompt | RunnableLambda(self._resize_history) | answering_llm

    self.chain_with_history = RunnableWithMessageHistory(
        chain,
        self.get_chat_history,
        input_messages_key="question",
        history_messages_key="history",
    )
    self.relevant_documents_pipeline = (retriever_chain | self._parse_documents)
def _resize_history(self, input_dict):
    """
    Trim the rendered prompt so that only the most recent messages are kept.

    When the prompt holds more than max_history_messages + 2 messages, the
    first message is preserved and followed by the last
    max_history_messages + 1 messages.

    Args:
        input_dict: The llm input; its `messages` attribute is read and overwritten.

    Returns:
        The same input object, with `messages` trimmed when it was too long.
    """
    all_messages = input_dict.messages
    trimmed = all_messages
    # The two messages not counted are presumably the system prompt and the
    # new user question that surround the history.
    history_len = len(all_messages) - 2
    if history_len > self.max_history_messages:
        tail_size = self.max_history_messages + 1
        trimmed = [all_messages[0]] + all_messages[-tail_size:]
    input_dict.messages = trimmed
    return input_dict
def get_chat_history(self, session_id: str):
    """
    Retrieve the chat history instance for a specific session ID.

    The backend is selected by config["chatDB"]["class"]:
      * 'FileChatMessageHistory': one JSON file per session under
        kwargs['output_path'].
      * 'DynamoDBChatMessageHistory': a boto3 session is created from the
        credentials in kwargs; the region is taken from the optional
        kwargs['region_name'] and defaults to 'eu-west-1'.
      * anything else: the class is called with session_id and all kwargs.

    Args:
        session_id (str): The unique identifier for the session.

    Returns:
        Chat history object: An instance of the configured chat history class.
    """
    chat_db_config = self.config["chatDB"]
    kwargs = chat_db_config["kwargs"]
    backend = chat_db_config["class"]

    if backend == 'FileChatMessageHistory':
        file_path = f"{kwargs['output_path']}/{session_id}.json"
        return self.chatDB_class(file_path=file_path)

    if backend == 'DynamoDBChatMessageHistory':
        session = boto3.Session(
            aws_access_key_id=kwargs["aws_access_key_id"],
            aws_secret_access_key=kwargs["aws_secret_access_key"],
            # Previously hard-coded; still the default when not configured.
            region_name=kwargs.get("region_name", 'eu-west-1'),
        )
        return self.chatDB_class(session_id=session_id,
                                 table_name=kwargs["table_name"],
                                 boto3_session=session)

    return self.chatDB_class(session_id=session_id, **kwargs)
def _parse_documents(self, docs: List[Document]) -> List[dict]:
    """
    Convert retrieved documents into plain dictionaries.

    Args:
        docs (List[Document]): Documents whose metadata holds the
            "source", "celex" and "_id" keys.

    Returns:
        List[dict]: One dictionary per document with the keys
        'text', 'source', 'celex' and '_id', in input order.
    """
    return [
        {
            'text': document.page_content,
            'source': document.metadata["source"],
            'celex': document.metadata["celex"],
            '_id': document.metadata["_id"],
        }
        for document in docs
    ]
def _format_context_docs(self, context_docs: List[dict]) -> str:
    """
    Join the text of the context documents into one string.

    Args:
        context_docs (List[dict]): Dictionaries that each carry a 'text' entry.

    Returns:
        str: Every document text followed by a blank line, in input order
        (an empty string when there are no documents).
    """
    return "".join(doc['text'] + "\n\n" for doc in context_docs)
def get_ids_from_celexes(self, celex_list: List[str]):
    """
    Retrieve the IDs of the documents given their CELEX numbers.

    Only the Qdrant backend is supported. The collection is scrolled in
    pages of 100 points, filtered on the "celex" payload field, until the
    client reports no further page.

    Args:
        celex_list (List[str]): A list of CELEX numbers.

    Returns:
        List[str]: The IDs of all points whose "celex" matches any of the
        provided CELEX numbers.

    Raises:
        NotImplementedError: If the configured vectorDB class is not 'Qdrant'.
    """
    if self.vectorDB_class != 'Qdrant':
        # The exception used to be created without `raise`, so execution fell
        # through to the return statement and failed on an unbound variable.
        raise NotImplementedError(f"Not supported {self.vectorDB_class} vectorDB class")

    scroll_filter = rest.Filter(
        must=[
            rest.FieldCondition(
                key="celex",
                match=rest.MatchAny(any=celex_list),
            )
        ])

    vectorstore = self.retriever.vectorstore
    ids = []
    offset = None  # None asks for the first page
    while True:
        points, offset = vectorstore.client.scroll(
            collection_name=vectorstore.collection_name,
            limit=100,
            offset=offset,
            scroll_filter=scroll_filter,
            with_payload=False
        )
        ids.extend(point.id for point in points)
        # A None offset after a call means there are no more pages.
        if offset is None:
            break
    return ids
def _get_qdrant_ids_filter(self, ids):
    """
    Build a Qdrant filter that only lets through points with the given IDs.

    Per the original author's note, filtering by ID is a workaround for a
    Qdrant issue that prevented correct filtering by CELEX number.

    Args:
        ids (List[str]): A list of document IDs.

    Returns:
        Qdrant filter: A filter with a single "must" condition on the point IDs.
    """
    id_condition = rest.HasIdCondition(has_id=ids)
    return rest.Filter(must=[id_condition])
def get_relevant_docs(self, question: str, ids_list: Optional[List[str]] = None) -> List[dict]:
    """
    Retrieve the documents relevant to a question.

    With a non-empty ids_list the retriever's search kwargs are copied,
    extended with an ID filter and passed as a per-call configuration.
    A missing or empty ids_list runs the unfiltered search.

    Args:
        question (str): The question for which relevant documents are retrieved.
        ids_list (Optional[List[str]]): Document IDs to restrict the search to.

    Returns:
        List[dict]: The output of the relevant-documents pipeline.

    Raises:
        ValueError: If ids_list is given and the vectorDB class is not 'Qdrant'.
    """
    query = {'question': question}
    if not ids_list:
        return self.relevant_documents_pipeline.invoke(query)

    # Work on a copy so the retriever's own defaults are never modified.
    search_kwargs = dict(self.retriever.search_kwargs)
    if self.vectorDB_class != 'Qdrant':
        raise ValueError(f'Celex filter not supported for {self.vectorDB_class}')
    search_kwargs['filter'] = self._get_qdrant_ids_filter(ids_list)

    return self.relevant_documents_pipeline.invoke(
        query,
        config={"configurable": {"search_kwargs": search_kwargs}},
    )
def get_context(self, text: str, ids_list:Optional[List[str]]=None) -> str:
    """
    Fetch the relevant documents for a text and return them as one string.

    Args:
        text (str): The text for which context is retrieved.
        ids_list (Optional[List[str]]): Document IDs to restrict the search to.

    Returns:
        str: The texts of the relevant documents, formatted by
        _format_context_docs.
    """
    return self._format_context_docs(
        self.get_relevant_docs(text, ids_list=ids_list)
    )
def _remove_last_messages(self, session_id:str, n:int) -> None:
    """
    Remove the last n messages from the chat history of a specific session.

    The history is cleared and the messages to keep are written back one
    by one. If n is zero or negative nothing is touched.

    Args:
        session_id (str): The session ID for which messages are removed.
        n (int): The number of last messages to remove.
    """
    if n <= 0:
        # messages[:-0] is an empty list, which would erase the whole history.
        return

    chat_history = self.get_chat_history(session_id=session_id)
    # Copy before clearing: a backend may return its own internal list,
    # which clear() could empty in place.
    kept_messages = list(chat_history.messages)[:-n]
    chat_history.clear()
    for message in kept_messages:
        chat_history.add_message(message)
def _format_history(self, session_id:str) -> str:
    """
    Render the chat history of a session as plain text.

    Args:
        session_id (str): The session ID for which the chat history is formatted.

    Returns:
        str: One "<type>: <content>" entry per stored message, each followed
        by a blank line (an empty string when the history is empty).
    """
    stored_messages = self.get_chat_history(session_id).messages
    return "".join(
        f"{message.type}: {message.content}\n\n" for message in stored_messages
    )
def _resize_context(self, context_docs: List[dict]) -> List[dict]:
    """
    Select the context documents that fit into the token budget.

    Documents are visited in order and a document is kept only if adding
    its token count does not push the running total above
    max_context_size. A document that does not fit is skipped entirely
    (its text is not truncated), and later, smaller documents may still
    be kept.

    Args:
        context_docs (List[dict]): List of formatted documents, each with a 'text' entry.

    Returns:
        List[dict]: The kept documents, in their original order.
    """
    token_counts = [self.llm.get_num_tokens(doc['text']) for doc in context_docs]
    kept_docs = []
    used_tokens = 0
    for doc, n_tokens in zip(context_docs, token_counts):
        if used_tokens + n_tokens > self.max_context_size:
            continue  # does not fit; try the next document
        kept_docs.append(doc)
        used_tokens += n_tokens
    return kept_docs
def get_answer(self,
               session_id: str,
               question: str,
               context_docs: List[dict],
               from_tool: bool = False,
               ids_list: Optional[List[str]] = None
               ) -> Answer:
    """
    Get an answer to a question of a specific session, considering context documents and history messages.

    Flow:
      1. The context documents are reduced to the token budget, formatted
         and sent to the history-aware chain together with the question.
      2. If tool calling is enabled and the model replied with a tool call,
         the "text" argument of the first call is used to retrieve new
         documents, the last two history messages are removed and the
         question is asked again with the new documents (from_tool=True).
      3. If the model asks for the tool again during that second attempt,
         the last history message is removed, the fallback chain produces
         a message from the formatted history, that message is stored in
         the history and returned with status -1.

    Args:
        session_id (str): The session ID for which the answer is retrieved.
        question (str): The new user message.
        context_docs (List[dict]): A list of documents used as context to answer the user message.
        from_tool (bool, optional): Whether this call is the retry that follows a tool call. Defaults to False.
        ids_list (Optional[List[str]]): A list of document IDs to filter the search results for new context documents.

    Returns:
        Answer: An object containing the answer along with a new list of context documents
        if those provided were insufficient to answer the question.
    """
    # Function-scope import so the method does not rely on a file-level import.
    import json

    resized_docs = self._resize_context(context_docs)
    context = self._format_context_docs(resized_docs)
    result = self.chain_with_history.invoke(
        {"context": context, "question": question},
        config={"configurable": {"session_id": session_id}}
    )

    # Only a real tool call triggers the retrieval branch; other entries in
    # additional_kwargs must not be mistaken for one.
    tool_calls = result.additional_kwargs.get('tool_calls') if self.use_functions else None
    if not tool_calls:
        return Answer(answer=result.content)

    if from_tool:
        # Second tool request in a row: stop retrying and use the fallback chain.
        self._remove_last_messages(session_id=session_id, n=1)
        history = self._format_history(session_id)
        result = self.chain_loop_answer.invoke({'history': history})
        self.get_chat_history(session_id=session_id).add_message(AIMessage(result.content))
        return Answer(answer=result.content, status=-1)

    # The arguments arrive as a JSON-encoded string produced by the model.
    # Parse them as data: eval() would execute model-controlled text and
    # fails on JSON literals such as null/true/false.
    arguments = json.loads(tool_calls[0]['function']['arguments'])
    text = arguments['text']
    new_docs = self.get_relevant_docs(text, ids_list=ids_list)

    # Drop the question and the tool-call reply before asking again.
    self._remove_last_messages(session_id=session_id, n=2)
    result = self.get_answer(
        session_id=session_id,
        question=question,
        context_docs=new_docs,
        from_tool=True,
        ids_list=ids_list
    )
    if result.status == 1:
        return Answer(answer=result.answer, new_documents=new_docs)
    # NOTE(review): the retry's status is not forwarded here, so the returned
    # Answer presumably carries the default status — confirm this is intended.
    return Answer(answer=result.answer)