Asmitha-28 committed on
Commit
d4f1f3e
·
verified ·
1 Parent(s): 34d1ffe

Upload src\ensemble_router.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src//ensemble_router.py +482 -0
src//ensemble_router.py ADDED
@@ -0,0 +1,482 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/ensemble_router.py
2
+ # SupportMind — Ensemble Confidence-Gated Router
3
+ # Combines DistilBERT (MC Dropout) + TF-IDF Logistic Regression
4
+ # for best-in-class accuracy on ticket routing.
5
+ #
6
+ # Strategy: weighted soft-voting on probability distributions
7
+ # final_probs = w_bert * bert_probs + w_sklearn * sklearn_probs
8
+ #
9
+ # Why this beats either model alone:
10
+ # - DistilBERT: captures semantic meaning, handles paraphrases
11
+ # - TF-IDF+LR : captures keyword/n-gram signals, very confident on clear cases
12
+ # - Ensemble : DistilBERT corrects LR on ambiguous tickets,
13
+ # LR corrects BERT on keyword-heavy ones
14
+
15
+ import os
16
+ import gc
17
+ import pickle
18
+ import logging
19
+ import numpy as np
20
+ from typing import Dict, Optional
21
+
22
logger = logging.getLogger(__name__)

# ── Category map ────────────────────────────────────────────────────────────
# Class index -> queue name. The rest of this module indexes probability
# vectors with these integers, so the order here defines the meaning of each
# probability column.
CATEGORY_MAP = {
    0: 'billing',
    1: 'technical_support',
    2: 'account_management',
    3: 'feature_request',
    4: 'compliance_legal',
    5: 'onboarding',
    6: 'general_inquiry',
    7: 'churn_risk',
}
# Queue name -> class index (inverse of CATEGORY_MAP).
CATEGORY_REVERSE = {v: k for k, v in CATEGORY_MAP.items()}

# ── Routing thresholds ──────────────────────────────────────────────────────
# NOTE(review): EnsembleRouter.route() in this file applies its own
# per-category literals; only ENTROPY_MAX is read (by the __main__ demo).
ROUTE_THRESHOLD = 0.82    # ensemble conf >= this -> auto-route
CLARIFY_THRESHOLD = 0.58  # ensemble conf >= this -> ask 1 question
ENTROPY_MAX = 0.32        # ensemble entropy <= this -> low ambiguity


