Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import os | |
| import shutil | |
| import schedule | |
| import time | |
| import pickle | |
| from langchain.prompts import ChatPromptTemplate | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_community.llms import HuggingFacePipeline | |
| from langchain.retrievers import ParentDocumentRetriever | |
| from langchain.storage import InMemoryStore | |
| from langchain_chroma import Chroma | |
| from langchain_openai import ChatOpenAI | |
| from langchain_core.prompts import ChatPromptTemplate, FewShotChatMessagePromptTemplate | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.runnables import RunnableLambda | |
| from datetime import date | |
| import time | |
| import subprocess | |
| import threading | |
# Display names of the chat models offered in the sidebar selector.
llm_list = [
    "Mistral-7B-Instruct-v0.2",
    "Mixtral-8x7B-Instruct-v0.1",
]

# Base URL handed to ChatOpenAI as the OpenAI-compatible API endpoint.
blablador_base = "https://helmholtz-blablador.fz-juelich.de:8000/v1"

# On-disk locations of the Chroma persistence directory and of the pickled
# parent-document store.
directory_path = "ohw_proj_chorma_db"
file_path = "ohw_proj_chorma_db.pcl"
# Helper used when rebuilding the retriever's document store from disk.
def load_from_pickle(filename):
    """Return the object stored in the pickle file at *filename*.

    The file is opened in binary mode and closed again before returning.
    """
    with open(filename, mode="rb") as handle:
        payload = pickle.load(handle)
    return payload
def load_retriever(docstore_path, chroma_path, embeddings, child_splitter, parent_splitter):
    """Open the persisted vector store and document store and wrap them in a retriever.

    Args:
        docstore_path: Path of the pickle file holding the parent documents.
        chroma_path: Directory in which the Chroma collection is persisted.
        embeddings: Embedding function handed to Chroma.
        child_splitter: Splitter passed to the retriever for child chunks.
        parent_splitter: Splitter passed to the retriever for parent chunks.

    Returns:
        A ParentDocumentRetriever configured with ``search_kwargs={"k": 5}``.
    """
    # The collection name must match the one used when the store was first built.
    vector_db = Chroma(
        collection_name="full_documents",
        embedding_function=embeddings,
        persist_directory=chroma_path,
    )

    # The pickle is expected to hold a key -> document mapping; its items are
    # copied into a fresh in-memory store.
    pickled_docs = load_from_pickle(docstore_path)
    parent_store = InMemoryStore()
    parent_store.mset(list(pickled_docs.items()))

    return ParentDocumentRetriever(
        vectorstore=vector_db,
        docstore=parent_store,
        child_splitter=child_splitter,
        parent_splitter=parent_splitter,
        search_kwargs={"k": 5},
    )
def inspect(state):
    """Publish the retrieved documents to the Streamlit session and pass *state* on.

    Reads the documents under ``state['normal_context']`` and stores their
    source labels in ``st.session_state.context_sources`` and their text in
    ``st.session_state.context_content`` (same order, same length).

    Args:
        state: Mapping with a ``'normal_context'`` entry holding the retrieved
            documents (objects exposing ``metadata`` and ``page_content``).

    Returns:
        The unchanged *state*, so the function can sit in the middle of a chain.
    """
    context = state['normal_context']
    # A document without a 'source' entry must not abort the whole answer;
    # fall back to a placeholder label instead of raising KeyError.
    st.session_state.context_sources = [
        doc.metadata.get('source', 'Unknown source') for doc in context
    ]
    st.session_state.context_content = [doc.page_content for doc in context]
    return state
| def retrieve_normal_context(retriever, question): | |
| docs = retriever.invoke(question) | |
| return docs | |
# Builds the retrieval-augmented question-answering chain used by the chat UI.
def get_chain(temperature, selected_model):
    """Assemble the retrieval + prompt + chat-model pipeline.

    Args:
        temperature: Sampling temperature forwarded to the chat model.
        selected_model: Model name sent to the Blablador endpoint.

    Returns:
        A runnable that expects a dict with ``"question"`` and
        ``"chat_history"`` keys and ends in a ``StrOutputParser``.
    """
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L12-v2")

    # Parent documents are the large chunks; child documents are the small ones.
    parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=500)
    child_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)

    # Reuse the module-level paths instead of repeating the same literals here.
    retriever = load_retriever(
        file_path, directory_path, embeddings, child_splitter, parent_splitter
    )

    # The API key comes from the environment only; credentials must never be
    # written into the source, not even in a comment.
    llm = ChatOpenAI(
        model_name=selected_model,
        temperature=temperature,
        openai_api_key=os.getenv("blablador_api"),
        openai_api_base=blablador_base,
        streaming=True,
    )

    # Captured once per chain build and injected into every prompt.
    today = date.today()

    # Response prompt
    response_prompt_template = """You are an assistant who helps Ocean Hack Week community to answer their questions. I am going to ask you a question. Your response should be comprehensive and not contradicted with the following context if they are relevant. Otherwise, ignore them if they are not relevant.
Keep track of chat history: {chat_history}
Today's date: {date}
## Normal Context:
{normal_context}
# Original Question: {question}
# Answer:
"""
    response_prompt = ChatPromptTemplate.from_template(response_prompt_template)

    def _build_prompt_inputs(x):
        # Gather every variable the prompt template needs, retrieving the
        # context documents for the incoming question.
        return {
            "question": x["question"],
            "normal_context": retrieve_normal_context(retriever, x["question"]),
            "chat_history": x["chat_history"],
            "date": today,
        }

    chain = (
        RunnableLambda(_build_prompt_inputs)
        | RunnableLambda(inspect)  # exposes the retrieved documents to the UI
        | response_prompt
        | llm
        | StrOutputParser()
    )
    return chain
