# Source: Hugging Face Space by Aditya757864 — "Update app.py" (commit ad1a23f, verified)
import streamlit as st
import torch
from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from streamlit_chat import message
from better_profanity import profanity
# Select the torch device: CUDA GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Prompt template for the RetrievalQA chain. {context} is filled with the
# retrieved document text and {question} with the user's query; the quoted
# fallback message is what the model should emit verbatim when the context
# does not contain an answer.
custom_prompt_template = """
Always answer the following QUESTION based on the CONTEXT ONLY and make sure the answer is in bullet points along with a few conversational lines related to the question.
If the CONTEXT doesn't contain the answer, or the question is outside the domain of expertise for CPGRAMS (Centralised Public Grievance Redress and Monitoring System), politely respond with:
"I'm sorry, but I don't have any information on that topic in my database. However, I'm here to help with any other questions or concerns you may have regarding grievance issues or anything else! Feel free to ask, and let's work together to find a solution. Your satisfaction is my priority!"
Context: {context}
Question: {question}
"""
# Check for offensive content
def is_offensive(text):
    """Return True when *text* contains profanity (per better-profanity), else False."""
    flagged = profanity.contains_profanity(text)
    return flagged
# Cache resources to optimize performance
@st.cache_resource
def load_model():
    """Load the google/gemma-2b causal LM and wrap it for LangChain.

    Returns:
        HuggingFacePipeline: a text-generation LLM usable by RetrievalQA.

    Cached with ``st.cache_resource`` so the weights load only once per
    Streamlit session.
    """
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
    # device_map="auto" lets accelerate place the weights (GPU when one is
    # available). With device_map set, the pipeline must NOT also receive an
    # explicit `device=` argument — transformers refuses to move a model that
    # accelerate has already dispatched. (The original code passed both, and
    # also assigned an unused `model_name` for a Mistral model it never loaded.)
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")
    hf_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
    return HuggingFacePipeline(pipeline=hf_pipeline)
@st.cache_resource
def load_faiss():
    """Build the FAISS retrieval backend, cached for the Streamlit session.

    Loads MiniLM sentence embeddings and deserializes the local
    ``Vector_Data`` index with them.
    """
    encoder = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    # allow_dangerous_deserialization is needed because the index is a local
    # pickle; acceptable only for trusted, self-generated files.
    return FAISS.load_local(
        "Vector_Data", encoder, allow_dangerous_deserialization=True
    )
def set_custom_prompt():
    """Wrap the module-level template in a PromptTemplate for the QA chain."""
    variables = ["context", "question"]
    return PromptTemplate(template=custom_prompt_template, input_variables=variables)
@st.cache_resource
def qa_llm():
    """Assemble and cache the RetrievalQA chain (LLM + FAISS retriever + prompt)."""
    chain = RetrievalQA.from_chain_type(
        llm=load_model(),
        chain_type="stuff",
        # k=1: only the single closest chunk is stuffed into the prompt.
        retriever=load_faiss().as_retriever(search_kwargs={"k": 1}),
        return_source_documents=True,
        chain_type_kwargs={"prompt": set_custom_prompt()},
    )
    return chain
def process_answer(instruction):
    """Run *instruction* through the cached QA chain and return the answer text."""
    response = qa_llm()({"query": instruction})
    return response["result"]
# Main Streamlit application
def main():
    """Render the CPGRAM grievance chatbot UI and drive one chat turn per rerun."""
    st.title("🤖 CPGRAM Grievance Chatbot")

    # Seed the two parallel transcript lists on first run: "past" holds user
    # turns, "generated" holds bot replies, paired index-for-index.
    if "generated" not in st.session_state:
        st.session_state["generated"] = ["Hello! Ask me any queries related to Grievance and CPGRAM Portal."]
    if "past" not in st.session_state:
        st.session_state["past"] = ["Hey! 👋"]

    reply_container = st.container()
    user_input = st.chat_input(placeholder="Please describe your queries here...", key="input")

    def _respond(query):
        # Record the user turn, then generate and record the matching bot turn.
        st.session_state["past"].append(query)
        with st.spinner("Generating response..."):
            st.session_state["generated"].append(process_answer(query))

    # Quick-access buttons take priority over typed input within one rerun.
    if st.button("What is CPGRAM?", key="cpgram_button"):
        _respond("What is CPGRAM?")
    elif st.button("How to fill grievance form?", key="grievance_button"):
        _respond("How to fill grievance form?")
    elif user_input:
        if is_offensive(user_input):
            # Do not echo profanity back into the transcript; store a flag
            # message and a canned refusal instead.
            st.session_state["past"].append("User input flagged as offensive")
            st.session_state["generated"].append("I'm sorry, but I can't assist with offensive content.")
        else:
            _respond(user_input)

    # Replay the whole transcript, alternating user/bot chat bubbles.
    if st.session_state["generated"]:
        with reply_container:
            for turn in range(len(st.session_state["generated"])):
                message(st.session_state["past"][turn], is_user=True, key=f"{turn}_user")
                message(st.session_state["generated"][turn], key=str(turn))


if __name__ == "__main__":
    main()