Spaces:
Sleeping
Sleeping
import html
import os
from dataclasses import dataclass
from typing import Literal

import streamlit as st
import streamlit.components.v1 as components
from langchain.callbacks import get_openai_callback
from langchain.chains import ConversationChain
from langchain.chains import ConversationalRetrievalChain
from langchain.chains import LLMChain
from langchain.chains.conversation.memory import ConversationSummaryMemory
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
@dataclass
class Message:
    """Class for keeping track of a chat message.

    Instances are created positionally as ``Message(origin, message)`` and
    rendered by the page-layout code, which picks icon and bubble styling
    from ``origin``.
    """

    # Who produced the message: the user ("human") or the model ("ai").
    origin: Literal["human", "ai"]
    # The message text shown in the chat bubble.
    message: str
def load_css(path="static/styles.css"):
    """Inline a local stylesheet into the Streamlit page.

    Args:
        path: CSS file to read; defaults to the app's ``static/styles.css``.
    """
    # Explicit encoding so reading the stylesheet does not depend on the
    # platform's default locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        css = f"<style>{f.read()}</style>"
    # unsafe_allow_html is required for the raw <style> tag to take effect.
    st.markdown(css, unsafe_allow_html=True)
@st.cache_resource
def load_index(
    file_path='dagens_nyheter_new.csv',
    chunk_size=1000,
    chunk_overlap=200,
):
    """Build a Chroma vector store over the Dagens Nyheter CSV export.

    Streamlit re-executes this whole script on every user interaction. The
    result is therefore cached with ``st.cache_resource``, so the CSV is
    loaded, chunked and embedded (OpenAI embedding calls) once per distinct
    argument set instead of on every rerun.

    Args:
        file_path: CSV file read via ``CSVLoader``.
        chunk_size: Maximum chunk size passed to ``CharacterTextSplitter``.
        chunk_overlap: Overlap between consecutive chunks, passed to
            ``CharacterTextSplitter``.

    Returns:
        The Chroma vector store built from the embedded chunks.
    """
    loader = CSVLoader(file_path=file_path)
    doc = loader.load()
    text_splitter = CharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    docs = text_splitter.split_documents(doc)
    embeddings = OpenAIEmbeddings()
    return Chroma.from_documents(docs, embeddings)


# Vector store used by initialize_session_state() to build the retriever;
# load_index is cached, so reruns reuse the same instance.
vectorstore = load_index()
def initialize_session_state():
    """Seed ``st.session_state`` with the keys the app relies on.

    Creates ``history`` (list of rendered chat messages), ``token_count``
    (running token total updated by ``on_click_callback``) and
    ``conversation`` (the retrieval QA chain) only when each is missing.
    Keys that already exist are left untouched.
    """
    # Simple containers/counters: list() -> [] and int() -> 0.
    for state_key, make_default in (("history", list), ("token_count", int)):
        if state_key not in st.session_state:
            st.session_state[state_key] = make_default()

    # The chain is assembled once per session; later reruns reuse it.
    if "conversation" in st.session_state:
        return

    # Summary-style memory, exposed to the chain under "chat_history".
    summary_llm = ChatOpenAI(temperature=0, model='gpt-3.5-turbo-0613')
    memory = ConversationSummaryMemory(
        llm=summary_llm,
        memory_key="chat_history",
        return_messages=True,
    )

    # Condenses the follow-up question plus chat history into a standalone
    # question (CONDENSE_QUESTION_PROMPT).
    condense_llm = ChatOpenAI(temperature=0, model="gpt-4")
    question_generator = LLMChain(llm=condense_llm, prompt=CONDENSE_QUESTION_PROMPT)

    # "stuff" chain: all retrieved documents go into a single answer prompt.
    answer_llm = ChatOpenAI(temperature=0, model='gpt-3.5-turbo-16k')
    doc_chain = load_qa_chain(answer_llm, chain_type="stuff")

    retriever = vectorstore.as_retriever(search_kwargs={"k": 50})
    st.session_state.conversation = ConversationalRetrievalChain(
        retriever=retriever,
        question_generator=question_generator,
        combine_docs_chain=doc_chain,
        memory=memory,
    )
def on_click_callback():
    """Send the submitted prompt to the conversation chain and record the turn.

    Reads the prompt from ``st.session_state.human_prompt`` (the chat input's
    widget key). It then:

    - appends the human prompt and the model reply to
      ``st.session_state.history``;
    - adds the tokens reported by ``get_openai_callback`` to
      ``st.session_state.token_count``;
    - clears the input box.

    Blank or whitespace-only prompts are ignored so that an accidental submit
    does not run the (paid) chain or add empty turns to the history.
    """
    human_prompt = st.session_state.human_prompt
    if not human_prompt.strip():
        # Nothing to send; just normalise the input box back to empty.
        st.session_state.human_prompt = ""
        return

    # Count the tokens of the OpenAI calls made while the chain runs.
    with get_openai_callback() as cb:
        llm_response = st.session_state.conversation.run(
            {"question": human_prompt}
        )
        st.session_state.history.append(Message("human", human_prompt))
        st.session_state.history.append(Message("ai", llm_response))
        st.session_state.token_count += cb.total_tokens

    # Reset the user input to empty string after sending the message
    st.session_state.human_prompt = ""
# --- Page layout: runs top to bottom on every Streamlit rerun -------------
load_css()
initialize_session_state()

st.title("Dagens Nyheter Reviews")

chat_placeholder = st.container()
prompt_placeholder = st.form("chat-form")
credit_card_placeholder = st.empty()

with chat_placeholder:
    for chat in st.session_state.history:
        # The text is user input or model output (derived from the CSV) and is
        # rendered with unsafe_allow_html=True. Escape it so "<", ">", "&" and
        # quotes display literally instead of being parsed as HTML
        # (markup injection / XSS).
        safe_message = html.escape(chat.message)
        # A zero-width space (U+200B) precedes the message text, as in the
        # original markup.
        div = f"""
<div class="chat-row
    {'' if chat.origin == 'ai' else 'row-reverse'}">
    <img class="chat-icon" src="https://huggingface.co/spaces/felix-weiland/dn/resolve/main/static/{
        'ai_icon.png' if chat.origin == 'ai'
        else 'user_icon.png'}"
         width=32 height=32>
    <div class="chat-bubble
    {'ai-bubble' if chat.origin == 'ai' else 'human-bubble'}">
        ​{safe_message}
    </div>
</div>
        """
        st.markdown(div, unsafe_allow_html=True)
    # Spacer lines between the conversation and the input form.
    for _ in range(3):
        st.markdown("")

with prompt_placeholder:
    st.markdown("**Chat**")
    cols = st.columns((6, 1))
    cols[0].text_input(
        "Chat",
        value="",
        label_visibility="collapsed",
        key="human_prompt",
    )
    cols[1].form_submit_button(
        "Submit",
        type="primary",
        on_click=on_click_callback,
    )

credit_card_placeholder.caption(f"""
Used {st.session_state.token_count} tokens \n
Debug Langchain conversation:
{st.session_state.conversation.memory.buffer}
""")

# Injected script: pressing Enter anywhere in the parent Streamlit document
# clicks the button whose text is 'Submit'.
components.html("""
<script>
const streamlitDoc = window.parent.document;
const buttons = Array.from(
    streamlitDoc.querySelectorAll('.stButton > button')
);
const submitButton = buttons.find(
    el => el.innerText === 'Submit'
);
streamlitDoc.addEventListener('keydown', function(e) {
    switch (e.key) {
        case 'Enter':
            submitButton.click();
            break;
    }
});
</script>
""",
    height=0,
    width=0,
)