File size: 7,552 Bytes
8e6fe89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01e9305
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import streamlit as st
from transformers import AutoTokenizer,  AutoModelForSequenceClassification, pipeline,AutoModelForSeq2SeqLM
from sqlalchemy import create_engine, Column, Integer, String, DateTime,Text
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from datetime import datetime
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
import torch


# App header. BUG FIX: the original title said "mysql", but the configured
# backend (DATABASE_URL below in this file) is a local SQLite file.
st.title("Simple Chatbot with persistent memory (SQLite) (Flan-T5)")


# --- Database setup (SQLAlchemy ORM) ---

# Declarative base class all ORM models inherit from.
Base = declarative_base()

class Conversation(Base):
    """ORM model for one user/bot exchange persisted to the `conversations` table."""
    __tablename__ = "conversations"

    # Surrogate primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Raw text the user typed.
    user_input = Column(Text, nullable=False)
    # Text the bot answered with.
    chatbot_response = Column(Text, nullable=False)
    # When the exchange happened; set by the caller at insert time (no DB default).
    timestamp = Column(DateTime, nullable=False)

    def __repr__(self) -> str:
        # Keep it short: full texts can be long; id + timestamp identify the row.
        return f"Conversation(id={self.id!r}, timestamp={self.timestamp!r})"
# Create the engine. NOTE: despite the "mysql" wording elsewhere in this file,
# this app actually uses a local SQLite database file.

DATABASE_URL = "sqlite:///chatbot.db"  # SQLite database file
# check_same_thread=False lets the SQLite connection be shared across
# Streamlit's worker threads (SQLite forbids cross-thread use by default).
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
Base.metadata.create_all(engine)  # create the `conversations` table if missing

# Session factory: call Session() per request and close it when done.
Session = sessionmaker(bind=engine)
    

# Load the generation model + tokenizer once; cached across Streamlit reruns.
@st.cache_resource
def load_model():
    """Return (tokenizer, model) for the Flan-T5 base checkpoint."""
    checkpoint = "google/flan-t5-base"
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )
    

tokenizer, model = load_model()

# --- Intent detection setup (TF-IDF + logistic regression) ---
# Training examples, a handful of utterances per intent label.
intents = {
    "greeting": ["hello", "hi", "hey", "good morning", "good afternoon"],
    "farewell": ["bye", "goodbye", "see you later", "farewell"],
    "general": ["how are you?", "tell me a joke", "what is the weather?", "who am i?", "do you recognize me?"],
    "about_me": ["what is your name?", "tell me about yourself", "who are you", "you?", "Murphy?"],
    "search_web": ["search for", "find", "what is", "look up", "search the web for", "google", "find information about", "show me", "give me info on", "tell me about"],
}

# Flatten the per-intent examples into parallel text/label lists.
all_texts = [utterance for samples in intents.values() for utterance in samples]
all_labels = [label for label, samples in intents.items() for _ in samples]

# Fit the TF-IDF vectorizer on the example utterances.
vectorizer = TfidfVectorizer()
X = vectorizer.fit_transform(all_texts)

# Train the lightweight intent classifier on the vectorized examples.
classifier = LogisticRegression()
classifier.fit(X, all_labels)

# Classic (TF-IDF) intent detector; kept alongside the transformer variant below.
def detect_intent(user_input):
    """Return the intent label predicted for *user_input*."""
    features = vectorizer.transform([user_input])
    return classifier.predict(features)[0]
# Few-shot examples prepended to the LLM prompt to steer the response style.
few_shot_examples = [
    "User: What is the weather like today?",
    "Chatbot: I'm sorry, I cannot provide real time information.",
    "User: Tell me a joke.",
    "Chatbot: Why don't scientists trust atoms? Because they make up everything!",
]

# Per-browser-session chat history: a list of "User: ..." / "Chatbot: ..." lines.
if 'conversation_history' not in st.session_state:
    st.session_state.conversation_history = []

    # sentiment analysis setup
# Default Hugging Face sentiment pipeline (no explicit model pinned).
sentiment_pipeline = pipeline("sentiment-analysis")


def get_sentiment(text):
    """Return (label, score) for *text* from the sentiment pipeline."""
    prediction = sentiment_pipeline(text)[0]
    return prediction["label"], prediction["score"]


# Only prepend a mood-aware prefix when the model is at least this confident.
sentiment_threshold = 0.8


# --- Summarization setup (BART, cached across Streamlit reruns) ---
@st.cache_resource
def load_summarizer():
    """Build and cache the history-summarization pipeline."""
    summarization_pipe = pipeline("summarization", model="facebook/bart-large-cnn")
    return summarization_pipe


summarizer = load_summarizer()


