Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from ragchatbot import RAGChatBot | |
| from pydantic_models import RequestModel, ChatHistoryItem | |
def validate_chat_history_item(chat_history_item: ChatHistoryItem):
    """Round-trip the item through dump/validate to confirm it satisfies the schema.

    Returns a freshly validated ChatHistoryItem built from the dumped fields.
    """
    dumped_fields = chat_history_item.model_dump()
    return ChatHistoryItem.model_validate(dumped_fields)
# Page chrome: wide layout so the two comparison panes fit side by side.
st.set_page_config(page_title="RAG-Chatbot", page_icon=":mag:", layout="wide")
st.title("Test Formatted Text - KCS10-31")

# Headline the two panes: left compares the current model, right the formatted text.
col1, col2 = st.columns(2)
for pane, pane_title in zip((col1, col2), ("Current Model", "Formatted Text")):
    pane.title(pane_title)
# Lazily seed session state on first run: one chatbot per vector store,
# plus an empty chat history for each. Factories defer RAGChatBot
# construction until the key is actually missing.
_SESSION_DEFAULTS = {
    "formatted_ragchatbot": lambda: RAGChatBot(vectorstore_path="KCS_formatted_vectorstore"),
    "just_ragchatbot": lambda: RAGChatBot(vectorstore_path="KCS_current_vectorstore"),
    "formatted_chat_history": list,
    "just_chat_history": list,
}
for _key, _make_default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _make_default()
# Replay both chat histories side by side, one row of columns per exchange.
# The two histories are appended in lockstep by the input handler below, so
# they must always be the same length — check the invariant once, up front,
# instead of re-asserting it on every iteration as before.
assert len(st.session_state.formatted_chat_history) == len(st.session_state.just_chat_history)

for chat_index in range(len(st.session_state.formatted_chat_history)):
    for col, chat_history, sources_text in zip(
        st.columns(2, vertical_alignment="top"),
        [st.session_state.just_chat_history, st.session_state.formatted_chat_history],
        ["Current Model", "Formatted Text"],
    ):
        chat = chat_history[chat_index]
        with col.chat_message("user"):
            # Double newlines so Streamlit markdown renders each line as its own paragraph.
            st.write(chat.get("user_message").replace("\n", "\n\n"))
        with col.chat_message("assistant"):
            st.write(chat.get("assistant_message").replace("\n", "\n\n"))
            st.write(chat.get("search_phrase"))
            for i, doc in enumerate(chat.get("sources_documents")):
                with st.expander(f"{sources_text} Sources - {i+1}"):
                    st.subheader(f"{doc.get('heading')} - {doc.get('relevance_score')}")
                    # NOTE(review): sources_text is only ever "Current Model" or
                    # "Formatted Text" here, so the old branch that stripped
                    # <chunk_content> tags for "Contextual Chunking" was dead
                    # code and has been removed.
                    st.write(doc.get("page_content").replace("\n", "\n\n"))
if user_query := st.chat_input("Enter your query"):

    def _ask_and_record(bot, history):
        """Query *bot* with the running *history*, then append the new exchange to it.

        The prior exchanges are replayed to the bot as ChatHistoryItem objects;
        the bot's answer, search phrase, and source documents are stored as a
        plain dict so the replay loop above can render them.
        """
        response = bot.get_response(
            RequestModel(
                user_question=user_query,
                chat_history=[
                    ChatHistoryItem(
                        user_message=chat.get("user_message"),
                        assistant_message=chat.get("assistant_message"),
                    )
                    for chat in history
                ],
            )
        )
        history.append({
            "user_message": user_query,
            "assistant_message": response.answer,
            "search_phrase": response.search_phrase,
            "sources_documents": [
                {
                    "heading": doc.heading,
                    "page_content": doc.page_content,
                    "relevance_score": doc.relevance_score,
                }
                for doc in response.sources_documents
            ],
        })

    # Echo the user's message into both panes immediately, before the
    # (slow) model calls, so the UI feels responsive.
    for col in st.columns(2, vertical_alignment="top"):
        with col.chat_message("user"):
            st.write(user_query.replace("\n", "\n\n"))

    with st.spinner("Generating response..."):
        # Same order as before: current model first, then formatted text.
        _ask_and_record(st.session_state.just_ragchatbot, st.session_state.just_chat_history)
        _ask_and_record(st.session_state.formatted_ragchatbot, st.session_state.formatted_chat_history)

    # Re-run the script so the replay loop above renders the new exchange
    # in its final layout.
    st.rerun()