import faiss
import numpy as np
import torch
import time
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import pipeline
class CustomerServiceAgent:
"""
AI Customer Service Agent with RAG + robust off-topic detection.
"""
def __init__(self):
print("Initializing IMPROVED Customer Service Agent...")
self._load_models()
self._build_knowledge_base()
print("\nAgent is ready.")
def _load_models(self):
"""
Loads all the ML models required for the agent.
"""
print("\n[1/4] Loading models...")
device = 0 if torch.cuda.is_available() else -1
# Embedding model for retrieval
self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
# Generation pipelines
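# Note: moderator_pipeline is loaded here but never invoked in the current flow;
# flan-t5-large handles both query rewriting and answer generation.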
self.moderator_pipeline = pipeline("text2text-generation", model='google/flan-t5-base', device=device)
self.llm_pipeline = pipeline("text2text-generation", model='google/flan-t5-large', device=device)
# Sentiment analysis
self.sentiment_classifier = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english", device=device)
print("All models loaded successfully.")
def _build_knowledge_base(self):
"""
Loads the dataset, chunks it, builds a FAISS index over normalized embeddings,
and prepares an optional zero-shot classifier.
"""
print("\n[2/4] Preparing Knowledge Base...")
try:
dataset = load_dataset("MakTek/Customer_support_faqs_dataset", split="train")
raw_docs = [item for item in dataset['answer'] if item and item.strip()]
self.knowledge_base = []
for doc in raw_docs:
self.knowledge_base.extend(doc.split('\n\n'))
print(f"Successfully loaded and chunked {len(raw_docs)} documents into {len(self.knowledge_base)} chunks.")
except Exception as e:
print(f"Failed to load dataset. Using fallback. Error: {e}")
self.knowledge_base = [
"You can update your payment method by going to the 'Billing' section in your account settings. All payment information is encrypted and processed securely over an SSL connection.",
"To check your order status, please log in to your account and navigate to the 'My Orders' page.",
"I am very sorry to hear your package has not arrived. Please provide your order number so I can investigate.",
]
print(f"Using fallback KB with {len(self.knowledge_base)} documents.")
print("\n[3/4] Creating embeddings for the knowledge base...")
raw_embeddings = self.embedding_model.encode(self.knowledge_base, show_progress_bar=True)
raw_embeddings = np.array(raw_embeddings).astype('float32')
# Normalize embeddings for cosine similarity
norms = np.linalg.norm(raw_embeddings, axis=1, keepdims=True)
norms[norms == 0] = 1e-10
self.kb_embeddings = raw_embeddings / norms
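# With unit-norm vectors, inner product equals cosine similarity, so the
# IndexFlatIP built below performs an exact cosine-similarity search.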
print("\n[4/4] Setting up FAISS cosine similarity index...")
d = self.kb_embeddings.shape[1]
self.index = faiss.IndexFlatIP(d)
self.index.add(self.kb_embeddings)
print("FAISS retriever ready.")
# Optional zero-shot classifier
try:
self.zero_shot = pipeline("zero-shot-classification", model="facebook/bart-large-mnli",
device=0 if torch.cuda.is_available() else -1)
print("Zero-shot classifier loaded.")
except Exception:
self.zero_shot = None
print("Zero-shot classifier unavailable (skipping).")
def _rewrite_followup(self, query, history, max_new_tokens=64):
"""
Rewrite a follow-up query into a standalone question.
- Uses the llm_pipeline but falls back to a heuristic if the rewrite equals
the previous user message (which indicates a bad rewrite).
- Always returns a non-empty string.
"""
query = query.strip()
if not history:
return query
# Build compact history showing only the last user message and assistant reply
# (keeps prompt short and focused)
last_turn = history[-1]
last_user = last_turn.get('user', '').strip()
last_assistant = last_turn.get('assistant', '').strip()
rewrite_prompt = f"""
Given the following short chat history and a follow-up question, rewrite the follow-up
question as a single, self-contained question that requires no prior context.
Return ONLY the rewritten question, with no explanation or extra commentary.
Chat history:
User: {last_user}
Assistant: {last_assistant}
Follow-up: {query}
Standalone question:
"""
try:
out = self.llm_pipeline(rewrite_prompt,
max_new_tokens=max_new_tokens,
num_beams=4,
do_sample=False)[0]['generated_text'].strip()
except Exception:
# If the model fails, fall back to the heuristic below
out = ""
# Sanity check: if the rewrite is empty or merely echoes the last user message
# (ignoring case and punctuation), fall back to a heuristic rewrite.
def norm(s): return "".join(ch for ch in s.lower() if ch.isalnum() or ch.isspace()).strip()
if not out or norm(out) == norm(last_user):
# Heuristic: attach the last user question as referent to the follow-up
# e.g., "Is that process secure?" -> "Is that process secure? (Referring to: How do I change my payment method?)"
if last_user:
out = f"{query} (Referring to: {last_user})"
else:
out = query
return out
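# Example of the rewrite-or-fallback behavior (hypothetical values):
#   history = [{'user': 'How do I change my payment method?',
#               'assistant': "Go to the 'Billing' section in your settings."}]
#   _rewrite_followup("Is that process secure?", history)
#   -> a standalone rewrite from the model, or the heuristic fallback
#      "Is that process secure? (Referring to: How do I change my payment method?)"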
def _is_query_on_topic(self,
query,
allowed_topics=None,
similarity_threshold=0.44, # acceptance threshold for the combined score
top_k=5,
use_zero_shot=True,
debug=True):
"""
Robust on-topic detector.
Combines:
- embedding best cosine similarity (top-1)
- mean cosine similarity of top_k
- zero-shot 'off-topic' probability -> converted to on-topic prob
- simple keyword whitelist fallback
Returns True if combined_score >= similarity_threshold.
"""
if allowed_topics is None:
allowed_topics = ['billing', 'orders', 'shipping', 'account', 'product issue', 'returns', 'security']
q = query.strip().lower()
if len(q) == 0:
return False
# Quick keyword whitelist: immediate accept if contains explicit intent words
keywords = ['payment', 'pay', 'card', 'invoice', 'order', 'tracking', 'shipment', 'ship', 'shipping',
'password', 'login', 'signin', 'account', 'refund', 'return', 'cancel', 'billing', 'subscribe',
'subscription', 'charge', 'charged', 'security']
for kw in keywords:
if kw in q:
if debug:
print(f"[Safeguard-kw] Keyword '{kw}' matched -> ACCEPT")
return True
# Embedding-based scores
q_emb = self.embedding_model.encode([query])
q_emb = np.array(q_emb).astype('float32')
q_emb /= (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-10)
# Search top_k (IndexFlatIP stored normalized vectors)
D, I = self.index.search(q_emb, top_k) # D: inner-products ~ cosine
# FAISS pads missing neighbors with index -1, so filter on the index array
d_list = [float(s) for s, i in zip(D[0], I[0]) if i >= 0]
if len(d_list) == 0:
if debug:
print("[Safeguard] No neighbors returned by FAISS.")
embedding_best = 0.0
embedding_mean = 0.0
else:
embedding_best = d_list[0]
embedding_mean = float(sum(d_list) / len(d_list))
if debug:
print(f"[Safeguard] embedding_best={embedding_best:.4f}, embedding_mean(top{top_k})={embedding_mean:.4f}")
# Zero-shot: compute probability of being on-topic = 1 - P(off-topic)
zs_on_prob = 0.0
if use_zero_shot and self.zero_shot is not None:
try:
candidate_labels = allowed_topics + ["off-topic"]
zs = self.zero_shot(query, candidate_labels, multi_label=False)
# find index of 'off-topic' label and its score
off_idx = zs['labels'].index('off-topic') if 'off-topic' in zs['labels'] else None
off_score = 0.0
if off_idx is not None:
off_score = float(zs['scores'][off_idx])
zs_on_prob = 1.0 - off_score
if debug:
print(f"[Safeguard] zero-shot off-topic_score={off_score:.3f} -> on_prob={zs_on_prob:.3f} (top_label='{zs['labels'][0]}')")
except Exception as e:
if debug:
print(f"[Safeguard] zero-shot failed: {e}")
zs_on_prob = 0.0
# Combine signals with weights (tune these if needed)
# We give embedding_best the most weight, embedding_mean helps stability, zs_on_prob is supportive.
w_best = 0.55
w_mean = 0.25
w_zs = 0.20
combined_score = (w_best * max(0.0, embedding_best) +
w_mean * max(0.0, embedding_mean) +
w_zs * max(0.0, zs_on_prob))
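# Worked example with hypothetical scores: embedding_best=0.55,
# embedding_mean=0.40, zs_on_prob=0.80 ->
# combined = 0.55*0.55 + 0.25*0.40 + 0.20*0.80 = 0.5625 >= 0.44 -> on-topic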
if debug:
print(f"[Safeguard] combined_score={combined_score:.4f}, threshold={similarity_threshold}")
return combined_score >= similarity_threshold
def _retrieve_context(self, query, k=3):
"""
Retrieves the top-k most relevant chunks from the knowledge base
based on cosine similarity of sentence embeddings.
"""
query_embedding = np.array(self.embedding_model.encode([query])).astype("float32")
# Normalize the query so the inner-product search returns true cosine scores
query_embedding /= (np.linalg.norm(query_embedding, axis=1, keepdims=True) + 1e-10)
scores, indices = self.index.search(query_embedding, k)
retrieved_docs = [self.knowledge_base[i] for i in indices[0] if i >= 0]
context = "\n\n".join(retrieved_docs)
return context
def get_rag_response(self, query, history, k=3):
"""
Generates a RAG-based response with safeguards. Uses a robust rewrite-first flow.
"""
print(f"\nProcessing query: '{query}'")
# 1) Rewrite follow-up into standalone question BEFORE the safeguard
standalone_query = self._rewrite_followup(query, history)
print(f"Rewritten query for retrieval & safeguard: '{standalone_query}'")
# 2) Safeguard check on the standalone query
if not self._is_query_on_topic(standalone_query, similarity_threshold=0.44, top_k=5, use_zero_shot=True):
return ("I'm sorry — I can only assist with customer-service related questions "
"like billing, orders, shipping, or account issues. Could you rephrase your question?")
# 3) Sentiment (optional; can be done earlier if you want)
sentiment = self.sentiment_classifier(standalone_query)[0]['label']
print(f"Detected Sentiment: {sentiment}")
# 4) Retrieve context using the standalone query
context = self._retrieve_context(standalone_query, k=k)
# 5) Persona and final prompt (use standalone query; forbid echo)
if sentiment == 'NEGATIVE':
persona = ("You are an exceptionally empathetic and understanding customer support agent. "
"Acknowledge frustration, apologize, and provide the next steps clearly.")
else:
persona = ("You are a friendly, efficient, and professional customer support agent. "
"Provide clear, concise, and helpful answers.")
prompt = f"""
{persona}
Your role is STRICTLY to be a customer support agent.
Use only the provided context to answer precise customer-support questions.
If the answer is not in the context, say you don't know and provide a safe next step (e.g., ask for order number).
Do NOT repeat the question back in your answer. Return a concise answer of 1-3 sentences.
Context:
{context}
Question: {standalone_query}
Answer:
"""
start_time = time.time()
llm_output = self.llm_pipeline(prompt, max_new_tokens=150, num_beams=4, early_stopping=True)
response = llm_output[0]['generated_text'].strip()
print(f"LLM Response Time: {time.time() - start_time:.2f}s")
# Some models echo the question back as their output when confused; guard against that:
if response.lower().startswith(standalone_query.lower()):
# If it echoed the question, ask the model one more time with an explicit instruction
retry_prompt = prompt + "\n(Do NOT repeat the question; give the answer only.)\nAnswer:"
llm_output = self.llm_pipeline(retry_prompt, max_new_tokens=150, num_beams=4, early_stopping=True)
response = llm_output[0]['generated_text'].strip()
return response
# --- Terminal Demo ---
if __name__ == "__main__":
agent = CustomerServiceAgent()
conversation_history = []
print("\n--- Testing ---")
query1 = "how do i change my password?"
response1 = agent.get_rag_response(query1, conversation_history)
conversation_history.append({'user': query1, 'assistant': response1})
print(f"\nUser: {query1}\nAgent: {response1}")
query2 = "my package never arrived."
response2 = agent.get_rag_response(query2, conversation_history)
conversation_history.append({'user': query2, 'assistant': response2})
print(f"\nUser: {query2}\nAgent: {response2}")
print("\n--- Testing Safeguard (Off-topic) ---")
query3 = "What's the best recipe for lasagna?"
response3 = agent.get_rag_response(query3, [])
print(f"\nUser: {query3}\nAgent: {response3}")
print("\n--- Demo Complete ---")