# DrakeLM: Gemini-backed RAG assistant over a DeepLake vector store.
import warnings
from typing import List, Tuple

from langchain_community.chat_message_histories.in_memory import ChatMessageHistory
from langchain_community.vectorstores import DeepLake
from langchain_core.documents.base import Document
from langchain_core.messages import AIMessage
from langchain_core.prompts import PromptTemplate, load_prompt
from langchain_google_genai import ChatGoogleGenerativeAI
class DrakeLM:
    """RAG assistant: answers questions and generates study notes with Gemini.

    Retrieval is backed by a DeepLake vector store; the running conversation
    is kept in an in-memory ``ChatMessageHistory``.
    """

    # Shared answer-formatting rules. Previously this exact text was duplicated
    # in both _chat_prompt and ask_llm; one constant keeps them in sync.
    _ANSWER_RULES = """
    - If the question says answer for X number of marks, you have to provide X number of points.
    - Each point has to be explained in 3-4 sentences.
    - In case the context express a mathematical equation, provide the equation in LaTeX format as shown in the example.
    - In case the user requests for a code snippet, provide the code snippet in the language specified in the example.
    - If the user requests to summarise or use the previous message as context ignoring the explicit context given in the message.
    """

    def __init__(self, model_path: str, db: DeepLake, config: dict):
        """Initialize the DrakeLM model.

        Parameters:
            model_path (str): The path to the model in case running Llama.
            db (DeepLake): The DeepLake DB object used for retrieval.
            config (dict): The configuration for the llama model.

        Note:
            ``model_path`` and ``config`` are accepted for the Llama use case
            but are not used by the Gemini-backed implementation below.
        """
        self.gemini = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human=True)
        self.retriever = db.as_retriever()
        self.chat_history = ChatMessageHistory()
        self.chat_history.add_user_message("You are assisting a student to understand topics.")
        self.notes_prompt = load_prompt("prompt_templates/notes_prompt.yaml")
        self.chat_prompt = load_prompt("prompt_templates/chat_prompt.yaml")

    def _chat_prompt(self, query: str, context: str) -> Tuple[PromptTemplate, str]:
        """Create the chat prompt for the LLM model.

        Parameters:
            query (str): The question asked by the user.
            context (str): The context retrieved from the DB.

        Returns:
            Tuple[PromptTemplate, str]: The prompt template and the fully
            rendered prompt string.
        """
        # f-string substitution instead of str.format: retrieved document text
        # may contain literal braces, which would make .format raise KeyError.
        prompt = f"""You are assisting a student to understand topics. \n\n
        You have to answer the below question by utilising the below context to answer the question. \n\n
        Note to follow the rules given below \n\n
        Question: {query} \n\n
        Context: {context} \n\n
        Rules: {self._ANSWER_RULES} \n\n
        Answer:
        """
        return PromptTemplate.from_template(prompt), prompt

    def _retrieve(self, query: str, metadata_filter, k=3, distance_metric="cos") -> str:
        """Retrieve the context from the DB.

        Parameters:
            query (str): The question asked by the user.
            metadata_filter (dict): The metadata filter for the DB; only the
                ``"id"`` key is used. ``None``/falsy disables filtering.
            k (int): The number of documents to retrieve.
            distance_metric (str): The distance metric for retrieval.

        Returns:
            str: The concatenated page contents of the retrieved documents.
        """
        self.retriever.search_kwargs["distance_metric"] = distance_metric
        self.retriever.search_kwargs["k"] = k
        if metadata_filter:
            self.retriever.search_kwargs["filter"] = {
                "metadata": {
                    "id": metadata_filter["id"]
                }
            }
        retrieved_docs = self.retriever.get_relevant_documents(query)
        # Join all retrieved chunks into a single newline-separated context.
        return "".join("\n" + doc.page_content for doc in retrieved_docs)

    def ask_llm(self, query: str, metadata_filter: dict = None) -> str:
        """Ask the LLM model a question.

        Parameters:
            query (str): The question asked by the user.
            metadata_filter (dict): The metadata filter for the DB.

        Returns:
            str: The response from the LLM model.
        """
        warnings.filterwarnings("ignore", message="Convert_system_message_to_human will be deprecated!")
        context = self._retrieve(query, metadata_filter)
        print("Retrieved context")
        # Only the rendered prompt string is needed for the history; the
        # PromptTemplate returned alongside it was never used here.
        _, prompt_string = self._chat_prompt(query, context)
        self.chat_history.add_user_message(prompt_string)
        print("Generating response...")
        # The YAML-loaded chat template is what actually gets sent to Gemini.
        final_prompt = self.chat_prompt.format(query=query, context=context, rules=self._ANSWER_RULES)
        response = self.gemini.invoke(final_prompt)
        self.chat_history.add_ai_message(AIMessage(content=response.content))
        return self.chat_history.messages[-1].content

    def create_notes(self, documents: List[Document]) -> str:
        """Create notes from the LLM model.

        Parameters:
            documents (List[Document]): The list of documents to create notes from.

        Returns:
            str: The notes generated from the LLM model, one Markdown section
            per input document, joined by newlines.
        """
        rules = """
        - Follow the Markdown format for creating notes as shown in the example.
        - The heading of the content should be the title of the markdown file.
        - Create subheadings for each section.
        - Use numbered bullet points for each point.
        """
        warnings.filterwarnings("ignore", message="Convert_system_message_to_human will be deprecated!")
        notes_chunk = []
        for doc in documents:
            prompt = self.notes_prompt.format(content_chunk=doc.page_content, rules=rules)
            response = self.gemini.invoke(prompt)
            notes_chunk.append(response.content)
        return '\n'.join(notes_chunk)