Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain | |
| from langchain.memory import ConversationBufferMemory | |
| from langchain_cohere import ChatCohere | |
| from langchain_community.chat_message_histories.streamlit import ( | |
| StreamlitChatMessageHistory, | |
| ) | |
| from langchain_openai import ChatOpenAI | |
| from calback_handler import PrintRetrievalHandler, StreamHandler | |
| from chat_profile import ChatProfileRoleEnum | |
| from document_retriever import configure_retriever | |
| from llm_provider import LLMProviderEnum | |
| # Constants | |
| GPT_LLM_MODEL = "gpt-3.5-turbo" | |
| COMMAND_R_LLM_MODEL = "command-r" | |
| # Properties | |
| uploaded_files = [] | |
| api_key = "" | |
| result_retriever = None | |
| chain = None | |
| llm = None | |
| model_name = "" | |
| # Set up sidebar | |
| if "sidebar_state" not in st.session_state: | |
| st.session_state.sidebar_state = "expanded" | |
| # Streamlit app configuration | |
| st.set_page_config( | |
| page_title="InkChatGPT: Chat with Documents", | |
| page_icon="π", | |
| initial_sidebar_state=st.session_state.sidebar_state, | |
| menu_items={ | |
| "Get Help": "https://x.com/vinhnx", | |
| "Report a bug": "https://github.com/vinhnx/InkChatGPT/issues", | |
| "About": """InkChatGPT is a simple Retrieval Augmented Generation (RAG) application that allows users to upload PDF documents and engage in a conversational Q&A, with a language model (LLM) based on the content of those documents. | |
| GitHub: https://github.com/vinhnx/InkChatGPT""", | |
| }, | |
| ) | |
| with st.sidebar: | |
| with st.container(): | |
| col1, col2 = st.columns([0.2, 0.8]) | |
| with col1: | |
| st.image( | |
| "./assets/app_icon.png", | |
| use_column_width="always", | |
| output_format="PNG", | |
| ) | |
| with col2: | |
| st.header(":books: InkChatGPT") | |
| # Model | |
| selected_model = st.selectbox( | |
| "Select a model", | |
| options=[ | |
| LLMProviderEnum.OPEN_AI.value, | |
| LLMProviderEnum.COHERE.value, | |
| ], | |
| index=None, | |
| placeholder="Select a model...", | |
| ) | |
| if selected_model: | |
| api_key = st.text_input(f"{selected_model} API Key", type="password") | |
| if selected_model == LLMProviderEnum.OPEN_AI: | |
| model_name = GPT_LLM_MODEL | |
| elif selected_model == LLMProviderEnum.COHERE: | |
| model_name = COMMAND_R_LLM_MODEL | |
| msgs = StreamlitChatMessageHistory() | |
| if len(msgs.messages) == 0: | |
| msgs.clear() | |
| msgs.add_ai_message(""" | |
| Hi, your uploaded document(s) had been analyzed. | |
| Feel free to ask me any questions. For example: you can start by asking me something like: | |
| `What is this context about?` | |
| `Help me summarize this!` | |
| """) | |
| if api_key: | |
| # Documents | |
| uploaded_files = st.file_uploader( | |
| label="Select files", | |
| type=["pdf", "txt", "docx"], | |
| accept_multiple_files=True, | |
| disabled=(not selected_model), | |
| ) | |
| if api_key and not uploaded_files: | |
| st.info("π You can upload some documents to get started") | |
| # Check if a model is selected | |
| if not selected_model: | |
| st.info( | |
| "πΊ Please select a model first, open the `Settings` tab from side bar menu to get started" | |
| ) | |
| # Check if API key is provided | |
| if selected_model and len(api_key.strip()) == 0: | |
| st.warning( | |
| f"π API key for {selected_model} is missing or invalid. Please provide a valid API key." | |
| ) | |
| # Process uploaded files | |
| if uploaded_files: | |
| result_retriever = configure_retriever(uploaded_files, cohere_api_key=api_key) | |
| if result_retriever is not None: | |
| memory = ConversationBufferMemory( | |
| memory_key="chat_history", | |
| chat_memory=msgs, | |
| return_messages=True, | |
| ) | |
| if selected_model == LLMProviderEnum.OPEN_AI: | |
| llm = ChatOpenAI( | |
| model=model_name, | |
| api_key=api_key, | |
| temperature=0, | |
| streaming=True, | |
| ) | |
| elif selected_model == LLMProviderEnum.COHERE: | |
| llm = ChatCohere( | |
| model=model_name, | |
| temperature=0.3, | |
| streaming=True, | |
| cohere_api_key=api_key, | |
| ) | |
| if llm is None: | |
| st.error( | |
| "Failed to initialize the language model. Please check your configuration." | |
| ) | |
| # Create the ConversationalRetrievalChain instance using the llm instance | |
| chain = ConversationalRetrievalChain.from_llm( | |
| llm=llm, | |
| retriever=result_retriever, | |
| memory=memory, | |
| verbose=True, | |
| max_tokens_limit=4000, | |
| ) | |
| avatars = { | |
| ChatProfileRoleEnum.HUMAN.value: "user", | |
| ChatProfileRoleEnum.AI.value: "assistant", | |
| } | |
| for msg in msgs.messages: | |
| st.chat_message(avatars[msg.type]).write(msg.content) | |
| # Get user input and generate response | |
| if user_query := st.chat_input( | |
| placeholder="Ask me anything!", | |
| disabled=(not uploaded_files), | |
| ): | |
| st.chat_message("user").write(user_query) | |
| with st.chat_message("assistant"): | |
| retrieval_handler = PrintRetrievalHandler(st.empty()) | |
| stream_handler = StreamHandler(st.empty()) | |
| response = chain.run( | |
| user_query, | |
| callbacks=[retrieval_handler, stream_handler], | |
| ) | |
| if selected_model and model_name: | |
| st.sidebar.caption(f"πͺ Using `{model_name}` model") | |