Spaces:
Runtime error
Runtime error
| from models import EmbeddingModel, LLM | |
| from utils import MistralPrompts | |
| from vector_store import FaissVectorStore | |
| import argparse | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
| # Create a ChatBot class to manage interactions | |
| class ChatBot: | |
| def __init__(self, llm, embedding_model, vector_store): | |
| self.llm = llm | |
| self.embedding_model = embedding_model | |
| self.chat_history = [] | |
| self.vector_store = vector_store | |
| def format_context(self, retrieved_documents): | |
| context, sources = '', '' | |
| # Format retrieved documents into context and sources | |
| # This is simplest way to combine. there are other techniques as well to try out. | |
| for doc in retrieved_documents: | |
| context += doc.text + '\n\n' | |
| sources += str(doc.metadata) + '\n' | |
| return context, sources | |
| def chat(self, question): | |
| if len(self.chat_history): | |
| # Create a prompt based on chat history | |
| chat_history_prompt = MistralPrompts.create_history_prompt(self.chat_history) | |
| standalone_question_prompt = MistralPrompts.create_standalone_question_prompt(question, chat_history_prompt) | |
| standalone_question = self.llm.generate_response(standalone_question_prompt) | |
| else: | |
| chat_history_prompt = '' | |
| standalone_question = question | |
| # Encode the question using the embedding model | |
| query_embedding = self.embedding_model.encode(standalone_question) | |
| # Retrieve documents related to the question | |
| retrieved_documents = self.vector_store.query(query_embedding, 3) | |
| context, sources = self.format_context(retrieved_documents) | |
| # Print information about retrieved documents | |
| print("Retrieved documents info: \n", sources) | |
| # Create a prompt and generate a response | |
| prompt = MistralPrompts.create_question_prompt(question, context, chat_history_prompt) | |
| response = self.llm.generate_response(prompt) | |
| # Extract the response and update chat history | |
| response = MistralPrompts.extract_response(response) | |
| self.chat_history.append((question, response)) | |
| return response | |
| if __name__ == '__main__': | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--vector_database_path", default='vector_db',help="Vector database which store embeddings vector") | |
| args = parser.parse_args() | |
| VECTOR_DATABASE_PATH = parser.vector_database_path | |
| # Initialize models and vector store | |
| embedding_model = EmbeddingModel(model_name='sentence-transformers/all-MiniLM-L6-v2') | |
| llm = LLM("mistralai/Mistral-7B-Instruct-v0.1") | |
| vector_store = FaissVectorStore.as_retriever(database_path=VECTOR_DATABASE_PATH) | |
| # Create a ChatBot instance | |
| chat_bot = ChatBot(llm, embedding_model, vector_store) | |
| # Start the conversation | |
| print("Assistant Bot: Hello, I'm the Assistant Bot! How may I assist you today?") | |
| while True: | |
| question = input("User:") | |
| response = chat_bot.chat(question) | |
| print("Assistant Bot:", response, '\n') | |