Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from streamlit_option_menu import option_menu | |
| import os | |
| # from langchain.llms import HuggingFaceHub # old, for calling HuggingFace Inference API (free for our use case) | |
| #from langchain_community.llms import HuggingFaceEndpoint # for calling HuggingFace Inference API (free for our use case) | |
| #from langchain.embeddings import HuggingFaceEmbeddings # to let program know what embeddings the vector store was embedded in earlier | |
| #from langchain_community.llms import HuggingFaceEndpoint | |
| from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint | |
| from langchain_huggingface.chat_models import ChatHuggingFace | |
| # to set up the agent and tools which will be used to answer questions later | |
| from langchain.agents import initialize_agent | |
| from langchain.agents import tool # decorator so each function will be recognized as a tool | |
| #from langchain.chains.retrieval_qa.base import RetrievalQA # to answer questions from vector store retriever. | |
| from langchain.chains import ConversationalRetrievalChain | |
| # from langchain.chains.question_answering import load_qa_chain # to further customize qa chain if needed | |
| from langchain.vectorstores import Chroma # vector store for retriever | |
| import ast # to parse user string input to list for one of the tools (agent tools do not support 2 inputs) | |
| #from langchain.memory import ConversationBufferMemory # not used as of now | |
| import pickle # for loading the bm25 retriever | |
| from langchain.retrievers import EnsembleRetriever # to use chroma and | |
| # for defining a generic LLMChain as a generic chat tool (if needed) | |
| from langchain.prompts import PromptTemplate | |
| from langchain.chains import LLMChain | |
| # for printing intermediate steps of agent (actions, tool calling etc.) | |
| from langchain.callbacks.base import BaseCallbackHandler | |
| import warnings | |
| warnings.filterwarnings("ignore", category=FutureWarning) | |
| warnings.filterwarnings("ignore", category=DeprecationWarning) | |
| # for web scraping and user to override | |
| from web_scrape_and_pdf_loader import ( | |
| duckduckgo_scrape, | |
| process_links_load_documents, | |
| setup_chromadb_vectorstore, | |
| setup_bm25_retriever, | |
| pdf_loader_local | |
| ) | |
| # look for new retrievers that user created (to override existing ones if user chooses) | |
| import glob | |
| # os.environ['HUGGINGFACEHUB_API_TOKEN'] = 'your_api_key' # for using HuggingFace Inference API | |
| # alternatively set your env variable above | |
################################ Callback ################################
# The agent reports its intermediate reasoning through LangChain callbacks.
# This handler mirrors those events into the Streamlit chat: each event is stored
# in st.session_state.messages (so it is re-rendered on later reruns) and is also
# shown immediately in an assistant chat bubble.
class MyCallbackHandler(BaseCallbackHandler):
    """Mirror agent actions and tool outputs into the Streamlit chat window."""

    def __init__(self):
        # Buffer for a token-streaming hook (on_llm_new_token); the HuggingFace
        # endpoint used here cannot stream, so nothing currently fills it.
        self.tokens = []

    @staticmethod
    def _post_to_chat(*contents):
        """Save each string as an assistant message and render them together in one bubble."""
        st.session_state.messages.extend(
            {"role": "assistant", "content": text} for text in contents
        )
        with st.chat_message("assistant"):
            for text in contents:
                st.markdown(text)

    def on_agent_action(self, action, **kwargs):
        """Show the agent's reasoning and which tool it is about to call."""
        print("\n\nnew action", action)
        # so streamlit will recognize each newline in the agent log as a line break
        reasoning = action.log.replace('\n', ' \n')
        announcement = f"I am calling the '{action.tool}' tool and waiting for it to give me a result..."
        self._post_to_chat(reasoning, announcement)

    def on_tool_end(self, output, **kwargs):
        """Show the raw output returned by a tool."""
        self._post_to_chat(
            f":blue[[Tool Output]] {output} \n \nI am processing the output from the tool..."
        )


my_callback_handler = MyCallbackHandler()
################################ Configs ################################
# Set the webpage title
st.set_page_config(
    page_title="ESG Countries Chatbot",
    # layout="wide"
)

