"""Streamlit "Ask Krishna" app: retrieval-augmented Q&A over a Bhagavad Gita FAISS index."""

import html
import os
import time

import streamlit as st
import translators as ts
from huggingface_hub import hf_hub_download
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import PromptTemplate

# Set page layout to wide
st.set_page_config(layout="wide")

# ================== CONFIGURATION ================== #
HF_TOKEN = os.getenv("HF_TOKEN")  # From Spaces secrets
VECTORSTORE_REPO_ID = "vashu2425/bhagavad-geeta-faiss-vectordb"
MODEL_REPO_ID = "mistralai/Mistral-7B-Instruct-v0.3"
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
VECTORSTORE_DIR = "vectorstore/db_faiss"
FAISS_FILES = ("index.faiss", "index.pkl")

# Number of chunks retrieved per question. The prompt below asks for
# "All 3 Contexts", so keep the two in sync if this value changes.
RETRIEVER_K = 3

USER_AVATAR = "🐿"
ASSISTANT_AVATAR = "🪈"

# Pause (seconds) between words while "typing out" an answer.
STREAM_DELAY_SECONDS = 0.02

# Characters of each source chunk shown in the "Source Documents" expander.
SOURCE_PREVIEW_CHARS = 500

PREDEFINED_QUESTIONS = (
    "Meaning of Dharma?",
    "What is the purpose of life?",
    "How to find inner peace?",
    "How can I be a better person?",
    "What is the meaning of life?",
    "How can I be a better friend?",
)

CUSTOM_PROMPT_TEMPLATE = """
Use The Pieces Of Information Provided In The Context To Answer User's Question.
If You Don't Know The Answer, Just Say "I Don't Have Information",except this do not say anything.
Don't Try To Make Up An Answer. Don't Provide Anything Out Of The Given Context.

Context: {context}
Question: {question}

Start The Answer Directly., Please. The Answer Should Contain All 3 Contexts.
Consider Yourself As God Krishna And Answer The Question
Result Should Not Start With "Answer"
"""


# ---------- Session Management Functions ---------- #
def initialize_session_states():
    """Populate ``st.session_state`` with default values for any missing keys.

    The defaults dict is rebuilt on every call so each session gets its own
    fresh ``messages`` list.
    """
    session_defaults = {
        "messages": [],
        "selected_question": None,
        "show_predefined": True,
        "last_response": None,
        "translation_done": False,
        "last_prompt": None,
    }
    for key, val in session_defaults.items():
        if key not in st.session_state:
            st.session_state[key] = val


def render_chat_messages():
    """Re-draw the whole chat history stored in ``st.session_state.messages``.

    Assistant messages that were translated (flagged ``"translated"``) hold
    pre-escaped HTML and are rendered with ``unsafe_allow_html``. Everything
    else is rendered as plain markdown, preferring the ``"original"`` text when
    present.
    """
    for message in st.session_state.messages:
        role = message["role"]
        avatar = USER_AVATAR if role == "user" else ASSISTANT_AVATAR
        with st.chat_message(role, avatar=avatar):
            if message.get("translated"):
                st.markdown(message["content"], unsafe_allow_html=True)
            else:
                st.markdown(message.get("original", message["content"]))


def render_predefined_questions():
    """Show one button per predefined question.

    Clicking a button stores that question in
    ``st.session_state.selected_question`` and hides the button row on later
    reruns.
    """
    st.markdown("### Or, try one of these:")
    columns = st.columns(len(PREDEFINED_QUESTIONS))
    for idx, (column, question) in enumerate(zip(columns, PREDEFINED_QUESTIONS)):
        if column.button(question, key=f"predefined_{idx}"):
            st.session_state.selected_question = question
            st.session_state.show_predefined = False


# ---------- Core Functionality Functions ---------- #
def _translate_or_none(text, dest_language="hi"):
    """Translate ``text`` via Google through ``translators``.

    Returns None on failure, after reporting the error in the UI.
    """
    try:
        return ts.translate_text(text, to_language=dest_language, translator="google")
    except Exception as e:
        st.error(f"Translation failed: {str(e)}")
        return None


