import faiss
import numpy as np
import torch
import time
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import pipeline


class CustomerServiceAgent:
    """
    AI Customer Service Agent with RAG + robust off-topic detection.
    """

    def __init__(self):
        print("Initializing IMPROVED Customer Service Agent...")
        self._load_models()
        self._build_knowledge_base()
        print("\nAgent is ready.")

    def _load_models(self):
        """
        Loads all the ML models required for the agent.
        """
        print("\n[1/4] Loading models...")
        device = 0 if torch.cuda.is_available() else -1
        # Embedding model for retrieval
        self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        # Generation pipelines
        self.moderator_pipeline = pipeline("text2text-generation", model='google/flan-t5-base', device=device)
        self.llm_pipeline = pipeline("text2text-generation", model='google/flan-t5-large', device=device)
        # Sentiment analysis
        self.sentiment_classifier = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english", device=device)
        print("All models loaded successfully.")

    def _build_knowledge_base(self):
        """
        Loads the dataset, chunks it, builds a FAISS index over normalized embeddings,
        and prepares an optional zero-shot classifier.
        """
        print("\n[2/4] Preparing Knowledge Base...")
        try:
            dataset = load_dataset("MakTek/Customer_support_faqs_dataset", split="train")
            raw_docs = [item for item in dataset['answer'] if item and item.strip()]
            self.knowledge_base = []
            for doc in raw_docs:
                self.knowledge_base.extend(doc.split('\n\n'))
            print(f"Successfully loaded and chunked {len(raw_docs)} documents into {len(self.knowledge_base)} chunks.")
        except Exception as e:
            print(f"Failed to load dataset. Using fallback. Error: {e}")
            self.knowledge_base = [
                "You can update your payment method by going to the 'Billing' section in your account settings. All payment information is encrypted and processed securely over an SSL connection.",
                "To check your order status, please log in to your account and navigate to the 'My Orders' page.",
                "I am very sorry to hear your package has not arrived. Please provide your order number so I can investigate.",
            ]
            print(f"Using fallback KB with {len(self.knowledge_base)} documents.")

        print("\n[3/4] Creating embeddings for the knowledge base...")
        raw_embeddings = self.embedding_model.encode(self.knowledge_base, show_progress_bar=True)
        raw_embeddings = np.array(raw_embeddings).astype('float32')
        # Normalize embeddings so that inner product equals cosine similarity
        norms = np.linalg.norm(raw_embeddings, axis=1, keepdims=True)
        norms[norms == 0] = 1e-10
        self.kb_embeddings = raw_embeddings / norms

        print("\n[4/4] Setting up FAISS cosine similarity index...")
        d = self.kb_embeddings.shape[1]
        self.index = faiss.IndexFlatIP(d)
        self.index.add(self.kb_embeddings)
        print("FAISS retriever ready.")

        # Optional zero-shot classifier for the off-topic safeguard
        try:
            self.zero_shot = pipeline("zero-shot-classification", model="facebook/bart-large-mnli",
                                      device=0 if torch.cuda.is_available() else -1)
            print("Zero-shot classifier loaded.")
        except Exception:
            self.zero_shot = None
            print("Zero-shot classifier unavailable (skipping).")

    def _rewrite_followup(self, query, history, max_new_tokens=64):
        """
        Rewrite a follow-up query into a standalone question.
        - Uses the llm_pipeline but falls back to a heuristic if the rewrite equals
          the previous user message (which indicates a bad rewrite).
        - Always returns a non-empty string.
        """
        query = query.strip()
        if not history:
            return query

        # Build a compact history with only the last user message and assistant reply
        # (keeps the prompt short and focused)
        last_turn = history[-1]
        last_user = last_turn.get('user', '').strip()
        last_assistant = last_turn.get('assistant', '').strip()

        rewrite_prompt = f"""
Given the following short chat history and a follow-up question, rewrite the follow-up
question as a single, self-contained question that requires no prior context.
Return ONLY the rewritten question (no explanation, no extra punctuation at the end).

Chat history:
User: {last_user}
Assistant: {last_assistant}

Follow-up: {query}

Standalone question:
"""
        try:
            out = self.llm_pipeline(rewrite_prompt,
                                    max_new_tokens=max_new_tokens,
                                    num_beams=4,
                                    do_sample=False)[0]['generated_text'].strip()
        except Exception:
            # If the model fails, fall back to the simple heuristic below
            out = ""

        # Sanity checks and fallback:
        # - if the rewrite is empty, use the heuristic
        # - if the rewrite equals last_user (ignoring case & punctuation), use the heuristic
        def norm(s):
            return "".join(ch for ch in s.lower() if ch.isalnum() or ch.isspace()).strip()

        if not out or norm(out) == norm(last_user):
            # Heuristic: attach the last user question as referent to the follow-up,
            # e.g., "Is that process secure?" -> "Is that process secure? (Referring to: How do I change my payment method?)"
            if last_user:
                out = f"{query} (Referring to: {last_user})"
            else:
                out = query
        return out
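
    # Illustrative intent (hypothetical, not an actual model output): with
    # last_user="How do I change my payment method?" and query="Is that process secure?",
    # a good rewrite would be something like "Is changing my payment method secure?";
    # the heuristic fallback instead yields
    # "Is that process secure? (Referring to: How do I change my payment method?)".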

    def _is_query_on_topic(self,
                           query,
                           allowed_topics=None,
                           similarity_threshold=0.44,  # lowered default
                           top_k=5,
                           use_zero_shot=True,
                           debug=True):
        """
        Robust on-topic detector.
        Combines:
          - embedding best cosine similarity (top-1)
          - mean cosine similarity of the top_k hits
          - zero-shot 'off-topic' probability, converted to an on-topic probability
          - simple keyword whitelist fallback
        Returns True if combined_score >= similarity_threshold.
        """
        if allowed_topics is None:
            allowed_topics = ['billing', 'orders', 'shipping', 'account', 'product issue', 'returns', 'security']

        q = query.strip().lower()
        if len(q) == 0:
            return False

        # Quick keyword whitelist: immediately accept if the query contains explicit intent words
        keywords = ['payment', 'pay', 'card', 'invoice', 'order', 'tracking', 'shipment', 'ship', 'shipping',
                    'password', 'login', 'signin', 'account', 'refund', 'return', 'cancel', 'billing', 'subscribe',
                    'subscription', 'charge', 'charged', 'security']
        for kw in keywords:
            if kw in q:
                if debug:
                    print(f"[Safeguard-kw] Keyword '{kw}' matched -> ACCEPT")
                return True

        # Embedding-based scores
        q_emb = self.embedding_model.encode([query])
        q_emb = np.array(q_emb).astype('float32')
        q_emb /= (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-10)

        # Search top_k (IndexFlatIP over normalized vectors, so inner products ~ cosine similarity)
        D, I = self.index.search(q_emb, top_k)
        d_list = [float(x) for x in D[0] if x is not None]
        if len(d_list) == 0:
            if debug:
                print("[Safeguard] No neighbors returned by FAISS.")
            embedding_best = 0.0
            embedding_mean = 0.0
        else:
            embedding_best = d_list[0]
            embedding_mean = float(sum(d_list) / len(d_list))
        if debug:
            print(f"[Safeguard] embedding_best={embedding_best:.4f}, embedding_mean(top{top_k})={embedding_mean:.4f}")

        # Zero-shot: probability of being on-topic = 1 - P(off-topic)
        zs_on_prob = 0.0
        if use_zero_shot and self.zero_shot is not None:
            try:
                candidate_labels = allowed_topics + ["off-topic"]
                zs = self.zero_shot(query, candidate_labels, multi_label=False)
                # Find the 'off-topic' label and its score
                off_idx = zs['labels'].index('off-topic') if 'off-topic' in zs['labels'] else None
                off_score = 0.0
                if off_idx is not None:
                    off_score = float(zs['scores'][off_idx])
                zs_on_prob = 1.0 - off_score
                if debug:
                    print(f"[Safeguard] zero-shot off-topic_score={off_score:.3f} -> on_prob={zs_on_prob:.3f} (top_label='{zs['labels'][0]}')")
            except Exception as e:
                if debug:
                    print(f"[Safeguard] zero-shot failed: {e}")
                zs_on_prob = 0.0

        # Combine signals with weights (tune these if needed).
        # embedding_best gets the most weight, embedding_mean adds stability, zs_on_prob is supportive.
        w_best = 0.55
        w_mean = 0.25
        w_zs = 0.20
        combined_score = (w_best * max(0.0, embedding_best) +
                          w_mean * max(0.0, embedding_mean) +
                          w_zs * max(0.0, zs_on_prob))
        if debug:
            print(f"[Safeguard] combined_score={combined_score:.4f}, threshold={similarity_threshold}")
        return combined_score >= similarity_threshold
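
    # Worked example with hypothetical scores: embedding_best=0.62, embedding_mean=0.51,
    # zs_on_prob=0.80 gives 0.55*0.62 + 0.25*0.51 + 0.20*0.80 = 0.6285, which clears the
    # default 0.44 threshold, so the query is treated as on-topic.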

    def _retrieve_context(self, query, k=3):
        """
        Retrieves the top-k most relevant chunks from the knowledge base
        based on cosine similarity of sentence embeddings.
        """
        query_embedding = self.embedding_model.encode([query])
        query_embedding = np.array(query_embedding).astype("float32")
        # Normalize the query so the inner-product search matches the normalized KB vectors
        query_embedding /= (np.linalg.norm(query_embedding, axis=1, keepdims=True) + 1e-10)
        scores, indices = self.index.search(query_embedding, k)
        retrieved_docs = [self.knowledge_base[i] for i in indices[0]]
        context = "\n\n".join(retrieved_docs)
        return context

    def get_rag_response(self, query, history, k=3):
        """
        Generates a RAG-based response with safeguards. Uses a robust rewrite-first flow.
        """
        print(f"\nProcessing query: '{query}'")

        # Flattened chat history (kept for debugging/inspection)
        history_string = "".join([f"User: {turn['user']}\nAssistant: {turn['assistant']}\n" for turn in history])

        # 1) Rewrite the follow-up into a standalone question BEFORE the safeguard
        standalone_query = self._rewrite_followup(query, history)
        print(f"Rewritten query for retrieval & safeguard: '{standalone_query}'")

        # 2) Safeguard check on the standalone query
        if not self._is_query_on_topic(standalone_query, similarity_threshold=0.44, top_k=5, use_zero_shot=True):
            return ("I'm sorry — I can only assist with customer-service related questions "
                    "like billing, orders, shipping, or account issues. Could you rephrase your question?")

        # 3) Sentiment analysis (optional; could also be done earlier)
        sentiment = self.sentiment_classifier(standalone_query)[0]['label']
        print(f"Detected Sentiment: {sentiment}")

        # 4) Retrieve context using the standalone query
        context = self._retrieve_context(standalone_query, k=k)

        # 5) Persona and final prompt (use the standalone query; forbid echoing the question)
        if sentiment == 'NEGATIVE':
            persona = ("You are an exceptionally empathetic and understanding customer support agent. "
                       "Acknowledge frustration, apologize, and provide the next steps clearly.")
        else:
            persona = ("You are a friendly, efficient, and professional customer support agent. "
                       "Provide clear, concise, and helpful answers.")

        prompt = f"""
{persona}
Your role is STRICTLY to be a customer support agent.
Use only the provided context to answer precise customer-support questions.
If the answer is not in the context, say you don't know and provide a safe next step (e.g., ask for the order number).
Do NOT repeat the question back in your answer. Return a concise answer of 1-3 sentences.

Context:
{context}

Question: {standalone_query}

Answer:
"""
        start_time = time.time()
        llm_output = self.llm_pipeline(prompt, max_new_tokens=150, num_beams=4, early_stopping=True)
        response = llm_output[0]['generated_text'].strip()
        print(f"LLM Response Time: {time.time() - start_time:.2f}s")

        # Some models return the question as the output when confused; guard against that.
        if response.lower().startswith(standalone_query.lower()):
            # If the question was echoed, retry once with an explicit instruction
            retry_prompt = prompt + "\n(Do NOT repeat the question; give the answer only.)\nAnswer:"
            llm_output = self.llm_pipeline(retry_prompt, max_new_tokens=150, num_beams=4, early_stopping=True)
            response = llm_output[0]['generated_text'].strip()
        return response


# --- Terminal Demo ---
if __name__ == "__main__":
    agent = CustomerServiceAgent()
    conversation_history = []

    print("\n--- Testing ---")
    query1 = "how do i change my password?"
    response1 = agent.get_rag_response(query1, conversation_history)
    conversation_history.append({'user': query1, 'assistant': response1})
    print(f"\nUser: {query1}\nAgent: {response1}")

    query2 = "my package never arrived."
    response2 = agent.get_rag_response(query2, conversation_history)
    conversation_history.append({'user': query2, 'assistant': response2})
    print(f"\nUser: {query2}\nAgent: {response2}")

    print("\n--- Testing Safeguard (Off-topic) ---")
    query3 = "What's the best recipe for lasagna?"
    response3 = agent.get_rag_response(query3, [])
    print(f"\nUser: {query3}\nAgent: {response3}")

    print("\n--- Demo Complete ---")