Spaces:
Build error
Build error
# --- Module setup: imports, .env loading, API-key wiring -------------------
# NOTE(review): the duplicate `from langchain.chat_models import ChatOpenAI`
# line was removed; everything else is kept even if currently unused
# (threading, Cohere/Elastic imports) since other parts of the app may rely
# on them.
import os
import textwrap
import threading
import time

import streamlit as st
from dotenv import load_dotenv
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings.cohere import CohereEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.vectorstores.elastic_vector_search import ElasticVectorSearch
from streamlit_chat import message

# Pull secrets from .env, then expose the key under the name the OpenAI
# client expects. Raises KeyError early if 'my_secret' is not configured.
load_dotenv()
os.environ["OPENAI_API_KEY"] = os.environ['my_secret']
# --- One-time per-session initialisation -----------------------------------
# Streamlit reruns the whole script on every interaction; session_state
# guards keep the (expensive) chain construction from repeating.
if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []

if 'chain' not in st.session_state:
    # Deterministic answers (temperature=0) from a pinned model snapshot.
    llm = ChatOpenAI(model_name="gpt-3.5-turbo-0613", temperature=0)

    # Load the source PDF and split it into ~800-character chunks.
    # NOTE(review): path is hard-coded for the HF Spaces container layout.
    loader = PyPDFLoader("/home/user/app/docs.pdf")
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=800, chunk_overlap=0)
    texts = text_splitter.split_documents(documents)

    # Embed chunks into an in-memory Chroma store; retrieve top-3 per query.
    embeddings = OpenAIEmbeddings()
    vector_store = Chroma.from_documents(texts, embeddings)
    retriever = vector_store.as_retriever(search_kwargs={"k": 3})

    # Local import: only needed while building the prompt once per session.
    from langchain.prompts.chat import (
        ChatPromptTemplate,
        SystemMessagePromptTemplate,
        HumanMessagePromptTemplate,
    )

    # {summaries} is filled by RetrievalQAWithSourcesChain with the
    # retrieved document chunks.
    system_template = """You act like a successful pharmacist. Talk to students about the career path of a pharmacist.
Never reveal that it is an AI. Use the following pieces of context to answer the users question shortly.
Given the following summaries of a long document and a question, create a final answer with references.
If you don't know the answer, just say that "I don't know", don't try to make up an answer.
----------------
{summaries}
You MUST answer in Korean and in Markdown format"""

    messages = [
        SystemMessagePromptTemplate.from_template(system_template),
        HumanMessagePromptTemplate.from_template("{question}"),
    ]
    prompt = ChatPromptTemplate.from_messages(messages)
    chain_type_kwargs = {"prompt": prompt}

    # "stuff" = concatenate all retrieved chunks into a single prompt;
    # reduce_k_below_max_tokens trims k when the stuffed prompt would
    # exceed the model's context window.
    st.session_state['chain'] = RetrievalQAWithSourcesChain.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs=chain_type_kwargs,
        reduce_k_below_max_tokens=True,
        verbose=True,
    )
def generate_response(user_input):
    """Answer *user_input* via the session's QA chain and append citations.

    Returns the chain's answer text followed by one "[n] source(page) "
    tag per retrieved source document. Assumes every returned document
    carries 'source' and 'page' metadata (set by PyPDFLoader) — raises
    KeyError otherwise.
    """
    result = st.session_state['chain'](user_input)
    # Build citation tags with f-strings + join instead of repeated string
    # concatenation; output is byte-identical to the old += loop.
    citations = ''.join(
        f"[{i + 1}] {doc.metadata['source']}({doc.metadata['page']}) "
        for i, doc in enumerate(result['source_documents'])
    )
    return result['answer'] + citations
def wrap_text(text, max_length=40):
    """Hard-wrap *text* to lines of at most *max_length* characters.

    Uses textwrap's standard rules (whitespace collapsed, long words
    broken); an empty input yields an empty string because
    textwrap.wrap('') produces no lines.
    """
    lines = textwrap.wrap(text, width=max_length)
    return "\n".join(lines)
| # st.header("[μμμ§λ‘ν΅] μ½μ¬μ κΈΈ \n μ€μ μ½μ¬μ μΈν°λ·° λ΄μ©μ κΈ°λ°μΌλ‘ μ§λ‘ μλ΄μ ν΄λ³΄μΈμ") | |
| # with st.form('form', clear_on_submit=True): | |
| # user_input = st.text_input('You: ', '', key='input') | |
| # submitted = st.form_submit_button('Send') | |
| # if submitted and user_input: | |
| # with st.spinner('μλ΅μ μμ±μ€μ λλ€...'): | |
| # output = generate_response(user_input) | |
| # st.session_state.chat_history.append({"User": user_input, "Bot": output}) | |
| # for idx, chat in enumerate(st.session_state['chat_history'][:-1]): | |
| # message(chat['User'], is_user=True, key=str(idx) + '_user') | |
| # message(wrap_text("μ½μ¬: " + chat['Bot']), key=str(idx)) | |
| # if st.session_state['chat_history']: | |
| # last_chat = st.session_state['chat_history'][-1] | |
| # message(last_chat['User'], is_user=True, key=str(len(st.session_state['chat_history'])-1) + '_user') | |
| # new_placeholder = st.empty() | |
| # sender_name = "μ½μ¬: " | |
| # for j in range(len(last_chat['Bot'])): | |
| # new_placeholder.text(wrap_text(sender_name + last_chat['Bot'][:j+1])) | |
| # time.sleep(0.05) | |
# --- Conversation view ------------------------------------------------------
# Replay all earlier exchanges instantly, then "type out" the newest bot
# reply one character at a time. (The two identical truthiness checks from
# the original were merged into one `if` — behavior is unchanged.)
st.header("[μμμ§λ‘ν΅] μ½μ¬μ κΈΈ \n μ€μ μ½μ¬μ μΈν°λ·° λ΄μ©μ κΈ°λ°μΌλ‘ μ§λ‘ μλ΄μ ν΄λ³΄μΈμ")

if st.session_state['chat_history']:
    # Every exchange except the most recent is rendered in full.
    for idx, chat in enumerate(st.session_state['chat_history'][:-1]):
        message(chat['User'], is_user=True, key=str(idx) + '_user')
        message(wrap_text("μ½μ¬: " + chat['Bot']), key=str(idx))

    # Typewriter animation for the latest reply. NOTE(review): this replays
    # on every Streamlit rerun, not only when the reply is new — confirm
    # that is intended.
    last_chat = st.session_state['chat_history'][-1]
    message(last_chat['User'], is_user=True,
            key=str(len(st.session_state['chat_history']) - 1) + '_user')
    new_placeholder = st.empty()
    sender_name = "μ½μ¬: "
    for j in range(len(last_chat['Bot'])):
        new_placeholder.text(wrap_text(sender_name + last_chat['Bot'][:j + 1]))
        time.sleep(0.05)
# --- Input form -------------------------------------------------------------
# clear_on_submit empties the text box after each send. The new exchange is
# appended to chat_history; Streamlit's automatic rerun after the form
# submit then renders it via the conversation view above.
with st.form('form', clear_on_submit=True):
    user_input = st.text_input('You: ', '', key='input')
    submitted = st.form_submit_button('Send')

if submitted and user_input:
    with st.spinner('μλ΅μ μμ±μ€μ λλ€...'):
        output = generate_response(user_input)
        st.session_state.chat_history.append({"User": user_input, "Bot": output})