def clear_chat_history():
    """Reset the conversation state kept in the Streamlit session.

    Empties the message list and both retrieved-context lists, and resets
    ``st.session_state.key`` to 0. Used as the ``on_click`` callback of the
    "Clear Chat History" button.
    """
    st.session_state.messages = []
    st.session_state.context_sources = []
    # Sources and content are always read together, so clear both; leaving
    # the content behind would keep stale documents from the previous chat.
    st.session_state.context_content = []
    st.session_state.key = 0
# Sidebar: branding, model choice, sampling temperature and the reset button.
with st.sidebar:
    st.image("logo_no_bg.png", use_column_width=True)  # Replace with your actual file path
    st.title("OHW Assistant")

    selected_model = st.sidebar.selectbox(
        'Choose a LLM model',
        llm_list,
        key='selected_model',
        index=1,
    )
    temperature = st.slider(
        "Temperature: ", 0.0, 1.0, 0.5, 0.1,
        help=("Controls the creativity of responses.\n"
              "Lower values make answers more focused.\n"
              "Higher values introduce more variety."),
    )

    # Translate the display name into the alias sent to the endpoint; a name
    # without an alias is passed through unchanged.
    selected_model = {
        'Mistral-7B-Instruct-v0.2': 'alias-fast',
        'Mixtral-8x7B-Instruct-v0.1': 'alias-large',
    }.get(selected_model, selected_model)

    chain = get_chain(temperature, selected_model)
    st.button('Clear Chat History', on_click=clear_chat_history)
# Main app
# Make sure every session-state list the chat UI reads exists before first
# use; each missing key gets its own fresh empty list.
for state_key in ("messages", "context_sources", "context_content"):
    if state_key not in st.session_state:
        st.session_state[state_key] = []
# Replay the stored conversation so it is redrawn on every Streamlit rerun.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        if message["role"] != 'assistant':
            # User turns are plain text.
            st.markdown(message["content"])
            continue

        # Assistant turns show the answer and the documents it was based on.
        tab1, tab2 = st.tabs(["Answer", "Sources"])
        with tab1:
            st.markdown(message["content"])
        with tab2:
            # Sources and context are stored as parallel lists; pair them up
            # instead of indexing one by the position in the other.
            stored_sources = message.get("sources", [])
            stored_context = message.get("context", [])
            for source, content in zip(stored_sources, stored_context):
                with st.expander(f'{source}'):
                    st.markdown(f'{content}')
# Handle a newly submitted question: show it, stream the answer, list the
# retrieved sources, and store the finished turn in the session.
if prompt := st.chat_input("How may I assist you today?"):
    # Only the role and text of earlier turns go to the model. The stored
    # messages also carry the retrieved documents of every previous answer,
    # which would bloat the prompt, and the current question is already
    # passed separately.
    history = [
        {"role": past["role"], "content": past["content"]}
        for past in st.session_state.messages
    ]

    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        tab1, tab2 = st.tabs(["Answer", "Sources"])
        with tab1:
            start_time = time.perf_counter()  # monotonic clock for durations
            placeholder = st.empty()  # Create a placeholder in Streamlit
            full_answer = ""
            for chunk in chain.stream({"question": prompt, "chat_history": history}):
                full_answer += chunk
                # Model output is untrusted text: render it as Markdown only,
                # exactly as the history replay does, never as raw HTML.
                placeholder.markdown(full_answer)
            elapsed = time.perf_counter() - start_time
            st.caption(f"Response time: {elapsed:.2f} seconds")

        # Snapshot what the chain published for this question so the stored
        # turn does not share list objects with the session state.
        sources = list(st.session_state.context_sources)
        contents = list(st.session_state.context_content)
        with tab2:
            if sources:
                for source, content in zip(sources, contents):
                    with st.expander(f'{source}'):
                        st.markdown(f'{content}')
            else:
                st.write("No sources available for this query.")

    st.session_state.messages.append({
        "role": "assistant",
        "content": full_answer,
        "sources": sources,
        "context": contents,
    })