|
|
import streamlit as st |
|
|
from rag_components import load_documents, split_documents, create_embeddings, setup_vector_store, create_qa_chain, create_streaming_response |
|
|
import os |
|
|
|
|
|
|
|
|
# Writable cache locations for the Hugging Face / sentence-transformers stack
# (needed on deployments where the default ~/.cache is read-only or ephemeral).
_HF_CACHE_ENV = {
    "HF_HOME": "/tmp/huggingface_cache",
    "TRANSFORMERS_CACHE": "/tmp/transformers_cache",
    "HF_HUB_CACHE": "/tmp/hf_hub_cache",
    "SENTENCE_TRANSFORMERS_HOME": "/tmp/sentence_transformers_cache",
}

for env_var, cache_dir in _HF_CACHE_ENV.items():
    os.makedirs(cache_dir, exist_ok=True)
    # BUGFIX: previously the directories were created but nothing referenced
    # them — without the env vars the libraries still use their defaults.
    # setdefault respects any value already set by the deployment environment.
    # NOTE(review): ideally these are exported before the HF libraries are
    # imported — confirm rag_components reads cache paths lazily.
    os.environ.setdefault(env_var, cache_dir)
|
|
|
|
|
# Global Streamlit page settings — must be the first st.* call in the script.
_PAGE_CONFIG = {
    "page_title": "Assistant",
    "page_icon": "🤖",
    "layout": "wide",
    "initial_sidebar_state": "collapsed",
}
st.set_page_config(**_PAGE_CONFIG)
|
|
|
|
|
# Page header: app title followed by a horizontal-rule separator.
heading_text = "Juma's Assistant"
st.title(heading_text)
st.markdown("---")
|
|
|
|
|
@st.cache_resource
def initialize_rag_components(file_path="./me.txt"):
    """Build and cache the RAG pipeline (QA chain and retriever).

    Decorated with ``st.cache_resource`` so the expensive model download /
    vector-store construction happens once per server process rather than on
    every Streamlit rerun.

    Args:
        file_path: Path to the source document fed into the pipeline.

    Returns:
        ``(qa_chain, retriever)`` on success, or ``(None, None)`` when the
        document is missing or any initialization step fails.
    """
    try:
        # Fail fast with a user-visible message instead of a traceback.
        if not os.path.exists(file_path):
            st.error(f"Error: Document file not found at {file_path}")
            return None, None

        with st.spinner("Loading documents..."):
            documents = load_documents(file_path)
            st.info(f"Loaded {len(documents)} documents")

        with st.spinner("Splitting documents into chunks..."):
            docs = split_documents(documents)
            st.info(f"Split into {len(docs)} chunks")

        with st.spinner("Creating embeddings (this may take a while)..."):
            embeddings = create_embeddings()

        with st.spinner("Setting up vector store..."):
            retriever = setup_vector_store(docs, embeddings)

        with st.spinner("Initializing QA chain..."):
            qa_chain = create_qa_chain(retriever)

        st.success("Welcome! Ask me anything about Juma.")
        return qa_chain, retriever

    except Exception as e:
        # Broad catch is deliberate: this is the app's top-level boundary and
        # any failure (model download, disk, etc.) should degrade gracefully
        # to (None, None) rather than crash the page.
        # BUGFIX: message was garbled ("Error: initializing: ...").
        st.error(f"Error initializing RAG components: {e}")
        st.info("This might be due to model download issues. Please try refreshing the page.")
        return None, None
|
|
|
|
|
# ---- App entry: build the pipeline once, then run the chat UI ----
qa_chain, retriever = initialize_rag_components()

if qa_chain is not None:
    # Chat transcript survives Streamlit reruns via session_state.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay the conversation so far (the script reruns on every user event).
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if prompt := st.chat_input("Ask me any question..."):
        # Record and display the user's turn.
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            # BUGFIX: the prompt was previously rendered twice
            # (st.markdown followed by st.write) — render it once.
            st.markdown(prompt)

        # Stream the assistant's answer into a placeholder inside the
        # assistant chat bubble.
        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            try:
                full_response = create_streaming_response(qa_chain, prompt, message_placeholder)
            except Exception as e:
                # Keep the app alive on per-question failures; surface the
                # error and store an apology so the transcript stays coherent.
                st.error(f"An error occurred: {e}")
                full_response = "I apologize, but I encountered an error while processing your question. Please try again."

        st.session_state.messages.append({"role": "assistant", "content": full_response})
else:
    st.warning("RAG components could not be initialized. Please check the document file path.")
|
|
|