# dn / app.py (felix-weiland/dn Hugging Face Space)
# Last commit: 32d2438 ("Update app.py")
import html
import os
from dataclasses import dataclass
from typing import Literal

import streamlit as st
import streamlit.components.v1 as components
from langchain.callbacks import get_openai_callback
from langchain.chains import ConversationChain
from langchain.chains import ConversationalRetrievalChain
from langchain.chains import LLMChain
from langchain.chains.conversation.memory import ConversationSummaryMemory
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
# Allowed speakers for a chat turn.
_Origin = Literal["human", "ai"]


@dataclass()
class Message:
    """One turn of the chat transcript.

    Fields are positional in declaration order, so ``Message("ai", text)``
    builds an assistant turn and ``Message("human", text)`` a user turn.
    """

    # Who produced the turn: "human" for the user, "ai" for the model.
    origin: _Origin
    # Text of the turn.
    message: str
def load_css(path="static/styles.css"):
    """Inject the app's stylesheet into the page.

    Reads the CSS file and emits it inside a ``<style>`` tag through
    ``st.markdown`` with raw HTML enabled.

    Args:
        path: Stylesheet location. It defaults to the previously hard-coded
            ``static/styles.css``. As with ``open()``, a relative path is
            resolved against the current working directory.

    Raises:
        OSError: If the stylesheet cannot be opened (e.g. FileNotFoundError).
    """
    # Decode explicitly as UTF-8 so non-ASCII CSS (e.g. in `content:` rules)
    # does not depend on the server's locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        css = f.read()
    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
@st.cache_resource()
def load_index():
    """Build the Chroma vector store over ``dagens_nyheter_new.csv``.

    The CSV is loaded with ``CSVLoader``. It is split by
    ``CharacterTextSplitter`` into chunks targeting 1000 characters with a
    200-character overlap, embedded with ``OpenAIEmbeddings``, and indexed in
    Chroma. ``st.cache_resource`` lets Streamlit reuse the returned store
    across reruns instead of re-embedding the CSV each time.

    Returns:
        The Chroma store built from the chunks.
    """
    raw_documents = CSVLoader(file_path='dagens_nyheter_new.csv').load()
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = splitter.split_documents(raw_documents)
    return Chroma.from_documents(chunks, OpenAIEmbeddings())
# Module-level vector store (cached by load_index's st.cache_resource);
# initialize_session_state() builds its retriever from this name.
vectorstore = load_index()
def _build_conversation():
    """Assemble the conversational retrieval chain used for every chat turn.

    Three OpenAI chat models are involved, all at temperature 0:
    - gpt-3.5-turbo-0613 backs the ConversationSummaryMemory.
    - gpt-4 rewrites follow-up questions with CONDENSE_QUESTION_PROMPT.
    - gpt-3.5-turbo-16k answers from the retrieved documents using the
      "stuff" QA chain.
    """
    summarizer = ChatOpenAI(temperature=0, model='gpt-3.5-turbo-0613')
    memory = ConversationSummaryMemory(
        llm=summarizer,
        memory_key="chat_history",
        return_messages=True,
    )

    condenser_llm = ChatOpenAI(temperature=0, model="gpt-4")
    question_generator = LLMChain(
        llm=condenser_llm,
        prompt=CONDENSE_QUESTION_PROMPT,
    )

    answer_llm = ChatOpenAI(temperature=0, model='gpt-3.5-turbo-16k')
    doc_chain = load_qa_chain(answer_llm, chain_type="stuff")

    # NOTE(review): up to 50 chunks of ~1000 characters are stuffed into a
    # 16k-context model; confirm the combined prompt fits the context window.
    retriever = vectorstore.as_retriever(search_kwargs={"k": 50})
    return ConversationalRetrievalChain(
        retriever=retriever,
        question_generator=question_generator,
        combine_docs_chain=doc_chain,
        memory=memory,
    )


def initialize_session_state():
    """Create this session's chat state the first time the script runs.

    Sets ``history`` (the list of Message turns), ``token_count`` (running
    token total) and ``conversation`` (the retrieval chain) in
    ``st.session_state``. Keys that already exist are left untouched, so
    reruns keep the ongoing chat.
    """
    state = st.session_state
    if "history" not in state:
        state.history = []
    if "token_count" not in state:
        state.token_count = 0
    if "conversation" not in state:
        state.conversation = _build_conversation()
