File size: 5,152 Bytes
4d772a0
 
ce1bfdf
4d772a0
 
ce1bfdf
 
 
4d772a0
 
 
ce1bfdf
 
 
 
 
 
 
 
 
 
 
 
 
 
4d772a0
 
 
 
 
 
 
ce1bfdf
 
 
 
ad1a23f
 
928a547
ce1bfdf
 
4d772a0
ce1bfdf
 
 
 
 
 
4d772a0
 
ce1bfdf
 
4d772a0
 
 
ce1bfdf
 
 
 
4d772a0
ce1bfdf
 
 
 
4d772a0
ce1bfdf
4d772a0
ce1bfdf
4d772a0
 
ce1bfdf
4d772a0
ce1bfdf
 
4d772a0
ce1bfdf
4d772a0
 
 
 
ce1bfdf
4d772a0
 
ce1bfdf
4d772a0
 
 
ce1bfdf
 
 
 
 
 
 
4d772a0
ce1bfdf
 
 
 
 
 
4d772a0
 
ce1bfdf
 
4d772a0
ce1bfdf
 
 
 
4d772a0
ce1bfdf
4d772a0
ce1bfdf
4d772a0
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import streamlit as st
import torch
from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from streamlit_chat import message
from better_profanity import profanity

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Prompt used by the RetrievalQA "stuff" chain: {context} receives the
# retrieved document text and {question} receives the user's query.
custom_prompt_template = """
Always answer the following QUESTION based on the CONTEXT ONLY and make sure the answer is in bullet points along with a few conversational lines related to the question. 
If the CONTEXT doesn't contain the answer, or the question is outside the domain of expertise for CPGRAMS (Centralised Public Grievance Redress and Monitoring System), politely respond with:
"I'm sorry, but I don't have any information on that topic in my database. However, I'm here to help with any other questions or concerns you may have regarding grievance issues or anything else! Feel free to ask, and let's work together to find a solution. Your satisfaction is my priority!"

Context: {context}
Question: {question}
"""

# Offensive-language screening for user messages
def is_offensive(text):
    """Report whether *text* should be rejected as offensive.

    The decision is delegated entirely to better-profanity's
    ``profanity.contains_profanity``.

    Args:
        text: The raw user message.

    Returns:
        True when better-profanity flags the text, False otherwise.
    """
    flagged = profanity.contains_profanity(text)
    return flagged

