DanielKiani committed on
Commit
6435ecc
·
verified ·
1 Parent(s): 8fb7841
Files changed (1) hide show
  1. scripts/agent.py +263 -60
scripts/agent.py CHANGED
@@ -6,117 +6,320 @@ from datasets import load_dataset
6
  from sentence_transformers import SentenceTransformer
7
  from transformers import pipeline
8
 
 
9
class CustomerServiceAgent:
    """
    AI customer service agent: owns the models, the FAQ knowledge base with
    its FAISS index, and the retrieval-augmented response generation.
    """

    def __init__(self):
        """Load the models, then build the RAG knowledge base and index."""
        print("Initializing Customer Service Agent...")
        self._load_models()
        self._build_knowledge_base()
        print("\nAgent is ready.")

    def _load_models(self):
        """Create the embedding model, the generator and the sentiment classifier."""
        print("\n[1/4] Loading all models...")
        # Pipelines take a device index: 0 when CUDA is available, -1 otherwise.
        if torch.cuda.is_available():
            device = 0
        else:
            device = -1

        self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        self.llm_pipeline = pipeline(
            "text2text-generation",
            model='google/flan-t5-large',
            device=device,
        )
        self.sentiment_classifier = pipeline(
            "sentiment-analysis",
            model="distilbert-base-uncased-finetuned-sst-2-english",
            device=device,
        )
        print("All models loaded successfully.")

    def _build_knowledge_base(self):
        """
        Load the FAQ answers (or a small built-in fallback), embed them and
        store the vectors in a FAISS L2 index.
        """
        print("\n[2/4] Preparing Knowledge Base...")
        try:
            faq_dataset = load_dataset("MakTek/Customer_support_faqs_dataset", split="train")
            answers = []
            for answer in faq_dataset['answer']:
                # Skip missing and whitespace-only answers.
                if answer and answer.strip():
                    answers.append(answer)
            self.knowledge_base = answers
            print(f"Successfully loaded {len(self.knowledge_base)} documents.")
        except Exception as e:
            print(f"Failed to load dataset. Using a fallback. Error: {e}")
            self.knowledge_base = [
                "You can update your payment method by going to the 'Billing' section in your account settings.",
                "To check your order status, please log in to your account and navigate to the 'My Orders' page.",
                "I am very sorry to hear your package has not arrived. Please provide your order number so I can investigate.",
            ]

        print("\n[3/4] Creating embeddings for the knowledge base...")
        kb_vectors = self.embedding_model.encode(self.knowledge_base, show_progress_bar=True)

        print("\n[4/4] Setting up FAISS vector index...")
        dimension = kb_vectors.shape[1]
        self.index = faiss.IndexFlatL2(dimension)
        self.index.add(np.array(kb_vectors))
        print("FAISS retriever is ready.")

    def get_rag_response(self, query, history, k=3):
        """
        Answer ``query`` with the RAG pipeline.

        ``history`` is a list of dicts with 'user' and 'assistant' keys; the
        ``k`` nearest knowledge-base documents are used as context. Returns
        the generated text with surrounding whitespace removed.
        """
        print(f"\nProcessing query: '{query}'")

        sentiment = self.sentiment_classifier(query)[0]['label']
        print(f"Detected Sentiment: {sentiment}")

        history_string = "".join(
            f"User: {turn['user']}\nAssistant: {turn['assistant']}\n" for turn in history
        )

        query_vector = np.array(self.embedding_model.encode([query]))
        _, neighbours = self.index.search(query_vector, k)
        context = "\n\n".join(self.knowledge_base[i] for i in neighbours[0])

        # A negative message gets an empathetic persona.
        if sentiment == 'NEGATIVE':
            persona = "You are an empathetic customer support agent."
        else:
            persona = "You are a helpful customer support agent."

        prompt = "\n".join([
            "",
            persona,
            "Based on history and context, answer the user's question.",
            "",
            "### History:",
            history_string,
            "### Context:",
            context,
            "### Question:",
            query,
            "### Answer:",
            "",
        ])

        start_time = time.time()
        generations = self.llm_pipeline(prompt, max_new_tokens=100, num_beams=5, early_stopping=True)
        response = generations[0]['generated_text']
        print(f"LLM Response Time: {time.time() - start_time:.2f} seconds")

        return response.strip()