def translate_text(text, dest_language="hi"):
    """Translate ``text`` to ``dest_language``.

    On failure, the error is shown and the untranslated ``text`` is returned.
    """
    translated = _translate_or_none(text, dest_language)
    return text if translated is None else translated


@st.cache_resource
def get_vectorstore():
    """Download (if needed) and load the FAISS index from the HF dataset repo.

    Missing index files are fetched into ``VECTORSTORE_DIR``. On any failure
    the error is shown and the script run is stopped.
    """
    try:
        embedding_model = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
        os.makedirs(VECTORSTORE_DIR, exist_ok=True)
        for filename in FAISS_FILES:
            if not os.path.exists(os.path.join(VECTORSTORE_DIR, filename)):
                hf_hub_download(
                    repo_id=VECTORSTORE_REPO_ID,
                    filename=filename,
                    local_dir=VECTORSTORE_DIR,
                    token=HF_TOKEN,
                    repo_type="dataset",
                )
        # SECURITY NOTE: allow_dangerous_deserialization lets load_local
        # unpickle index.pkl. Only point VECTORSTORE_REPO_ID at a repo you trust.
        return FAISS.load_local(
            VECTORSTORE_DIR, embedding_model, allow_dangerous_deserialization=True
        )
    except Exception as e:
        st.error(f"Vectorstore initialization failed: {str(e)}")
        st.stop()


def set_custom_prompt(custom_prompt_template):
    """Wrap the template string in a PromptTemplate with ``context``/``question`` inputs."""
    return PromptTemplate(
        template=custom_prompt_template, input_variables=["context", "question"]
    )


def load_llm(huggingface_repo_id, hf_token):
    """Create the HuggingFace Inference endpoint LLM used by the QA chain."""
    return HuggingFaceEndpoint(
        repo_id=huggingface_repo_id,
        temperature=0.4,
        huggingfacehub_api_token=hf_token,
        model_kwargs={"max_length": 512},
    )


@st.cache_resource
def get_qa_chain():
    """Build the retrieval QA chain once and share it across reruns via st.cache_resource.

    Previously the LLM client and chain were rebuilt on every Streamlit rerun,
    i.e. on every message and every button click.
    """
    vectorstore = get_vectorstore()
    return RetrievalQA.from_chain_type(
        llm=load_llm(MODEL_REPO_ID, HF_TOKEN),
        chain_type="stuff",
        retriever=vectorstore.as_retriever(search_kwargs={"k": RETRIEVER_K}),
        return_source_documents=True,
        chain_type_kwargs={"prompt": set_custom_prompt(CUSTOM_PROMPT_TEMPLATE)},
    )


def handle_translation():
    """Translate the latest assistant answer to Hindi in place, then rerun.

    Does nothing if there is no answer yet or it is already translated. If the
    translation fails, the answer is left untouched and ``translation_done``
    stays False, so the user can retry. Previously the English text was
    wrapped as Hindi and reported as a successful translation.
    """
    last_response = st.session_state.get("last_response")
    if not last_response or st.session_state.get("translation_done", False):
        return

    translated_text = _translate_or_none(last_response, "hi")
    if translated_text is None:
        return  # Error already shown by _translate_or_none.

    for message in reversed(st.session_state.messages):
        if message["role"] == "assistant":
            # Escape third-party output before rendering it as HTML.
            # NOTE(review): the original wrapper markup is missing from this
            # copy of the file; the "hindi-text" class is inferred from the
            # old `"hindi-text" in content` check — confirm it matches the CSS.
            safe_text = html.escape(translated_text, quote=False)
            message["content"] = f'<div class="hindi-text">{safe_text}</div>'
            message["translated"] = True
            break

    st.session_state.translation_done = True
    st.rerun()  # Forces a UI refresh