# Default value for every configurable setting. Streamlit re-executes this script on
# each interaction, so a default is written only when the key is still absent; values
# changed through the sidebar widgets (which use these same keys) are kept.
# The dict is rebuilt on every rerun, so the list defaults are fresh objects each time.
_SESSION_DEFAULTS = {
    # --- Document config ---
    # Countries whose pre-built retrievers are replaced by the user's own scraped or
    # uploaded documents (those documents must be scraped/uploaded first).
    'countries_override': [],
    'chunk_size': 1000,  # choose one of [500, 600, 700, 800, 900, 1000, 1250, 1500, 1750, 2000, 2250, 2500, 2750, 3000]
    'chunk_overlap': 100,  # choose one of [50, 100, 150, 200]
    # --- Retriever config ---
    'chroma_n_similar_documents': 5,  # chunks returned by the Chroma (semantic) retriever
    'bm25_n_similar_documents': 5,  # chunks returned by the BM25 (keyword) retriever
    # one of 'Ensemble (Both Re-Ranked)', 'Semantic (Chroma DB)', 'Keyword (BM 2.5)'
    'retriever_config': 'Ensemble (Both Re-Ranked)',
    # between 0 and 1; weight of the keyword retriever in the ensemble
    # (the semantic retriever gets 1 minus this value)
    'keyword_retriever_weight': 0.3,
    'source_documents': [],  # source documents collected while answering the last query
    # --- LLM config (HuggingFace Inference API) ---
    # other models can be picked in the sidebar, e.g. "mistralai/Mistral-7B-Instruct-v0.2"
    'model': "mistralai/Mixtral-8x7B-Instruct-v0.1",
    'temperature': 0.25,
    'max_new_tokens': 500,  # max tokens generated by LLM
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Countries present in the pre-built vector stores. The stores were prepared in advance
# because building their embeddings takes very long, which is also why only 6 countries
# are included. The RAG tool checks a requested country against this list, because
# filtering the store by an unknown country silently returns an empty result instead of
# raising an error; the tool must report the problem to the agent itself.
# Having more countries would NOT affect answer quality when comparing two countries,
# because the RAG only retrieves chunks for the countries of interest.
countries = [
    "Australia",
    "China",
    "Japan",
    "Malaysia",
    "Singapore",
    "Germany",
]
################################ Get LLM and Embeddings ################################
def get_llm():
    """Build the HuggingFace Inference API endpoint configured in session state.

    The model is not run locally; it is served by HuggingFace's inference API, which
    is free for quick online testing. Model id, temperature and max_new_tokens are read
    from st.session_state.

    Returns:
        HuggingFaceEndpoint: the raw endpoint (wrapped in ChatHuggingFace by callers).
    """
    api_token = os.environ.get('HUGGINGFACEHUB_API_TOKEN')
    if not api_token:
        # Fail with an actionable message instead of a bare KeyError traceback.
        st.error(
            "The HUGGINGFACEHUB_API_TOKEN environment variable is not set. "
            "Set it to your HuggingFace access token and restart the app."
        )
        st.stop()
    return HuggingFaceEndpoint(
        repo_id=st.session_state['model'],
        huggingfacehub_api_token=api_token,
        task="conversational",  # important for models mapped to chat
        temperature=st.session_state['temperature'],
        max_new_tokens=st.session_state['max_new_tokens'],
    )


# for the chromadb vector store
@st.cache_resource
def get_embeddings():
    """Return the open-source HuggingFace embedding model used by the Chroma stores.

    The pre-built vector stores were embedded with HuggingFaceEmbeddings(), so the same
    model must be used to embed queries. The model is cached with st.cache_resource so it
    is loaded once per server process instead of on every Streamlit rerun.
    """
    return HuggingFaceEmbeddings()


def _build_chat_llm():
    """Wrap the configured endpoint in ChatHuggingFace so it can be used as a chat model."""
    return ChatHuggingFace(llm=get_llm())


llm = _build_chat_llm()
hf_embeddings = get_embeddings()


def update_llm():
    """Rebuild the global `llm`; used as on_change callback of the sidebar LLM widgets."""
    global llm
    llm = _build_chat_llm()
################################ Download and Initialize Pre-Built Retrievers ################################
# Chroma vector stores and BM25 retrievers were pre-built for the countries above, for
# every chunk size/overlap combination, and zipped up, because generating the embeddings
# takes a long time. They are downloaded from Google Drive when the app starts; choosing a
# different chunk size/overlap then only changes which persisted directory/pickle is loaded.
# If the user later scrapes new data or uploads their own PDF, new stores are created
# next to these. This step takes some time on the first start.
import subprocess  # only needed for the one-off downloads below
import zipfile


def _download_from_gdrive(archive_path, gdrive_id, spinner_text):
    """Download a Google Drive file with the gdown CLI unless archive_path already exists.

    On failure, show an error and stop the script rather than letting the later
    unzip/load steps fail with a confusing FileNotFoundError.
    """
    if os.path.exists(archive_path):
        return
    with st.spinner(spinner_text):
        try:
            completed = subprocess.run(
                ["gdown", f"https://drive.google.com/uc?id={gdrive_id}"],
                check=False,
            )
            succeeded = completed.returncode == 0
        except FileNotFoundError:  # the gdown executable is not installed
            succeeded = False
    if not succeeded or not os.path.exists(archive_path):
        # Remove a partial download so the next run retries instead of skipping it.
        if os.path.exists(archive_path):
            os.remove(archive_path)
        st.error(f"Failed to download {archive_path} from Google Drive (is gdown installed?).")
        st.stop()


def _extract_zip(archive_path, extract_dir, spinner_text):
    """Extract archive_path into the working directory unless extract_dir already exists."""
    if os.path.exists(extract_dir):
        return
    with st.spinner(spinner_text):
        # zipfile avoids depending on an external `unzip` command being installed.
        with zipfile.ZipFile(archive_path) as archive:
            archive.extractall()


_download_from_gdrive(
    "bm25.zip",
    "1q-hNnyyBA8tKyF3vR69nkwCk9kJj7WHi",
    'Downloading bm25 retriever for all chunk sizes and overlaps, will take some time',
)
_download_from_gdrive(
    "chromadb.zip",
    "1zad6tgYm2o5M9E2dTLQqmm6GoI8kxNC3",
    'Downloading chromadb retrievers for all chunk sizes and overlaps, will take some time',
)
_extract_zip(
    "bm25.zip",
    "bm25/",
    'Unzipping bm25 retriever for all chunk sizes and overlaps, will take some time',
)
_extract_zip(
    "chromadb.zip",
    "chromadb/",
    'Unzipping chromadb retrievers for all chunk sizes and overlaps, will take some time',
)
# Two kinds of retriever are used: a semantic one (Chroma vector store) and a keyword
# one (BM25). LangChain's EnsembleRetriever can then re-rank their combined results
# before they are passed to the question-answering chain used by the agent's tool.
def get_retrievers():
    """Load the pre-built retrievers for the chunk size/overlap in session state.

    BM25 is keyword based, so it needs no embeddings and is quick to set up, but it
    cannot filter by metadata. One BM25 retriever was therefore pickled ahead of time per
    country (and per chunk size/overlap); Chroma keeps all countries in one store.

    Returns:
        tuple: (chroma_db, bm25_retrievers) where bm25_retrievers maps every name in
        `countries` to its unpickled BM25 retriever.
    """
    settings = f"chunk_{st.session_state['chunk_size']}_overlap_{st.session_state['chunk_overlap']}"

    with st.spinner(f'Setting up pre-built chroma vector store'):
        chroma_db = Chroma(
            persist_directory=f"chromadb/chromadb_esg_countries_{settings}",
            embedding_function=hf_embeddings,
        )

    def load_pickle(path):
        # NOTE(security): unpickling can execute arbitrary code; these files come from
        # the bm25.zip archive downloaded from Google Drive at start-up.
        with open(path, 'rb') as fh:
            return pickle.load(fh)

    with st.spinner(f'Setting up pre-built bm25 retrievers'):
        bm25_retrievers = {
            name: load_pickle(f"bm25/bm25_esg_countries_{name}_{settings}.pickle")
            for name in countries
        }

    return chroma_db, bm25_retrievers


chroma_db, bm25_retrievers = get_retrievers()
def update_retrievers():
    """Reload the pre-built retrievers; on_change callback of the chunk size/overlap widgets."""
    global chroma_db
    global bm25_retrievers
    chroma_db, bm25_retrievers = get_retrievers()


# Retrievers built from the user's own scraped/uploaded documents, for the countries
# selected to override the pre-built data.
# NOTE: each overridden country has its own Chroma store on disk, but chroma_db_new only
# keeps the store opened for the last country in this loop.
chroma_db_new = None
bm25_new_retrievers = {}  # to store retrievers for different countries
_overrides_found = []
for country in st.session_state['countries_override']:
    _store_name = f"new_{country}_chunk_{st.session_state['chunk_size']}_overlap_{st.session_state['chunk_overlap']}_"
    _chroma_dir = f"chromadb/{_store_name}"
    bm25_filename = f"bm25/{_store_name}.pickle"
    # Own documents are built for one chunk size/overlap. After the user changes those
    # settings, a previously selected country may have no store, and opening the missing
    # pickle would raise FileNotFoundError on every rerun, so skip it instead.
    if not (os.path.isdir(_chroma_dir) and os.path.isfile(bm25_filename)):
        continue
    chroma_db_new = Chroma(persist_directory=_chroma_dir, embedding_function=hf_embeddings)
    with open(bm25_filename, 'rb') as handle:
        bm25_new_retrievers[country] = pickle.load(handle)
    _overrides_found.append(country)

# Drop selections whose stores are missing so the session state matches the retrievers
# actually loaded. This runs before the sidebar multiselect with this key is created in
# the current run, which Streamlit allows.
if _overrides_found != list(st.session_state['countries_override']):
    st.session_state['countries_override'] = _overrides_found
def check_for_new_retrievers(chunk_size=None, chunk_overlap=None):
    """Find countries that have user-built retrievers for the given chunk settings.

    User-built stores are named ``new_{country}_chunk_{size}_overlap_{overlap}_``: a
    Chroma directory under ``chromadb/`` plus a BM25 retriever pickled as
    ``bm25/<same name>.pickle``. A country counts only when both exist and were built
    with the requested chunk size and overlap. Malformed names are skipped individually
    so one bad entry does not hide the valid ones.

    Args:
        chunk_size: chunk size to match; defaults to st.session_state['chunk_size'].
        chunk_overlap: chunk overlap to match; defaults to st.session_state['chunk_overlap'].

    Returns:
        tuple[list[str], str]: the matching countries (in glob order) and an info
        message for the sidebar label saying whether any were found.
    """
    if chunk_size is None:
        chunk_size = st.session_state['chunk_size']
    if chunk_overlap is None:
        chunk_overlap = st.session_state['chunk_overlap']
    wanted_settings = (str(chunk_size), str(chunk_overlap))

    bm25_names = {os.path.basename(path) for path in glob.glob("bm25/new*")}
    new_countries = []
    for chroma_path in glob.glob("chromadb/new*"):
        store_name = os.path.basename(chroma_path)
        if store_name + ".pickle" not in bm25_names:
            continue  # the matching BM25 retriever is required as well
        parts = store_name.split('_')
        if len(parts) < 6:
            print(f"Skipping unrecognised retriever name: {store_name}")
            continue
        # parts = ['new', country, 'chunk', size, 'overlap', overlap, '']
        if (parts[3], parts[5]) == wanted_settings:
            new_countries.append(parts[1])

    if not new_countries:
        info = ' (Own documents are :red[NOT FOUND]. Must first scrape or upload own PDF (in menu above) before you can select any countries to override.)'
    else:
        info = ' (⚠️Own documents for the following countries :green[FOUND], select them in the list below to override.)'
    return new_countries, info
################################ Tools for Agent to Use ################################
# The most important tool is the first one. It runs a retrieval QA chain to answer a
# question about ONE country's ESG policies, e.g. the carbon emissions policy of Singapore.
# By calling it once per country, the agent can compare the answers for several
# countries. This works far better than retrieving chunks for the user's whole query and
# sending everything to a single QA chain.
# Multi-input tools are not available, so the agent is prompted to pass a list written
# as a string, which is converted back into a list with ast.literal_eval.


def _set_bm25_k(retriever, k):
    """Set how many documents a (possibly unpickled) BM25 retriever returns.

    Assigning ``k`` can fail on some unpickled retrievers. The fallback converts their
    pydantic "fields set" bookkeeping attributes from dict to set and then retries.
    """
    try:
        retriever.k = k
    except Exception:
        # NOTE(review): presumably the pickles were created under a different pydantic
        # version, leaving these attributes as dicts after unpickling — confirm.
        fields_set = getattr(retriever, "__pydantic_fields_set__", None)
        if isinstance(fields_set, dict):
            retriever.__pydantic_fields_set__ = set(fields_set)
        legacy_fields_set = getattr(retriever, "__fields_set__", None)
        if isinstance(legacy_fields_set, dict):
            retriever.__fields_set__ = set(legacy_fields_set)
        retriever.k = k


# The docstring below is the tool description shown to the agent, so its text is part
# of the prompt. Returns the answer (or an error message the agent can act on) and
# appends the retrieved source documents to st.session_state['source_documents'].
@tool
def retrieve_answer_for_country(query_and_country: str) -> str: # TODO, change diff chain type diff version answers, change
    """Gives answer to a query about a single country's public ESG policy.
The input list should be of the following format:
[query, country]
The first element of the list is the user query, surrounded by double quotes.
The second element is the full name of the country involved, surrounded by double quotes, for example "Singapore".
The 2 inputs are separated by a comma. Do not write a list comprehension.
The 2 inputs, together, are surrounded by square brackets as it is a list.
Do not put multiple countries into the input at once. Instead use this tool multiple times, one time for each country.
If you have multiple queries to ask about a country, break the query into separate parts and use this tool multiple times, one for each query.
"""
    try:
        query_and_country_list = ast.literal_eval(query_and_country)
        query = query_and_country_list[0]

        # Metadata filtering is case sensitive, so map the LLM's spelling onto the exact
        # known name. A case-insensitive lookup also works for multi-word names, which
        # str.capitalize() would mangle ("United Kingdom" -> "United kingdom").
        overrides = list(st.session_state['countries_override'])
        known_countries = {name.casefold(): name for name in countries + overrides}
        country = known_countries.get(query_and_country_list[1].strip().casefold())
        if country is None:
            return """The country that you input into the tool cannot be found.
If you did not make a mistake and the country that you input is indeed what the user asked,
then there is no record for the country and no answer can be obtained."""

        if country in overrides:
            # The user's own documents: every overridden country has its own Chroma store
            # on disk, so open that country's store rather than a single shared handle.
            bm = bm25_new_retrievers[country]
            chroma_store = Chroma(
                persist_directory=f"chromadb/new_{country}_chunk_{st.session_state['chunk_size']}_overlap_{st.session_state['chunk_overlap']}_",
                embedding_function=hf_embeddings,
            )
        else:
            bm = bm25_retrievers[country]
            chroma_store = chroma_db

        # keyword retriever
        _set_bm25_k(bm, int(st.session_state['bm25_n_similar_documents']))
        # semantic retriever, restricted to this country's chunks
        chroma = chroma_store.as_retriever(
            search_kwargs={'filter': {'country': country}, 'k': st.session_state['chroma_n_similar_documents']}
        )
        # the ensemble re-ranks the results of both retrievers above
        keyword_weight = st.session_state['keyword_retriever_weight']
        ensemble = EnsembleRetriever(retrievers=[bm, chroma], weights=[keyword_weight, 1 - keyword_weight])
        # keyed by the sidebar "Retriever to Use" options
        retrievers = {'Ensemble (Both Re-Ranked)': ensemble, 'Semantic (Chroma DB)': chroma, 'Keyword (BM 2.5)': bm}

        qa = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retrievers[st.session_state['retriever_config']],  # selected retriever based on user config
            return_source_documents=True,
        )
        # Every tool call is an independent question, so the chain gets an empty history.
        out = qa.invoke({"question": query, "chat_history": []})
        answer = out.get("answer", "")
        sources = out.get("source_documents", [])

        # Record the documents behind every (partial) answer so they can be inspected on
        # the "View Source Docs for Last Query" page.
        st.session_state['source_documents'].append(f"Documents retrieved for agent query '{query}' for country '{country}'.")
        st.session_state['source_documents'].append(sources)

        # Upper-case only the first character: str.capitalize() would also lower-case the
        # rest of the query (e.g. "ESG" -> "esg").
        return f"'{query[:1].upper() + query[1:]}' for '{country}': " + answer
    except Exception as e:
        return f"""There is an error using this tool: {e}. Check if you have input anything wrongly and try again.
Remember the 2 inputs, query and country, must both be surrounded by double quotes.
The 2 inputs, together, are surrounded by square brackets as it is a list."""
# If a user tries to casually chat with the agent chatbot, the LLM can use this tool to
# reply. It is optional and not registered with the agent by default: users should rather
# be told that the chatbot is not meant for casual chatting.
# The docstring below is the tool description shown to the agent (part of the prompt).
@tool
def generic_chat_llm(query: str) -> str:
    """Use this tool for general queries and casual chat. Forward the user input directly into this tool, do not come up with your own input.
This tool IS NOT FOR MAKING COMPARISONS of anything.
This tool IS NOT FOR FINDING ESG POLICY of any country!
It is only for casual chat! Do not use this tool unnecessarily!
"""
    try:
        # The input is sent to the LLM unchanged (the old prompt template was just
        # "{query}"). invoke() replaces the deprecated LLMChain(...).run().
        reply = llm.invoke(query)
        # Chat models return a message object; plain LLMs return a string.
        return getattr(reply, "content", reply)
    except Exception as e:
        return f"""There is an error using this tool: {e}. Check if you have input anything wrongly and try again.
If you have already tried 2 times, do not try anymore, there is no response for your input.
Move on to the next step of your plan."""
# Sometimes the agent suddenly asks for a 'compare' tool even though it was never given
# one, so it gets this tool. The tool returns a prompt reminding the agent to look back at
# its earlier observations and decide whether it is time to draw a conclusion.
# Tools cannot take zero inputs, so the agent passes a 'query' that is not used. Having
# to write the query out lets the LLM "recall" what was asked, instead of losing it all
# the way back at the start of the ReAct process.
# The docstring below is the tool description shown to the agent (part of the prompt).
@tool
def compare(query:str) -> str:
    """Use this tool to give you hints and instructions on how you can compare between policies of countries.
Use this tool as a final step, only after you have used other tools to obtain all the information you need.
When putting the query into this tool, look at the entire query that the user has asked at the start,
do not leave any details in the query out.
"""
    return """Once again, check through all your previous observations to answer the user query.
Make sure every part of the query is addressed by the context, or that you have at least tried to do so.
Make sure you have not forgotten to address anything in the query.
If you still need more details, you can use another tool to find out more if you have not tried using the same tool with the necessary input earlier.
If you have enough information, use your reasoning to answer them to the best of your ability.
Give as much elaboration in your answer as possible but they MUST be from the earlier context.
Do not give details that cannot be found in the earlier context."""
# Attach the chat-mirroring callback handler to every tool (each gets its own list).
for _agent_tool in (retrieve_answer_for_country, compare, generic_chat_llm):
    _agent_tool.callbacks = [my_callback_handler]

