# NOTE: Hugging Face Spaces file-viewer chrome (status banner, blob hashes,
# gutter line numbers) was captured with this file when it was scraped; it has
# been reduced to this comment so the module parses.
import faiss
import numpy as np
import torch
import time
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import pipeline
class CustomerServiceAgent:
    """
    AI Customer Service Agent with RAG + robust off-topic detection.

    Flow: rewrite follow-up questions into standalone ones, gate them through a
    combined on-topic safeguard (keyword whitelist + embedding cosine similarity
    + optional zero-shot classification), retrieve supporting chunks from a
    FAISS index, then answer with a persona-conditioned seq2seq model.
    """

    def __init__(self):
        print("Initializing IMPROVED Customer Service Agent...")
        self._load_models()
        self._build_knowledge_base()
        print("\nAgent is ready.")

    def _load_models(self):
        """
        Loads all the ML models required for the agent.
        """
        print("\n[1/4] Loading models...")
        device = 0 if torch.cuda.is_available() else -1
        # Embedding model for retrieval
        self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        # Generation pipelines (base model for moderation, large for answering)
        self.moderator_pipeline = pipeline("text2text-generation", model='google/flan-t5-base', device=device)
        self.llm_pipeline = pipeline("text2text-generation", model='google/flan-t5-large', device=device)
        # Sentiment analysis drives the persona choice in get_rag_response
        self.sentiment_classifier = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english", device=device)
        print("All models loaded successfully.")

    def _build_knowledge_base(self):
        """
        Loads dataset, chunks it, builds FAISS index with normalized embeddings,
        and prepares optional zero-shot classifier.
        """
        print("\n[2/4] Preparing Knowledge Base...")
        try:
            dataset = load_dataset("MakTek/Customer_support_faqs_dataset", split="train")
            raw_docs = [item for item in dataset['answer'] if item and item.strip()]
            # Chunk on blank lines; drop whitespace-only chunks so they are
            # never embedded/indexed (fix: original kept empty chunks).
            self.knowledge_base = [
                chunk.strip()
                for doc in raw_docs
                for chunk in doc.split('\n\n')
                if chunk.strip()
            ]
            print(f"Successfully loaded and chunked {len(raw_docs)} documents into {len(self.knowledge_base)} chunks.")
        except Exception as e:
            print(f"Failed to load dataset. Using fallback. Error: {e}")
            self.knowledge_base = [
                "You can update your payment method by going to the 'Billing' section in your account settings. All payment information is encrypted and processed securely over an SSL connection.",
                "To check your order status, please log in to your account and navigate to the 'My Orders' page.",
                "I am very sorry to hear your package has not arrived. Please provide your order number so I can investigate.",
            ]
            print(f"Using fallback KB with {len(self.knowledge_base)} documents.")
        print("\n[3/4] Creating embeddings for the knowledge base...")
        raw_embeddings = self.embedding_model.encode(self.knowledge_base, show_progress_bar=True)
        raw_embeddings = np.array(raw_embeddings).astype('float32')
        # Normalize embeddings so inner products equal cosine similarity
        norms = np.linalg.norm(raw_embeddings, axis=1, keepdims=True)
        norms[norms == 0] = 1e-10  # avoid division by zero for degenerate rows
        self.kb_embeddings = raw_embeddings / norms
        print("\n[4/4] Setting up FAISS cosine similarity index...")
        d = self.kb_embeddings.shape[1]
        self.index = faiss.IndexFlatIP(d)
        self.index.add(self.kb_embeddings)
        print("FAISS retriever ready.")
        # Optional zero-shot classifier; the safeguard works without it
        try:
            self.zero_shot = pipeline("zero-shot-classification", model="facebook/bart-large-mnli",
                                      device=0 if torch.cuda.is_available() else -1)
            print("Zero-shot classifier loaded.")
        except Exception:
            self.zero_shot = None
            print("Zero-shot classifier unavailable (skipping).")

    def _rewrite_followup(self, query, history, max_new_tokens=64):
        """
        Rewrite a follow-up query into a standalone question.
        - Uses the llm_pipeline but falls back to a heuristic if the rewrite equals
          the previous user message (which indicates a bad rewrite).
        - Always returns a non-empty string.
        """
        query = query.strip()
        if not history:
            return query
        # Build compact history showing only the last user message and assistant reply
        # (keeps prompt short and focused)
        last_turn = history[-1]
        last_user = last_turn.get('user', '').strip()
        last_assistant = last_turn.get('assistant', '').strip()
        rewrite_prompt = f"""
Given the following short chat history and a follow-up question, rewrite the follow-up
question as a single, self-contained question that requires no prior context.
Return ONLY the rewritten question (no explanation, no punctuation at the end beyond normal).
Chat history:
User: {last_user}
Assistant: {last_assistant}
Follow-up: {query}
Standalone question:
"""
        try:
            out = self.llm_pipeline(rewrite_prompt,
                                    max_new_tokens=max_new_tokens,
                                    num_beams=4,
                                    do_sample=False)[0]['generated_text'].strip()
        except Exception:
            # If the model fails, fall back to the heuristic below
            out = ""
        # Basic sanity checks and fallback:
        # - if the rewrite is empty or exactly equals the last user question, do heuristic
        # - comparison ignores case and punctuation
        def norm(s): return "".join(ch for ch in s.lower() if ch.isalnum() or ch.isspace()).strip()
        if not out or norm(out) == norm(last_user):
            # Heuristic: attach the last user question as referent to the follow-up
            # e.g., "Is that process secure?" -> "Is that process secure? (Referring to: How do I change my payment method?)"
            if last_user:
                out = f"{query} (Referring to: {last_user})"
            else:
                out = query
        return out

    def _is_query_on_topic(self,
                           query,
                           allowed_topics=None,
                           similarity_threshold=0.44,  # lowered default
                           top_k=5,
                           use_zero_shot=True,
                           debug=True):
        """
        Robust on-topic detector.
        Combines:
          - embedding best cosine similarity (top-1)
          - mean cosine similarity of top_k
          - zero-shot 'off-topic' probability -> converted to on-topic prob
          - simple keyword whitelist fallback
        Returns True if combined_score >= similarity_threshold.
        """
        if allowed_topics is None:
            allowed_topics = ['billing', 'orders', 'shipping', 'account', 'product issue', 'returns', 'security']
        q = query.strip().lower()
        if len(q) == 0:
            return False
        # Quick keyword whitelist: immediate accept if contains explicit intent words
        keywords = ['payment', 'pay', 'card', 'invoice', 'order', 'tracking', 'shipment', 'ship', 'shipping',
                    'password', 'login', 'signin', 'account', 'refund', 'return', 'cancel', 'billing', 'subscribe',
                    'subscription', 'charge', 'charged', 'security']
        for kw in keywords:
            if kw in q:
                if debug:
                    print(f"[Safeguard-kw] Keyword '{kw}' matched -> ACCEPT")
                return True
        # Embedding-based scores (normalized query -> inner products are cosine)
        q_emb = self.embedding_model.encode([query])
        q_emb = np.array(q_emb).astype('float32')
        q_emb /= (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-10)
        D, I = self.index.search(q_emb, top_k)  # D: inner-products ~ cosine
        # FAISS pads with index -1 (and sentinel scores) when the KB holds
        # fewer than top_k vectors; keep only real neighbours so the mean is
        # not poisoned (fix: original's `x is not None` check never filtered).
        d_list = [float(score) for score, idx in zip(D[0], I[0]) if idx != -1]
        if len(d_list) == 0:
            if debug:
                print("[Safeguard] No neighbors returned by FAISS.")
            embedding_best = 0.0
            embedding_mean = 0.0
        else:
            embedding_best = d_list[0]
            embedding_mean = float(sum(d_list) / len(d_list))
        if debug:
            print(f"[Safeguard] embedding_best={embedding_best:.4f}, embedding_mean(top{top_k})={embedding_mean:.4f}")
        # Zero-shot: compute probability of being on-topic = 1 - P(off-topic)
        zs_on_prob = 0.0
        if use_zero_shot and self.zero_shot is not None:
            try:
                candidate_labels = allowed_topics + ["off-topic"]
                zs = self.zero_shot(query, candidate_labels, multi_label=False)
                # find index of 'off-topic' label and its score
                off_idx = zs['labels'].index('off-topic') if 'off-topic' in zs['labels'] else None
                off_score = 0.0
                if off_idx is not None:
                    off_score = float(zs['scores'][off_idx])
                zs_on_prob = 1.0 - off_score
                if debug:
                    print(f"[Safeguard] zero-shot off-topic_score={off_score:.3f} -> on_prob={zs_on_prob:.3f} (top_label='{zs['labels'][0]}')")
            except Exception as e:
                if debug:
                    print(f"[Safeguard] zero-shot failed: {e}")
                zs_on_prob = 0.0
        # Combine signals with weights (tune these if needed)
        # We give embedding_best the most weight, embedding_mean helps stability, zs_on_prob is supportive.
        w_best = 0.55
        w_mean = 0.25
        w_zs = 0.20
        combined_score = (w_best * max(0.0, embedding_best) +
                          w_mean * max(0.0, embedding_mean) +
                          w_zs * max(0.0, zs_on_prob))
        if debug:
            print(f"[Safeguard] combined_score={combined_score:.4f}, threshold={similarity_threshold}")
        return combined_score >= similarity_threshold

    def _retrieve_context(self, query, k=3):
        """
        Retrieves the top-k most relevant chunks from the knowledge base
        based on cosine similarity of sentence embeddings.
        """
        q_emb = np.array(self.embedding_model.encode([query])).astype("float32")
        # Normalize the query so inner products against the normalized KB
        # vectors are true cosine similarities (fix: original searched with an
        # unnormalized query, inconsistent with _is_query_on_topic).
        q_emb /= (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-10)
        scores, indices = self.index.search(q_emb, k)
        # Skip the -1 padding FAISS returns when k exceeds the KB size;
        # indexing with -1 would silently return the last chunk.
        retrieved_docs = [self.knowledge_base[i] for i in indices[0] if i != -1]
        context = "\n\n".join(retrieved_docs)
        return context

    def get_rag_response(self, query, history, k=3):
        """
        Generates a RAG-based response with safeguards. Uses a robust rewrite-first flow.
        """
        print(f"\nProcessing query: '{query}'")
        # 1) Rewrite follow-up into standalone question BEFORE the safeguard
        standalone_query = self._rewrite_followup(query, history)
        print(f"Rewritten query for retrieval & safeguard: '{standalone_query}'")
        # 2) Safeguard check on the standalone query
        if not self._is_query_on_topic(standalone_query, similarity_threshold=0.44, top_k=5, use_zero_shot=True):
            return ("I'm sorry — I can only assist with customer-service related questions "
                    "like billing, orders, shipping, or account issues. Could you rephrase your question?")
        # 3) Sentiment (optional; can be done earlier if you want)
        sentiment = self.sentiment_classifier(standalone_query)[0]['label']
        print(f"Detected Sentiment: {sentiment}")
        # 4) Retrieve context using the standalone query
        context = self._retrieve_context(standalone_query, k=k)
        # 5) Persona and final prompt (use standalone query; forbid echo)
        if sentiment == 'NEGATIVE':
            persona = ("You are an exceptionally empathetic and understanding customer support agent. "
                       "Acknowledge frustration, apologize, and provide the next steps clearly.")
        else:
            persona = ("You are a friendly, efficient, and professional customer support agent. "
                       "Provide clear, concise, and helpful answers.")
        prompt = f"""
{persona}
Your role is STRICTLY to be a customer support agent.
Use only the provided context to answer precise customer-support questions.
If the answer is not in the context, say you don't know and provide a safe next step (e.g., ask for order number).
Do NOT repeat the question back in your answer. Return a concise answer of 1-3 sentences.
Context:
{context}
Question: {standalone_query}
Answer:
"""
        start_time = time.time()
        llm_output = self.llm_pipeline(prompt, max_new_tokens=150, num_beams=4, early_stopping=True)
        response = llm_output[0]['generated_text'].strip()
        print(f"LLM Response Time: {time.time() - start_time:.2f}s")
        # Some models sometimes return the question as the output when confused; guard against that:
        if response.lower().startswith(standalone_query.lower()):
            # If it echoed the question, ask the model one more time with an explicit instruction
            retry_prompt = prompt + "\n(Do NOT repeat the question; give the answer only.)\nAnswer:"
            llm_output = self.llm_pipeline(retry_prompt, max_new_tokens=150, num_beams=4, early_stopping=True)
            response = llm_output[0]['generated_text'].strip()
        return response
# --- Terminal Demo ---
if __name__ == "__main__":
    agent = CustomerServiceAgent()
    conversation_history = []
    print("\n--- Testing ---")
    # Two on-topic turns that build up the conversation history
    for question in ("how do i change my password?", "my package never arrived."):
        answer = agent.get_rag_response(question, conversation_history)
        conversation_history.append({'user': question, 'assistant': answer})
        print(f"\nUser: {question}\nAgent: {answer}")
    print("\n--- Testing Safeguard (Off-topic) ---")
    # Off-topic query with a fresh (empty) history should be refused
    off_topic_question = "What's the best recipe for lasagna?"
    off_topic_answer = agent.get_rag_response(off_topic_question, [])
    print(f"\nUser: {off_topic_question}\nAgent: {off_topic_answer}")
    print("\n--- Demo Complete ---")