# Cache resources to optimize performance
@st.cache_resource
def load_model(model_name="google/gemma-2b", *, max_new_tokens=256):
    """Load a Hugging Face causal LM and wrap it as a LangChain LLM.

    The result is cached by ``st.cache_resource`` so the weights are loaded
    once rather than on every Streamlit rerun.

    Args:
        model_name: Hugging Face Hub id used for both tokenizer and model.
            Defaults to the checkpoint this app has always used.
        max_new_tokens: Upper bound on tokens generated per answer. Without an
            explicit limit, generation falls back to the model's generation
            config, whose default ``max_length`` can be shorter than the
            retrieval-augmented prompt itself.

    Returns:
        A ``HuggingFacePipeline`` wrapping a ``text-generation`` pipeline.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # device_map="auto" lets accelerate place the weights (GPU when available,
    # otherwise CPU). A model dispatched this way must NOT also receive an
    # explicit `device=` in pipeline(); transformers rejects that combination.
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    hf_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=max_new_tokens,
    )
    return HuggingFacePipeline(pipeline=hf_pipeline)

@st.cache_resource
def load_faiss(
    folder_path="Vector_Data",
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",
):
    """Load a saved FAISS vector store together with its embedding model.

    Cached by ``st.cache_resource`` so the index is read from disk once.

    Args:
        folder_path: Directory the FAISS index was saved to. Defaults to the
            app's original ``"Vector_Data"`` folder.
        embedding_model: SentenceTransformer model id used to embed queries.
            It should be the same model that built the index; otherwise query
            vectors are not comparable with the stored ones.

    Returns:
        The loaded ``FAISS`` vector store.
    """
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
    # SECURITY: allow_dangerous_deserialization=True opts into pickle-based
    # loading of the saved store. Only load an index you built yourself; never
    # point folder_path at files from an untrusted source.
    return FAISS.load_local(folder_path, embeddings, allow_dangerous_deserialization=True)

def set_custom_prompt():
    """Build the QA PromptTemplate from the module-level custom_prompt_template.

    The template declares two input variables: ``"context"`` (the retrieved
    text) and ``"question"`` (the user's query).
    """
    expected_vars = ["context", "question"]
    return PromptTemplate(input_variables=expected_vars, template=custom_prompt_template)

@st.cache_resource
def qa_llm():
    """Assemble the RetrievalQA chain (cached by st.cache_resource).

    The chain combines the cached LLM, a retriever over the cached FAISS store
    that returns the single best-matching chunk (k=1), and the custom prompt.
    It uses the "stuff" chain type and also returns the source documents.
    """
    language_model = load_model()
    store = load_faiss()
    top_match_retriever = store.as_retriever(search_kwargs={"k": 1})
    return RetrievalQA.from_chain_type(
        llm=language_model,
        retriever=top_match_retriever,
        chain_type="stuff",
        chain_type_kwargs={"prompt": set_custom_prompt()},
        return_source_documents=True,
    )

def process_answer(instruction):
    """Answer *instruction* using the cached RetrievalQA chain.

    Args:
        instruction: The user's question, passed to the chain as ``"query"``.

    Returns:
        The generated answer text (the chain's ``"result"`` output). The
        source documents the chain also returns are not used here.
    """
    qa = qa_llm()
    # Chain.invoke replaces calling the chain object directly, which LangChain
    # deprecated in 0.1.0.
    response = qa.invoke({"query": instruction})
    return response["result"]

# Main Streamlit application
def _answer_and_record(question):
    """Ask the QA chain *question* and append the exchange to the chat history.

    The user turn and the bot turn are appended together, only after the
    answer is available. An exception from the model therefore cannot leave
    ``st.session_state["past"]`` one entry longer than ``["generated"]``.
    """
    with st.spinner("Generating response..."):
        answer = process_answer(question)
    st.session_state["past"].append(question)
    st.session_state["generated"].append(answer)


def main():
    """Render the chatbot page and handle one round of user interaction.

    Chat history lives in ``st.session_state`` as two parallel lists:
    ``"past"`` (user turns) and ``"generated"`` (bot turns). Entry *i* of each
    list forms one exchange.
    """
    st.title("🤖 CPGRAM Grievance Chatbot")

    if "generated" not in st.session_state:
        st.session_state["generated"] = ["Hello! Ask me any queries related to Grievance and CPGRAM Portal."]
    if "past" not in st.session_state:
        st.session_state["past"] = ["Hey! 👋"]

    reply_container = st.container()
    user_input = st.chat_input(placeholder="Please describe your queries here...", key="input")

    # Predefined buttons for common questions. Every st.button call runs before
    # any branching, so all buttons are drawn on every rerun. An if/elif chain
    # would short-circuit and hide later buttons once an earlier one is clicked.
    quick_questions = (
        ("What is CPGRAM?", "cpgram_button"),
        ("How to fill grievance form?", "grievance_button"),
    )
    clicked = None
    for label, key in quick_questions:
        if st.button(label, key=key) and clicked is None:
            clicked = label

    if clicked is not None:
        _answer_and_record(clicked)
    elif user_input:
        if is_offensive(user_input):
            # Store a placeholder instead of echoing the offensive text back.
            st.session_state["past"].append("User input flagged as offensive")
            st.session_state["generated"].append("I'm sorry, but I can't assist with offensive content.")
        else:
            _answer_and_record(user_input)

    # Display the conversation. zip keeps user/bot turns paired and never
    # indexes past the end of the shorter list.
    history = zip(st.session_state["past"], st.session_state["generated"])
    with reply_container:
        for i, (user_text, bot_text) in enumerate(history):
            message(user_text, is_user=True, key=f"{i}_user")
            message(bot_text, key=str(i))


if __name__ == "__main__":
    main()