Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| """ | |
| To run: | |
| - activate the virtual environment | |
| - streamlit run path/to/streamlit_app.py | |
| """ | |
| import logging | |
| import os | |
| import re | |
| import sys | |
| import time | |
| import warnings | |
| import shutil | |
| from langchain.chat_models import ChatOpenAI | |
| from langchain.embeddings.openai import OpenAIEmbeddings | |
| import openai | |
| import pandas as pd | |
| import streamlit as st | |
| from st_aggrid import GridOptionsBuilder, AgGrid, GridUpdateMode, ColumnsAutoSizeMode | |
| from streamlit_chat import message | |
| from streamlit_langchain_chat.constants import * | |
| from streamlit_langchain_chat.customized_langchain.llms import OpenAI, AzureOpenAI, AzureOpenAIChat | |
| from streamlit_langchain_chat.dataset import Dataset | |
# Configure logger
# Root logger gets the timestamped format; a second handler sends the
# records to stdout as well.
logging.basicConfig(format="\n%(asctime)s\n%(message)s", level=logging.INFO, force=True)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
warnings.filterwarnings('ignore')

# Seed the session state: a key is only created when it is missing, so values
# already stored in the session are left untouched. Factories are used so
# every list-valued key gets its own fresh list.
for _state_key, _make_default in (
    ('generated', list),
    ('past', list),
    ('costs', list),
    ('contexts', list),
    ('chunks', list),
    ('user_input', str),
    ('dataset', lambda: None),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
def check_api_keys() -> bool:
    """Tell whether the API keys needed to build an index are available.

    The OpenAI key is always required. The Pinecone key is only required when
    ``app.params['index_id']`` is 2 (the Pinecone index store).

    Returns:
        True when every required key is a non-empty environment variable.
    """
    # os.getenv with a '' default always yields a str, so truthiness is enough.
    openai_key_ready = bool(os.getenv('OPENAI_API_KEY', ''))
    needs_pinecone = app.params['index_id'] == 2
    pinecone_key_ready = bool(os.getenv('PINECONE_API_KEY', '')) or not needs_pinecone
    return openai_key_ready and pinecone_key_ready
def check_combination_point() -> bool:
    """Check that step 2 of the sidebar (model / endpoint) is filled in.

    Returns:
        True when an OpenAI API key and an API base are set and, for
        ``type_id`` 1, a deployment id as well. False for any ``type_id``
        other than 1 or 2.
    """
    model_type = app.params['type_id']
    has_key = bool(os.getenv('OPENAI_API_KEY', ''))
    endpoint = app.params['api_base']

    if model_type not in (1, 2):
        return False

    required = [has_key, endpoint]
    if model_type == 1:
        # Only type 1 additionally needs a deployment id.
        required.append(app.params['deployment_id'])
    return all(required)
def check_index() -> bool:
    """Tell whether the chat can be used.

    Returns:
        True when the dataset stored in the session exposes a truthy
        ``index_docstore`` attribute, or when ``app.params['source_id']``
        is 4.
    """
    current = st.session_state['dataset']
    # A missing attribute (e.g. when no dataset is stored) counts as "no index".
    docstore = getattr(current, "index_docstore", False)
    no_source_needed = app.params['source_id'] == 4
    return bool(docstore or no_source_needed)
def check_index_point() -> bool:
    """Check that step 3 of the sidebar (index store) is filled in.

    Returns:
        True when an index is selected and, for index id 2, both the
        ``PINECONE_API_KEY`` and ``PINECONE_ENVIRONMENT`` environment
        variables are set to non-empty values.
    """
    selected = app.params['index_id']
    if selected == 2:
        key_ok = bool(os.getenv('PINECONE_API_KEY', ''))
        env_ok = os.getenv('PINECONE_ENVIRONMENT', False)
    else:
        # The environment is not consulted for any other index.
        key_ok = env_ok = True
    return bool(selected and key_ok and env_ok)
def check_params_point() -> bool:
    """Check that step 4 of the sidebar (configuration) is filled in.

    Returns:
        True when ``max_sources`` is truthy and ``temperature`` is a float.
    """
    top_k = app.params['max_sources']
    temp = app.params['temperature']
    return bool(top_k) and isinstance(temp, float)
def check_source_point() -> bool:
    """Report whether step 1 of the sidebar (source selection) is complete.

    Returns:
        Always True: no condition is checked for this step.
    """
    return True
def clear_chat_history():
    """Empty every chat-history list kept in the session state.

    Nothing is reassigned when all of the lists are already empty.
    """
    history_keys = ('past', 'generated', 'contexts', 'chunks', 'costs')
    if not any(st.session_state[name] for name in history_keys):
        return
    for name in history_keys:
        st.session_state[name] = []
def clear_index():
    """Delete the index directory from disk and forget the session dataset.

    When a dataset is stored in the session, its ``index_path`` directory is
    removed (if present) and the session entry is reset to None. Otherwise
    the ``TEMP_DIR / "default"`` directory is removed if it exists.
    """
    current = st.session_state['dataset']
    if current:
        # delete directory (with files)
        stored_at = current.index_path
        if stored_at.exists():
            shutil.rmtree(str(stored_at))
        # update variable
        st.session_state['dataset'] = None
        return

    fallback_dir = TEMP_DIR / "default"
    if fallback_dir.exists():
        shutil.rmtree(str(fallback_dir))
def check_sources() -> bool:
    """Check that step 5 of the sidebar (sources) is filled in.

    Returns:
        True when exactly one kind of source is provided (local files or
        urls, not both), or when ``app.params['source_id']`` is 4.
    """
    file_rows = app.params['uploaded_files_rows']
    url_table = app.params['urls_df']
    source_kind = app.params['source_id']

    # The last selected row having an empty filepath counts as "no files".
    has_files = bool(file_rows) and file_rows[-1].get('filepath') != ""
    has_urls = any(url for url, _citation in url_table.to_numpy())

    # For two booleans, `!=` is true exactly when only one of them is set.
    exactly_one_kind = has_files != has_urls
    return exactly_one_kind or (source_kind == 4)
def collect_dataset_and_built_index():
    """Collect the selected sources into a Dataset, index it and store it.

    Reads the sidebar selections from ``app.params``, configures the
    ``openai`` module for the chosen endpoint, adds every url / uploaded file
    to a new ``Dataset``, builds the FAISS (index id 1) or Pinecone index and
    saves the dataset in ``st.session_state['dataset']``.

    A url that cannot be added is logged and skipped; a failure while adding
    an uploaded file is propagated.
    """
    start = time.time()
    uploaded_files_rows = app.params['uploaded_files_rows']
    urls_df = app.params['urls_df']
    type_id = app.params['type_id']
    temperature = app.params['temperature']
    index_id = app.params['index_id']
    api_base = app.params['api_base']
    deployment_id = app.params['deployment_id']

    some_files = bool(uploaded_files_rows) and uploaded_files_rows[-1].get('filepath') != ""
    some_urls = any(url for url, _citation in urls_df.to_numpy())

    # type_id 1 selects the Azure endpoint, anything else the OpenAI one.
    is_azure = type_id == 1
    openai.api_type = "azure" if is_azure else "open_ai"
    openai.api_base = api_base
    openai.api_version = "2023-03-15-preview" if is_azure else None

    # "text-davinci-003" goes through the OpenAI wrapper, every other
    # deployment through ChatOpenAI. The deployment id is passed as-is in both
    # cases (the previous lookup used an undefined name and raised NameError).
    llm_class = OpenAI if deployment_id == "text-davinci-003" else ChatOpenAI
    dataset = Dataset(
        llm=llm_class(
            temperature=temperature,
            max_tokens=512,
            deployment_id=deployment_id,
        )
    )

    # get url documents
    if some_urls:
        for _, url_row in urls_df.reset_index().iterrows():
            url = url_row.get('urls', '')
            citation = url_row.get('citation string', '')
            if not url:
                continue
            try:
                dataset.add(
                    url,
                    citation,
                    citation,
                    disable_check=True  # True to accept Japanese letters
                )
            except Exception:
                # Best effort: one unreachable url must not abort the build,
                # but the failure is recorded with its traceback.
                logging.exception("Could not add url %r to the dataset", url)

    # get local documents
    if some_files:
        for uploaded_files_row in uploaded_files_rows:
            citation = uploaded_files_row.get('citation string')
            # A citation containing a comma (or a missing one) is not used as key.
            key = None if citation is None or ',' in citation else citation
            dataset.add(
                uploaded_files_row.get('filepath'),
                citation,
                key=key,
                disable_check=True  # True to accept Japanese letters
            )

    openai_embeddings = OpenAIEmbeddings(
        document_model_name="text-embedding-ada-002",
        query_model_name="text-embedding-ada-002",
    )
    if index_id == 1:
        dataset._build_faiss_index(openai_embeddings)
    else:
        dataset._build_pinecone_index(openai_embeddings)
    st.session_state['dataset'] = dataset

    if OPERATING_MODE == "debug":
        print(f"time to collect dataset: {time.time() - start:.2f} [s]")
def configure_streamlit_and_page():
    """Apply the page configuration and inject the column-layout CSS."""
    # Configure Streamlit page and state
    st.set_page_config(**ST_CONFIG)

    # Force responsive layout for columns also on mobile
    column_css = """<style>
[data-testid="column"] {
width: calc(50% - 1rem);
flex: 1 1 calc(50% - 1rem);
min-width: calc(50% - 1rem);
}
</style>"""
    st.write(column_css, unsafe_allow_html=True)
def get_answer():
    """Query the session dataset with the text typed by the user.

    Returns:
        The object returned by ``dataset.query`` or None when the query, the
        dataset, the model type or the index selection is missing.
    """
    query = st.session_state['user_input']
    dataset = st.session_state['dataset']
    type_id = app.params['type_id']
    index_id = app.params['index_id']
    # NOTE(review): app.params['max_sources'] is configured in the sidebar but
    # is not forwarded to dataset.query here - confirm whether it should be.

    if not (query and dataset and type_id and index_id):
        return None

    # Previous turns as (question, answer) pairs.
    chat_history = list(zip(st.session_state['past'], st.session_state['generated']))

    start = time.time()
    openai_embeddings = OpenAIEmbeddings(
        document_model_name="text-embedding-ada-002",
        query_model_name="text-embedding-ada-002",
    )
    result = dataset.query(
        query,
        openai_embeddings,
        chat_history,
        marginal_relevance=index_id == 1,  # if pinecone is used it must be False
    )
    if OPERATING_MODE == "debug":
        print(f"time to get answer: {time.time() - start:.2f} [s]")
        print("-" * 10)
    return result
def load_main_page():
    """
    Load the body of web.

    Renders the title and status line, the "clear" buttons, the stored
    conversation (each turn with its context, chunks and costs) and the input
    box, which stays disabled until ``check_index()`` reports the chat usable.
    """
    # Streamlit HTML Markdown
    # st.title <h1> #
    # st.header <h2> ##
    # st.subheader <h3> ###
    st.markdown(f"## Augmented-Retrieval Q&A ChatGPT ({APP_VERSION})")
    validate_status()
    st.markdown(f"#### **Status**: {app.params['status']}")

    # hidden div with anchor
    st.markdown("<div id='linkto_top'></div>", unsafe_allow_html=True)

    col1, col2, col3 = st.columns(3)
    col1.button(label="clear index", type="primary", on_click=clear_index)
    col2.button(label="clear conversation", type="primary", on_click=clear_chat_history)
    col3.markdown("<a href='#linkto_bottom'>Link to bottom</a>", unsafe_allow_html=True)

    # The five history lists are appended to together, so they are walked in
    # lockstep; an incomplete turn is skipped instead of raising IndexError.
    state = st.session_state
    turns = zip(state['past'], state['generated'], state['contexts'], state['chunks'], state['costs'])
    for i, (question, answer, context, chunks, cost) in enumerate(turns):
        message(question, is_user=True, key=f"{i}_user")
        message(answer, key=str(i))
        with st.expander("See context"):
            st.write(context)
        with st.expander("See chunks"):
            st.write(chunks)
        with st.expander("See costs"):
            st.write(cost)

    st.text_input("You:",
                  key='user_input',
                  on_change=on_enter,
                  disabled=not check_index()
                  )

    st.markdown("<a href='#linkto_top'>Link to top</a>", unsafe_allow_html=True)
    # hidden div with anchor
    st.markdown("<div id='linkto_bottom'></div>", unsafe_allow_html=True)
def load_sidebar_page():
    """Render the sidebar and store every selection in ``app.params``.

    Steps, in rendering order: source type, model / endpoint credentials,
    index store, configuration (top-k and temperature), source upload
    (local files or urls) and the "Build index" button.

    Keys written to ``app.params``: ``source_id``, ``type_id``, ``api_base``,
    ``deployment_id``, ``index_id``, ``max_sources``, ``temperature``,
    ``uploaded_files_rows`` and ``urls_df``. API keys are written to
    ``os.environ``.
    """
    st.sidebar.markdown("## Instructions")

    # ############ #
    # SOURCES TYPE #
    # ############ #
    st.sidebar.markdown("1. Select a source:")
    source_selected = st.sidebar.selectbox(
        "Choose the location of your info to give context to chatgpt",
        list(SOURCES_IDS))
    app.params['source_id'] = SOURCES_IDS.get(source_selected, None)

    # ##### #
    # MODEL #
    # ##### #
    # NOTE(review): os.environ is shared by the whole Python process - confirm
    # that is acceptable if several users can share one server.
    st.sidebar.markdown("2. Select a model (LLM):")
    combination_selected = st.sidebar.selectbox(
        "Choose type: MSF Azure OpenAI and model / OpenAI",
        list(TYPE_IDS))
    app.params['type_id'] = TYPE_IDS.get(combination_selected, None)

    if app.params['type_id'] == 1:  # with AzureOpenAI endpoint
        # https://docs.streamlit.io/library/api-reference/widgets/st.text_input
        os.environ['OPENAI_API_KEY'] = st.sidebar.text_input(
            label="Enter Azure OpenAI API Key",
            type="password"
        ).strip()
        app.params['api_base'] = st.sidebar.text_input(
            label="Enter Azure API base",
            placeholder="https://<api_base_endpoint>.openai.azure.com/",
        ).strip()
        app.params['deployment_id'] = st.sidebar.text_input(
            label="Enter Azure deployment_id",
        ).strip()
    elif app.params['type_id'] == 2:  # with OpenAI endpoint
        os.environ['OPENAI_API_KEY'] = st.sidebar.text_input(
            label="Enter OpenAI API Key",
            placeholder="sk-...",
            type="password"
        ).strip()
        app.params['api_base'] = "https://api.openai.com/v1"
        app.params['deployment_id'] = None
    else:
        # Unknown type: still define the keys so later readers of app.params
        # see "not configured" instead of raising KeyError.
        app.params['api_base'] = None
        app.params['deployment_id'] = None

    # ####### #
    # INDEXES #
    # ####### #
    st.sidebar.markdown("3. Select a index store:")
    index_selected = st.sidebar.selectbox(
        "Type of Index",
        list(INDEX_IDS))
    app.params['index_id'] = INDEX_IDS.get(index_selected, None)

    if app.params['index_id'] == 2:  # with pinecone
        os.environ['PINECONE_API_KEY'] = st.sidebar.text_input(
            label="Enter pinecone API Key",
            type="password"
        ).strip()
        os.environ['PINECONE_ENVIRONMENT'] = st.sidebar.text_input(
            label="Enter pinecone environment",
            placeholder="eu-west1-gcp",
        ).strip()

    # ############## #
    # CONFIGURATIONS #
    # ############## #
    st.sidebar.markdown("4. Choose configuration:")
    # https://docs.streamlit.io/library/api-reference/widgets/st.number_input
    # The bounds enforce the 1-5 range announced in the label.
    max_sources = st.sidebar.number_input(
        label="Top-k: Number of chunks/sections (1-5)",
        step=1,
        format="%d",
        value=5,
        min_value=1,
        max_value=5
    )
    app.params['max_sources'] = max_sources
    temperature = st.sidebar.number_input(
        label="Temperature (0.0 – 1.0)",
        step=0.1,
        format="%f",
        value=0.0,
        min_value=0.0,
        max_value=1.0
    )
    app.params['temperature'] = round(temperature, 1)

    # ############## #
    # UPLOAD SOURCES #
    # ############## #
    app.params['uploaded_files_rows'] = []
    if app.params['source_id'] == 1:
        # https://docs.streamlit.io/library/api-reference/widgets/st.file_uploader
        # https://towardsdatascience.com/make-dataframes-interactive-in-streamlit-c3d0c4f84ccb
        st.sidebar.markdown("""5. Upload your local documents and modify citation strings (optional)""")
        uploaded_files = st.sidebar.file_uploader(
            "Choose files",
            accept_multiple_files=True,
            type=['pdf', 'PDF',
                  'txt', 'TXT',
                  'html',
                  'docx', 'DOCX',
                  'pptx', 'PPTX',
                  ],
        )
        uploaded_files_dataset = request_pathname(uploaded_files)
        uploaded_files_df = pd.DataFrame(
            uploaded_files_dataset,
            columns=['filepath', 'citation string'])
        uploaded_files_grid_options_builder = GridOptionsBuilder.from_dataframe(uploaded_files_df)
        # Pre-select every row unless the table only holds the empty
        # placeholder row returned when nothing has been uploaded.
        has_real_rows = uploaded_files_df.iloc[-1, 0] != ""
        uploaded_files_grid_options_builder.configure_selection(
            selection_mode='multiple',
            pre_selected_rows=list(range(uploaded_files_df.shape[0])) if has_real_rows else [],
            use_checkbox=True,
        )
        uploaded_files_grid_options_builder.configure_column("citation string", editable=True)
        uploaded_files_grid_options_builder.configure_auto_height()
        uploaded_files_grid_options = uploaded_files_grid_options_builder.build()
        with st.sidebar:
            uploaded_files_ag_grid = AgGrid(
                uploaded_files_df,
                gridOptions=uploaded_files_grid_options,
                update_mode=GridUpdateMode.SELECTION_CHANGED | GridUpdateMode.VALUE_CHANGED,
            )
        app.params['uploaded_files_rows'] = uploaded_files_ag_grid["selected_rows"]

    app.params['urls_df'] = pd.DataFrame()
    if app.params['source_id'] == 3:
        st.sidebar.markdown("""5. Write some urls and modify citation strings if you want (to look prettier)""")
        # option 1: with streamlit version 1.20.0+
        # app.params['urls_df'] = st.sidebar.experimental_data_editor(
        #     pd.DataFrame([["", ""]], columns=['urls', 'citation string']),
        #     use_container_width=True,
        #     num_rows="dynamic",
        # )
        # option 2: with streamlit version 1.19.0
        # Five empty, editable rows for the user to fill in.
        urls_dataset = [["", ""] for _ in range(5)]
        urls_df = pd.DataFrame(
            urls_dataset,
            columns=['urls', 'citation string'])
        urls_grid_options_builder = GridOptionsBuilder.from_dataframe(urls_df)
        urls_grid_options_builder.configure_columns(['urls', 'citation string'], editable=True)
        urls_grid_options_builder.configure_auto_height()
        urls_grid_options = urls_grid_options_builder.build()
        with st.sidebar:
            urls_ag_grid = AgGrid(
                urls_df,
                gridOptions=urls_grid_options,
                update_mode=GridUpdateMode.SELECTION_CHANGED | GridUpdateMode.VALUE_CHANGED,
            )
        # Keep only the rows where a url was actually typed.
        df = urls_ag_grid.data
        df = df[df.urls != ""]
        app.params['urls_df'] = df

    if app.params['source_id'] in (1, 2, 3):
        st.sidebar.markdown("""6. Build an index where you can ask""")
        api_keys_ready = check_api_keys()
        source_ready = check_sources()
        enable_index_button = api_keys_ready and source_ready
        if st.sidebar.button("Build index", disabled=not enable_index_button):
            collect_dataset_and_built_index()
def main():
    """Render the app: page configuration, then the sidebar, then the body."""
    render_steps = (
        configure_streamlit_and_page,
        load_sidebar_page,
        load_main_page,
    )
    for render_step in render_steps:
        render_step()
def on_enter():
    """Handle a submitted question (``on_change`` callback of the input box).

    When ``get_answer()`` yields a result, the question and the result's
    answer, context, chunks and cost string are appended to the session
    history and the input box is emptied. Nothing happens otherwise.
    """
    reply = get_answer()
    if not reply:
        return

    state = st.session_state
    state.past.append(state['user_input'])
    state.generated.append(reply.answer)
    state.contexts.append(reply.context)
    state.chunks.append(reply.chunks)
    state.costs.append(reply.cost_str)
    state['user_input'] = ""
def request_pathname(files):
    """Save the uploaded files under ``TEMP_DIR`` and return their table rows.

    Args:
        files: uploaded file objects exposing ``name`` and ``getbuffer()``;
            may be empty or None.

    Returns:
        One ``[filepath, citation string]`` row per file, where the filepath
        is relative to ``ROOT_DIR`` and the citation string is the uploaded
        file's name. A single empty row ``[["", ""]]`` when there are no
        files.
    """
    if not files:
        return [["", ""]]

    # mkdir(exist_ok=True) already tolerates an existing directory, so no
    # separate existence check is needed.
    TEMP_DIR.mkdir(
        parents=True,
        exist_ok=True,
    )

    rows = []
    for file in files:
        # The name is supplied by the client: keep only its last component so
        # the file cannot be written outside TEMP_DIR.
        safe_name = os.path.basename(file.name)
        # relative path
        file_path = str((TEMP_DIR / safe_name).relative_to(ROOT_DIR))
        with open(file_path, "wb") as f:
            f.write(file.getbuffer())
        rows.append([file_path, file.name])
    return rows
def validate_status():
    """Compute the status line shown on the main page.

    Runs every readiness check and stores in ``app.params['status']`` either
    the "ready" message or the hint for the first sidebar step that is not
    complete.
    """
    # Every check is evaluated up front, in sidebar-step order; each one is
    # paired with the hint displayed when it fails.
    checks = (
        (check_source_point(), "⚠️Review step 1 on the sidebar."),
        (check_combination_point(), "⚠️Review step 2 on the sidebar. API Keys or endpoint, ..."),
        (check_index_point(), "⚠️Review step 3 on the sidebar. Index API Key or environment."),
        (check_params_point(), "⚠️Review step 4 on the sidebar"),
        (check_sources(), "⚠️Review step 5 on the sidebar. Waiting for some source..."),
        (check_index(), "⚠️Review step 6 on the sidebar. Waiting for press button to create index ..."),
    )
    for passed, hint in checks:
        if not passed:
            app.params['status'] = hint
            return
    app.params['status'] = "✨Ready✨"
class StreamlitLangchainChatApp():
    """Application object: holds the shared ``params`` dict and the entry point."""

    def __init__(self) -> None:
        """Use __init__ to define instance variables. It cannot have any arguments."""
        # Per-instance parameter store, empty at creation.
        self.params = {}

    def run(self, **state) -> None:
        """Define here all logic required by your application."""
        # The keyword arguments are accepted but not used.
        main()
if __name__ == "__main__":
    # The functions of this module read the module-level name `app`, so it is
    # bound here before anything is rendered.
    app = StreamlitLangchainChatApp()
    app.run()