import streamlit as st
import torch
from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from streamlit_chat import message
from better_profanity import profanity

# Select GPU if available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Custom prompt: keeps answers grounded in the retrieved context and
# gives a polite fallback for out-of-domain questions
custom_prompt_template = """
Always answer the following QUESTION based on the CONTEXT ONLY and make sure the answer is in bullet points along with a few conversational lines related to the question.
If the CONTEXT doesn't contain the answer, or the question is outside the domain of expertise for CPGRAMS (Centralised Public Grievance Redress and Monitoring System), politely respond with:
"I'm sorry, but I don't have any information on that topic in my database. However, I'm here to help with any other questions or concerns you may have regarding grievance issues or anything else! Feel free to ask, and let's work together to find a solution. Your satisfaction is my priority!"

Context: {context}
Question: {question}
"""


def is_offensive(text):
    """Return True if the given text contains offensive language, using better-profanity."""
    return profanity.contains_profanity(text)


# Cache heavyweight resources so they are loaded only once per session
@st.cache_resource
def load_model():
    """Load the Hugging Face model and tokenizer and wrap them for LangChain."""
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
    # Note: don't combine device_map="auto" with an explicit pipeline device;
    # transformers rejects moving an accelerate-dispatched model to a specific
    # device. Here the pipeline places the model on the detected device instead.
    hf_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=0 if device.type == "cuda" else -1,
        max_new_tokens=256,  # the default generation length is too short for full answers
    )
    return HuggingFacePipeline(pipeline=hf_pipeline)


@st.cache_resource
def load_faiss():
    """Load the FAISS vector store with SentenceTransformer embeddings."""
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vector_store = FAISS.load_local("Vector_Data", embeddings, allow_dangerous_deserialization=True)
    return vector_store


def set_custom_prompt():
    """Build the PromptTemplate used by the QA chain."""
    return PromptTemplate(template=custom_prompt_template, input_variables=["context", "question"])


@st.cache_resource
def qa_llm():
    """Create the RetrievalQA chain (LLM + retriever + custom prompt)."""
    llm = load_model()
    vector_store = load_faiss()
    retriever = vector_store.as_retriever(search_kwargs={"k": 1})
    prompt = set_custom_prompt()
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt},
    )
    return qa_chain


def process_answer(instruction):
    """Run the user's query through the QA chain and return the answer text."""
    qa = qa_llm()
    response = qa({"query": instruction})
    return response["result"]


# Main Streamlit application
def main():
    st.title("🤖 CPGRAM Grievance Chatbot")
    if "generated" not in st.session_state:
        st.session_state["generated"] = ["Hello! Ask me any queries related to Grievance and CPGRAM Portal."]
    if "past" not in st.session_state:
        st.session_state["past"] = ["Hey! 👋"]
👋"] reply_container = st.container() user_input = st.chat_input(placeholder="Please describe your queries here...", key="input") # Predefined buttons for common questions if st.button("What is CPGRAM?", key="cpgram_button"): st.session_state["past"].append("What is CPGRAM?") with st.spinner("Generating response..."): answer = process_answer("What is CPGRAM?") st.session_state["generated"].append(answer) elif st.button("How to fill grievance form?", key="grievance_button"): st.session_state["past"].append("How to fill grievance form?") with st.spinner("Generating response..."): answer = process_answer("How to fill grievance form?") st.session_state["generated"].append(answer) # Handle user input elif user_input: if is_offensive(user_input): st.session_state["past"].append("User input flagged as offensive") st.session_state["generated"].append("I'm sorry, but I can't assist with offensive content.") else: st.session_state["past"].append(user_input) with st.spinner("Generating response..."): answer = process_answer(user_input) st.session_state["generated"].append(answer) # Display conversation if st.session_state["generated"]: with reply_container: for i in range(len(st.session_state["generated"])): message(st.session_state["past"][i], is_user=True, key=str(i) + "_user") message(st.session_state["generated"][i], key=str(i)) if __name__ == "__main__": main()