Spaces:
Runtime error
Runtime error
| from dotenv import load_dotenv | |
| import pandas as pd | |
| import streamlit as st | |
| import streamlit_authenticator as stauth | |
| from streamlit_modal import Modal | |
| from utils import new_file, clear_memory, append_documentation_to_sidebar, load_authenticator_config, init_qa, \ | |
| append_header | |
| from haystack.document_stores.in_memory import InMemoryDocumentStore | |
| from haystack import Document | |
# Load variables from a local .env file into the process environment before
# any of the code below runs.
load_dotenv()

# Models that get an API-key slot in `init_api_key_dict`.
OPENAI_MODELS = [
    'gpt-3.5-turbo',
    'gpt-4',
    'gpt-4-1106-preview',
]

# Additional models offered in the model selector alongside the OpenAI ones.
OPEN_MODELS = [
    'mistralai/Mistral-7B-Instruct-v0.1',
    'HuggingFaceH4/zephyr-7b-beta',
]
def reset_chat_memory():
    """Render the 'Reset chat memory' button.

    Clicking the button invokes the `clear_memory` callback imported from
    `utils`. The button is always enabled.
    """
    button_options = {
        "key": "reset-memory-button",
        "on_click": clear_memory,
        "help": "Clear the conversational memory. Currently implemented to retain the 4 most recent messages.",
        "disabled": False,
    }
    st.button('Reset chat memory', **button_options)
def manage_files(modal, document_store):
    """Render the "Manage Files" sidebar button and, when open, its modal.

    The modal contains a PDF uploader and an editable table backed by
    ``st.session_state['files']``. When a file is present in the uploader it
    is recorded in the table, written to ``document_store`` and ingested by
    the active QA model.

    Args:
        modal: Modal object exposing ``open()``, ``is_open()`` and
            ``container()``.
        document_store: Store passed through to ``store_file_in_table``.
    """
    if st.sidebar.button("Manage Files", use_container_width=True):
        modal.open()

    if not modal.is_open():
        return

    with modal.container():
        uploaded_file = st.file_uploader(
            "Upload a document in PDF format",
            type=("pdf",),
            # Pass the callable itself. The previous `new_file()` invoked the
            # helper on every render and registered its return value as the
            # callback instead of the function (compare `on_click=clear_memory`).
            on_change=new_file,
            # Uploading is pointless until a QA model has been initialised.
            disabled=st.session_state['document_qa_model'] is None,
            label_visibility="collapsed",
            help="The document is used to answer your questions. The system will process the document and store it in a RAG to answer your questions.",
        )
        edited_df = st.data_editor(
            use_container_width=True,
            data=st.session_state['files'],
            num_rows='dynamic',
            column_order=['name', 'size', 'is_active'],
            column_config={
                'name': {'editable': False},
                'size': {'editable': False},
                'is_active': {'editable': True, 'type': 'checkbox', 'width': 100},
            },
        )
        # NOTE(review): the table is reset to an empty frame on every render
        # and only repopulated from `edited_df` when an upload is present, so
        # edits appear to be dropped otherwise - confirm this is intended.
        st.session_state['files'] = pd.DataFrame(columns=['name', 'content', 'size', 'is_active'])

        if uploaded_file:
            st.session_state['file_uploaded'] = True
            st.session_state['files'] = pd.concat([st.session_state['files'], edited_df])
            with st.spinner('Processing the document content...'):
                store_file_in_table(document_store, uploaded_file)
                ingest_document(uploaded_file)
def ingest_document(uploaded_file):
    """Feed the uploaded PDF to the active QA model and report the outcome.

    Shows a warning when no QA model is stored in the session. Any exception
    raised while ingesting is shown as an error and the session's
    ``file_uploaded`` flag is cleared.

    Args:
        uploaded_file: The object returned by the file uploader.
    """
    qa_model = st.session_state['document_qa_model']
    if not qa_model:
        st.warning('Please select a model to start asking questions')
        return

    try:
        qa_model.ingest_pdf(uploaded_file)
        st.success('Document processed successfully')
    except Exception as exc:
        # UI boundary: surface the failure to the user instead of crashing.
        st.error(f"Error processing the document: {exc}")
        st.session_state['file_uploaded'] = False