def format_source_docs(source_documents, max_chars=SOURCE_PREVIEW_CHARS):
    """Format retrieved documents as numbered markdown previews.

    Each preview shows the document's ``page`` metadata ("N/A" if absent) and
    its first ``max_chars`` characters, with tabs/newlines flattened to
    spaces. "..." is appended only when the text was actually truncated.
    """
    formatted_docs = []
    for idx, doc in enumerate(source_documents, start=1):
        content = doc.page_content.replace("\t", " ").replace("\n", " ").strip()
        if len(content) > max_chars:
            content = content[:max_chars] + "..."
        page = doc.metadata.get("page", "N/A")
        formatted_docs.append(f"**Source {idx}** (Page {page}):\n\n{content}")
    return "\n\n".join(formatted_docs)


def _stream_markdown(placeholder, text, delay=STREAM_DELAY_SECONDS):
    """Render ``text`` into ``placeholder`` word by word to simulate typing.

    Every update re-renders the whole accumulated text, so updating per word
    rather than per character sharply cuts re-renders and sleeps. Joining the
    ``split(" ")`` pieces with single spaces reproduces ``text`` exactly.
    """
    shown = ""
    for index, word in enumerate(text.split(" ")):
        shown = word if index == 0 else f"{shown} {word}"
        placeholder.markdown(shown)
        time.sleep(delay)


def handle_user_input(prompt, qa_chain):
    """Answer ``prompt`` with ``qa_chain`` and append the exchange to the chat history.

    Args:
        prompt: The user's question. Falsy values are ignored.
        qa_chain: A chain whose ``invoke({"query": ...})`` returns a mapping
            with ``"result"`` and ``"source_documents"`` keys.
    """
    if not prompt:
        return
    # Skip a prompt identical to the previous one. Note this also ignores a
    # user who deliberately repeats their last question.
    if st.session_state.get("last_prompt") == prompt:
        return
    st.session_state.last_prompt = prompt

    with st.chat_message("user", avatar=USER_AVATAR):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    try:
        with st.chat_message("assistant", avatar=ASSISTANT_AVATAR):
            response_placeholder = st.empty()
            response = qa_chain.invoke({"query": prompt})
            result = response["result"]
            source_documents = response["source_documents"]

            # Plain markdown (no unsafe_allow_html): the answer is model
            # output and must not be able to inject HTML into the page.
            _stream_markdown(response_placeholder, result)

            # Record the answer as soon as it is complete, so a later failure
            # (e.g. while formatting sources) cannot drop it from the history.
            st.session_state.messages.append(
                {"role": "assistant", "content": result, "original": result}
            )
            st.session_state.last_response = result
            st.session_state.show_predefined = False
            st.session_state.translation_done = False

            if "don't have information" not in result.lower():
                with st.expander("Source Documents"):
                    st.markdown(format_source_docs(source_documents))
    except Exception as e:
        st.error(f"Error: {str(e)}")


def main():
    """Render the page, handle a new question if any, and offer Hindi translation."""
    # NOTE(review): the HTML markup that originally surrounded these two
    # header strings is missing from this copy of the file; restore it if the
    # styled link/subtitle is still wanted.
    st.markdown("Source Bhagavad Gita PDF", unsafe_allow_html=True)
    st.title("Ask Krishna! 🦚")
    st.markdown("शांति स्वीकृति से शुरू होती है", unsafe_allow_html=True)

    initialize_session_states()
    render_chat_messages()

    if st.session_state.show_predefined:
        render_predefined_questions()

    prompt = st.chat_input("What's your curiosity?") or st.session_state.selected_question
    # Consume the button selection so it is not re-submitted on the next rerun.
    st.session_state.selected_question = None

    try:
        qa_chain = get_qa_chain()
    except Exception as e:
        st.error(f"Initialization error: {str(e)}")
        return

    if prompt:
        handle_user_input(prompt, qa_chain)

    if st.session_state.get("last_response"):
        col1, col2 = st.columns([1, 3])
        with col1:
            if st.button("🌐 Translate to Hindi", key="translate_btn"):
                handle_translation()
        with col2:
            if st.session_state.get("translation_done"):
                st.success("Translation to Hindi completed!")


if __name__ == "__main__":
    main()