Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain_community.llms import HuggingFaceEndpoint | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.runnables import RunnablePassthrough | |
| from langchain_core.output_parsers import StrOutputParser | |
| import tempfile | |
| import zipfile | |
| import os | |
st.title('Testing and QA')

# Read the model names selected elsewhere (presumably on another page of this
# app) from session state, falling back to the defaults listed here when no
# selection has been stored yet.
EMBEDDING_MODEL_NAME, LLM_MODEL_NAME = (
    st.session_state.get(state_key, fallback)
    for state_key, fallback in (
        ('selected_embedding_model', "thenlper/gte-small"),
        ('selected_llm_model', "mistralai/Mistral-7B-Instruct-v0.2"),
    )
)
# Build the embedding model once per session and cache it in session state.
# The cached model is rebuilt when the selected model name no longer matches
# the one it was built with; a presence-only check would silently keep using
# the old model after the selection changes. Cached objects that expose no
# `model_name` are trusted as-is.
_cached_embedding_model = st.session_state.get('embedding_model')
_needs_embedding_model = _cached_embedding_model is None or (
    getattr(_cached_embedding_model, 'model_name', EMBEDDING_MODEL_NAME)
    != EMBEDDING_MODEL_NAME
)
if _needs_embedding_model:
    st.session_state['embedding_model'] = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL_NAME,
        # NOTE(review): kept as-is so query embeddings are produced the same way
        # as (presumably) the uploaded index. In the langchain_community versions
        # I know, multi_process=True starts and stops a worker pool on every embed
        # call, which is slow for single queries - confirm before changing.
        multi_process=True,
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": True},
    )
    st.info("embedding_model has been initialized.")  # Debug message for initialization
else:
    st.info("embedding_model was already initialized.")  # Debug message if already initialized

# The branch above guarantees 'embedding_model' is present in session state.
embedding_model = st.session_state['embedding_model']
st.write("Accessing embedding_model...")  # Debug message for accessing
# Form for LLM settings, allowing dynamic model selection. The endpoint stored
# in st.session_state['llm'] is (re)created only when the form is submitted.
with st.form("llm_settings_form"):
    st.subheader("LLM Settings")
    repo_id = st.text_input("Repo ID", value=LLM_MODEL_NAME, key="repo_id")
    # Lower/upper bounds stop obviously invalid values (e.g. zero tokens,
    # negative temperature, probabilities above 1) from being submitted.
    max_new_tokens = st.number_input("Max New Tokens", value=250, min_value=1, key="max_new_tokens")
    top_k = st.number_input("Top K", value=3, min_value=1, key="top_k")
    top_p = st.number_input("Top P", value=0.95, min_value=0.0, max_value=1.0, key="top_p")
    typical_p = st.number_input("Typical P", value=0.95, min_value=0.0, max_value=1.0, key="typical_p")
    temperature = st.number_input("Temperature", value=0.01, min_value=0.0, key="temperature")
    # Three decimals so the 1.035 default is shown in full rather than rounded.
    repetition_penalty = st.number_input(
        "Repetition Penalty",
        value=1.035,
        min_value=0.0,
        format="%.3f",
        key="repetition_penalty",
    )
    submitted = st.form_submit_button("Update LLM Settings")

if submitted:
    repo_id = repo_id.strip()
    if not repo_id:
        st.error("Repo ID must not be empty.")
    else:
        try:
            st.session_state['llm'] = HuggingFaceEndpoint(
                repo_id=repo_id,
                max_new_tokens=max_new_tokens,
                top_k=top_k,
                top_p=top_p,
                typical_p=typical_p,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
            )
        except Exception as err:  # UI boundary: report config/auth errors instead of crashing the page
            # Any previously configured LLM in session state is left in place.
            st.error(f"Could not create the LLM endpoint: {err}")
        else:
            st.success("LLM settings updated.")
