Spaces:
Sleeping
Sleeping
| from langchain_openai import ChatOpenAI | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.runnables import RunnablePassthrough | |
| from langchain.text_splitter import CharacterTextSplitter | |
| from langchain_community.callbacks import StreamlitCallbackHandler | |
| from langchain_community.document_loaders import TextLoader | |
| from langchain_community.document_loaders import PyPDFLoader | |
| from manage_vectordb import VectorDB | |
| import tempfile | |
| import streamlit as st | |
| import os | |
# ---------------------------------------------------------------------------
# Runtime configuration, read from environment variables with local defaults.
# ---------------------------------------------------------------------------

# Base URL of the model server; the "/v1" path segment is always appended.
model_service = os.environ.get("MODEL_ENDPOINT", "http://0.0.0.0:8001") + "/v1"
# Optional bearer token for the model server (None when the variable is unset).
model_service_bearer = os.environ.get("MODEL_ENDPOINT_BEARER")
model_name = os.environ.get("MODEL_NAME", "")

# May be a str (from the environment) or the int default; it is cast with
# int() where it is consumed.
chunk_size = os.environ.get("CHUNK_SIZE", 150)
embedding_model = os.environ.get("EMBEDDING_MODEL", "BAAI/bge-base-en-v1.5")

# Vector database connection settings.
vdb_vendor = os.environ.get("VECTORDB_VENDOR", "chromadb")
vdb_host = os.environ.get("VECTORDB_HOST", "0.0.0.0")
vdb_port = os.environ.get("VECTORDB_PORT", "8000")
vdb_name = os.environ.get("VECTORDB_NAME", "test_collection")

# Build the vector-store wrapper and connect as soon as the script runs.
vdb = VectorDB(vdb_vendor, vdb_host, vdb_port, vdb_name, embedding_model)
vectorDB_client = vdb.connect()
def split_docs(raw_documents):
    """Split loaded documents into chunks for the vector store.

    The text is split on "." with no overlap between chunks; the chunk size
    is the module-level ``chunk_size`` setting, converted to ``int`` here.

    Args:
        raw_documents: Documents as returned by a loader's ``load()``.

    Returns:
        The chunked documents produced by ``split_documents``.
    """
    splitter = CharacterTextSplitter(
        separator=".",
        chunk_size=int(chunk_size),
        chunk_overlap=0,
    )
    return splitter.split_documents(raw_documents)
def read_file(file):
    """Load an uploaded PDF or plain-text file into documents.

    The upload's bytes are written to a temporary file on disk because the
    loaders used here are constructed from a file path.

    Args:
        file: Uploaded file object exposing ``type`` (a MIME type string) and
            ``getvalue()`` (the raw bytes).

    Returns:
        The documents returned by the selected loader's ``load()``.

    Raises:
        ValueError: If ``file.type`` is neither ``application/pdf`` nor
            ``text/plain``.
    """
    file_type = file.type
    if file_type == "application/pdf":
        loader_cls = PyPDFLoader
    elif file_type == "text/plain":
        loader_cls = TextLoader
    else:
        # Fail with a clear message instead of an unbound-variable error.
        raise ValueError(f"Unsupported file type: {file_type!r}")

    # delete=False lets the handle be closed before the loader reopens the
    # path; an open NamedTemporaryFile cannot be reopened by name on Windows.
    temp = tempfile.NamedTemporaryFile(delete=False)
    try:
        with temp:
            temp.write(file.getvalue())
        return loader_cls(temp.name).load()
    finally:
        # Always remove the temporary copy, even if writing or loading failed.
        os.remove(temp.name)
st.title("π RAG DEMO")

with st.sidebar:
    # vdb.clear_db is registered as the uploader's on_change callback.
    file = st.file_uploader(
        label="π Upload Document",
        type=[".txt", ".pdf"],
        on_change=vdb.clear_db,
    )

### populate the DB ####
if file is not None:
    # Load the upload, chunk it, and index the chunks in the vector store.
    text = read_file(file)
    documents = split_docs(text)
    db = vdb.populate_db(documents)
    # NOTE(review): `threshold` does not appear among the documented
    # as_retriever arguments (search_type / search_kwargs) and may be
    # ignored; confirm whether a score threshold was intended here.
    retriever = db.as_retriever(threshold=0.75)
else:
    # Nothing uploaded: use an empty mapping as the chain's context source.
    retriever = {}
    print("Empty VectorDB")
########################
# Seed the conversation with a greeting when the session has no history yet.
if "messages" not in st.session_state:
    st.session_state["messages"] = [
        {"role": "assistant", "content": "How can I help you?"},
    ]

# Re-draw every stored message, in order, on each script run.
for entry in st.session_state["messages"]:
    bubble = st.chat_message(entry["role"])
    bubble.write(entry["content"])
# Callback handler that renders the model's streamed output into its own
# container on the page.
stream_handler = StreamlitCallbackHandler(
    st.container(),
    collapse_completed_thoughts=True,
)

# Chat model client; the placeholder key "EMPTY" is used when no bearer
# token is configured.
llm = ChatOpenAI(
    base_url=model_service,
    api_key=model_service_bearer if model_service_bearer is not None else "EMPTY",
    model=model_name,
    streaming=True,
    callbacks=[stream_handler],
)

# Prompt with two slots: the retrieved context and the user's question.
prompt = ChatPromptTemplate.from_template("""Answer the question based only on the following context:
{context}
Question: {input}
"""
)

# The invoked value is passed through unchanged as "input" and is also given
# to `retriever` to produce "context"; the filled prompt then goes to the LLM.
chain_inputs = {"context": retriever, "input": RunnablePassthrough()}
chain = chain_inputs | prompt | llm
# Handle a newly submitted chat message. The text is bound to its own name:
# binding it to `prompt` would overwrite the ChatPromptTemplate stored under
# that module-level name.
if user_input := st.chat_input():
    # Record and display the user's message.
    st.session_state.messages.append({"role": "user", "content": user_input})
    st.chat_message("user").markdown(user_input)

    # Run the chain, then record and display the answer.
    response = chain.invoke(user_input)
    answer = response.content
    st.chat_message("assistant").markdown(answer)
    st.session_state.messages.append({"role": "assistant", "content": answer})

    # Re-run the script so the page is redrawn from the stored history.
    st.rerun()