# NOTE: the lines "Spaces:" / "Runtime error" here were residue from the
# Streamlit error-page dump this file was recovered from, not program text.
# Standard library
import time
from io import StringIO

# Third-party
import streamlit as st
from langchain.docstore.document import Document
from langchain.text_splitter import Language, RecursiveCharacterTextSplitter

# Local
import vector_db as vdb
from llm_model import LLMModel
def default_state():
    """Seed st.session_state with the app's defaults on the first run.

    Keys that already exist (from a previous rerun) are left untouched.
    """
    defaults = {
        "startup": True,
        "messages": [],
        "uploaded_docs": [],
        "llm_option": "Local",
        "answer_loading": False,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
def load_doc(file_name: str, file_content: str):
    """Wrap file content in a Document and split it into overlapping chunks.

    Returns the list of chunk Documents, or None when file_name is None.
    """
    if file_name is None:
        return None
    # Attach the originating file name so answers can cite their source.
    source_doc = Document(page_content=file_content, metadata={"source": file_name})
    # Split into 1000-character chunks with a 150-character overlap, using
    # separators appropriate for the file's markup language.
    splitter = RecursiveCharacterTextSplitter.from_language(
        language=get_language(file_name),
        chunk_size=1000,
        chunk_overlap=150,
    )
    return splitter.split_documents([source_doc])
def get_language(file_name: str):
    """Map a file name's extension to a langchain splitter Language.

    Only .rst is special-cased; .md, .mdx, and any unrecognized
    extension all fall back to MARKDOWN (same as the original branches).
    """
    if file_name.endswith(".rst"):
        return Language.RST
    return Language.MARKDOWN
def get_vector_db():
    """Construct and return the application's vector database wrapper."""
    db = vdb.VectorDB()
    return db
def get_llm_model(_db: vdb.VectorDB):
    """Initialize a RetrievalQA chain backed by the vector DB's retriever.

    The retriever is configured to return the top-2 most similar chunks
    per query.
    """
    doc_retriever = _db.docs_db.as_retriever(search_kwargs={"k": 2})
    qa_chain = LLMModel(retriever=doc_retriever).create_qa_chain()
    return qa_chain
def init_sidebar():
    """Render the sidebar: loading toggle, LLM selector, and file uploader.

    Newly uploaded files are chunked via load_doc and ingested into the
    module-level vector_db; their names are remembered in session state so
    the same file is not ingested twice across reruns.
    """
    with st.sidebar:
        st.toggle(
            "Loading from LLM",
            # BUGFIX: pass the callback itself. The original passed
            # enable_sidebar() — calling it during every render instead of
            # when the toggle changes.
            on_change=enable_sidebar,
            disabled=not st.session_state.answer_loading
        )
        llm_option = st.selectbox(
            'Select to use local model or inference API',
            options=['Local', 'Inference API']
        )
        st.session_state.llm_option = llm_option
        uploaded_files = st.file_uploader(
            'Upload file(s)',
            type=['md', 'mdx', 'rst', 'txt'],
            accept_multiple_files=True
        )
        for uploaded_file in uploaded_files:
            # Skip files already ingested on a previous rerun.
            if uploaded_file.name not in st.session_state.uploaded_docs:
                # Read the uploaded bytes as UTF-8 text
                stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
                string_data = stringio.read()
                # Split into chunks and index them in the vector DB
                doc_chunks = load_doc(uploaded_file.name, string_data)
                st.write(f"Number of chunks={len(doc_chunks)}")
                vector_db.load_docs_into_vector_db(doc_chunks)
                st.session_state.uploaded_docs.append(uploaded_file.name)
def init_chat():
    """Replay the stored chat history so it survives Streamlit reruns."""
    for entry in st.session_state.messages:
        with st.chat_message(entry["role"]):
            st.markdown(entry["content"])
def disable_sidebar():
    """Mark an answer as in flight so the sidebar toggle becomes inactive.

    Intended as a widget callback (on_submit of the chat input).
    BUGFIX: the original also called st.rerun() here. Streamlit documents
    st.rerun() inside a callback as a no-op, and because this function was
    being invoked inline at the chat_input call site, the rerun fired on
    every script run — an endless rerun loop. Streamlit reruns the script
    after a callback anyway, so no explicit rerun is needed.
    """
    st.session_state.answer_loading = True
def enable_sidebar():
    """Clear the in-flight flag so the sidebar toggle is re-enabled."""
    st.session_state["answer_loading"] = False
# ---- Application entry point (Streamlit reruns this top-to-bottom) ----
# Page config must be the first Streamlit call of the script.
st.set_page_config(page_title="Document Answering Tool", page_icon=":book:")
# Module-level vector DB; init_sidebar ingests uploads into it.
# NOTE(review): rebuilt on every rerun — looks like a candidate for
# @st.cache_resource (the _db parameter name hints at that) — confirm.
vector_db = get_vector_db()
default_state()
init_sidebar()
st.header("Document answering tool")
st.subheader("Upload your documents on the side and ask questions")
# Replay prior chat messages, then build the QA chain over the vector DB.
init_chat()
llm_model = get_llm_model(vector_db)
st.session_state.startup = False
# React to user input.
# BUGFIX: pass disable_sidebar as the callback instead of calling it —
# the original `on_submit=disable_sidebar()` invoked it on every rerun,
# and its st.rerun() then retriggered the script in an endless loop.
if user_prompt := st.chat_input("What's up?", on_submit=disable_sidebar):
    start_time = time.time()
    # Display the user's message and record it in the history.
    with st.chat_message("user"):
        st.markdown(user_prompt)
    st.session_state.messages.append({"role": "user", "content": user_prompt})
    if llm_model is not None:
        assistant_chat = st.chat_message("assistant")
        if not st.session_state.uploaded_docs:
            assistant_chat.warning("WARN: Will try answer question without documents")
        with st.spinner('Resolving question...'):
            res = llm_model({"query": user_prompt})
            # Collect the file names of the chunks the retriever cited.
            sources = []
            for source_docs in res['source_documents']:
                if 'source' in source_docs.metadata:
                    sources.append(source_docs.metadata['source'])
            end_time = time.time()
            time_taken = "{:.2f}".format(end_time - start_time)
            format_answer = f"## Result\n\n{res['result']}\n\n### Sources\n\n{sources}\n\nTime taken: {time_taken}s"
            assistant_chat.markdown(format_answer)
            # Full chunk text for each cited source, collapsed by default.
            source_expander = assistant_chat.expander("See full sources")
            for source_docs in res['source_documents']:
                if 'source' in source_docs.metadata:
                    format_source = f"## File: {source_docs.metadata['source']}\n\n{source_docs.page_content}"
                    source_expander.markdown(format_source)
            # Add assistant response to chat history
            st.session_state.messages.append({"role": "assistant", "content": format_answer})
    # Re-enable the sidebar toggle and rerun so it is redrawn enabled.
    enable_sidebar()
    st.rerun()