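"""Streamlit page: California Drinking Water Regulation Chatbot (LlamaIndex demo).

The Config tab picks one of the integrated LLMs, collects and validates the
matching API key, and loads a Pinecone-backed vector index for the chosen
document; the Chat tab renders the conversation.
"""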
import streamlit as st
import os
import pathlib

# local imports
from models.llms import load_llm, integrated_llms
from models.embeddings import openai_embed_model
from models.llamaCustom import LlamaCustom

# from models.llamaCustomV2 import LlamaCustomV2
from models.vector_database import get_pinecone_index
from utils.chatbox import show_previous_messages, show_chat_input
from utils.util import validate_openai_api_key

# llama_index
from llama_index.core import (
    SimpleDirectoryReader,
    VectorStoreIndex,
    StorageContext,
    Settings,
    load_index_from_storage,
)
from llama_index.core.base.llms.types import ChatMessage

# huggingface
from huggingface_hub import HfApi
| SAVE_DIR = "uploaded_files" | |
| VECTOR_STORE_DIR = "vectorStores" | |
| HF_REPO_ID = "zhtet/RegBotBeta" | |
| # global | |
| # Settings.embed_model = hf_embed_model | |
| Settings.embed_model = openai_embed_model | |
| # huggingface api | |
| hf_api = HfApi() | |


def init_session_state():
    if "llama_messages" not in st.session_state:
        st.session_state.llama_messages = [
            {"role": "assistant", "content": "How can I help you today?"}
        ]

    # TODO: create a chat history for each different document
    if "llama_chat_history" not in st.session_state:
        st.session_state.llama_chat_history = [
            ChatMessage.from_str(role="assistant", content="How can I help you today?")
        ]

    if "llama_custom" not in st.session_state:
        st.session_state.llama_custom = None

    if "openai_api_key" not in st.session_state:
        st.session_state.openai_api_key = ""

    if "replicate_api_token" not in st.session_state:
        st.session_state.replicate_api_token = ""

    if "hf_token" not in st.session_state:
        st.session_state.hf_token = ""


# @st.cache_resource
def get_index(
    filename: str,
) -> VectorStoreIndex:
    """Load a persisted index for `filename` from local storage if one exists;
    otherwise build a new index from the uploaded document and persist it.
    """
    try:
        index_path = pathlib.Path(f"{VECTOR_STORE_DIR}/{filename.replace('.', '_')}")
        if index_path.exists():
            print("Loading index from storage ...")
            storage_context = StorageContext.from_defaults(persist_dir=index_path)
            index = load_index_from_storage(storage_context=storage_context)
        else:
            reader = SimpleDirectoryReader(input_files=[f"{SAVE_DIR}/{filename}"])
            docs = reader.load_data(show_progress=True)
            index = VectorStoreIndex.from_documents(
                documents=docs,
                show_progress=True,
            )
            index.storage_context.persist(persist_dir=str(index_path))
    except Exception as e:
        print(f"Error: {e}")
        raise
    return index


def check_api_key(model_name: str, source: str):
    if source.startswith("openai"):
        if not st.session_state.openai_api_key:
            with st.expander("OpenAI API Key", expanded=True):
                openai_api_key = st.text_input(
                    label="Enter your OpenAI API Key:",
                    type="password",
                    help="Get your key from https://platform.openai.com/account/api-keys",
                    value=st.session_state.openai_api_key,
                )
                if openai_api_key:
                    # st.spinner is a context manager, not a condition; wrap
                    # the validation call so the spinner actually displays
                    with st.spinner("Validating OpenAI API Key ..."):
                        result = validate_openai_api_key(openai_api_key)
                    if result["status"] == "success":
                        st.session_state.openai_api_key = openai_api_key
                        st.success(result["message"])
                    else:
                        st.error(result["message"])
                        st.info("You can still select a different model to proceed.")
                        st.stop()
| elif source.startswith("replicate"): | |
| if not st.session_state.replicate_api_token: | |
| with st.expander("Replicate API Token", expanded=True): | |
| replicate_api_token = st.text_input( | |
| label="Enter your Replicate API Token:", | |
| type="password", | |
| help="Get your key from https://replicate.ai/account", | |
| value=st.session_state.replicate_api_token, | |
| ) | |
| # TODO: need to validate the token | |
| if replicate_api_token: | |
| st.session_state.replicate_api_token = replicate_api_token | |
| # set the environment variable | |
| os.environ["REPLICATE_API_TOKEN"] = replicate_api_token | |
| elif source.startswith("huggingface"): | |
| if not st.session_state.hf_token: | |
| with st.expander("Hugging Face Token", expanded=True): | |
| hf_token = st.text_input( | |
| label="Enter your Hugging Face Token:", | |
| type="password", | |
| help="Get your key from https://huggingface.co/settings/token", | |
| value=st.session_state.hf_token, | |
| ) | |
| if hf_token: | |
| st.session_state.hf_token = hf_token | |
| # set the environment variable | |
| os.environ["HF_TOKEN"] = hf_token | |


init_session_state()

st.set_page_config(page_title="Llama", page_icon="🦙")
st.header("California Drinking Water Regulation Chatbot - LlamaIndex Demo")

tab1, tab2 = st.tabs(["Config", "Chat"])

with tab1:
    selected_llm_name = st.selectbox(
        label="Select a model:",
        options=[f"{key} | {value}" for key, value in integrated_llms.items()],
    )
    model_name, source = selected_llm_name.split("|")
    check_api_key(model_name=model_name.strip(), source=source.strip())

    selected_file = st.selectbox(
        label="Choose a file to chat with: ", options=os.listdir(SAVE_DIR)
    )
| if st.button("Clear all api keys"): | |
| st.session_state.openai_api_key = "" | |
| st.session_state.replicate_api_token = "" | |
| st.session_state.hf_token = "" | |
| st.success("All API keys cleared!") | |
| st.rerun() | |
| if st.button("Submit", key="submit", help="Submit the form"): | |
| with st.status("Loading ...", expanded=True) as status: | |
| try: | |
| st.write("Loading Model ...") | |
| llama_llm = load_llm( | |
| model_name=model_name.strip(), source=source.strip() | |
| ) | |
| if llama_llm is None: | |
| raise ValueError("Model not found!") | |
| Settings.llm = llama_llm | |
| st.write("Processing Data ...") | |
| # index = get_index(selected_file) | |
| index = get_pinecone_index(selected_file) | |
| st.write("Finishing Up ...") | |
| llama_custom = LlamaCustom(model_name=selected_llm_name, index=index) | |
| # llama_custom = LlamaCustomV2(model_name=selected_llm_name, index=index) | |
| st.session_state.llama_custom = llama_custom | |
| status.update(label="Ready to query!", state="complete", expanded=False) | |
| except Exception as e: | |
| status.update(label="Error!", state="error", expanded=False) | |
| st.error(f"Error: {e}") | |
| st.stop() | |

with tab2:
    messages_container = st.container(height=300)

    show_previous_messages(framework="llama", messages_container=messages_container)
    show_chat_input(
        # disable chat until a model has been loaded from the Config tab,
        # so the handler never receives model=None
        disabled=st.session_state.llama_custom is None,
        framework="llama",
        model=st.session_state.llama_custom,
        messages_container=messages_container,
    )

    def clear_history():
        messages_container.empty()
        st.session_state.llama_messages = [
            {"role": "assistant", "content": "How can I help you today?"}
        ]
        st.session_state.llama_chat_history = [
            ChatMessage.from_str(role="assistant", content="How can I help you today?")
        ]

    if st.button("Clear Chat History"):
        clear_history()
        st.rerun()