# Tools the agent may use. To also allow casual chat, add generic_chat_llm here
# (users should however be advised not to use the chatbot that way).
_agent_tools = [retrieve_answer_for_country, compare]

agent = initialize_agent(
    _agent_tools,
    llm=llm,
    agent="zero-shot-react-description",
    verbose=False,
    handle_parsing_errors=True,
    return_intermediate_steps=True,
    callbacks=[my_callback_handler],
    # No conversation memory: RAM is limited on HuggingFace Spaces. In production the
    # conversation could be stored per user/chat session, e.g. in a PostgreSQL database:
    # memory=ConversationBufferMemory(memory_key="chat_history", return_messages=True),
    # max_iterations=10
)
################################ Sidebar with Menu ################################
# Navigation menu plus all configuration widgets. Each widget uses a session-state key
# initialised in the config section, so the widget and the code share one value.
with st.sidebar:
    st.title("ESG Countries Chatbot")
    page = option_menu(
        "Menu",
        [
            "Main Chatbot",
            "View Source Docs for Last Query",
            "Scrape or Upload Own Docs",
        ],
        icons=['robot', 'list-task', 'cloud-upload-fill'],
        default_index=0,
    )
    with st.expander("Warning", expanded=True):
        st.write("⚠️ DO NOT navigate between pages or change config when chat is ongoing. Wait for query to complete first.")
    st.write("")

    # Countries with user-built retrievers for the current chunk settings can be selected
    # to override the pre-built ones; otherwise the label asks the user to scrape or
    # upload their own PDF first.
    new_countries, info = check_for_new_retrievers()
    with st.expander("Document Config", expanded=True):
        st.multiselect(
            'Countries to Override with Own Docs:' + info,
            new_countries,
            key="countries_override",
        )
        st.selectbox(
            "Chunk Size",
            options=[500, 600, 700, 800, 900, 1000, 1250, 1500, 1750, 2000, 2250, 2500, 2750, 3000],
            on_change=update_retrievers,
            key="chunk_size",
        )
        st.selectbox(
            "Chunk Overlap",
            options=[50, 100, 150, 200],
            on_change=update_retrievers,
            key="chunk_overlap",
        )
    st.write("")

    with st.expander("LLM Config", expanded=True):
        st.selectbox(
            "HuggingFace Inference Model",
            options=[
                "mistralai/Mixtral-8x7B-Instruct-v0.1",
                "mistralai/Mistral-7B-Instruct-v0.2",
                "meta-llama/Llama-3.1-8B-Instruct",
                "meta-llama/Llama-3.3-70B-Instruct",
                "meta-llama/Llama-4-Scout-17B-16E-Instruct",
            ],
            on_change=update_llm,
            key="model",
        )
        st.slider(
            "Temperature",
            0.0, 1.0,
            # step must be a keyword: the 3rd positional argument of st.slider is the
            # default value, which would conflict with the session-state value.
            step=0.05,
            on_change=update_llm,
            key="temperature",
        )
        st.slider(
            "Max Tokens Generated",
            200, 1000,
            on_change=update_llm,
            key="max_new_tokens",
        )
    st.write("")

    with st.expander("Retriever Config", expanded=True):
        st.selectbox(
            "Retriever to Use",
            options=['Ensemble (Both Re-Ranked)', 'Semantic (Chroma DB)', 'Keyword (BM 2.5)'],
            key="retriever_config",
        )
        st.slider(
            "Keyword Retriever Weight (If using ensemble retriever, this is the weight of the keyword retriever, semantic retriever would be 1 minus this value)",
            # Range 0-1 in steps of 0.05 (previously max_value=0.05 and value=1.0 were
            # passed positionally, which put the default outside the allowed range).
            0.0, 1.0,
            step=0.05,
            key="keyword_retriever_weight",
        )
        st.number_input(
            "Number of Relevant Documents Returned by Keyword Retriever (BM25)",
            0, 20,
            key="bm25_n_similar_documents",
        )
        st.number_input(
            "Number of Relevant Documents Returned by Semantic Retriever (ChromaDB)",
            0, 20,
            key="chroma_n_similar_documents",
        )
