# Importing the libraries
import math
import os
import uuid

import bs4
import nltk
import numpy as np
import openai
import requests
import streamlit as st
import textract
import tiktoken
import validators
from dotenv import load_dotenv
from streamlit_chat import message as show_message
# Helper variables
load_dotenv()
openai.api_key = os.environ["openapi"]  # Load the OpenAI API key from the .env file
llm_model = "gpt-3.5-turbo"  # https://platform.openai.com/docs/guides/chat/introduction
llm_context_window = 4097  # https://platform.openai.com/docs/guides/chat/managing-tokens
embed_context_window, embed_model = (
    8191,
    "text-embedding-ada-002",
)  # https://platform.openai.com/docs/guides/embeddings/second-generation-models
nltk.download("punkt")  # Download the nltk punkt tokenizer for splitting text into sentences
tokenizer = tiktoken.get_encoding(
    "cl100k_base"
)  # Load the cl100k_base tokenizer, which is designed to work with the ada-002 model
download_chunk_size = 128  # TODO: Find the optimal chunk size for downloading files
split_chunk_tokens = 300  # TODO: Find the optimal chunk size for splitting text
num_citations = 5  # TODO: Find the optimal number of citations to give context to the LLM
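# Rough token budget implied by the defaults above (a sketch, assuming they are
# kept as-is): at most num_citations chunks of up to split_chunk_tokens tokens are
# injected as context, i.e. 5 * 300 = 1500 tokens, leaving roughly
# 4097 - 1500 = 2597 tokens of the chat window for the conversation itself.
# Chunks can exceed split_chunk_tokens for very long sentences (see the TODO on
# split_into_many), so this is an estimate, not a guarantee.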
# Streamlit settings
user_avatar_style = "fun-emoji"  # https://www.dicebear.com/styles
assistant_avatar_style = "bottts-neutral"
# Helper functions
def get_num_tokens(text):  # Count the number of tokens in a string
    return len(
        tokenizer.encode(text, disallowed_special=())
    )  # disallowed_special=() encodes special tokens as ordinary text instead of raising an error
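
# Illustrative sketch (not called anywhere): what get_num_tokens returns for a
# short string; the exact count depends on the cl100k_base vocabulary.
def _demo_get_num_tokens():
    return get_num_tokens("Hello, world!")  # e.g. 4 tokens under cl100k_base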
# TODO:
# Currently, any sentence that is longer than the max number of tokens becomes its own chunk
# This is not ideal, since it doesn't guarantee that every chunk stays under the maximum size
# Find a way to split such sentences into chunks of a maximum size
def split_into_many(text):  # Split text into chunks of a maximum number of tokens
    sentences = nltk.tokenize.sent_tokenize(text)  # Split the text into sentences
    total_tokens = [
        get_num_tokens(sentence) for sentence in sentences
    ]  # Get the number of tokens for each sentence
    chunks = []
    tokens_so_far = 0
    chunk = []
    for sentence, num_tokens in zip(sentences, total_tokens):
        if not tokens_so_far:  # This is the first sentence in the chunk
            if num_tokens > split_chunk_tokens:
                # The sentence alone exceeds the limit, so it becomes its own chunk
                chunks.append(sentence)
                continue
        elif tokens_so_far + num_tokens > split_chunk_tokens:
            # Adding this sentence would push the chunk over the limit, so close the chunk
            chunks.append(" ".join(chunk))
            chunk = []
            tokens_so_far = 0
        # Otherwise, add the sentence to the chunk and add its tokens to the running total
        chunk.append(sentence)
        tokens_so_far += num_tokens + 1
    if chunk:  # Add any leftover sentences (or the whole text, if it was short) as the final chunk
        chunks.append(" ".join(chunk))
    return chunks
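
# Illustrative sketch (not called anywhere): with split_chunk_tokens = 300, a few
# short sentences all fit into a single chunk; the sentences are made up.
def _demo_split_into_many():
    text = "CacheChat answers questions. It cites sources. It runs on Streamlit."
    return split_into_many(text)  # -> one chunk containing all three sentences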
def embed(prompt):  # Embed the prompt
    embeds = []
    if isinstance(prompt, str):
        if get_num_tokens(prompt) > embed_context_window:  # Too long for one request
            prompt = split_into_many(prompt)  # Split the prompt into multiple chunks
        else:  # The prompt fits in the context window, so embed it directly
            embeds = openai.Embedding.create(input=prompt, model=embed_model)["data"]
    if not embeds:  # The prompt is, or was split into, a list of chunks
        max_num_chunks = (
            embed_context_window // split_chunk_tokens
        )  # Number of chunks that can fit in the context window per batch
        for i in range(math.ceil(len(prompt) / max_num_chunks)):  # For each batch of chunks
            embeds.extend(
                openai.Embedding.create(
                    input=prompt[i * max_num_chunks : (i + 1) * max_num_chunks],
                    model=embed_model,
                )["data"]
            )  # Embed the batch of chunks
    return embeds  # Return the list of embeddings
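
# Illustrative sketch (not called anywhere, since it would hit the OpenAI API):
# embed() accepts either a single string or a list of chunks and always returns a
# list of embedding objects, each carrying an "embedding" vector.
def _demo_embed():
    vectors = [x["embedding"] for x in embed(["chunk one", "chunk two"])]
    return len(vectors)  # -> 2, one 1536-dimensional ada-002 vector per chunk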
def embed_file(filename):  # Create embeddings for a file
    source_type = "file"  # To help distinguish between local/URL files and URLs
    file_source = ""  # Source of the file
    file_chunks = []  # Text chunks extracted from the file
    file_vectors = []  # One embedding per chunk
    try:
        extracted_text = (
            textract.process(filename)  # Extraction returns bytes
            .decode("utf-8")  # Convert the bytes to a string
            .encode("ascii", "ignore")  # Remove non-ASCII characters
            .decode()  # Convert back to a string
        )
        if not extracted_text:  # The file is empty
            raise Exception
        os.remove(filename)  # Remove the file from the server since it is no longer needed
        file_source = filename
        file_chunks = split_into_many(extracted_text)  # Split the text into chunks
        file_vectors = [x["embedding"] for x in embed(file_chunks)]  # Embed the chunks
    except Exception:  # The file could not be extracted, so return empty values
        if os.path.exists(filename):  # The file may still be on disk, so clean it up
            os.remove(filename)
        source_type = ""
        file_source = ""
        file_chunks = []
        file_vectors = []
    return source_type, file_source, file_chunks, file_vectors
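
# Illustrative sketch (not called anywhere, since it needs a real file on disk and
# an API call): "./example.pdf" is a hypothetical filename.
def _demo_embed_file():
    source_type, source, chunks, vectors = embed_file("./example.pdf")
    # On success: ("file", "./example.pdf", [chunk, ...], [vector, ...])
    # On failure: ("", "", [], []); any temp file left on disk is deleted either way
    return source_type, source, len(chunks), len(vectors)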
def embed_url(url):  # Create embeddings for a URL
    source_type = "url"  # To help distinguish between local/URL files and URLs
    url_source = ""  # Source of the URL
    url_chunks = []  # Text chunks extracted from the URL
    url_vectors = []  # One embedding per chunk
    try:
        if validators.url(url, public=True):  # Verify the URL is valid and public
            response = requests.get(url)  # Get the URL info
            header = response.headers["Content-Type"]  # Get the Content-Type header
            is_application = (
                header.split("/")[0] == "application"
            )  # Check if the URL points to a file
            if is_application:  # The URL is a file, so download it and call embed_file
                filetype = header.split("/")[1]  # Get the filetype
                url_parts = url.split("/")  # Get the parts of the URL
                # Replace / with whitespace to avoid issues with the file path, and swap the
                # URL's extension for the one implied by Content-Type, since they may not match
                filename = (
                    "./"
                    + " ".join(url_parts[:-1] + [url_parts[-1].split(".")[0]])
                    + "."
                    + filetype
                )
                with requests.get(url, stream=True) as stream_response:  # Download the file
                    stream_response.raise_for_status()
                    with open(filename, "wb") as file:
                        for chunk in stream_response.iter_content(
                            chunk_size=download_chunk_size
                        ):
                            file.write(chunk)
                return embed_file(filename)  # Embed the file
            else:  # The URL is a webpage, so use BeautifulSoup to extract its text
                soup = bs4.BeautifulSoup(response.text, "html.parser")
                extracted_text = (
                    soup.get_text()  # Extract the text from the webpage
                    .encode("ascii", "ignore")  # Remove non-ASCII characters
                    .decode()  # Convert back to a string
                )
                if not extracted_text:  # The webpage is empty
                    raise Exception
                url_source = url
                url_chunks = split_into_many(extracted_text)  # Split the text into chunks
                url_vectors = [x["embedding"] for x in embed(url_chunks)]  # Embed the chunks
        else:  # The URL is not valid or not public
            raise Exception
    except Exception:  # The URL could not be extracted, so return empty values
        source_type = ""
        url_source = ""
        url_chunks = []
        url_vectors = []
    return source_type, url_source, url_chunks, url_vectors
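
# Illustrative sketch (not called anywhere): how the Content-Type header steers
# embed_url; the header values are typical examples, not taken from a live response.
def _demo_content_type_branch():
    for header in ("application/pdf", "text/html; charset=utf-8"):
        is_application = header.split("/")[0] == "application"
        # "application/..." -> download the file and call embed_file
        # anything else     -> treat the URL as a webpage and scrape it
        print(header, "->", "file" if is_application else "webpage")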
def get_most_relevant(
    prompt_embedding, sources_embeddings
):  # Get which sources/chunks are most relevant to the prompt
    sources_indices = []  # Indices of the most relevant chunks, per source
    sources_cosine_sims = []  # Cosine similarities of those chunks, per source
    for source_embeddings in sources_embeddings:
        # source_embeddings contains the embeddings of every chunk in one source
        cosine_sims = np.array(
            (source_embeddings @ prompt_embedding)
            / (
                np.linalg.norm(source_embeddings, axis=1)
                * np.linalg.norm(prompt_embedding)
            )
        )  # Cosine similarity between the prompt and each chunk's vector
        # Get the indices of the most relevant chunks: https://stackoverflow.com/questions/6910641/how-do-i-get-indices-of-n-maximum-values-in-a-numpy-array
        num_chunks = min(
            num_citations, len(cosine_sims)
        )  # In case the source has fewer chunks than num_citations
        indices = np.argpartition(cosine_sims, -num_chunks)[
            -num_chunks:
        ]  # Indices of the most relevant chunks
        indices = indices[np.argsort(cosine_sims[indices])]  # Sort them by similarity
        sources_indices.append(indices)
        sources_cosine_sims.append(
            cosine_sims[indices]
        )  # Cosine similarities of the most relevant chunks
    # Merge the per-source candidates into the overall top num_citations chunks
    indexes = []  # [source_idx, chunk_idx] pairs of the best chunks
    max_cosine_sims = []  # Their cosine similarities
    for source_idx in range(len(sources_indices)):  # For each source
        for chunk_idx in range(len(sources_indices[source_idx])):  # For each candidate chunk
            sources_chunk_idx = sources_indices[source_idx][chunk_idx]  # Index of the chunk
            similarity = sources_cosine_sims[source_idx][chunk_idx]  # Its cosine similarity
            if len(max_cosine_sims) < num_citations:  # Not full yet, so keep the candidate
                indexes.append([source_idx, sources_chunk_idx])
                max_cosine_sims.append(similarity)
            elif similarity > min(max_cosine_sims):
                # Full, but this candidate beats the current minimum, so swap it in
                indexes.append([source_idx, sources_chunk_idx])
                max_cosine_sims.append(similarity)
                min_idx = max_cosine_sims.index(min(max_cosine_sims))
                indexes.pop(min_idx)
                max_cosine_sims.pop(min_idx)
    return indexes
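
# Illustrative sketch (not called anywhere): a toy run of get_most_relevant with
# made-up 2-dimensional vectors instead of real 1536-dimensional embeddings.
def _demo_get_most_relevant():
    prompt = np.array([1.0, 0.0])
    sources = [
        [[1.0, 0.1], [0.0, 1.0]],  # source 0: chunk 0 is nearly parallel to the prompt
        [[0.5, 0.5]],              # source 1: one chunk at 45 degrees
    ]
    # Fewer than num_citations candidates exist, so all are returned
    return get_most_relevant(prompt, sources)  # -> [source_idx, chunk_idx] pairs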
def process_source(
    source, source_type
):  # Process the source name to be used in a message, since URL files are stored differently
    return (
        source if source_type == "file" else source.replace(" ", "/")
    )  # For a URL, reverse the whitespace substitution done in embed_url
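
# Illustrative sketch (not called anywhere): how a hypothetical file URL is
# flattened into a local filename by embed_url and recovered by process_source.
def _demo_process_source():
    url = "https://example.com/docs/report.pdf"
    url_parts = url.split("/")
    filename = "./" + " ".join(url_parts[:-1] + [url_parts[-1].split(".")[0]]) + ".pdf"
    # filename == "./https:  example.com docs report.pdf" (slashes became spaces)
    return process_source(filename, "url")  # -> "./https://example.com/docs/report.pdf"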
# TODO: Find a better way to create/store messages instead of rebuilding them every time a question is asked
def ask():  # Ask a question
    messages = [
        {
            "role": "system",
            "content": str(
                "You are a helpful chatbot that answers questions a user may have about a topic. "
                + "Sometimes, the user may give you external data which you can use as needed. "
                + "They will give it to you in the following way:\n"
                + "Source 1: the source's name\n"
                + "Text 1: the relevant text from the source\n"
                + "Source 2: the source's name\n"
                + "Text 2: the relevant text from the source\n"
                + "...\n"
                + "You can use this data to answer the user's questions or to ask the user questions. "
                + "Take note that if you plan to reference a source, ALWAYS do so using the source's name.\n"
            ),
        },
        {"role": "user", "content": st.session_state["questions"][0]},
    ]  # Add the system's introduction message and the user's first question to messages
    show_message(
        st.session_state["questions"][0],
        is_user=True,
        key=str(uuid.uuid4()),
        avatar_style=user_avatar_style,
    )  # Display the user's first question
    if (
        len(st.session_state["questions"]) > 1 and st.session_state["answers"]
    ):  # This is not the first question, so replay the conversation so far
        for interaction, message in enumerate(
            [
                message
                for pair in zip(
                    st.session_state["answers"], st.session_state["questions"][1:]
                )
                for message in pair
            ]  # Interleave the previous conversation as [answer, question, answer, question, ...]: https://stackoverflow.com/questions/7946798/interleave-multiple-lists-of-the-same-length-in-python
        ):
            if interaction % 2 == 0:  # The message is an answer
                messages.append(
                    {"role": "assistant", "content": message}
                )  # Add the answer to messages
                show_message(
                    message,
                    key=str(uuid.uuid4()),
                    avatar_style=assistant_avatar_style,
                )  # Display the answer
            else:  # The message is a question
                messages.append(
                    {"role": "user", "content": message}
                )  # Add the question to messages
                show_message(
                    message,
                    is_user=True,
                    key=str(uuid.uuid4()),
                    avatar_style=user_avatar_style,
                )  # Display the question
    if (
        st.session_state["sources_types"]
        and st.session_state["sources"]
        and st.session_state["chunks"]
        and st.session_state["vectors"]
    ):  # Sources have been uploaded, so look for relevant context
        prompt_embedding = np.array(
            embed(st.session_state["questions"][-1])[0]["embedding"]
        )  # Embed the last question
        indexes = get_most_relevant(
            prompt_embedding, st.session_state["vectors"]
        )  # Get the most relevant chunks
        if indexes:  # There are relevant chunks, so append them to the last message
            messages[-1]["content"] += "Here are some sources that may be helpful:\n"
            for idx, ind in enumerate(indexes):
                source_idx, chunk_idx = ind[0], ind[1]  # Get the source and chunk index
                messages[-1]["content"] += str(
                    "Source "
                    + str(idx + 1)
                    + ": "
                    + process_source(
                        st.session_state["sources"][source_idx],
                        st.session_state["sources_types"][source_idx],
                    )
                    + "\n"
                    + "Text "
                    + str(idx + 1)
                    + ": "
                    + st.session_state["chunks"][source_idx][chunk_idx]  # The chunk itself
                    + "\n"
                )
    while (
        get_num_tokens("\n".join([message["content"] for message in messages]))
        > llm_context_window
    ):  # The conversation is too long for the model's context window
        if (
            len(messages) == 2
        ):  # Only the introduction message and the user's most recent question remain
            max_tokens_left = llm_context_window - get_num_tokens(
                messages[0]["content"]
            )  # Maximum number of tokens the question can use
            # Per https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them,
            # 4 characters is roughly 1 token, but since that isn't certain, truncate the
            # question to max_tokens_left characters to be safe (a token is never shorter
            # than a character)
            messages[1]["content"] = messages[1]["content"][:max_tokens_left]
        else:  # More than two messages remain, so drop the oldest question/answer pair
            messages.pop(1)  # Remove the oldest question
            messages.pop(1)  # Remove the oldest answer (now at index 1)
    answer = openai.ChatCompletion.create(model=llm_model, messages=messages)[
        "choices"
    ][0]["message"][
        "content"
    ]  # Get the answer from the chatbot
    st.session_state["answers"].append(answer)  # Add the answer to answers
    show_message(
        st.session_state["answers"][-1],
        key=str(uuid.uuid4()),
        avatar_style=assistant_avatar_style,
    )  # Display the answer
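
# Illustrative sketch (not called anywhere): how ask() interleaves the stored
# answers and follow-up questions into chat order; the strings are made up.
def _demo_interleave_history():
    answers = ["a1", "a2"]
    questions = ["q1", "q2", "q3"]  # q1 is added to messages separately
    return [m for pair in zip(answers, questions[1:]) for m in pair]
    # -> ["a1", "q2", "a2", "q3"]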
# Main function, defines the layout of the app
def main():
    # Initialize session state variables
    if "questions" not in st.session_state:
        st.session_state["questions"] = []
    if "answers" not in st.session_state:
        st.session_state["answers"] = []
    if "sources_types" not in st.session_state:
        st.session_state["sources_types"] = []
    if "sources" not in st.session_state:
        st.session_state["sources"] = []
    if "chunks" not in st.session_state:
        st.session_state["chunks"] = []
    if "vectors" not in st.session_state:
        st.session_state["vectors"] = []
    st.title("CacheChat :money_with_wings:")  # Title
    st.markdown(
        "Check out the repo [here](https://github.com/andrewhinh/CacheChat) and notes on using the app [here](https://github.com/andrewhinh/CacheChat#notes)."
    )  # Link to the repo
    uploaded_files = st.file_uploader(
        "Choose file(s):", accept_multiple_files=True, key="files"
    )  # File upload section
    if uploaded_files:  # Files were uploaded, so create embeddings for them
        with st.spinner("Processing..."):  # Show a loading spinner
            for uploaded_file in uploaded_files:
                if (
                    uploaded_file.name not in st.session_state["sources"]
                ):  # The file has not been uploaded before, so process it
                    with open(uploaded_file.name, "wb") as file:  # Save the file to disk
                        file.write(uploaded_file.getbuffer())
                    source_type, file_source, file_chunks, file_vectors = embed_file(
                        uploaded_file.name
                    )  # Embed the file
                    if (
                        not source_type
                        and not file_source
                        and not file_chunks
                        and not file_vectors
                    ):  # The file is invalid
                        st.error("Invalid file(s). Please try again.")
                    else:  # The file is valid
                        st.session_state["sources_types"].append(source_type)
                        st.session_state["sources"].append(file_source)
                        st.session_state["chunks"].append(file_chunks)
                        st.session_state["vectors"].append(file_vectors)
| with st.form(key="url", clear_on_submit=True): # form for question input | |
| uploaded_url = st.text_input( | |
| "Enter a URL:", | |
| placeholder="https://www.africau.edu/images/default/sample.pdf", | |
| ) # URL input text box | |
| upload_url_button = st.form_submit_button(label="Add URL") # Add URL button | |
| if upload_url_button and uploaded_url: # If a URL is entered, create embeddings | |
| with st.spinner("Processing..."): # Show loading spinner | |
| if not ( | |
| uploaded_url in st.session_state["sources"] # Non-file URL in sources | |
| or "./" + uploaded_url.replace("/", " ") # File URL in sources | |
| in st.session_state["sources"] | |
| ): # If the URL has not been uploaded, process it | |
| source_type, url_source, url_chunks, url_vectors = embed_url( | |
| uploaded_url | |
| ) # Embed URL | |
| if ( | |
| not source_type | |
| and not url_source | |
| and not url_chunks | |
| and not url_vectors | |
| ): # If the URL is invalid | |
| st.error("Invalid URL. Please try again.") | |
| else: # If the URL is valid | |
| st.session_state["sources_types"].append(source_type) | |
| st.session_state["sources"].append(url_source) | |
| st.session_state["chunks"].append(url_chunks) | |
| st.session_state["vectors"].append(url_vectors) | |
    st.divider()  # Divider between the uploads and the chat
    input_container = (
        st.container()
    )  # Container for inputs/uploads: https://docs.streamlit.io/library/api-reference/layout/st.container
    response_container = (
        st.container()
    )  # Container for the chat history: https://docs.streamlit.io/library/api-reference/layout/st.container
    with input_container:
        with st.form(key="question", clear_on_submit=True):  # Form for question input
            uploaded_question = st.text_input(
                "Enter your input:",
                placeholder="e.g. Summarize the research paper in 3 sentences.",
                key="input",
            )  # Question text box
            uploaded_question_button = st.form_submit_button(
                label="Send"
            )  # Send button
    with response_container:
        if (
            uploaded_question_button and uploaded_question
        ):  # The send button was pressed and the text box is not empty
            with st.spinner("Thinking..."):  # Show a loading spinner
                st.session_state["questions"].append(
                    uploaded_question
                )  # Add the question to questions
                ask()  # Ask the question to the chatbot


if __name__ == "__main__":
    main()