def _env_positive_int(name: str, default: int) -> int:
    """Read a positive integer from environment variable *name*.

    Returns *default* when the variable is unset, is not an integer, or is
    smaller than 1, so that a malformed setting cannot crash the import of
    this module or request zero inference passes.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        value = int(raw)
    except ValueError:
        logger.warning(f"[EnsembleRouter] {name}={raw!r} is not an integer; using {default}.")
        return default
    if value < 1:
        logger.warning(f"[EnsembleRouter] {name}={value} must be >= 1; using {default}.")
        return default
    return value


# Number of MC-Dropout forward passes per ticket (CPU demo default: 3).
MC_PASSES = _env_positive_int('SUPPORTMIND_MC_PASSES', 3)

# ── Ensemble weights ────────────────────────────────────────────────────────
# Soft-voting weights: final = BERT_W * transformer_probs + SKLEARN_W * lr_probs.
# The transformer gets the larger share because it generalises better to
# unseen phrasing; raise SKLEARN_W if LR is more accurate on your data.
# NOTE(review): the file refers to the transformer both as DistilBERT and as
# DeBERTa-v3 (models/deberta_ultimate) — confirm which checkpoint is deployed.
BERT_W = 0.75
SKLEARN_W = 0.25
49
+
50
+
51
class EnsembleRouter:
    """
    Ensemble confidence-gated router.

    Blends two classifiers by weighted soft voting:
      1. A fine-tuned transformer wrapped by ``ConfidenceGatedRouter``
         (MC Dropout for uncertainty).
      2. A TF-IDF + Logistic Regression scikit-learn pipeline.

    Falls back to sklearn-only when the transformer is disabled, its weights
    are absent, or loading fails. Intended as a drop-in replacement for
    ``ConfidenceGatedRouter`` — same ``.route()`` interface.
    """

    # Keywords that give their category a small boost in route().
    # Matching is plain substring search on the lower-cased ticket text.
    # NOTE(review): substring matching means e.g. 'add' also hits 'address';
    # left unchanged because tightening it would alter routing decisions.
    _REINFORCE_KEYWORDS = {
        'billing': ('invoice', 'refund', 'charge', 'payment', 'billing'),
        'technical_support': ('error', 'bug', 'crash', '500', 'api', 'broken', 'not working'),
        'account_management': ('login', 'password', 'reset', 'account', 'permission', 'access', 'sso', 'user'),
        'feature_request': ('feature', 'add', 'request', 'enhancement', 'dark mode', 'new capability', 'could you add'),
        'compliance_legal': ('gdpr', 'compliance', 'legal', 'audit', 'privacy'),
        'churn_risk': ('cancel', 'leaving', 'competitor', 'terminate', 'switching'),
        'onboarding': ('setup', 'configure', 'getting started', 'new user', 'import'),
    }

    # Keywords that signal a possible billing misroute when the top class is
    # technical_support.
    _BILLING_OVERLAP_KEYWORDS = (
        'invoice', 'billing', 'charge', 'refund', 'payment', 'subscription', 'plan',
    )

    # Categories that are auto-routed only under the strictest thresholds.
    _CRITICAL_LABELS = ('compliance_legal', 'account_management')

    def __init__(self, model_dir: Optional[str] = None, device: str = 'cpu'):
        """
        Load both models and the optional historical memory layer.

        Args:
            model_dir: Directory holding the transformer checkpoint (and
                optionally ``sklearn_router.pkl``). When ``None``,
                ``models/deberta_ultimate`` is used if it contains a
                ``config.json``, otherwise ``models/ticket_classifier``.
            device: Device string forwarded to ``ConfidenceGatedRouter``.
        """
        base = self._project_root()
        ultimate_path = os.path.join(base, 'models', 'deberta_ultimate')
        standard_path = os.path.join(base, 'models', 'ticket_classifier')

        if model_dir is not None:
            self.model_dir = model_dir
        elif os.path.exists(os.path.join(ultimate_path, 'config.json')):
            self.model_dir = ultimate_path
        else:
            self.model_dir = standard_path

        self._bert_router = None
        self._sklearn_pipe = None
        self._bert_available = False
        self._bert_reason = 'not_loaded'
        self._sklearn_source = 'unknown'

        # IMPORTANT (original author's note): load BERT first and do a warmup
        # pass. On Windows, unpickling sklearn before PyTorch's first forward
        # pass causes a segfault in torch.distributed/optree DLLs.
        self._load_bert(device)
        if self._bert_available:
            self._warmup_bert()
        self._load_sklearn()

        # The memory layer is optional: any failure to import or construct it
        # simply disables the historical boost.
        try:
            from historical_memory import HistoricalMemoryLayer
            self._memory_layer = HistoricalMemoryLayer()
        except Exception as e:
            logger.warning(f"[EnsembleRouter] Could not load Historical Memory Layer: {e}")
            self._memory_layer = None

        memory_ready = self._memory_ready()
        self.model_status = {
            'mode': 'ensemble_transformer_lr' if self._bert_available else 'sklearn_fallback',
            'bert_available': self._bert_available,
            'bert_reason': self._bert_reason,
            'sklearn_source': self._sklearn_source,
            'model_dir': self._display_path(self.model_dir, base),
            'memory_available': memory_ready,
        }

        logger.info(
            f"[EnsembleRouter] BERT={'ON' if self._bert_available else 'OFF (fallback)'} | "
            f"sklearn=ON | weights=({BERT_W}/{SKLEARN_W}) | memory={'ON' if memory_ready else 'OFF'}"
        )

    # ── Small internal helpers ───────────────────────────────────────────────

    @staticmethod
    def _project_root() -> str:
        """Return the directory two levels above this source file."""
        return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    @staticmethod
    def _display_path(path: str, base: str) -> str:
        """Return *path* relative to *base*, or *path* itself if that fails.

        ``os.path.relpath`` raises ``ValueError`` on Windows when the two
        paths live on different drives; status reporting must not crash the
        router in that case.
        """
        try:
            return os.path.relpath(path, base)
        except ValueError:
            return path

    def _memory_ready(self) -> bool:
        """Return True when a memory layer exists and reports ``is_ready``."""
        return bool(getattr(getattr(self, '_memory_layer', None), 'is_ready', False))

    @staticmethod
    def _boost_note(hist_boost: float, suffix: str = '') -> str:
        """Return the HTML line describing a positive historical boost, else ''."""
        if hist_boost > 0:
            return (
                f'<br>• <span style="color:var(--green)">'
                f'Historical Match Boost: +{hist_boost:.2%}</span>{suffix}'
            )
        return ''

    def _warmup_bert(self):
        """Run one throw-away forward pass; failures are logged, not raised."""
        try:
            self._bert_router.mc_predict("warmup", n_passes=1)
            logger.info("[EnsembleRouter] BERT warmup complete.")
        except Exception as e:
            logger.warning(f"[EnsembleRouter] BERT warmup failed: {e}")

    # ── Model loaders ────────────────────────────────────────────────────────

    def _use_embedded_sklearn(self):
        """Install the tiny in-memory pipeline as the sklearn model."""
        self._sklearn_pipe = self._build_embedded_sklearn()
        self._sklearn_source = 'embedded_synthetic'

    def _load_sklearn(self):
        """
        Load ``sklearn_router.pkl`` or fall back to the embedded pipeline.

        Looks in ``self.model_dir`` first, then in
        ``models/ticket_classifier``. If no file exists, or the file that
        exists cannot be unpickled, the embedded synthetic pipeline is used
        instead so construction still succeeds.
        """
        base = self._project_root()
        candidates = (
            os.path.join(self.model_dir, 'sklearn_router.pkl'),
            os.path.join(base, 'models', 'ticket_classifier', 'sklearn_router.pkl'),
        )
        pkl = next((path for path in candidates if os.path.exists(path)), None)

        if pkl is None:
            logger.warning(
                "[EnsembleRouter] sklearn_router.pkl not found. "
                "Using embedded synthetic fallback model."
            )
            self._use_embedded_sklearn()
            return

        # SECURITY: pickle.load executes arbitrary code from the file. Only
        # load sklearn_router.pkl from a trusted, write-protected location.
        try:
            with open(pkl, 'rb') as f:
                self._sklearn_pipe = pickle.load(f)
        except (OSError, EOFError, pickle.UnpicklingError, AttributeError, ImportError) as e:
            # A truncated file or a pickle written by an incompatible
            # library version should degrade, not crash.
            logger.error(
                f"[EnsembleRouter] Could not load sklearn pipeline from {pkl}: {e}. "
                "Using embedded synthetic fallback model."
            )
            self._use_embedded_sklearn()
            return

        self._sklearn_source = self._display_path(pkl, base)
        logger.info(f"[EnsembleRouter] sklearn pipeline loaded from {pkl}.")

    def _build_embedded_sklearn(self):
        """Build a tiny in-memory classifier so clean clones and CI still run."""
        from sklearn.feature_extraction.text import TfidfVectorizer
        from sklearn.linear_model import LogisticRegression
        from sklearn.pipeline import Pipeline

        # A handful of synthetic phrases per category; every category in
        # CATEGORY_MAP is represented.
        examples = {
            'billing': [
                'invoice is wrong', 'refund request', 'payment failed',
                'billing charge incorrect', 'subscription price changed',
            ],
            'technical_support': [
                'api returns 500 error', 'export is broken', 'dashboard crash',
                'integration timeout', 'feature not working',
            ],
            'account_management': [
                'reset password', 'add user account', 'sso login issue',
                'change admin permission', 'locked out of account',
            ],
            'feature_request': [
                'please add dark mode', 'new feature request',
                'need custom dashboard', 'enhancement idea',
            ],
            'compliance_legal': [
                'gdpr data request', 'soc 2 audit report',
                'data processing agreement', 'privacy compliance',
            ],
            'onboarding': [
                'help with setup', 'new user onboarding',
                'configure integration', 'getting started guide',
            ],
            'general_inquiry': [
                'how do i use this', 'pricing question', 'where is documentation',
                'do you offer a demo',
            ],
            'churn_risk': [
                'cancel my account', 'switching to competitor',
                'very frustrated', 'not renewing contract',
            ],
        }

        texts, labels = [], []
        for category, samples in examples.items():
            texts.extend(samples)
            labels.extend([CATEGORY_REVERSE[category]] * len(samples))

        pipeline = Pipeline([
            ('tfidf', TfidfVectorizer(stop_words='english', ngram_range=(1, 2))),
            ('clf', LogisticRegression(class_weight='balanced', max_iter=1000)),
        ])
        pipeline.fit(texts, labels)
        return pipeline

    def _load_bert(self, device: str):
        """
        Load the transformer router when the runtime is configured for it.

        Sets ``_bert_router``, ``_bert_available`` and ``_bert_reason``.
        Never raises: every failure path leaves the router in sklearn-only
        mode with ``_bert_reason`` explaining why.

        Environment:
            SUPPORTMIND_DISABLE_TRANSFORMER=1  skip loading entirely.
            SUPPORTMIND_FORCE_TRANSFORMER=1    allow loading on Windows,
                where it is disabled by default.
        """
        import json

        disable_transformer = os.getenv('SUPPORTMIND_DISABLE_TRANSFORMER', '0') == '1'
        force_transformer = os.getenv('SUPPORTMIND_FORCE_TRANSFORMER', '0') == '1'

        if disable_transformer:
            self._bert_reason = 'disabled_by_SUPPORTMIND_DISABLE_TRANSFORMER'
            logger.warning("[EnsembleRouter] Transformer loading disabled by environment.")
            return

        if os.name == 'nt' and not force_transformer:
            self._bert_reason = 'disabled_on_windows_set_SUPPORTMIND_FORCE_TRANSFORMER_to_enable'
            logger.warning(
                "[EnsembleRouter] Transformer loading disabled on Windows by default "
                "to avoid native PyTorch/safetensors crashes. Set "
                "SUPPORTMIND_FORCE_TRANSFORMER=1 to enable it."
            )
            return

        model_bin = os.path.join(self.model_dir, 'pytorch_model.bin')
        model_safe = os.path.join(self.model_dir, 'model.safetensors')
        config = os.path.join(self.model_dir, 'config.json')

        bert_ready = os.path.exists(config) and (
            os.path.exists(model_bin) or os.path.exists(model_safe)
        )
        if not bert_ready:
            self._bert_reason = 'weights_not_found'
            logger.warning(
                "[EnsembleRouter] DistilBERT weights not found — running sklearn-only."
            )
            return

        # Check for a stale baseline stub (only present before the first real
        # training run). An unreadable config is not fatal here; the actual
        # load below reports real problems.
        try:
            with open(config, encoding='utf-8') as f:
                cfg = json.load(f)
        except (OSError, ValueError) as e:
            logger.debug(f"[EnsembleRouter] Could not inspect {config}: {e}")
            cfg = None
        if isinstance(cfg, dict) and cfg.get('model_type') == 'baseline_sklearn':
            self._bert_reason = 'baseline_stub_config'
            logger.warning("[EnsembleRouter] config.json is baseline stub — skipping BERT.")
            return

        try:
            from confidence_router import ConfidenceGatedRouter
            self._bert_router = ConfidenceGatedRouter(self.model_dir, device=device)
            # The wrapped router signals its own degraded mode via _fallback_mode.
            self._bert_available = not getattr(self._bert_router, '_fallback_mode', False)
            fallback_reason = getattr(self._bert_router, 'fallback_reason', None)
            self._bert_reason = (
                'loaded' if self._bert_available
                else f'confidence_router_fallback: {fallback_reason or "unknown"}'
            )
            gc.collect()
            if self._bert_available:
                logger.info(
                    f"[EnsembleRouter] {self._bert_router.model.config.model_type.upper()} "
                    "loaded successfully."
                )
        except Exception as e:
            logger.error(f"[EnsembleRouter] BERT load failed (likely memory constraint): {e}")
            # Ensure we don't leave a half-initialized router
            self._bert_router = None
            self._bert_available = False
            self._bert_reason = f'load_failed: {type(e).__name__}'
            gc.collect()

    # ── Prediction ───────────────────────────────────────────────────────────

    def _sklearn_probs(self, text: str) -> np.ndarray:
        """Return the sklearn pipeline's probability vector for *text*."""
        return self._sklearn_pipe.predict_proba([text])[0]

    def _bert_probs(self, text: str, n_passes: int = MC_PASSES) -> np.ndarray:
        """Return the mean MC-Dropout probability vector from the transformer."""
        _, _, _, mean_p, _ = self._bert_router.mc_predict(text, n_passes=n_passes)
        return mean_p

    def _blend(self, text: str, n_passes: int = MC_PASSES):
        """
        Compute the blended probability distribution.

        Args:
            text: Ticket text.
            n_passes: Number of MC-Dropout passes for the transformer.

        Returns:
            ``(blended_probs, bert_probs_or_None, sklearn_probs, bert_std)``.
            In sklearn-only mode the blend is the sklearn vector itself and
            ``bert_std`` is all zeros.
        """
        sk_probs = self._sklearn_probs(text)

        if not self._bert_available:
            return sk_probs, None, sk_probs, np.zeros(len(CATEGORY_MAP))

        _, _, _, bert_mean, bert_std = self._bert_router.mc_predict(text, n_passes=n_passes)
        blended = BERT_W * bert_mean + SKLEARN_W * sk_probs
        # Re-normalise (floating point can drift slightly)
        blended = blended / blended.sum()
        return blended, bert_mean, sk_probs, bert_std

    # ── Public API ───────────────────────────────────────────────────────────

    def route(self, ticket_text: str, n_passes: int = MC_PASSES) -> Dict:
        """
        Route a ticket through the ensemble confidence gate.

        Args:
            ticket_text: Raw ticket text.
            n_passes: Number of MC-Dropout passes to run on the transformer
                (ignored in sklearn-only mode).

        Returns:
            A dict with the scores (``confidence``, ``raw_confidence``,
            ``entropy``, ``margin``, ``all_probs``, ``std_probs``,
            ``category_ranking``, ``top_two_classes``), ``ensemble``
            diagnostics, ``historical_boost``, and the decision:
            ``action`` in {'route', 'clarify', 'escalate'}, ``queue``
            (the category when routed, else ``None``) and an HTML
            ``reason``. Meant to match ``ConfidenceGatedRouter.route()``
            so it is a drop-in replacement in api.py.
        """
        n_classes = len(CATEGORY_MAP)
        blended, bert_p, sk_p, bert_std = self._blend(ticket_text, n_passes)

        # Entropy is measured on the raw blend, before sharpening and
        # keyword boosts are applied.
        entropy = float(-np.sum(blended * np.log(blended + 1e-9)))

        # ── Temperature scaling (T=0.7) ──────────────────────────────────
        # Sharpen probabilities to reduce noise in unrelated classes. Only
        # probabilities are available, so logit scaling is approximated by
        # power scaling: p_scaled = p^(1/T) / sum(p^(1/T)).
        temperature = 0.7
        blended_sharp = np.power(blended + 1e-9, 1.0 / temperature)
        blended_sharp = blended_sharp / blended_sharp.sum()

        # ── Keyword reinforcement ────────────────────────────────────────
        # Each category whose keywords appear in the text gets a capped
        # multiplicative and additive boost.
        text_low = ticket_text.lower()
        for cat, kws in self._REINFORCE_KEYWORDS.items():
            hit_count = sum(1 for kw in kws if kw in text_low)
            if hit_count:
                idx = CATEGORY_REVERSE[cat]
                blended_sharp[idx] *= 1.0 + min(0.45, hit_count * 0.12)
                blended_sharp[idx] += min(0.12, hit_count * 0.03)

        # Re-normalise after boost
        blended_sharp = blended_sharp / blended_sharp.sum()

        confidence = float(blended_sharp.max())
        pred_class = int(blended_sharp.argmax())
        category = CATEGORY_MAP[pred_class]

        # ── Visual confidence cap (98.5%) ────────────────────────────────
        # Probabilistic ML should rarely claim 100% certainty.
        display_confidence = min(confidence, 0.985)

        # Ranking of (category, rounded score), best first.
        ranking = sorted(
            [(CATEGORY_MAP[i], round(float(blended_sharp[i]), 4)) for i in range(n_classes)],
            key=lambda x: x[1], reverse=True
        )
        top_two = [ranking[0][0], ranking[1][0]]

        sklearn_top = CATEGORY_MAP[int(sk_p.argmax())]
        bert_top = CATEGORY_MAP[int(bert_p.argmax())] if bert_p is not None else None

        base = {
            'confidence': round(display_confidence, 4),
            'raw_confidence': round(confidence, 4),
            'entropy': round(entropy, 4),
            'top_category': category,
            'all_probs': {CATEGORY_MAP[i]: round(float(blended_sharp[i]), 4) for i in range(n_classes)},
            'std_probs': {CATEGORY_MAP[i]: round(float(bert_std[i]), 4) for i in range(n_classes)},
            'category_ranking': ranking,
            'top_two_classes': top_two,
            'mc_passes': n_passes,
            # Extra ensemble diagnostics
            'ensemble': {
                'bert_available': self._bert_available,
                'bert_top': bert_top,
                'sklearn_top': sklearn_top,
                'bert_weight': BERT_W if self._bert_available else 0.0,
                'sklearn_weight': SKLEARN_W if self._bert_available else 1.0,
                # With no transformer there is nothing to disagree with.
                'agreement': (bert_top == sklearn_top) if bert_p is not None else True,
            }
        }

        # Margin is taken from the rounded ranking scores.
        margin = ranking[0][1] - ranking[1][1]

        hist_boost = 0.0
        if self._memory_ready():
            hist_boost = self._memory_layer.compute_historical_boost(ticket_text, category)
        base['historical_boost'] = hist_boost
        base['margin'] = round(margin, 4)

        # The historical boost only influences the gate, never the reported
        # confidence values.
        effective_conf = confidence + hist_boost

        if category in self._CRITICAL_LABELS:
            # Sensitive intents are either routed under strict thresholds or
            # escalated; they are never sent to clarification.
            if effective_conf >= 0.90 and margin >= 0.35 and entropy < 0.60:
                action = 'route'
                reason = (
                    f'• Safe to auto-route sensitive intent<br>• Confidence: {confidence:.2%}'
                    f'<br>• Margin: {margin:.2f}'
                ) + self._boost_note(hist_boost)
            else:
                action = 'escalate'
                reason = (
                    f'• Escalated sensitive intent ({category})'
                    '<br>• Strict confidence/margin threshold not met'
                ) + self._boost_note(hist_boost, ' (Insufficient)')
        else:
            if category == 'technical_support':
                # Stricter thresholds for technical_support, plus a check
                # that catches billing tickets misread as technical ones.
                route_conf, route_margin, route_entropy = 0.88, 0.30, 0.65
                billing_overlap = (
                    any(kw in text_low for kw in self._BILLING_OVERLAP_KEYWORDS)
                    and 'billing' in [name for name, _ in ranking[:3]]
                )
            else:
                route_conf, route_margin, route_entropy = 0.85, 0.25, 0.70
                billing_overlap = False

            if billing_overlap:
                action = 'clarify'
                reason = (
                    '• Billing overlap detected'
                    '<br>• Clarification needed between technical_support and billing'
                )
            elif (effective_conf >= route_conf and margin >= route_margin
                    and entropy < route_entropy):
                action = 'route'
                reason = (
                    f'• Strong dominant intent<br>• Confidence: {confidence:.2%}'
                    f'<br>• Margin: {margin:.2f}<br>• Safe to auto-route'
                ) + self._boost_note(hist_boost)
            elif effective_conf >= 0.60 and entropy < 1.05:
                action = 'clarify'
                reason = (
                    '• Medium ambiguity detected'
                    f'<br>• Clarification needed between {top_two[0]} and {top_two[1]}'
                    f'<br>• Margin: {margin:.2f}'
                ) + self._boost_note(hist_boost, ' (Insufficient for auto-route)')
            else:
                action = 'escalate'
                reason = (
                    f'• High ambiguity / Low confidence ({confidence:.2%})'
                    '<br>• Multiple overlapping intents detected<br>• Human triage needed'
                )

        return {
            **base,
            'action': action,
            'queue': category if action == 'route' else None,
            'reason': reason,
        }

    def batch_route(self, tickets: list, n_passes: int = MC_PASSES) -> list:
        """Route each ticket in *tickets* in order and return the result dicts."""
        return [self.route(t, n_passes) for t in tickets]

    # Properties to expose model/tokenizer for the SHAP explainer in api.py
    @property
    def model(self):
        """The wrapped transformer model, or ``None`` in sklearn-only mode."""
        if self._bert_available:
            return self._bert_router.model
        return None

    @property
    def tokenizer(self):
        """The wrapped tokenizer, or ``None`` in sklearn-only mode."""
        if self._bert_available:
            return self._bert_router.tokenizer
        return None