def on_click_callback():
    """Send the submitted prompt through the conversation chain.

    This is the ``on_click`` handler of the chat form's Submit button. It:
    1. Reads the prompt from ``st.session_state.human_prompt``.
    2. Appends the human turn and the model's answer to
       ``st.session_state.history``.
    3. Adds the OpenAI tokens spent to ``st.session_state.token_count``.
    4. Clears the input box.

    Empty or whitespace-only prompts are not sent; the input is just cleared.
    Both the Submit button and the page's Enter-key handler fire on an empty
    box, and sending that would spend tokens and add blank chat bubbles.
    """
    human_prompt = st.session_state.human_prompt
    if not human_prompt or not human_prompt.strip():
        st.session_state.human_prompt = ""
        return

    # get_openai_callback tallies the tokens of OpenAI calls made in this block.
    with get_openai_callback() as cb:
        llm_response = st.session_state.conversation.run(
            {"question": human_prompt}
        )
        st.session_state.history.append(Message("human", human_prompt))
        st.session_state.history.append(Message("ai", llm_response))
        st.session_state.token_count += cb.total_tokens
    # Reset the user input to an empty string after sending the message.
    # Skipped if run() raises, so the user can retry the same prompt.
    st.session_state.human_prompt = ""
# Page setup: inject the CSS and make sure per-session state exists before
# anything below reads st.session_state.
load_css()
initialize_session_state()
st.title("Dagens Nyheter Reviews")
# Layout slots, created in display order: chat transcript, input form, then
# an empty slot that later receives the token/debug caption.
chat_placeholder = st.container()
prompt_placeholder = st.form("chat-form")
credit_card_placeholder = st.empty()
# Base URL of the avatar images in the Space's static/ folder.
CHAT_ICON_BASE_URL = "https://huggingface.co/spaces/felix-weiland/dn/resolve/main/static"


def _chat_message_html(origin, message):
    """Return the HTML for one chat row.

    Args:
        origin: "ai" renders an assistant row. Any other value renders a
            reversed (user) row, as before.
        message: Raw message text. It is HTML-escaped so user input or model
            output cannot inject markup into the page. Line breaks become
            ``<br>`` so a blank line inside a multi-paragraph answer cannot
            end Markdown's HTML block early and spill text out of the bubble.

    Returns:
        A snippet meant for ``st.markdown(..., unsafe_allow_html=True)``. It is
        kept flush-left, because a leading indent of four or more spaces would
        make Markdown render it as a code block.
    """
    is_ai = origin == 'ai'
    row_class = 'chat-row' if is_ai else 'chat-row row-reverse'
    icon = 'ai_icon.png' if is_ai else 'user_icon.png'
    bubble_class = 'chat-bubble ai-bubble' if is_ai else 'chat-bubble human-bubble'
    body = "<br>".join(html.escape(message).splitlines())
    return (
        f'<div class="{row_class}">\n'
        f'<img class="chat-icon" src="{CHAT_ICON_BASE_URL}/{icon}" width=32 height=32>\n'
        f'<div class="{bubble_class}">\n'
        f'&#8203;{body}\n'
        '</div>\n'
        '</div>\n'
    )


with chat_placeholder:
    for chat in st.session_state.history:
        st.markdown(
            _chat_message_html(chat.origin, chat.message),
            unsafe_allow_html=True,
        )
    # Empty markdown elements, presumably to add vertical space below the transcript.
    for _ in range(3):
        st.markdown("")
# Chat input form: a wide text box next to a narrow Submit button.
with prompt_placeholder:
    st.markdown("**Chat**")
    input_col, submit_col = st.columns((6, 1))
    with input_col:
        # Bound to st.session_state.human_prompt, which on_click_callback reads and clears.
        st.text_input(
            "Chat",
            value="",
            label_visibility="collapsed",
            key="human_prompt",
        )
    with submit_col:
        st.form_submit_button(
            "Submit",
            type="primary",
            on_click=on_click_callback,
        )
# Fill the reserved slot with the running token total and a dump of the
# conversation memory's buffer.
# NOTE(review): this debug dump is visible to every user; confirm it should
# remain in the deployed app.
credit_card_placeholder.caption(f"""
Used {st.session_state.token_count} tokens \n
Debug Langchain conversation:
{st.session_state.conversation.memory.buffer}
""")
# Let the Enter key trigger the chat form's Submit button. The script runs in
# the component's zero-size iframe and reaches the Streamlit page through
# window.parent.document.
# NOTE(review): the '.stButton > button' selector depends on Streamlit's
# generated DOM; re-check it after Streamlit upgrades.
components.html("""
<script>
const streamlitDoc = window.parent.document;
streamlitDoc.addEventListener('keydown', function (e) {
    if (e.key !== 'Enter') {
        return;
    }
    // Look the button up on each key press. A reference captured when the
    // script first ran may be missing (not rendered yet) or stale after a
    // rerun re-renders the page.
    const submitButton = Array.from(
        streamlitDoc.querySelectorAll('.stButton > button')
    ).find(el => el.innerText === 'Submit');
    if (submitButton) {
        submitButton.click();
    }
});
</script>
""",
    height=0,
    width=0,
)