def store_file_in_table(document_store, uploaded_file):
    """Record an uploaded file in the session file table and document store.

    Stores the raw bytes under ``pdf_content``, clears the chat messages,
    marks every previously listed file as inactive, appends the new file as
    the active one, and writes a ``Document`` to ``document_store``.

    Args:
        document_store: Store exposing ``write_documents``.
        uploaded_file: Uploaded file exposing ``getvalue()`` and ``name``.
    """
    raw_pdf = uploaded_file.getvalue()
    st.session_state['pdf_content'] = raw_pdf
    # A new document starts a fresh conversation.
    st.session_state.messages = []

    document = Document(content=raw_pdf, meta={"name": uploaded_file.name})

    # Only the newly uploaded file is flagged as active.
    existing_files = pd.DataFrame(st.session_state['files'])
    existing_files['is_active'] = False
    new_entry = pd.DataFrame([{
        "name": uploaded_file.name,
        "content": raw_pdf,
        "size": len(raw_pdf),
        "is_active": True,
    }])
    st.session_state['files'] = pd.concat([existing_files, new_entry])

    document_store.write_documents([document])
def init_session_state():
    """Seed ``st.session_state`` with defaults for every key the app reads.

    Uses ``setdefault`` so values already present in the session are kept.
    """
    defaults = {
        'files': pd.DataFrame(columns=['name', 'content', 'size', 'is_active']),
        'models': [],
        'api_keys': {},
        'current_selected_model': 'gpt-3.5-turbo',
        'current_api_key': '',
        'messages': [],
        'pdf_content': None,
        'memory': None,
        'pdf': None,
        'document_qa_model': None,
        'file_uploaded': False,
    }
    for key, value in defaults.items():
        st.session_state.setdefault(key, value)
def set_page_config():
    """Apply the page-level Streamlit configuration (title, icon, layout, menu)."""
    menu_entries = {
        'Get Help': 'https://www.extremelycoolapp.com/help',
        'Report a bug': "https://www.extremelycoolapp.com/bug",
        'About': "# This is a header. This is an *extremely* cool app!",
    }
    st.set_page_config(
        page_title="AI Audit Assistant",
        page_icon=":shark:",
        initial_sidebar_state="expanded",
        layout="wide",
        menu_items=menu_entries,
    )
def update_running_model(api_key, model):
    """Remember ``api_key`` for ``model`` and rebuild the active QA model.

    Args:
        api_key: Key entered by the user for the selected model.
        model: Name of the selected model.
    """
    state = st.session_state
    # Store the key first so it is kept even if initialisation below fails.
    state['api_keys'][model] = api_key
    state['document_qa_model'] = init_qa(model, api_key)
def init_api_key_dict():
    """Publish the selectable models and ensure each OpenAI model has a key slot.

    This runs from ``StreamlitApp.__init__``, i.e. on every script rerun.
    The slots are therefore created with ``setdefault``: unconditionally
    assigning ``None`` would wipe a key the user entered on an earlier run,
    and ``setup_task_selection`` would then stop offering the 'Generative'
    task because the key is only re-stored when the input text changes.
    """
    st.session_state['models'] = [*OPENAI_MODELS, *OPEN_MODELS, 'local LLM']

    api_keys = st.session_state['api_keys']
    for model_name in OPENAI_MODELS:
        # Create the slot on first run only; keep any key already stored.
        api_keys.setdefault(model_name, None)
def display_chat_messages(chat_box, chat_input):
    """Replay the chat history and answer the newly submitted prompt.

    Nothing is rendered unless ``chat_input`` is truthy. Otherwise the stored
    messages are drawn inside ``chat_box``, followed by the new user prompt
    and the QA model's response; both are then appended to the history.

    Args:
        chat_box: Container (context manager) in which messages are drawn.
        chat_input: Text submitted by the user on this run, if any.
    """
    with chat_box:
        if not chat_input:
            return

        # Replay the stored history.
        for past_message in st.session_state.messages:
            with st.chat_message(past_message["role"]):
                st.markdown(past_message["content"], unsafe_allow_html=True)

        st.chat_message("user").markdown(chat_input)

        with st.chat_message("assistant"):
            # Process the user input and generate the response.
            qa_model = st.session_state['document_qa_model']
            response = qa_model.inference(chat_input, st.session_state.messages)
            st.markdown(response)

        # Record the exchange only after the answer has been produced.
        st.session_state.messages.extend([
            {"role": "user", "content": chat_input},
            {"role": "assistant", "content": response},
        ])