449
+
450
+
451
def _run_demo():
    """Route a fixed set of sample tickets and print one summary line each."""
    router = EnsembleRouter()

    tests = [
        "My invoice from last month is incorrect, please fix the billing.",
        "The API keeps returning 500 errors since last Tuesday's update.",
        "I want to cancel — this tool has been broken for weeks.",
        "How do I add another user to our account?",
        "We need GDPR data processing agreements for our EU customers.",
        "Not happy at all, considering switching to a competitor.",
        "Can you add a dark mode to the dashboard?",
        "Just signed up — how do I import my existing data?",
        # Tricky ambiguous cases
        "Invoice is wrong AND the app keeps crashing.",
        "Not happy with service",
    ]

    print(f"\n{'='*90}")
    print(f" SupportMind Ensemble Router — BERT={'ON' if router._bert_available else 'OFF (sklearn only)'}")
    print(f"{'='*90}\n")

    for ticket in tests:
        r = router.route(ticket)
        agree = 'AGREE' if r['ensemble']['agreement'] else 'DISAGREE'
        # 'H'/'L' marks whether the blend entropy is below ENTROPY_MAX.
        print(
            f"[{r['action'].upper():8s}] [{r['confidence']:.2%}] "
            f"{'H' if r['entropy'] < ENTROPY_MAX else 'L'}-certainty | "
            f"{r['top_category']:20s} | "
            f"Models: {agree} | {ticket[:60]}"
        )


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
    _run_demo()