Spaces:
Sleeping
Sleeping
| from llm_constants import LLM_MODEL_NAME, MAX_TOKENS, RERANKER_MODEL_NAME, EMBEDDINGS_MODEL_NAME, EMBEDDINGS_TOKENS_COST, INPUT_TOKENS_COST, OUTPUT_TOKENS_COST, COHERE_RERANKER_COST | |
| from prompts import CHAT_PROMPT, TOOLS | |
| import os | |
| from langchain_openai import OpenAIEmbeddings | |
| from langchain_core.documents import Document | |
| from langchain_community.retrievers import BM25Retriever | |
| from typing import List, Dict, Sequence | |
| from pydantic_models import RequestModel, ResponseModel, ChatHistoryItem, VectorStoreDocumentItem | |
| import tiktoken | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| from langchain_community.vectorstores import FAISS | |
| import anthropic | |
| import cohere | |
class RAGChatBot:
    # Retrieval-augmented chatbot: hybrid FAISS + BM25 retrieval, Cohere
    # reranking, and Anthropic Claude generation driven via tool use.
    #
    # NOTE(review): everything below is CLASS-level state, shared across all
    # instances; __init__ only rebinds some of it per instance. The API keys
    # are read once, at class-definition time (after load_dotenv() above).
    __cohere_api_key = os.getenv("COHERE_API_KEY")
    # NOTE(review): attribute name has a typo ("anthroic"); it is used
    # consistently within this class, so it is harmless but worth renaming
    # in a coordinated refactor.
    __anthroic_api_key = os.getenv("ANTHROPIC_API_KEY")
    __openai_api_key = os.getenv("OPENAI_API_KEY")
    # Instantiated at import/class-definition time — a module-load side
    # effect; it will raise there if the OpenAI key is missing or invalid.
    __embedding_function = OpenAIEmbeddings(model=EMBEDDINGS_MODEL_NAME)
    __base_retriever = None  # FAISS similarity retriever; set by set_base_retriever()
    __bm25_retriever = None  # lexical BM25 retriever over the same docstore
    anthropic_client = None  # anthropic.Anthropic; set by set_anthropic_client()
    cohere_client = None  # cohere.Client; set by set_cohere_client()
    top_n: int = 3  # number of reranked documents kept per query
    # NOTE(review): declared but never referenced in this file — presumably a
    # cap on history length for callers; confirm before removing.
    chat_history_length: int = 10