def summarize_history(history):
    """Collapse the accumulated chat lines into one short summary string."""
    transcript = "\n".join(history)
    result = summarizer(transcript, max_length=150, min_length=30, do_sample=False)
    return result[0]['summary_text']

# Streamlit text box for the user's message.
user_input = st.text_input("You:", key="user_input_1")

# --- Transformer-based intent detection setup ---
# NOTE(review): loading a sequence-classification head with num_labels on the
# plain "distilbert-base-uncased" checkpoint creates a freshly initialized
# (untrained) head — its intent predictions are not fine-tuned for `intents`.
# Confirm whether a fine-tuned checkpoint was intended here.
intent_model_name = "distilbert-base-uncased"
intent_tokenizer = AutoTokenizer.from_pretrained(intent_model_name)
intent_model = AutoModelForSequenceClassification.from_pretrained(intent_model_name, num_labels=len(intents))

# Move both models to GPU when available, otherwise stay on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
intent_model.to(device)  # intent detection model
model.to(device)  # Flan-T5 generation model on the same device

def detect_intent_transformer(user_input):
    """Classify *user_input* into one of the keys of `intents`.

    Runs the DistilBERT sequence classifier and returns the label whose logit
    is highest. FIX: truncate overlong inputs to the model's context length so
    an arbitrarily long message cannot crash the forward pass.
    """
    inputs = intent_tokenizer(
        user_input, return_tensors="pt", truncation=True, max_length=512
    ).to(device)
    with torch.no_grad():  # inference only — no gradient bookkeeping
        outputs = intent_model(**inputs)
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
    predicted_class = torch.argmax(probabilities, dim=-1).item()
    # Dict insertion order is stable, so class index -> label is deterministic.
    return list(intents.keys())[predicted_class]


if user_input:
    session = Session()
    try:
        intent = detect_intent_transformer(user_input)  # transformer-based intent
        # Sentiment of the user's message; used to prefix the reply below.
        sentiment_label, sentiment_score = get_sentiment(user_input)

        if intent == "greeting":
            response = "Hello there! I'm Murphy, developed by Mr.Abhishek"
        elif intent == "farewell":
            response = "Goodbye!"
        elif intent == "about_me":
            response = "I'm Murphy, developed by Abhishek, i can learn by myself and i'm feeling good right now, seems like i am alive hehe!"
        else:
            # Compact the history once it grows past 10 lines so the prompt stays short.
            if len(st.session_state.conversation_history) > 10:
                try:
                    summary = summarize_history(st.session_state.conversation_history)
                    st.session_state.conversation_history = [summary]
                except Exception as e:
                    st.write(f"Summarization failed: {e}")

            # Build the full prompt: few-shot examples + running context + new turn.
            few_shot_context = "\n".join(few_shot_examples)
            context = "\n".join(st.session_state.conversation_history)
            prompt = f"{few_shot_context}\n{context}\nUser: {user_input}\nChatbot:"
            # BUG FIX: the original tokenized `user_input`, silently discarding
            # the prompt it had just built. Feed the full prompt to the model;
            # truncate so an oversized history cannot overflow the context.
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
            outputs = model.generate(
                **inputs,
                max_length=50,
                # FIX: T5 defines a real pad token; the original padded with EOS.
                pad_token_id=tokenizer.pad_token_id,
            )
            response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Adjust the reply's tone only when the sentiment model is confident.
        if sentiment_score > sentiment_threshold:
            if sentiment_label == "NEGATIVE":
                response = f"I sense you're feeling a bit down. {response}"
            elif sentiment_label == "POSITIVE":
                response = f"I'm glad you're in a good mood! {response}"

        # Persist the exchange to the database.
        conversation = Conversation(
            user_input=user_input,
            chatbot_response=response,
            timestamp=datetime.now(),
        )
        session.add(conversation)
        session.commit()

        # Display this run's exchange.
        st.write(f"You: {user_input}")
        st.write(f"Murphy:{response}")

        # Append to the in-session history used for future prompts.
        st.session_state.conversation_history.append(f"User: {user_input}")
        st.session_state.conversation_history.append(f"Chatbot: {response}")

    except Exception as e:
        st.write(f"An error occurred: {e}")
        session.rollback()
    finally:
        session.close()


# --- Render the full persisted conversation history ---
session = Session()
try:
    # FIX: order explicitly by primary key so the history always displays
    # chronologically; without order_by the row order is backend-dependent.
    conversations = session.query(Conversation).order_by(Conversation.id).all()
    if conversations:
        st.subheader("Conversations History")
        for conv in conversations:
            st.write(f"You: {conv.user_input}")
            st.write(f"Murphy: {conv.chatbot_response}")
except Exception as e:
    st.write(f"Error retrieving conversations: {e}")
finally:
    session.close()