import faiss
import numpy as np
import torch
import time
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import pipeline


class CustomerServiceAgent:
    """
    AI Customer Service Agent with RAG + robust off-topic detection.
    """

    def __init__(self):
        print("Initializing IMPROVED Customer Service Agent...")
        self._load_models()
        self._build_knowledge_base()
        print("\nAgent is ready.")

    def _load_models(self):
        """
        Loads all the ML models required for the agent.
        """
        print("\n[1/4] Loading models...")
        device = 0 if torch.cuda.is_available() else -1

        # Embedding model for retrieval
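        # all-MiniLM-L6-v2 maps sentences to 384-dimensional vectors and is
        # small enough to run comfortably on CPU.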
        self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

        # Generation pipelines. Note: moderator_pipeline is loaded here but never
        # called below; drop it unless you plan to add a moderation step, since it
        # keeps a second model in memory.
        self.moderator_pipeline = pipeline("text2text-generation", model='google/flan-t5-base', device=device)
        self.llm_pipeline = pipeline("text2text-generation", model='google/flan-t5-large', device=device)

        # Sentiment analysis
        self.sentiment_classifier = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english", device=device)

        print("All models loaded successfully.")

    def _build_knowledge_base(self):
        """
        Loads the dataset, chunks it, builds a FAISS index over normalized
        embeddings, and prepares an optional zero-shot classifier.
        """
        print("\n[2/4] Preparing Knowledge Base...")
        try:
            dataset = load_dataset("MakTek/Customer_support_faqs_dataset", split="train")
            raw_docs = [item for item in dataset['answer'] if item and item.strip()]
            self.knowledge_base = []
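            # Split each answer on blank lines so every chunk is a single,
            # paragraph-sized retrieval unit; smaller chunks tend to retrieve
            # more precisely than whole documents.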
            for doc in raw_docs:
                self.knowledge_base.extend(doc.split('\n\n'))
            print(f"Successfully loaded and chunked {len(raw_docs)} documents into {len(self.knowledge_base)} chunks.")
        except Exception as e:
            print(f"Failed to load dataset. Using fallback. Error: {e}")
            self.knowledge_base = [
                "You can update your payment method by going to the 'Billing' section in your account settings. All payment information is encrypted and processed securely over an SSL connection.",
                "To check your order status, please log in to your account and navigate to the 'My Orders' page.",
                "I am very sorry to hear your package has not arrived. Please provide your order number so I can investigate.",
            ]
            print(f"Using fallback KB with {len(self.knowledge_base)} documents.")

        print("\n[3/4] Creating embeddings for the knowledge base...")
        raw_embeddings = self.embedding_model.encode(self.knowledge_base, show_progress_bar=True)
        raw_embeddings = np.array(raw_embeddings).astype('float32')

        # Normalize embeddings for cosine similarity
        norms = np.linalg.norm(raw_embeddings, axis=1, keepdims=True)
        norms[norms == 0] = 1e-10
        self.kb_embeddings = raw_embeddings / norms

        print("\n[4/4] Setting up FAISS cosine similarity index...")
        d = self.kb_embeddings.shape[1]
        self.index = faiss.IndexFlatIP(d)
        self.index.add(self.kb_embeddings)
        print("FAISS retriever ready.")

        # Optional zero-shot classifier
        try:
            self.zero_shot = pipeline("zero-shot-classification", model="facebook/bart-large-mnli",
                                      device=0 if torch.cuda.is_available() else -1)
            print("Zero-shot classifier loaded.")
        except Exception:
            self.zero_shot = None
            print("Zero-shot classifier unavailable (skipping).")

    def _rewrite_followup(self, query, history, max_new_tokens=64):
        """
        Rewrite a follow-up query into a standalone question.
        - Uses the llm_pipeline but falls back to a heuristic if the rewrite equals
          the previous user message (which indicates a bad rewrite).
        - Always returns a non-empty string.
        """
        query = query.strip()
        if not history:
            return query

        # Build compact history showing only the last user message and assistant reply
        # (keeps prompt short and focused)
        last_turn = history[-1]
        last_user = last_turn.get('user', '').strip()
        last_assistant = last_turn.get('assistant', '').strip()

        rewrite_prompt = f"""
        Given the following short chat history and a follow-up question, rewrite the follow-up
        question as a single, self-contained question that requires no prior context.
        Return ONLY the rewritten question, with no explanation or extra commentary.

        Chat history:
        User: {last_user}
        Assistant: {last_assistant}

        Follow-up: {query}

        Standalone question:
        """
        try:
            out = self.llm_pipeline(rewrite_prompt,
                                     max_new_tokens=max_new_tokens,
                                     num_beams=4,
                                     do_sample=False)[0]['generated_text'].strip()
        except Exception:
            # If the model call fails, fall back to the heuristic below.
            out = ""

        # Sanity check: if the rewrite is empty, or merely parrots the last user
        # message (ignoring case and punctuation), fall back to a heuristic.
        def norm(s): return "".join(ch for ch in s.lower() if ch.isalnum() or ch.isspace()).strip()
        if not out or norm(out) == norm(last_user):
            # Heuristic: attach the last user question as referent to the follow-up
            # e.g., "Is that process secure?" -> "Is that process secure? (Referring to: How do I change my payment method?)"
            if last_user:
                out = f"{query} (Referring to: {last_user})"
            else:
                out = query

        return out
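
    # Example usage (illustrative; the model's actual rewrite will vary):
    #   history = [{'user': "How do I change my payment method?",
    #               'assistant': "Go to the 'Billing' section in settings."}]
    #   agent._rewrite_followup("Is that process secure?", history)
    #   -> e.g. "Is changing my payment method secure?", or the heuristic
    #      fallback "Is that process secure? (Referring to: How do I change my payment method?)"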

    def _is_query_on_topic(self,
                           query,
                           allowed_topics=None,
                           similarity_threshold=0.44,
                           top_k=5,
                           use_zero_shot=True,
                           debug=True):
        """
        Robust on-topic detector.

        Combines:
         - embedding best cosine similarity (top-1)
         - mean cosine similarity of top_k
         - zero-shot 'off-topic' probability -> converted to on-topic prob
         - simple keyword whitelist fallback

        Returns True if combined_score >= similarity_threshold.
        """
        if allowed_topics is None:
            allowed_topics = ['billing', 'orders', 'shipping', 'account', 'product issue', 'returns', 'security']

        q = query.strip().lower()
        if len(q) == 0:
            return False

        # Quick keyword whitelist: accept immediately if the query contains explicit intent words
        keywords = ['payment', 'pay', 'card', 'invoice', 'order', 'tracking', 'shipment', 'ship', 'shipping',
                    'password', 'login', 'signin', 'account', 'refund', 'return', 'cancel', 'billing', 'subscribe',
                    'subscription', 'charge', 'charged', 'security']
        for kw in keywords:
            if kw in q:
                if debug:
                    print(f"[Safeguard-kw] Keyword '{kw}' matched -> ACCEPT")
                return True

        # Embedding-based scores
        q_emb = self.embedding_model.encode([query])
        q_emb = np.array(q_emb).astype('float32')
        q_emb /= (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-10)

        # Search top_k neighbors; the index stores normalized vectors, so the
        # returned inner products are cosine similarities.
        D, I = self.index.search(q_emb, top_k)
        # FAISS pads results with index -1 when fewer than top_k vectors exist.
        d_list = [float(score) for score, idx in zip(D[0], I[0]) if idx != -1]
        if len(d_list) == 0:
            if debug:
                print("[Safeguard] No neighbors returned by FAISS.")
            embedding_best = 0.0
            embedding_mean = 0.0
        else:
            embedding_best = d_list[0]
            embedding_mean = float(sum(d_list) / len(d_list))

        if debug:
            print(f"[Safeguard] embedding_best={embedding_best:.4f}, embedding_mean(top{top_k})={embedding_mean:.4f}")

        # Zero-shot: compute probability of being on-topic = 1 - P(off-topic)
        zs_on_prob = 0.0
        if use_zero_shot and self.zero_shot is not None:
            try:
                candidate_labels = allowed_topics + ["off-topic"]
                zs = self.zero_shot(query, candidate_labels, multi_label=False)
                # find index of 'off-topic' label and its score
                off_idx = zs['labels'].index('off-topic') if 'off-topic' in zs['labels'] else None
                off_score = 0.0
                if off_idx is not None:
                    off_score = float(zs['scores'][off_idx])
                zs_on_prob = 1.0 - off_score
                if debug:
                    print(f"[Safeguard] zero-shot off-topic_score={off_score:.3f} -> on_prob={zs_on_prob:.3f} (top_label='{zs['labels'][0]}')")
            except Exception as e:
                if debug:
                    print(f"[Safeguard] zero-shot failed: {e}")
                zs_on_prob = 0.0

        # Combine signals with weights (tune these if needed)
        # We give embedding_best the most weight, embedding_mean helps stability, zs_on_prob is supportive.
        w_best = 0.55
        w_mean = 0.25
        w_zs = 0.20
        combined_score = (w_best * max(0.0, embedding_best) +
                          w_mean * max(0.0, embedding_mean) +
                          w_zs * max(0.0, zs_on_prob))
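        # Illustrative numbers: embedding_best=0.62, embedding_mean=0.48,
        # zs_on_prob=0.85 gives 0.55*0.62 + 0.25*0.48 + 0.20*0.85 = 0.631,
        # comfortably above the 0.44 default threshold.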

        if debug:
            print(f"[Safeguard] combined_score={combined_score:.4f}, threshold={similarity_threshold}")

        return combined_score >= similarity_threshold

    def _retrieve_context(self, query, k=3):
        """
        Retrieves the top-k most relevant chunks from the knowledge base
        based on cosine similarity of sentence embeddings.
        """
        query_embedding = np.array(self.embedding_model.encode([query])).astype("float32")
        # Normalize the query the same way the KB vectors were normalized, so
        # the inner products returned by the index are true cosine similarities.
        query_embedding /= (np.linalg.norm(query_embedding, axis=1, keepdims=True) + 1e-10)
        scores, indices = self.index.search(query_embedding, k)
        retrieved_docs = [self.knowledge_base[i] for i in indices[0] if i != -1]
        return "\n\n".join(retrieved_docs)

    def get_rag_response(self, query, history, k=3):
        """
        Generates a RAG-based response with safeguards. Uses a robust rewrite-first flow.
        """
        print(f"\nProcessing query: '{query}'")

        # 1) Rewrite follow-up into standalone question BEFORE the safeguard
        standalone_query = self._rewrite_followup(query, history)
        print(f"Rewritten query for retrieval & safeguard: '{standalone_query}'")

        # 2) Safeguard check on the standalone query
        if not self._is_query_on_topic(standalone_query, similarity_threshold=0.44, top_k=5, use_zero_shot=True):
            return ("I'm sorry — I can only assist with customer-service related questions "
                    "like billing, orders, shipping, or account issues. Could you rephrase your question?")

        # 3) Sentiment analysis on the user's original wording (the rewrite can
        #    neutralize emotional phrasing, so classify the raw query)
        sentiment = self.sentiment_classifier(query)[0]['label']
        print(f"Detected Sentiment: {sentiment}")

        # 4) Retrieve context using the standalone query
        context = self._retrieve_context(standalone_query, k=k)

        # 5) Persona and final prompt (use standalone query; forbid echo)
        if sentiment == 'NEGATIVE':
            persona = ("You are an exceptionally empathetic and understanding customer support agent. "
                       "Acknowledge frustration, apologize, and provide the next steps clearly.")
        else:
            persona = ("You are a friendly, efficient, and professional customer support agent. "
                       "Provide clear, concise, and helpful answers.")

        prompt = f"""
        {persona}

        Your role is STRICTLY to be a customer support agent.
        Use only the provided context to answer precise customer-support questions.
        If the answer is not in the context, say you don't know and provide a safe next step (e.g., ask for the order number).
        Do NOT repeat the question back in your answer. Return a concise answer of 1-3 sentences.

        Context:
        {context}

        Question: {standalone_query}

        Answer:
        """

        start_time = time.time()
        llm_output = self.llm_pipeline(prompt, max_new_tokens=150, num_beams=4, early_stopping=True)
        response = llm_output[0]['generated_text'].strip()
        print(f"LLM Response Time: {time.time() - start_time:.2f}s")

        # Some models sometimes return the question as the output when confused; guard against that:
        if response.lower().startswith(standalone_query.lower()):
            # If it echoed the question, ask the model one more time with an explicit instruction
            retry_prompt = prompt + "\n(Do NOT repeat the question; give the answer only.)\nAnswer:"
            llm_output = self.llm_pipeline(retry_prompt, max_new_tokens=150, num_beams=4, early_stopping=True)
            response = llm_output[0]['generated_text'].strip()

        return response


# --- Terminal Demo ---
if __name__ == "__main__":
    agent = CustomerServiceAgent()
    conversation_history = []

    print("\n--- Testing  ---")
    query1 = "how do i change my password?"
    response1 = agent.get_rag_response(query1, conversation_history)
    conversation_history.append({'user': query1, 'assistant': response1})
    print(f"\nUser: {query1}\nAgent: {response1}")

    query2 = "my package never arrived."
    response2 = agent.get_rag_response(query2, conversation_history)
    conversation_history.append({'user': query2, 'assistant': response2})
    print(f"\nUser: {query2}\nAgent: {response2}")

    print("\n--- Testing Safeguard (Off-topic) ---")
    query3 = "What's the best recipe for lasagna?"
    response3 = agent.get_rag_response(query3, [])
    print(f"\nUser: {query3}\nAgent: {response3}")

    print("\n--- Demo Complete ---")