98
 
99
# --- Terminal-based Demo ---
if __name__ == "__main__":
    agent = CustomerServiceAgent()
    conversation_history = []

    print("\n--- Starting Terminal Demo ---")

    # The second message is a follow-up, so it exercises the conversation memory.
    demo_messages = (
        "This is so frustrating, my package never arrived!",
        "Okay, what do you need from me to find it?",
    )
    for user_message in demo_messages:
        reply = agent.get_rag_response(user_message, conversation_history)
        conversation_history.append({'user': user_message, 'assistant': reply})

        print(f"\nUser: {user_message}")
        print(f"Agent: {reply}")

    print("\n--- Demo Complete ---")
 
6
  from sentence_transformers import SentenceTransformer
7
  from transformers import pipeline
8
 
9
+
10
class CustomerServiceAgent:
    """
    AI Customer Service Agent with RAG + robust off-topic detection.

    ``get_rag_response`` is the entry point. It rewrites a follow-up into a
    standalone question, rejects off-topic questions, classifies sentiment,
    retrieves knowledge-base chunks from a FAISS inner-product index and
    generates the answer with a text2text pipeline.
    """

    def __init__(self):
        """Load all models, then build the knowledge base and its FAISS index."""
        print("Initializing IMPROVED Customer Service Agent...")
        self._load_models()
        self._build_knowledge_base()
        print("\nAgent is ready.")

    def _load_models(self):
        """
        Loads all the ML models required for the agent.

        Sets ``embedding_model``, ``moderator_pipeline``, ``llm_pipeline`` and
        ``sentiment_classifier``.
        """
        print("\n[1/4] Loading models...")
        # Pipelines take a device index: 0 when CUDA is available, -1 otherwise.
        device = 0 if torch.cuda.is_available() else -1

        # Embedding model for retrieval
        self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

        # Generation pipelines
        # NOTE(review): moderator_pipeline is not referenced by any other
        # method of this class; kept so existing attribute access still works.
        self.moderator_pipeline = pipeline("text2text-generation", model='google/flan-t5-base', device=device)
        self.llm_pipeline = pipeline("text2text-generation", model='google/flan-t5-large', device=device)

        # Sentiment analysis
        self.sentiment_classifier = pipeline(
            "sentiment-analysis",
            model="distilbert-base-uncased-finetuned-sst-2-english",
            device=device,
        )

        print("All models loaded successfully.")

    @staticmethod
    def _chunk_documents(raw_docs):
        """Split each document on blank lines and return the non-empty, stripped chunks."""
        chunks = []
        for doc in raw_docs:
            for piece in doc.split('\n\n'):
                piece = piece.strip()
                # A trailing or doubled blank line would otherwise produce an
                # empty chunk that gets embedded and can be retrieved as context.
                if piece:
                    chunks.append(piece)
        return chunks

    def _build_knowledge_base(self):
        """
        Loads dataset, chunks it, builds FAISS index with normalized embeddings,
        and prepares optional zero-shot classifier.

        Sets ``knowledge_base`` (list of text chunks), ``kb_embeddings``
        (float32, L2-normalized rows), ``index`` (``faiss.IndexFlatIP``) and
        ``zero_shot`` (a pipeline, or ``None`` when it could not be loaded).
        Falls back to a small built-in knowledge base when the dataset cannot
        be loaded or contains no usable text.
        """
        print("\n[2/4] Preparing Knowledge Base...")
        try:
            dataset = load_dataset("MakTek/Customer_support_faqs_dataset", split="train")
            raw_docs = [item for item in dataset['answer'] if item and item.strip()]
            chunks = self._chunk_documents(raw_docs)
            if not chunks:
                # An empty knowledge base cannot be embedded or indexed, so
                # treat it like a failed load and use the fallback below.
                raise ValueError("dataset contained no usable answer text")
            self.knowledge_base = chunks
            print(f"Successfully loaded and chunked {len(raw_docs)} documents into {len(self.knowledge_base)} chunks.")
        except Exception as e:
            print(f"Failed to load dataset. Using fallback. Error: {e}")
            self.knowledge_base = [
                "You can update your payment method by going to the 'Billing' section in your account settings. All payment information is encrypted and processed securely over an SSL connection.",
                "To check your order status, please log in to your account and navigate to the 'My Orders' page.",
                "I am very sorry to hear your package has not arrived. Please provide your order number so I can investigate.",
            ]
            print(f"Using fallback KB with {len(self.knowledge_base)} documents.")

        print("\n[3/4] Creating embeddings for the knowledge base...")
        raw_embeddings = self.embedding_model.encode(self.knowledge_base, show_progress_bar=True)
        raw_embeddings = np.asarray(raw_embeddings, dtype='float32')

        # Normalize embeddings for cosine similarity
        norms = np.linalg.norm(raw_embeddings, axis=1, keepdims=True)
        norms[norms == 0] = 1e-10  # avoid dividing an all-zero row by zero
        self.kb_embeddings = raw_embeddings / norms

        print("\n[4/4] Setting up FAISS cosine similarity index...")
        d = self.kb_embeddings.shape[1]
        self.index = faiss.IndexFlatIP(d)
        self.index.add(self.kb_embeddings)
        print("FAISS retriever ready.")

        # Optional zero-shot classifier
        try:
            self.zero_shot = pipeline("zero-shot-classification", model="facebook/bart-large-mnli",
                                      device=0 if torch.cuda.is_available() else -1)
            print("Zero-shot classifier loaded.")
        except Exception as e:
            # Best effort: the safeguard still works from embeddings alone.
            self.zero_shot = None
            print(f"Zero-shot classifier unavailable (skipping): {e}")

    def _rewrite_followup(self, query, history, max_new_tokens=64):
        """
        Rewrite a follow-up query into a standalone question.

        - Returns the stripped query unchanged when ``history`` is empty.
        - Otherwise asks ``llm_pipeline`` to rewrite it using only the last
          turn of ``history`` (a dict with 'user' and 'assistant' keys).
        - Falls back to a heuristic when the model fails, returns nothing, or
          merely repeats the previous user message (a bad rewrite).
        """
        query = query.strip()
        if not history:
            return query

        # Only the last turn is used, which keeps the prompt short and focused.
        last_turn = history[-1]
        last_user = (last_turn.get('user') or '').strip()
        last_assistant = (last_turn.get('assistant') or '').strip()

        rewrite_prompt = "\n".join([
            "",
            "Given the following short chat history and a follow-up question, rewrite the follow-up",
            "question as a single, self-contained question that requires no prior context.",
            "Return ONLY the rewritten question (no explanation, no punctuation at the end beyond normal).",
            "",
            "Chat history:",
            f"User: {last_user}",
            f"Assistant: {last_assistant}",
            "",
            f"Follow-up: {query}",
            "",
            "Standalone question:",
            "",
        ])
        try:
            out = self.llm_pipeline(rewrite_prompt,
                                    max_new_tokens=max_new_tokens,
                                    num_beams=4,
                                    do_sample=False)[0]['generated_text'].strip()
        except Exception as e:
            # Report the failure, then continue with the heuristic below.
            print(f"[Rewrite] LLM rewrite failed, using heuristic fallback: {e}")
            out = ""

        def _normalize(text):
            """Lower-case and drop punctuation so near-identical strings compare equal."""
            return "".join(ch for ch in text.lower() if ch.isalnum() or ch.isspace()).strip()

        if not out or _normalize(out) == _normalize(last_user):
            # Heuristic: attach the last user question as referent, e.g.
            # "Is that process secure? (Referring to: How do I change my payment method?)"
            out = f"{query} (Referring to: {last_user})" if last_user else query

        return out

    def _is_query_on_topic(self,
                           query,
                           allowed_topics=None,
                           similarity_threshold=0.44,  # lowered default
                           top_k=5,
                           use_zero_shot=True,
                           debug=True):
        """
        Robust on-topic detector.

        An empty query is rejected and a query containing a whitelisted
        keyword is accepted immediately. Otherwise these signals are combined:
          - best cosine similarity to the knowledge base (weight 0.55)
          - mean cosine similarity of the top_k neighbours (weight 0.25)
          - zero-shot on-topic probability, 1 - P('off-topic') (weight 0.20;
            contributes 0 when the classifier is disabled or fails)

        Returns True if combined_score >= similarity_threshold.
        """
        if allowed_topics is None:
            allowed_topics = ['billing', 'orders', 'shipping', 'account', 'product issue', 'returns', 'security']

        q = query.strip().lower()
        if not q:
            return False

        # Quick keyword whitelist: immediate accept on explicit intent words.
        # NOTE(review): this is a substring test, so e.g. 'ship' also matches
        # inside 'relationship'. Behaviour is kept as-is.
        keywords = ('payment', 'pay', 'card', 'invoice', 'order', 'tracking', 'shipment', 'ship', 'shipping',
                    'password', 'login', 'signin', 'account', 'refund', 'return', 'cancel', 'billing', 'subscribe',
                    'subscription', 'charge', 'charged', 'security')
        matched = next((kw for kw in keywords if kw in q), None)
        if matched is not None:
            if debug:
                print(f"[Safeguard-kw] Keyword '{matched}' matched -> ACCEPT")
            return True

        # Embedding-based scores: normalize the query so inner product == cosine.
        q_emb = np.asarray(self.embedding_model.encode([query]), dtype='float32')
        q_emb /= (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-10)

        scores, labels = self.index.search(q_emb, top_k)
        # When the index holds fewer than top_k vectors FAISS pads the result
        # with label -1 and a huge negative sentinel score. Drop those slots so
        # they cannot wreck the mean.
        d_list = [float(score) for score, label in zip(scores[0], labels[0]) if label >= 0]
        if not d_list:
            if debug:
                print("[Safeguard] No neighbors returned by FAISS.")
            embedding_best = 0.0
            embedding_mean = 0.0
        else:
            embedding_best = d_list[0]
            embedding_mean = sum(d_list) / len(d_list)

        if debug:
            print(f"[Safeguard] embedding_best={embedding_best:.4f}, embedding_mean(top{top_k})={embedding_mean:.4f}")

        # Zero-shot: probability of being on-topic = 1 - P(off-topic)
        zs_on_prob = 0.0
        if use_zero_shot and self.zero_shot is not None:
            try:
                candidate_labels = list(allowed_topics) + ["off-topic"]
                zs = self.zero_shot(query, candidate_labels, multi_label=False)
                off_score = 0.0
                if 'off-topic' in zs['labels']:
                    off_score = float(zs['scores'][zs['labels'].index('off-topic')])
                zs_on_prob = 1.0 - off_score
                if debug:
                    print(f"[Safeguard] zero-shot off-topic_score={off_score:.3f} -> on_prob={zs_on_prob:.3f} (top_label='{zs['labels'][0]}')")
            except Exception as e:
                if debug:
                    print(f"[Safeguard] zero-shot failed: {e}")
                zs_on_prob = 0.0

        # Combine signals with weights (tune these if needed). embedding_best
        # dominates, embedding_mean adds stability, zs_on_prob is supportive.
        w_best = 0.55
        w_mean = 0.25
        w_zs = 0.20
        combined_score = (w_best * max(0.0, embedding_best) +
                          w_mean * max(0.0, embedding_mean) +
                          w_zs * max(0.0, zs_on_prob))

        if debug:
            print(f"[Safeguard] combined_score={combined_score:.4f}, threshold={similarity_threshold}")

        return combined_score >= similarity_threshold

    def _retrieve_context(self, query, k=3):
        """
        Retrieves up to k of the most relevant chunks from the knowledge base
        and returns them joined by blank lines.

        Fewer than k chunks are returned when the index holds fewer vectors.
        """
        # The stored vectors are unit length, so ranking by inner product with
        # the unnormalized query gives the same order as cosine similarity.
        query_embedding = np.asarray(self.embedding_model.encode([query]), dtype="float32")
        _, indices = self.index.search(query_embedding, k)
        # FAISS pads missing neighbours with -1; without this filter Python's
        # negative indexing would silently return the last document.
        kb_size = len(self.knowledge_base)
        retrieved_docs = [self.knowledge_base[i] for i in indices[0] if 0 <= i < kb_size]
        return "\n\n".join(retrieved_docs)

    def _generate_answer(self, prompt):
        """Run the answer-generation pipeline on ``prompt`` and return the stripped text."""
        llm_output = self.llm_pipeline(prompt, max_new_tokens=150, num_beams=4, early_stopping=True)
        return llm_output[0]['generated_text'].strip()

    def get_rag_response(self, query, history, k=3):
        """
        Generates a RAG-based response with safeguards. Uses a robust rewrite-first flow.

        ``history`` is a list of dicts with 'user' and 'assistant' keys, and
        ``k`` is the number of knowledge-base chunks to retrieve. Returns the
        generated answer, or a fixed refusal message when the (rewritten)
        query is judged off-topic.
        """
        print(f"\nProcessing query: '{query}'")

        # 1) Rewrite follow-up into standalone question BEFORE the safeguard
        standalone_query = self._rewrite_followup(query, history)
        print(f"Rewritten query for retrieval & safeguard: '{standalone_query}'")

        # 2) Safeguard check on the standalone query
        if not self._is_query_on_topic(standalone_query, similarity_threshold=0.44, top_k=5, use_zero_shot=True):
            return ("I'm sorry — I can only assist with customer-service related questions "
                    "like billing, orders, shipping, or account issues. Could you rephrase your question?")

        # 3) Sentiment of the standalone query selects the persona
        sentiment = self.sentiment_classifier(standalone_query)[0]['label']
        print(f"Detected Sentiment: {sentiment}")

        # 4) Retrieve context using the standalone query
        context = self._retrieve_context(standalone_query, k=k)

        # 5) Persona and final prompt (use standalone query; forbid echo)
        if sentiment == 'NEGATIVE':
            persona = ("You are an exceptionally empathetic and understanding customer support agent. "
                       "Acknowledge frustration, apologize, and provide the next steps clearly.")
        else:
            persona = ("You are a friendly, efficient, and professional customer support agent. "
                       "Provide clear, concise, and helpful answers.")

        prompt = "\n".join([
            "",
            persona,
            "",
            "Your role is STRICTLY to be a customer support agent.",
            "Use only the provided context to answer precise customer-support questions.",
            "If the answer is not in the context, say you don't know and provide a safe next step (e.g., ask for order number).",
            "Do NOT repeat the question back in your answer. Return a concise answer of 1-3 sentences.",
            "",
            "Context:",
            context,
            "",
            f"Question: {standalone_query}",
            "",
            "Answer:",
            "",
        ])

        start_time = time.time()
        response = self._generate_answer(prompt)
        print(f"LLM Response Time: {time.time() - start_time:.2f}s")

        # The model sometimes echoes the question when confused; retry once
        # with an explicit instruction not to.
        if response.lower().startswith(standalone_query.lower()):
            retry_prompt = prompt + "\n(Do NOT repeat the question; give the answer only.)\nAnswer:"
            response = self._generate_answer(retry_prompt)

        return response
301
+
302
+
303
+
304
# --- Terminal Demo ---
if __name__ == "__main__":
    agent = CustomerServiceAgent()
    conversation_history = []

    print("\n--- Testing ---")
    # Both on-topic questions share one growing conversation history.
    for user_message in ("how do i change my password?", "my package never arrived."):
        reply = agent.get_rag_response(user_message, conversation_history)
        conversation_history.append({'user': user_message, 'assistant': reply})
        print(f"\nUser: {user_message}\nAgent: {reply}")

    print("\n--- Testing Safeguard (Off-topic) ---")
    # Sent with an empty history, so the question is judged exactly as typed.
    off_topic_message = "What's the best recipe for lasagna?"
    off_topic_reply = agent.get_rag_response(off_topic_message, [])
    print(f"\nUser: {off_topic_message}\nAgent: {off_topic_reply}")

    print("\n--- Demo Complete ---")