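"""Streamlit front end for a RAG-based e-commerce FAQ chatbot.

Pipeline: load FAQs (Hugging Face dataset or local CSV), preprocess and
optionally augment them, embed them with FAQEmbedder, retrieve the most
relevant entries per query, and generate an answer with a local LLM.
"""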
import streamlit as st
import time
import os
import gc
import json
import torch
from src.data_processing import load_huggingface_faq_data, load_faq_data, preprocess_faq, augment_faqs, translate_faq
from src.embedding import FAQEmbedder
from src.llm_response import ResponseGenerator
from src.utils import time_function, format_memory_stats, evaluate_response, evaluate_retrieval, baseline_keyword_search
from deep_translator import GoogleTranslator

# Force CPU execution and suppress Torch path-check warnings
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["TORCH_NO_PATH_CHECK"] = "1"

st.set_page_config(page_title="E-Commerce FAQ Chatbot", layout="wide", initial_sidebar_state="expanded")
def initialize_components(use_huggingface: bool = True, model_name: str = "microsoft/phi-2", enable_augmentation: bool = True):
    """Initialize the RAG system components: FAQ data, embedder, and response generator."""
    try:
        if use_huggingface:
            faqs = load_huggingface_faq_data("NebulaByte/E-Commerce_FAQs")
        else:
            faqs = load_faq_data("data/faq_data.csv")
        processed_faqs = augment_faqs(preprocess_faq(faqs), enable_augmentation=enable_augmentation)

        embedder = FAQEmbedder()
        if os.path.exists("embeddings"):
            # Reuse cached embeddings instead of recomputing them on every start
            embedder.load("embeddings")
        else:
            embedder.create_embeddings(processed_faqs)
            embedder.save("embeddings")

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        response_generator = ResponseGenerator(model_name=model_name)
        # Warm up the model so the first real query is not slowed by lazy initialization
        response_generator.generate_response("Warmup query", [{"question": "Test", "answer": "Test"}])
        return embedder, response_generator, len(processed_faqs)
    except Exception as e:
        st.error(f"Initialization failed: {e}")
        raise
def main():
    st.title("E-Commerce Customer Support FAQ Chatbot")
    st.subheader("Ask about orders, shipping, returns, or other e-commerce queries")

    st.sidebar.title("Configuration")
    use_huggingface = st.sidebar.checkbox("Use Hugging Face Dataset", value=True)
    enable_augmentation = st.sidebar.checkbox("Enable FAQ Augmentation", value=True, help="Generate paraphrased questions to expand the dataset")
    target_lang = st.sidebar.selectbox("Language", ["en", "es", "fr"], index=0)

    model_options = {
        "Phi-2 (Recommended for 16GB RAM)": "microsoft/phi-2",
        "TinyLlama-1.1B (Fastest)": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "Mistral-7B (For 15GB+ GPU)": "mistralai/Mistral-7B-Instruct-v0.1"
    }
    selected_model = st.sidebar.selectbox("Select LLM Model", list(model_options.keys()), index=0)
    model_name = model_options[selected_model]

    if st.sidebar.checkbox("Show Memory Usage", value=True):
        st.sidebar.subheader("Memory Usage")
        for key, value in format_memory_stats().items():
            st.sidebar.text(f"{key}: {value}")

    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []
    if "query_cache" not in st.session_state:
        st.session_state.query_cache = {}
    if "feedback" not in st.session_state:
        st.session_state.feedback = []

    if "system_initialized" not in st.session_state or st.sidebar.button("Reload System"):
        with st.spinner("Initializing system..."):
            try:
                st.session_state.embedder, st.session_state.response_generator, num_faqs = initialize_components(
                    use_huggingface=use_huggingface,
                    model_name=model_name,
                    enable_augmentation=enable_augmentation
                )
                st.session_state.system_initialized = True
                st.sidebar.success(f"System initialized with {num_faqs} FAQs!")
            except Exception as e:
                st.error(f"System initialization failed: {e}")
                return
    col1, col2 = st.columns([2, 1])

    with col1:
        st.subheader("Conversation")
        chat_container = st.container(height=400)
        with chat_container:
            for i, message in enumerate(st.session_state.chat_history):
                if message["role"] == "user":
                    st.markdown(f"**You**: {message['content']}")
                else:
                    st.markdown(f"**Bot**: {message['content']}")
                if i < len(st.session_state.chat_history) - 1:
                    st.markdown("---")

        with st.form(key="chat_form"):
            user_query = st.text_input("Type your question:", key="user_input", placeholder="e.g., How do I track my order?")
            submit_button = st.form_submit_button("Ask")

        if len(st.session_state.chat_history) > 0:
            with st.form(key=f"feedback_form_{len(st.session_state.chat_history)}"):
                rating = st.slider("Rate this response (1-5)", 1, 5, key=f"rating_{len(st.session_state.chat_history)}")
                comments = st.text_area("Comments", key=f"comments_{len(st.session_state.chat_history)}")
                if st.form_submit_button("Submit Feedback"):
                    st.session_state.feedback.append({
                        "rating": rating,
                        "comments": comments,
                        "response": st.session_state.chat_history[-1]["content"]
                    })
                    with open("feedback.json", "w") as f:
                        json.dump(st.session_state.feedback, f)
                    st.success("Feedback submitted!")
    with col2:
        if st.session_state.get("system_initialized", False):
            st.subheader("Retrieved Information")
            info_container = st.container(height=500)
            with info_container:
                if "current_faqs" in st.session_state:
                    for i, faq in enumerate(st.session_state.current_faqs):
                        st.markdown(f"**Relevant FAQ #{i+1}**")
                        st.markdown(f"**Q**: {faq['question']}")
                        st.markdown(f"**A**: {faq['answer'][:150]}..." if len(faq['answer']) > 150 else f"**A**: {faq['answer']}")
                        st.markdown(f"*Similarity Score*: {faq['similarity']:.2f}")
                        if 'category' in faq and faq['category']:
                            st.markdown(f"*Category*: {faq['category']}")
                        st.markdown("---")
                else:
                    st.markdown("Ask a question to see relevant FAQs.")

    if "retrieval_time" in st.session_state and "generation_time" in st.session_state:
        st.sidebar.subheader("Performance Metrics")
        st.sidebar.markdown(f"Retrieval time: {st.session_state.retrieval_time:.2f} seconds")
        st.sidebar.markdown(f"Response generation: {st.session_state.generation_time:.2f} seconds")
        st.sidebar.markdown(f"Total time: {st.session_state.retrieval_time + st.session_state.generation_time:.2f} seconds")
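
    # Query handling: non-English input is translated to English for retrieval,
    # answered from the per-session cache when possible, and the generated
    # answer is translated back to the selected UI language.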
    if submit_button and user_query:
        to_english = GoogleTranslator(source='auto', target='en')
        if target_lang != "en":
            user_query_translated = to_english.translate(user_query)
        else:
            user_query_translated = user_query

        if user_query_translated in st.session_state.query_cache:
            response, relevant_faqs = st.session_state.query_cache[user_query_translated]
        else:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            start_time = time.time()
            relevant_faqs = st.session_state.embedder.retrieve_relevant_faqs(user_query_translated)
            retrieval_time = time.time() - start_time

            if target_lang != "en":
                relevant_faqs = [translate_faq(faq, target_lang) for faq in relevant_faqs]

            start_time = time.time()
            response = st.session_state.response_generator.generate_response(user_query_translated, relevant_faqs)
            generation_time = time.time() - start_time

            if target_lang != "en":
                # deep_translator fixes the target language at construction time,
                # so a separate translator is needed for English -> target_lang
                response = GoogleTranslator(source='en', target=target_lang).translate(response)

            st.session_state.query_cache[user_query_translated] = (response, relevant_faqs)
            st.session_state.retrieval_time = retrieval_time
            st.session_state.generation_time = generation_time

        st.session_state.current_faqs = relevant_faqs
        st.session_state.chat_history.append({"role": "user", "content": user_query})
        st.session_state.chat_history.append({"role": "assistant", "content": response})
| if st.button("Clear Chat History"): | |
| st.session_state.chat_history = [] | |
| st.session_state.query_cache = {} | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| if st.session_state.get("system_initialized", False): | |
| st.sidebar.subheader("Baseline Comparison") | |
| baseline_faqs = baseline_keyword_search(user_query_translated if 'user_query_translated' in locals() else "", st.session_state.embedder.faqs) | |
| st.sidebar.write(f"RAG FAQs: {[faq['question'][:50] for faq in st.session_state.get('current_faqs', [])]}") | |
| st.sidebar.write(f"Keyword FAQs: {[faq['question'][:50] for faq in baseline_faqs]}") | |
| st.subheader("Sample Questions") | |
| sample_questions = [ | |
| "How do I track my order?", | |
| "What should I do if my delivery is delayed?", | |
| "How do I return a product?", | |
| "Can I cancel my order after placing it?", | |
| "How quickly will my order be delivered?" | |
| ] | |
| cols = st.columns(2) | |
| for i, question in enumerate(sample_questions): | |
| col_idx = i % 2 | |
| if cols[col_idx].button(question, key=f"sample_{i}"): | |
| st.session_state.user_input = question | |
| st.session_state.chat_history.append({"role": "user", "content": question}) | |
| translator = GoogleTranslator(source='auto', target='en') # Updated translator | |
| if target_lang != "en": | |
| question_translated = translator.translate(question) | |
| else: | |
| question_translated = question | |
| if question_translated in st.session_state.query_cache: | |
| response, relevant_faqs = st.session_state.query_cache[question_translated] | |
| else: | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| start_time = time.time() | |
| relevant_faqs = st.session_state.embedder.retrieve_relevant_faqs(question_translated) | |
| retrieval_time = time.time() - start_time | |
| if target_lang != "en": | |
| relevant_faqs = [translate_faq(faq, target_lang) for faq in relevant_faqs] | |
| start_time = time.time() | |
| response = st.session_state.response_generator.generate_response(question_translated, relevant_faqs) | |
| generation_time = time.time() - start_time | |
| if target_lang != "en": | |
| response = translator.translate(response, target=target_lang) | |
| st.session_state.query_cache[question_translated] = (response, relevant_faqs) | |
| st.session_state.retrieval_time = retrieval_time | |
| st.session_state.generation_time = generation_time | |
| st.session_state.current_faqs = relevant_faqs | |
| st.session_state.chat_history.append({"role": "assistant", "content": response}) | |
if __name__ == "__main__":
    main()
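
# Run locally with `streamlit run app.py` (assuming this file is saved as
# app.py at the project root, alongside the src/ package).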