from dataclasses import dataclass from typing import Literal import streamlit as st from langchain.chat_models import ChatOpenAI from langchain.embeddings.openai import OpenAIEmbeddings from langchain.text_splitter import CharacterTextSplitter from langchain.vectorstores import Chroma from langchain.document_loaders.csv_loader import CSVLoader from langchain.callbacks import get_openai_callback from langchain.chains import ConversationChain from langchain.chains.conversation.memory import ConversationSummaryMemory from langchain.memory import ConversationBufferMemory from langchain.chains import ConversationalRetrievalChain import streamlit.components.v1 as components import os from langchain.chains import LLMChain from langchain.chains.question_answering import load_qa_chain from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT @dataclass class Message: """Class for keeping track of a chat message.""" origin: Literal["human", "ai"] message: str def load_css(): with open("static/styles.css", "r") as f: css = f"" st.markdown(css, unsafe_allow_html=True) @st.cache_resource() def load_index(): loader = CSVLoader(file_path='dagens_nyheter_new.csv') doc = loader.load() text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200) docs = text_splitter.split_documents(doc) embeddings = OpenAIEmbeddings() docsearch = Chroma.from_documents(docs, embeddings) return docsearch vectorstore = load_index() def initialize_session_state(): if "history" not in st.session_state: st.session_state.history = [] if "token_count" not in st.session_state: st.session_state.token_count = 0 if "conversation" not in st.session_state: memory = ConversationSummaryMemory( llm=ChatOpenAI(temperature=0, model='gpt-3.5-turbo-0613'), memory_key="chat_history", return_messages=True ) question_generator = LLMChain( llm=ChatOpenAI(temperature=0, model="gpt-4"), prompt=CONDENSE_QUESTION_PROMPT ) doc_chain = load_qa_chain( ChatOpenAI(temperature=0, model='gpt-3.5-turbo-16k'), chain_type="stuff" ) st.session_state.conversation = ConversationalRetrievalChain( retriever=vectorstore.as_retriever(search_kwargs=dict(k=50)), question_generator=question_generator, combine_docs_chain=doc_chain, memory=memory ) def on_click_callback(): with get_openai_callback() as cb: human_prompt = st.session_state.human_prompt llm_response = st.session_state.conversation.run( {"question": human_prompt} ) st.session_state.history.append( Message("human", human_prompt) ) st.session_state.history.append( Message("ai", llm_response) ) st.session_state.token_count += cb.total_tokens # Reset the user input to empty string after sending the message st.session_state.human_prompt = "" load_css() initialize_session_state() st.title("Dagens Nyheter Reviews") chat_placeholder = st.container() prompt_placeholder = st.form("chat-form") credit_card_placeholder = st.empty() with chat_placeholder: for chat in st.session_state.history: div = f"""