# Spaces: Sleeping
import hashlib
import html

import streamlit as st

from src.file_loader import load_file
from src.model_utils import load_hf_model, generate_answer
from src.rag_pipeline import build_rag_pipeline, get_relevant_docs
from src.utils import get_font_css
# --- Page configuration and global font styling ---------------------------
st.set_page_config(
    page_title="AI Chatbot",
    page_icon=":robot_face:",
    layout="wide",
)
font_css = get_font_css()
st.markdown(font_css, unsafe_allow_html=True)

# --- Sidebar: branding plus the three user-configurable inputs ------------
sidebar = st.sidebar
sidebar.image("assets/logo.png", width=180)
sidebar.title("AI Chatbot")
sidebar.markdown("Upload a file to get started:")

# The three names below are read by the main flow further down the script.
accepted_extensions = ["pdf", "csv", "xlsx"]
uploaded_file = sidebar.file_uploader(
    "Upload PDF, CSV, or XLSX",
    type=accepted_extensions,
)
model_name = sidebar.text_input(
    "HuggingFace Model (text-generation)",
    value="amiguel/GM_Qwen1.8B_Finetune",
)
embedding_model = sidebar.text_input(
    "Embedding Model",
    value="sentence-transformers/all-MiniLM-L6-v2",
)

sidebar.markdown("---")
sidebar.markdown("Powered by [Your Company]")

# --- Main-area header banner (static HTML, no user data interpolated) -----
header_html = """
<div style="display: flex; align-items: center; margin-bottom: 1rem;">
<img src="app/assets/logo.png" width="60" style="margin-right: 1rem;">
<h1 style="font-family: 'Tw Cen MT', sans-serif; margin: 0;">AI Chatbot</h1>
</div>
"""
st.markdown(header_html, unsafe_allow_html=True)
# HTML skeleton shared by both chat bubbles; only colours, spacing, alignment
# and the speaker label differ between the user and the AI.
_BUBBLE_TEMPLATE = """
<div style="background: {background}; border-radius: 10px; padding: 10px; margin-bottom: {gap}; text-align: {align}; font-family: 'Tw Cen MT', sans-serif;">
<b>{label}:</b> {body}
</div>
"""


def _render_message(sender, msg):
    """Render one chat-history entry as a styled bubble.

    Args:
        sender: ``"user"`` for the person asking; any other value is
            rendered as the AI.
        msg: The message text. It is HTML-escaped before being embedded,
            because both the user's question and the model's answer are
            untrusted text and the markup is rendered with
            ``unsafe_allow_html=True``.
    """
    if sender == "user":
        look = {"background": "#e6f0fa", "gap": "5px", "align": "right", "label": "You"}
    else:
        look = {"background": "#f4f4f4", "gap": "10px", "align": "left", "label": "AI"}
    bubble = _BUBBLE_TEMPLATE.format(body=html.escape(str(msg)), **look)
    st.markdown(bubble, unsafe_allow_html=True)


if uploaded_file:
    # Streamlit re-executes this script on every widget interaction, so the
    # expensive objects (index and model) are kept in session_state and only
    # rebuilt when their inputs change. This is harmless if the helper
    # functions already cache internally.
    file_digest = hashlib.sha256(uploaded_file.getvalue()).hexdigest()
    index_key = (file_digest, embedding_model)
    if st.session_state.get("retriever_key") != index_key:
        with st.spinner("Processing file..."):
            text = load_file(uploaded_file)
            # NOTE(review): iterating `text` yields one document per element,
            # so this presumably expects load_file to return a sequence of
            # chunks rather than a single string - confirm against load_file.
            docs = [{"page_content": chunk, "metadata": {}} for chunk in text]
            st.session_state.retriever = build_rag_pipeline(docs, embedding_model)
            # Set the key last so a failed build is retried on the next run.
            st.session_state.retriever_key = index_key
    retriever = st.session_state.retriever
    st.success("File processed and indexed!")

    if st.session_state.get("text_gen_key") != model_name:
        with st.spinner("Loading model..."):
            st.session_state.text_gen = load_hf_model(model_name)
            st.session_state.text_gen_key = model_name
    text_gen = st.session_state.text_gen
    st.success("Model loaded!")

    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    user_input = st.text_input("Ask a question about your document:", key="user_input")
    # The button is evaluated first so it is always drawn, even when the
    # question box is empty.
    if st.button("Send", use_container_width=True) and user_input:
        with st.spinner("Generating answer..."):
            context_docs = get_relevant_docs(retriever, user_input)
            context = " ".join(doc["page_content"] for doc in context_docs)
            answer = generate_answer(text_gen, user_input, context)
        st.session_state.chat_history.append(("user", user_input))
        st.session_state.chat_history.append(("bot", answer))

    for sender, msg in st.session_state.chat_history:
        _render_message(sender, msg)
else:
    st.info("Please upload a PDF, CSV, or XLSX file to begin.")