################################ Main Chatbot Page ################################
if page == "Main Chatbot":
    st.subheader("Chatbot")

    # The conversation is kept in session state so it can be re-rendered on every rerun.
    # It starts with a greeting message for the user.
    if "messages" not in st.session_state:
        st.session_state.messages = [
            {"role": "assistant",
             "content": f"""
Hello, I am a chatbot which specializes in ESG policies of countries.
Currently I have data for {', '.join(countries)}.
You can update the data or add data for more countries in the left menu under "Scrape or Upload Own Docs".
You can ask me to compare specific policies between multiple countries too. An example of a question you can ask me is:
"What are the differences between carbon emissions policy in Singapore, Malaysia and China?" How may I help you today?
"""}
        ]

    # Render the conversation so far.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Questions/instructions from the chat input are passed to the agent.
    if user_query := st.chat_input("Your message here", key="user_input"):
        # A new query starts a fresh list of source documents.
        st.session_state['source_documents'] = [f"User query: '{user_query}'"]

        formatted_user_query = f":blue[{user_query}]"
        st.session_state.messages.append({"role": "user", "content": formatted_user_query})
        with st.chat_message("user"):
            st.markdown(formatted_user_query)

        # Let the user know the agent is planning its actions.
        action_plan_message = "Please wait while I plan out a best set of actions to obtain the necessary information to answer your query."
        st.session_state.messages.append({"role": "assistant", "content": action_plan_message})
        with st.chat_message("assistant"):
            st.markdown(action_plan_message)

        # Intermediate steps are shown by the callback handler while the agent runs.
        try:
            results = agent.invoke({"input": user_query})
            response = f":blue[The answer to your query is:] {results['output']}"
        except Exception as err:
            # Top-level boundary: an inference-API or agent failure should be reported
            # in the chat instead of replacing the whole page with a traceback.
            print(f"Agent failed on query {user_query!r}: {err!r}")
            response = f":red[Sorry, I could not answer your query because of an error:] {err}"

        st.session_state.messages.append({"role": "assistant", "content": response})
        with st.chat_message("assistant"):
            st.markdown(response)