def _load_vectorstore_from_zip(uploaded_file, embeddings):
    """Extract an uploaded ZIP and load the FAISS index from its ``docs_vectors`` folder.

    Args:
        uploaded_file: The file object returned by ``st.file_uploader``.
        embeddings: Embedding model passed to ``FAISS.load_local`` for queries.

    Returns:
        The loaded FAISS vector store, or None after showing an ``st.error``
        when the archive cannot be used.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        try:
            with zipfile.ZipFile(uploaded_file, 'r') as zip_ref:
                # NOTE(review): archive size is not checked, so very large
                # (or zip-bomb) uploads are extracted in full.
                zip_ref.extractall(temp_dir)
        except zipfile.BadZipFile:
            # The uploader's type filter only checks the extension.
            st.error("The uploaded file is not a valid ZIP archive.")
            return None

        docs_vectors_path = os.path.join(temp_dir, "docs_vectors")
        if not os.path.isdir(docs_vectors_path):
            st.error('The ZIP archive must contain a top-level "docs_vectors" folder.')
            return None

        try:
            # NOTE(security): allow_dangerous_deserialization=True lets load_local
            # unpickle data taken from the uploaded archive, so a crafted upload
            # can execute arbitrary code on the server. Only upload archives you
            # built yourself.
            return FAISS.load_local(
                docs_vectors_path,
                embeddings=embeddings,
                allow_dangerous_deserialization=True,
            )
        except Exception as err:  # UI boundary: show the failure instead of a traceback
            st.error(f"Could not load the vector store: {err}")
            return None


# Vector store upload and setup; the uploader is shown until a vector store
# has been loaded into this session.
if 'collection_vectorstore' not in st.session_state:
    uploaded_file = st.file_uploader("Upload Vector Store ZIP", type=["zip"])
    if uploaded_file is not None:
        vectorstore = _load_vectorstore_from_zip(uploaded_file, embedding_model)
        if vectorstore is not None:
            st.session_state['collection_vectorstore'] = vectorstore
            st.success("Vector store uploaded and loaded successfully.")
            # Create the retriever as soon as the vector store is created
            st.session_state['retriever'] = vectorstore.as_retriever()
            st.info("Retriever has been created.")  # Debug message to confirm the retriever's creation
# Placeholders the chain below supplies values for.
_TEMPLATE_VARIABLES = {"context", "question"}


def _format_docs(docs):
    """Join the page text of retrieved documents into one context string.

    Without this step the retriever's list of Document objects would be
    rendered into the prompt as their repr (class name and metadata included).
    """
    return "\n\n".join(doc.page_content for doc in docs)


def _build_prompt(template):
    """Parse *template* into a ChatPromptTemplate, or show an error and return None.

    Only ``{context}`` and ``{question}`` receive values from the chain, so any
    other placeholder is rejected. A template that omits one of the two is
    allowed but flagged with a warning.
    """
    try:
        prompt = ChatPromptTemplate.from_template(template)
    except (ValueError, KeyError) as err:  # e.g. unbalanced braces
        st.error(f"Invalid prompt template: {err}")
        return None
    variables = set(prompt.input_variables)
    unknown = variables - _TEMPLATE_VARIABLES
    if unknown:
        st.error(
            f"Prompt template uses unsupported placeholders: {sorted(unknown)}. "
            "Only {context} and {question} are available."
        )
        return None
    missing = _TEMPLATE_VARIABLES - variables
    if missing:
        st.warning(f"Prompt template does not use: {sorted(missing)}.")
    return prompt


# Check if LLM and vector store are ready
if 'llm' in st.session_state and 'collection_vectorstore' in st.session_state:
    # Ensure there's a default prompt template
    if 'prompt_template' not in st.session_state:
        st.session_state['prompt_template'] = "You are a knowledgeable assistant answering the following question based on the provided documents: {context} Question: {question}"

    # A single editor for the template: its live contents drive the chain below,
    # and the button saves them as the stored template. The button must read this
    # text area; a text area created inside `if st.button(...)` would only exist
    # on the click run and would save its empty default.
    current_template = st.text_area(
        "Edit Prompt Template",
        value=st.session_state['prompt_template'],
        key="current_prompt_template",
    )
    if st.button("Update Prompt Template"):
        st.session_state['prompt_template'] = current_template
        st.success("Prompt template updated.")

    # Question input and processing
    question = st.text_input("Enter your question", key="question_input")
    if question:
        prompt = _build_prompt(current_template)
        if prompt is not None:
            llm = st.session_state['llm']
            retriever = st.session_state.get('retriever')
            if retriever is None:
                # Vector store was stored without a retriever; derive one.
                retriever = st.session_state['collection_vectorstore'].as_retriever()
                st.session_state['retriever'] = retriever
            chain = (
                {"context": retriever | _format_docs, "question": RunnablePassthrough()}
                | prompt
                | llm
                | StrOutputParser()
            )
            if st.button("Ask"):
                try:
                    with st.spinner("Generating answer..."):
                        result = chain.invoke(question)
                except Exception as err:  # UI boundary: endpoint/network/template errors
                    st.error(f"Failed to generate an answer: {err}")
                else:
                    st.subheader("Answer:")
                    st.write(result)
else:
    st.warning("Please configure and submit the LLM settings and ensure the vector store is loaded to ask questions.")