def setup_model_selection():
    """Render the model selector and API-key input; keep the QA model in sync.

    Selecting 'local LLM' initialises the QA model without a key. For any
    model, a non-empty key triggers ``update_running_model`` when the key
    changed. It is also triggered when the user switched to a different
    non-local model with the key unchanged: previously that case left
    ``document_qa_model`` bound to the model selected before the switch.

    Returns:
        The model name chosen in the select box.
    """
    model = st.selectbox(
        "Model:",
        options=st.session_state['models'],
        index=0,  # default to the first model in the list gpt-3.5-turbo
        placeholder="Select model",
        help="Select an LLM:"
    )
    if not model:
        return model

    is_local = model == 'local LLM'
    model_changed = model != st.session_state['current_selected_model']
    if model_changed:
        st.session_state['current_selected_model'] = model
        if is_local:
            st.session_state['document_qa_model'] = init_qa(model)

    # 'current_selected_model' equals `model` at this point, so the input is
    # disabled exactly when the local LLM is selected.
    api_key = st.sidebar.text_input("Enter LLM-authorization Key:", type="password",
                                    disabled=is_local)

    key_changed = api_key != st.session_state['current_api_key']
    # Re-initialise on a new key (as before) and also when the user switched
    # to another remote model while keeping the same key.
    switched_remote_model = model_changed and not is_local
    if api_key and (key_changed or switched_remote_model):
        update_running_model(api_key, model)
        st.session_state['current_api_key'] = api_key

    return model
def setup_task_selection(model):
    """Render the task radio buttons in the sidebar.

    'Extractive' is always offered. 'Generative' is added when ``model`` is
    the local LLM or a key is stored for it in ``st.session_state['api_keys']``.

    Args:
        model: Name of the currently selected model.
    """
    task_options = ['Extractive']
    has_api_key = model == 'local LLM' or st.session_state['api_keys'].get(model)
    if has_api_key:
        task_options.append('Generative')

    task_selection = st.sidebar.radio('Select the task:', task_options)
    # TODO: Add the task selection logic here (initializing the model based on the task)
def setup_page_body():
    """Render the chat area: a message container plus the chat input.

    The input stays disabled, and no messages are drawn, until the session's
    ``file_uploaded`` flag is set.
    """
    chat_box = st.container(height=350, border=False)
    has_document = st.session_state['file_uploaded']
    chat_input = st.chat_input(
        placeholder="Upload a document to start asking questions...",
        disabled=not has_document,
    )
    if not has_document:
        return
    display_chat_messages(chat_box, chat_input)
class StreamlitApp:
    """Top-level application object: authentication, sidebar and chat page."""

    def __init__(self):
        """Load configuration, configure the page and seed the session state."""
        self.authenticator_config = load_authenticator_config()
        self.document_store = InMemoryDocumentStore()
        set_page_config()
        self.authenticator = self.init_authenticator()
        init_session_state()
        init_api_key_dict()

    def init_authenticator(self):
        """Build the authenticator from the loaded credentials and cookie settings."""
        config = self.authenticator_config
        cookie = config['cookie']
        return stauth.Authenticate(
            config['credentials'],
            cookie['name'],
            cookie['key'],
            cookie['expiry_days'],
        )

    def setup_sidebar(self):
        """Populate the sidebar: logo, model/task options, logout, memory reset, files, docs."""
        with st.sidebar:
            st.sidebar.image("resources/ml_logo.png", use_column_width=True)

            # Sidebar for Task Selection
            st.sidebar.header('Options:')
            selected_model = setup_model_selection()
            setup_task_selection(selected_model)
            st.divider()

            self.authenticator.logout()
            reset_chat_memory()

            files_modal = Modal("Manage Files", key="demo-modal")
            manage_files(files_modal, self.document_store)
            st.divider()

            append_documentation_to_sidebar()

    def run(self):
        """Show the login form and, once authenticated, the application itself."""
        _name, authentication_status, _username = self.authenticator.login()
        if authentication_status:
            self.run_authenticated_app()
            return

        session_status = st.session_state["authentication_status"]
        if session_status is False:
            st.error('Username/password is incorrect')
        elif session_status is None:
            st.warning('Please enter your username and password')

    def run_authenticated_app(self):
        """Render the sidebar, the header and the chat body for a logged-in user."""
        self.setup_sidebar()
        append_header()
        setup_page_body()
# Script entry point: build the application object and render the page.
app = StreamlitApp()
app.run()