################################ Source Documents Page ################################
if page == "View Source Docs for Last Query":
    st.subheader("Source Documents for Last Query")
    source_documents = st.session_state.get('source_documents', [])
    if source_documents:
        # The first entry is the "User query: ..." header; after it, each tool call adds a
        # description line followed by the documents it retrieved.
        st.subheader(source_documents[0])
        for doc in source_documents[1:]:
            st.write(doc)
    else:
        st.write("No source documents retrieved yet. Please run a full user query before coming back to this page.")
| ################################ Scrape or Upload Documents Page ################################ | |
| # - scrape new documents from DuckDuckGo | |
| # - or upload your own PDF | |
| # - retrievers built from the newly scraped/uploaded data can then override the existing data for a country | |
if page == "Scrape or Upload Own Docs":
    st.header("Scrape or Upload Own PDF")
    st.write("Here you can choose to upload your own PDF or scrape more recent data via DuckDuckGo search for a selected country below.")
    st.write(":blue[NOTE: Certain countries were not present in the original default vector stores, you can scrape data for these countries too so you can ask about them in the chat.]")
    st.write("You will create new BM25 (keyword) and Chroma (semantic) retrievers for it. Note that this can take a very long time.")
    country_scrape_upload = st.selectbox(
        "Select Country",
        options=[
            "Australia", "Bangladesh", "Brunei", "Cambodia", "China", "India", "Indonesia", "Japan", "Laos", "Macau", "Malaysia", "Myanmar",
            "Nepal", "Philippines", "Singapore", "South Korea", "Sri Lanka", "Thailand", "Vietnam", "France", "Germany", "Israel", "Poland",
            "Sweden", "Turkey", "United Kingdom", "United States"
        ],
    )

    # Show the chunk size/overlap (read from session state, editable in the
    # sidebar) that will be used for the new documents.
    col1, col2 = st.columns(2)
    with col1:
        with st.container(border=True):
            st.write("New Documents Chunk Size: (Can change in sidebar)")
            st.text(f"{st.session_state['chunk_size']}")
    with col2:
        with st.container(border=True):
            st.write("New Documents Chunk Overlap: (Can change in sidebar)")
            st.text(f"{st.session_state['chunk_overlap']}")

    # How the user wishes to populate the new documents.
    options = [
        "Upload Own PDF",
        "Automatically Scrape Web Data using DuckDuckGo (may take more than 5 mins)"
    ]
    option = st.radio(
        "How Do You Wish To Create New Documents",
        options=options
    )
    submit_upload_pdf = False
    submit_scrape_web = False
    submit_scrape_vector_store = False  # not used in this section; kept in case it is referenced elsewhere

    def save_new_retrievers(all_documents, chunk_size, chunk_overlap, country_scrape_upload):
        """Build new BM25 and Chroma retrievers from `all_documents`, then rerun the app.

        Args:
            all_documents: documents produced by `pdf_loader_local` or
                `process_links_load_documents`.
            chunk_size: chunk size passed to both retriever builders.
            chunk_overlap: chunk overlap passed to both retriever builders.
            country_scrape_upload: country the new retrievers are created for.

        Ends with `st.rerun()`, which halts the current script run, so this
        function does not return to its caller.
        """
        with st.spinner('Setting up new bm25 retrievers with documents, may take more than 5 mins...'):
            # Per the original author, stored under
            # "bm25/new_{country}_chunk_{chunk_size}_overlap_{chunk_overlap}_" and
            # selectable in the sidebar document configuration to override this country.
            setup_bm25_retriever(all_documents, chunk_size, chunk_overlap, country_scrape_upload)
        with st.spinner('Setting up new chromadb vector stores with documents, may take more than 5 mins...'):
            # Per the original author, stored under
            # "chroma_db/new_{country}_chunk_{chunk_size}_overlap_{chunk_overlap}_".
            setup_chromadb_vectorstore(hf_embeddings, all_documents, chunk_size, chunk_overlap, country_scrape_upload)
        st.toast(":blue[SUCCESS!] New retrievers set up with your new data. To override data for this country, you can :blue[Select the Countries to Override in the 'Document Config'] section of the left sidebar.")
        st.rerun()

    # Form for the user to upload their own PDF.
    if option == options[0]:
        with st.form(key='upload_pdf_form'):
            st.subheader(f"Selected Option: {option}")
            uploaded_pdf = st.file_uploader("Upload a PDF")
            submit_upload_pdf = st.form_submit_button(label='Upload and Create Vector Store (Scroll down after clicking)')
            st.markdown(":blue[NOTE:] After you are done creating the vector store, the country will appear under :blue[Countries to Override in the 'Document Config'] section of the left sidebar. Select the country to override it.")
        if submit_upload_pdf:
            # Check explicitly for a missing upload instead of relying on a
            # NameError for the undefined `temp_file`.
            if uploaded_pdf is None:
                st.write("Error! Did you remember to upload the PDF file?")
            else:
                # Write the file only on submit (not on every rerun). Keep only
                # the base name so a crafted upload name cannot point outside
                # the working directory.
                temp_file = os.path.basename(uploaded_pdf.name)
                with open(temp_file, "wb") as file:
                    file.write(uploaded_pdf.getvalue())
                try:
                    with st.spinner('Generating documents from PDF...may take more than 5 mins...'):
                        all_documents = pdf_loader_local(temp_file, country_scrape_upload)
                    save_new_retrievers(all_documents, st.session_state['chunk_size'], st.session_state['chunk_overlap'], country_scrape_upload)
                except Exception as e:
                    # st.rerun() inside save_new_retrievers raises a
                    # BaseException subclass, so it is not caught here.
                    st.write(f"Error while creating documents or retrievers from the PDF! Error Message: {e}")

    # Form for the user to configure DuckDuckGo web scraping.
    if option == options[1]:
        with st.form(key='scrape_web_form'):
            st.subheader(f"Selected Option: {option}")
            # Minimum of 1: scraping zero results can never produce documents.
            n_search_results = st.number_input(
                "How many DuckDuckGo search results would you like to scrape? In the default vector stores, the number is 10 but it will take a very long time!",
                1, 20,
                value=5
            )
            search_term = st.text_input(
                "Search Term",
                value=f"{country_scrape_upload} sustainability esg newest updated public policy document government",
            )
            submit_scrape_web = st.form_submit_button(label='Scrape Web for Results and Create Vector Store (Scroll down after clicking)')
            st.markdown(":blue[NOTE:] After you are done creating the vector store, the country will appear under :blue[Countries to Override in the 'Document Config'] section of the left sidebar. Select the country to override it.")
        if submit_scrape_web:
            # Report scraping/loading/indexing failures the same way the PDF
            # path does, instead of crashing the page.
            try:
                with st.spinner('Scraping web using Duck Duck Go search...'):
                    all_links, df_links = duckduckgo_scrape(country_scrape_upload, search_term, n_search_results)
                st.write("Results from Web Scrape")
                st.write(df_links)
                with st.spinner('Generating documents from web search results...may take more than 5 mins...'):
                    all_documents = process_links_load_documents(all_links)
                save_new_retrievers(all_documents, st.session_state['chunk_size'], st.session_state['chunk_overlap'], country_scrape_upload)
            except Exception as e:
                st.write(f"Error while scraping the web or creating retrievers! Error Message: {e}")