| def __init__(self, vectorstore_path:str, top_n:int = 3): | |
| if self.__cohere_api_key is None: | |
| raise ValueError("COHERE_API_KEY must be set in the environment") | |
| if self.__anthroic_api_key is None: | |
| raise ValueError("ANTHROPIC_API_KEY must be set in the environment") | |
| if self.__openai_api_key is None: | |
| raise ValueError("OPENAI_API_KEY must be set in the environment") | |
| if not isinstance(top_n, int): | |
| raise ValueError("top_n must be an integer") | |
| self.top_n = top_n | |
| self.set_base_retriever(vectorstore_path) | |
| self.set_anthropic_client() | |
| self.set_cohere_client() | |
| def set_base_retriever(self, vectorstore_path:str): | |
| db = FAISS.load_local(vectorstore_path, self.__embedding_function, allow_dangerous_deserialization=True) | |
| retriever = db.as_retriever(search_kwargs={"k": 25}) | |
| self.__base_retriever = retriever | |
| self.__bm25_retriever = BM25Retriever.from_documents(list(db.docstore.__dict__.get('_dict').values()), k=25) | |
| def set_anthropic_client(self): | |
| self.anthropic_client = anthropic.Anthropic(api_key=self.__anthroic_api_key) | |
| def set_cohere_client(self): | |
| self.cohere_client = cohere.Client(self.__cohere_api_key) | |
| def make_llm_api_call(self, messages:list): | |
| return self.anthropic_client.messages.create( | |
| model=LLM_MODEL_NAME, | |
| max_tokens=MAX_TOKENS, | |
| temperature=0, | |
| messages=messages, | |
| tools=TOOLS | |
| ) | |
| def make_rerank_api_call(self, search_phrase:str, documents: Sequence[str]): | |
| return self.cohere_client.rerank(query=search_phrase, documents=documents, model=RERANKER_MODEL_NAME, top_n=self.top_n) | |
| def retrieve_documents(self, search_phrase:str): | |
| similarity_documents = self.__base_retriever.invoke(search_phrase) | |
| bm25_documents = self.__bm25_retriever.invoke(search_phrase) | |
| unique_docs = [] | |
| for doc in bm25_documents: | |
| if doc not in unique_docs: | |
| unique_docs.append(doc) | |
| for doc in similarity_documents: | |
| if doc not in unique_docs: | |
| unique_docs.append(doc) | |
| return unique_docs | |
| def retrieve_and_rerank(self, search_phrase:str): | |
| documents = self.retrieve_documents(search_phrase) | |
| if len(documents) == 0: # to avoid empty api call | |
| return [] | |
| docs = [doc.page_content for doc in documents if isinstance(doc, Document) ] | |
| api_result = self.make_rerank_api_call(search_phrase, docs) | |
| reranked_docs = [] | |
| max_score = max([res.relevance_score for res in api_result.results]) | |
| threshold_score = max_score * 0.8 | |
| for res in api_result.results: | |
| # if res.relevance_score < threshold_score: | |
| # continue | |
| doc = documents[res.index] | |
| documentItem = VectorStoreDocumentItem(page_content=doc.page_content, filename=doc.metadata['filename'], heading=doc.metadata['heading'], relevance_score=res.relevance_score) | |
| reranked_docs.append(documentItem) | |
| return reranked_docs | |
| def get_context_and_docs(self, search_phrase:str): | |
| docs = self.retrieve_and_rerank(search_phrase) | |
| context = "\n\n\n".join([f"Filename:{doc.heading}\n\n{doc.page_content}" for doc in docs]) | |
| return context, docs | |
| def get_tool_use_assistant_message(self, tool_use_block): | |
| return {'role': 'assistant', | |
| 'content':tool_use_block | |
| } | |
| def get_tool_use_user_message(self, tool_use_id, context): | |
| return {'role': 'user', | |
| 'content': [{'type': 'tool_result', | |
| 'tool_use_id': tool_use_id, | |
| 'content': context}]} | |
| def process_tool_call(self, tool_name, tool_input): | |
| if tool_name == "Documents_Retriever": | |
| context, sources_list = self.get_context_and_docs(tool_input["search_phrase"]) | |
| search_phrase = tool_input["search_phrase"] | |
| return sources_list, search_phrase, context | |
| def calculate_cost(self, input_tokens, output_tokens, search_phrase): | |
| MILLION = 1000000 | |
| if search_phrase: | |
| enc = tiktoken.get_encoding("cl100k_base") | |
| query_encode = enc.encode(search_phrase) | |
| embeddings_cost = len(query_encode) * (EMBEDDINGS_TOKENS_COST/MILLION) | |
| total_cost = embeddings_cost + COHERE_RERANKER_COST + (input_tokens*(INPUT_TOKENS_COST/MILLION)) + (output_tokens*(OUTPUT_TOKENS_COST/MILLION)) | |
| else: | |
| total_cost = (input_tokens*(INPUT_TOKENS_COST/MILLION)) + (output_tokens*(OUTPUT_TOKENS_COST/MILLION)) | |
| return total_cost | |
| def chat_with_claude(self, user_message_history:list): | |
| input_tokens = 0 | |
| output_tokens = 0 | |
| message = self.make_llm_api_call(user_message_history) | |
| input_tokens += message.usage.input_tokens | |
| output_tokens += message.usage.output_tokens | |
| documents_list = [] | |
| search_phrase = "" | |
| while message.stop_reason == "tool_use": | |
| tool_use = next(block for block in message.content if block.type == "tool_use") | |
| tool_name = tool_use.name | |
| tool_input = tool_use.input | |
| tool_use_id = tool_use.id | |
| documents_list, search_phrase, tool_result = self.process_tool_call(tool_name, tool_input) | |
| user_message_history.append( self.get_tool_use_assistant_message(message.content)) | |
| user_message_history.append( self.get_tool_use_user_message(tool_use_id, tool_result)) | |
| message = self.make_llm_api_call(user_message_history) | |
| input_tokens += message.usage.input_tokens | |
| output_tokens += message.usage.output_tokens | |
| answer = next( | |
| (block.text for block in message.content if hasattr(block,"text")), | |
| None, | |
| ) | |
| if "<answer>" in answer: | |
| answer = answer.split("<answer>")[1].split("</answer>")[0].strip() | |
| total_cost = self.calculate_cost(input_tokens, output_tokens, search_phrase) | |
| return (documents_list, search_phrase, answer, total_cost) | |
| def get_chat_history_text(self, chat_history: List[ChatHistoryItem]): | |
| chat_history_text = "" | |
| for chat_message in chat_history: | |
| chat_history_text += f"User: {chat_message.user_message}\nAssistant: {chat_message.assistant_message}\n" | |
| return chat_history_text.strip() | |
| def get_response(self, input:RequestModel) -> ResponseModel: | |
| chat_history = self.get_chat_history_text(input.chat_history) | |
| user_question = input.user_question | |
| user_prompt = CHAT_PROMPT.format(CHAT_HISTORY=chat_history, USER_QUESTION=user_question) | |
| if input.use_tool: | |
| user_prompt = f"{user_prompt}\nUse Documents_Retriever tool in your response." | |
| sources_list, search_phrase, answer, _ = self.chat_with_claude([{"role":"user","content":[{"type":"text","text":user_prompt}]}]) | |
| updated_chat_history = input.chat_history.copy() | |
| updated_chat_history.append(ChatHistoryItem(user_message=user_question, assistant_message=answer)) | |
| return ResponseModel(answer = answer, sources_documents = sources_list, chat_history=updated_chat_history, search_phrase=search_phrase) | |