Toadoum commited on
Commit
49b1566
·
verified ·
1 Parent(s): ae6619f

Update nlu.py

Browse files
Files changed (1) hide show
  1. nlu.py +262 -255
nlu.py CHANGED
@@ -1,32 +1,29 @@
1
  """
2
- NLU — NLLB + Qwen pivot-through-English architecture with keyword fast-path.
3
-
4
- Flow:
5
- 1. Deterministic structural extractors run FIRST on the original Hausa
6
- text (digits, amounts, yes/no keywords). These MUST be deterministic
7
- because "1234" "provide_digits" with digits="1234" is non-negotiable
8
- for banks, and regex is faster + more reliable than any model for
9
- this sub-task.
10
-
11
- 2. Keyword fast-path for common Hausa + English intent phrases. Matches
12
- "check balance", "duba ma'auni", "canjin kuɗi", etc. in <10ms without
13
- loading any model. This is what real voice bots use for 90% of turns.
14
-
15
- 3. If structural + keyword layers don't match, the text is translated
16
- Hausa English via NLLB-200 (skipped if input is already English),
17
- then classified by Qwen2.5-1.5B in English (where it is strong) into
18
- one of a small fixed set of intent labels.
19
-
20
- 4. If NLLB or Qwen fails, we return "unknown" cleanly — the dialogue
21
- manager routes to a vertical-specific fallback prompt.
22
-
23
- All heavy models are lazy-loaded on first use. Cold-start downloads:
24
- - NLLB-200-distilled-600M: ~2.4 GB
25
- - Qwen2.5-1.5B-Instruct: ~3 GB
26
  """
27
  from __future__ import annotations
28
  import re
29
- import json
30
  import logging
31
  from typing import Optional
32
 
@@ -48,11 +45,9 @@ WORD_AMOUNTS = {
48
  "ɗari": 100, "dari": 100,
49
  }
50
 
51
- # Hausa yes/no keywords for the sole case where we short-circuit Qwen
52
  HAUSA_YES = {"i", "eh", "haka ne", "haka", "ok", "okay", "yes"}
53
  HAUSA_NO = {"a'a", "a'aa", "ba haka", "ba", "no"}
54
 
55
- # Human-agent escape hatch
56
  HUMAN_KEYWORDS = {"mutum", "wakili", "agent", "human"}
57
 
58
 
@@ -92,10 +87,9 @@ def _contains_human_keyword(text: str) -> bool:
92
  return any(kw in t for kw in HUMAN_KEYWORDS)
93
 
94
 
95
- # Keyword fast-path for common intents. Runs BEFORE NLLB+Qwen so that the
96
- # scripted demo flows don't require a 6GB LLM load. Phrases are Hausa and
97
- # English pairs that customers actually use. When none match, we fall
98
- # through to NLLB+Qwen for paraphrases.
99
  INTENT_KEYWORDS = {
100
  "check_balance": [
101
  "duba ma'auni", "ma'auni", "balance", "check balance",
@@ -136,10 +130,7 @@ INTENT_KEYWORDS = {
136
 
137
 
138
  def _match_intent_keyword(text: str) -> Optional[str]:
139
- """Keyword fast-path for common customer-service intents.
140
- Returns the intent name if a keyword matches, else None."""
141
  t = text.lower().strip()
142
- # Check longer phrases first so "check balance" wins over "check order"
143
  all_kw = [(intent, kw) for intent, kws in INTENT_KEYWORDS.items() for kw in kws]
144
  all_kw.sort(key=lambda x: len(x[1]), reverse=True)
145
  for intent, kw in all_kw:
@@ -148,204 +139,231 @@ def _match_intent_keyword(text: str) -> Optional[str]:
148
  return None
149
 
150
 
151
- def _looks_english(text: str) -> bool:
152
- """Heuristic: if text contains no Hausa-specific characters and is majority
153
- ASCII, treat as English and skip NLLB translation. Hausa uses ɓ, ɗ, ƙ, ƴ
154
- and the apostrophe in 'a'a', 'ma'auni', 'jumma'a' etc."""
155
- hausa_chars = set("ɓɗƙƴƁƊƘƳ")
156
- if any(c in hausa_chars for c in text):
157
- return False
158
- # Common Hausa words — if any match, treat as Hausa
159
- hausa_markers = {
160
- "duba", "ma'auni", "toshe", "kati", "canjin", "kuɗi", "kudi",
161
- "saya", "airtime", "bundle", "korafi", "bincika", "oda",
162
- "sake", "tsara", "mayar", "kaya", "wakili", "mutum",
163
- "sannu", "nagode", "don", "allah", "ka", "yana", "tana",
164
- "dubu", "ɗari", "dari", "biyar", "biyu", "uku", "hudu", "huɗu",
165
- }
166
- tokens = set(text.lower().split())
167
- return not bool(tokens & hausa_markers)
168
-
169
-
170
  # ---------------------------------------------------------------------------
171
- # NLLB-200 Ha En translation (lazy-loaded)
 
 
 
 
172
  # ---------------------------------------------------------------------------
173
- _nllb_model = None
174
- _nllb_tokenizer = None
175
- _nllb_failed = False
176
-
177
-
178
- def _load_nllb():
179
- """Lazy-load NLLB-200-distilled-600M."""
180
- global _nllb_model, _nllb_tokenizer, _nllb_failed
181
- if _nllb_failed:
182
- return None, None
183
- if _nllb_model is not None:
184
- return _nllb_model, _nllb_tokenizer
185
- try:
186
- import torch
187
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
188
- logger.info("Loading NLLB-200-distilled-600M…")
189
- model_id = "facebook/nllb-200-distilled-600M"
190
- _nllb_tokenizer = AutoTokenizer.from_pretrained(model_id)
191
- _nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
192
- model_id,
193
- torch_dtype=torch.float32,
194
- low_cpu_mem_usage=True,
195
- )
196
- _nllb_model.eval()
197
- logger.info("NLLB-200 ready.")
198
- return _nllb_model, _nllb_tokenizer
199
- except Exception as e:
200
- logger.warning(f"NLLB load failed: {e}")
201
- _nllb_failed = True
202
- return None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
 
204
 
205
- def translate_ha_to_en(text: str) -> Optional[str]:
206
- """Translate Hausa to English via NLLB. Returns None on failure."""
207
- model, tokenizer = _load_nllb()
208
- if model is None or not text.strip():
209
- return None
210
- try:
211
- import torch
212
- # NLLB requires source language token set on tokenizer
213
- tokenizer.src_lang = "hau_Latn"
214
- inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
215
- # Force English output via forced_bos_token_id
216
- forced_bos_id = tokenizer.convert_tokens_to_ids("eng_Latn")
217
- with torch.no_grad():
218
- out = model.generate(
219
- **inputs,
220
- forced_bos_token_id=forced_bos_id,
221
- max_new_tokens=128,
222
- num_beams=2,
223
- )
224
- translated = tokenizer.batch_decode(out, skip_special_tokens=True)[0].strip()
225
- logger.info(f"NLLB Ha→En: {text!r} → {translated!r}")
226
- return translated
227
- except Exception as e:
228
- logger.warning(f"NLLB translate failed: {e}")
229
- return None
230
 
231
 
232
  # ---------------------------------------------------------------------------
233
- # Qwen2.5-1.5B intent classifier (operates on English text)
234
  # ---------------------------------------------------------------------------
235
- _llm_model = None
236
- _llm_tokenizer = None
237
- _llm_failed = False
238
 
239
 
240
- def _load_llm():
241
- global _llm_model, _llm_tokenizer, _llm_failed
242
- if _llm_failed:
243
- return None, None
244
- if _llm_model is not None:
245
- return _llm_model, _llm_tokenizer
 
246
  try:
247
- import torch
248
- from transformers import AutoModelForCausalLM, AutoTokenizer
249
- logger.info("Loading Qwen2.5-1.5B-Instruct…")
250
- model_id = "Qwen/Qwen2.5-1.5B-Instruct"
251
- _llm_tokenizer = AutoTokenizer.from_pretrained(model_id)
252
- _llm_model = AutoModelForCausalLM.from_pretrained(
253
- model_id,
254
- torch_dtype=torch.float32,
255
- low_cpu_mem_usage=True,
256
- )
257
- _llm_model.eval()
258
- logger.info("Qwen2.5-1.5B ready.")
259
- return _llm_model, _llm_tokenizer
 
 
260
  except Exception as e:
261
- logger.warning(f"Qwen load failed: {e}")
262
- _llm_failed = True
263
- return None, None
264
-
265
-
266
- CANDIDATE_INTENTS = {
267
- None: ["check_balance", "block_card", "transfer_money",
268
- "buy_airtime", "buy_bundle", "complaint",
269
- "check_order", "reschedule", "return_item",
270
- "human_agent", "unknown"],
271
- "intent": ["check_balance", "block_card", "transfer_money",
272
- "buy_airtime", "buy_bundle", "complaint",
273
- "check_order", "reschedule", "return_item",
274
- "human_agent", "unknown"],
275
- "yesno": ["yes", "no", "human_agent", "unknown"],
276
- "name": ["provide_name", "human_agent", "unknown"],
277
- "date": ["provide_date", "human_agent", "unknown"],
278
- "bundle": ["provide_bundle", "human_agent", "unknown"],
279
- "text": ["provide_text", "human_agent", "unknown"],
280
- }
281
-
282
-
283
- SYSTEM_PROMPT = """You are an intent classifier for a customer-service voice bot.
284
-
285
- You will be given an English-language utterance (translated from Hausa) and a list of candidate intents. Return JSON with the single best-matching intent and any entities you can extract.
286
-
287
- Intent meanings:
288
- - check_balance: user wants to check an account balance
289
- - block_card: user wants to block, freeze, or cancel a bank card
290
- - transfer_money: user wants to send or transfer money
291
- - buy_airtime: user wants to buy phone airtime / top-up
292
- - buy_bundle: user wants to buy a data bundle / internet package
293
- - complaint: user wants to file a complaint or report a problem
294
- - check_order: user wants to check the status of an order
295
- - reschedule: user wants to reschedule a delivery
296
- - return_item: user wants to return an item
297
- - human_agent: user wants to speak to a human person
298
- - yes / no: affirmative or negative reply
299
- - provide_name / provide_date / provide_bundle / provide_text: user is supplying information
300
- - unknown: cannot determine intent
301
-
302
- Return ONLY valid JSON. No explanation, no markdown. Example: {"intent": "check_balance", "entities": {}}"""
303
 
304
 
305
- def _qwen_classify(english_text: str, expected: Optional[str]) -> Optional[tuple[str, dict]]:
306
- """Classify an English utterance into an intent. Returns None on failure."""
307
- model, tokenizer = _load_llm()
308
- if model is None:
 
309
  return None
310
-
311
- candidates = CANDIDATE_INTENTS.get(expected, CANDIDATE_INTENTS[None])
312
- user_prompt = (
313
- f'Utterance: "{english_text}"\n'
314
- f'Candidate intents: {", ".join(candidates)}\n\n'
315
- 'Return JSON only.'
316
- )
317
- messages = [
318
- {"role": "system", "content": SYSTEM_PROMPT},
319
- {"role": "user", "content": user_prompt},
320
- ]
321
  try:
322
- import torch
323
- prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
324
- inputs = tokenizer(prompt, return_tensors="pt")
325
- with torch.no_grad():
326
- out = model.generate(
327
- **inputs,
328
- max_new_tokens=60,
329
- do_sample=False,
330
- pad_token_id=tokenizer.eos_token_id,
331
- )
332
- generated = tokenizer.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()
333
- logger.info(f"Qwen raw: {generated}")
334
-
335
- m = re.search(r"\{.*?\}", generated, re.DOTALL)
336
- if not m:
337
- return None
338
- parsed = json.loads(m.group())
339
- intent = parsed.get("intent", "unknown")
340
- entities = parsed.get("entities", {}) or {}
341
- if not isinstance(entities, dict):
342
- entities = {}
343
- if intent not in candidates:
344
- logger.info(f"Qwen returned out-of-candidate intent: {intent}")
345
- return None
346
- return intent, entities
347
  except Exception as e:
348
- logger.warning(f"Qwen inference failed: {e}")
349
  return None
350
 
351
 
@@ -355,23 +373,26 @@ def _qwen_classify(english_text: str, expected: Optional[str]) -> Optional[tuple
355
  def parse(text: str, expected: Optional[str] = None,
356
  use_llm: bool = True) -> tuple[str, dict, str]:
357
  """
358
- NLU. Returns (intent, entities, source) where source is one of:
359
- - 'structural': deterministic extractor caught digits/amount/yes-no
360
- - 'keyword': fast-path keyword matcher caught a common intent
361
- - 'qwen_en': input was English, classified directly by Qwen
362
- - 'nllb+qwen': translated via NLLB then classified via Qwen
363
- - 'human_keyword': caught human-agent escape hatch by keyword
364
- - 'unknown': nothing matched
 
 
 
365
  """
366
  entities: dict = {}
367
  if not text or not text.strip():
368
  return "unknown", entities, "unknown"
369
 
370
- # Always-on human-agent escape (safety)
371
  if _contains_human_keyword(text):
372
  return "human_agent", entities, "human_keyword"
373
 
374
- # Layer 1: deterministic structural extractors for strict-format slots
375
  if expected == "digits":
376
  d = _extract_digits(text)
377
  if d:
@@ -390,7 +411,6 @@ def parse(text: str, expected: Optional[str] = None,
390
  return yn, entities, "structural"
391
 
392
  if expected == "name":
393
- # Name is free-form; take the last token as a quick heuristic.
394
  name = text.strip().split()[-1] if text.strip() else ""
395
  if name:
396
  entities["name"] = name
@@ -400,51 +420,38 @@ def parse(text: str, expected: Optional[str] = None,
400
  entities["date"] = text.strip()
401
  return "provide_date", entities, "structural"
402
 
403
- # Layer 1.5: Keyword fast-path for common intents (Hausa + English).
404
- # Runs in ANY state so users can pivot intent mid-flow ("actually I want
405
- # to transfer money instead"). Structural extractors above already
406
- # claimed strict-slot cases, so if we're in a slot-filling state and
407
- # the text didn't match the slot, it's fair game to re-interpret as a
408
- # new intent.
409
  kw_intent = _match_intent_keyword(text)
410
  if kw_intent:
411
- logger.info(f"NLU: keyword matched {text!r} → {kw_intent}")
412
  return kw_intent, entities, "keyword"
413
 
414
- # Layer 2: NLLB Ha → En (skip if input already English), then Qwen
415
  if not use_llm:
416
  logger.info(f"NLU: use_llm=False, returning unknown for {text!r}")
417
  return "unknown", entities, "unknown"
418
 
419
- if _looks_english(text):
420
- logger.info(f"NLU: input looks English, skipping NLLB: {text!r}")
421
- english_text = text
422
- source_tag = "qwen_en"
423
- else:
424
- logger.info(f"NLU: translating Hausa via NLLB: {text!r}")
425
- english_text = translate_ha_to_en(text)
426
- if english_text is None:
427
- logger.warning("NLU: NLLB failed, returning unknown")
428
- return "unknown", entities, "unknown"
429
- source_tag = "nllb+qwen"
430
-
431
- qwen_result = _qwen_classify(english_text, expected)
432
- if qwen_result is None:
433
- logger.warning(f"NLU: Qwen returned no valid intent for {english_text!r}")
434
  return "unknown", entities, "unknown"
435
 
436
- intent, llm_entities = qwen_result
437
- logger.info(f"NLU: Qwen classified {english_text!r} → intent={intent}")
 
 
 
438
 
439
- # For free-text slots, pass the original Hausa text through
440
  if expected == "bundle":
441
  t = text.lower()
442
  for b in ("rana", "mako", "wata"):
443
  if b in t:
444
- llm_entities["bundle"] = b
445
  break
446
-
447
  if expected == "text":
448
- llm_entities["text"] = text.strip()
449
 
450
- return intent, llm_entities, source_tag
 
 
1
  """
2
+ NLU — Embedding similarity architecture.
3
+ =========================================
4
+ Replaces the legacy NLLB+Qwen pipeline (preserved in nlu_legacy.py).
5
+
6
+ Why embeddings?
7
+ - Latency: ~200ms vs ~10s on CPU for the legacy stack
8
+ - Memory: ~420MB vs ~8GB
9
+ - Hausa coverage: paraphrase-multilingual-MiniLM-L12-v2 was trained on 50+
10
+ languages including Hausa, so we no longer need a translation step
11
+ - Confidence comes for free: cosine similarity IS a calibrated confidence
12
+
13
+ Pipeline (in order):
14
+ Layer 0: Human-keyword escape ("wakili", "agent") → always wins
15
+ Layer 1: Structural extractors (digits, amounts, yes/no, name, date)
16
+ when the dialogue state has expected_slot set
17
+ Layer 1.5: Keyword fast-path for ultra-common phrases ("duba ma'auni")
18
+ sub-millisecond, no model call
19
+ Layer 2: Sentence-embedding similarity vs per-intent centroids
20
+ cosine sim threshold (0.4) that intent, else unknown
21
+
22
+ The dialogue manager receives the same (intent, entities, source) tuple
23
+ as before, so app.py needs no changes.
 
 
24
  """
25
  from __future__ import annotations
26
  import re
 
27
  import logging
28
  from typing import Optional
29
 
 
45
  "ɗari": 100, "dari": 100,
46
  }
47
 
 
48
  HAUSA_YES = {"i", "eh", "haka ne", "haka", "ok", "okay", "yes"}
49
  HAUSA_NO = {"a'a", "a'aa", "ba haka", "ba", "no"}
50
 
 
51
  HUMAN_KEYWORDS = {"mutum", "wakili", "agent", "human"}
52
 
53
 
 
87
  return any(kw in t for kw in HUMAN_KEYWORDS)
88
 
89
 
90
+ # ---------------------------------------------------------------------------
91
+ # Keyword fast-path instant matches for common scripted phrases
92
+ # ---------------------------------------------------------------------------
 
93
  INTENT_KEYWORDS = {
94
  "check_balance": [
95
  "duba ma'auni", "ma'auni", "balance", "check balance",
 
130
 
131
 
132
  def _match_intent_keyword(text: str) -> Optional[str]:
 
 
133
  t = text.lower().strip()
 
134
  all_kw = [(intent, kw) for intent, kws in INTENT_KEYWORDS.items() for kw in kws]
135
  all_kw.sort(key=lambda x: len(x[1]), reverse=True)
136
  for intent, kw in all_kw:
 
139
  return None
140
 
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  # ---------------------------------------------------------------------------
143
+ # Intent example dataset the heart of the embedding NLU.
144
+ # These phrases are encoded once into centroids; at inference, user input is
145
+ # compared (cosine similarity) against each centroid. More examples = better
146
+ # coverage of paraphrases. Hausa + English mixed deliberately so cross-lingual
147
+ # matches work via the multilingual encoder.
148
  # ---------------------------------------------------------------------------
149
+ INTENT_EXAMPLES = {
150
+ "check_balance": [
151
+ # Hausa
152
+ "duba ma'auni",
153
+ "ina son sanin kuɗin asusuna",
154
+ "nawa ne a asusuna",
155
+ "menene ma'aunin asusuna",
156
+ "yi mini bayanin asusuna",
157
+ "ina son ganin kuɗina",
158
+ # English
159
+ "check my balance",
160
+ "what is my account balance",
161
+ "how much money do I have",
162
+ "show me my balance",
163
+ "tell me my balance",
164
+ "how much is in my account",
165
+ ],
166
+ "block_card": [
167
+ "toshe kati",
168
+ "ina son toshe katina",
169
+ "ɓatar da kati na",
170
+ "katina ya ɓace",
171
+ "yi mini taimako, kati na ya ɓace",
172
+ "in toshe ATM card",
173
+ "block my card",
174
+ "I lost my card",
175
+ "freeze my debit card",
176
+ "I need to cancel my card",
177
+ "my card was stolen",
178
+ "please block my ATM card",
179
+ ],
180
+ "transfer_money": [
181
+ "canjin kuɗi",
182
+ "ina son aika kuɗi",
183
+ "tura kuɗi zuwa wani",
184
+ "yi canji",
185
+ "in turawa abokina kuɗi",
186
+ "aiki kuɗi ga abokina",
187
+ "transfer money",
188
+ "send money to someone",
189
+ "I want to make a transfer",
190
+ "wire money to my friend",
191
+ "send naira to another account",
192
+ "make a payment",
193
+ ],
194
+ "buy_airtime": [
195
+ "saya airtime",
196
+ "ina son saya airtime",
197
+ "kunna waya",
198
+ "in saya credit",
199
+ "saya credit na waya",
200
+ "recharge waya na",
201
+ "buy airtime",
202
+ "top up my phone",
203
+ "recharge my phone",
204
+ "I need airtime",
205
+ "load credit",
206
+ "add credit to my phone",
207
+ ],
208
+ "buy_bundle": [
209
+ "saya bundle",
210
+ "ina son saya data",
211
+ "kunna intanet",
212
+ "in saya data bundle",
213
+ "saya megabyte",
214
+ "buy data",
215
+ "buy internet bundle",
216
+ "I want a data plan",
217
+ "purchase data bundle",
218
+ "get me a megabyte plan",
219
+ "subscribe to data",
220
+ "renew my data",
221
+ ],
222
+ "complaint": [
223
+ "yin korafi",
224
+ "ina da matsala",
225
+ "in yi koka",
226
+ "akwai matsala da hidima",
227
+ "ina son in kawo matsala",
228
+ "ba na gamsuwa",
229
+ "I want to file a complaint",
230
+ "I have a problem",
231
+ "report an issue",
232
+ "something is wrong",
233
+ "the service is bad",
234
+ "I'm not satisfied",
235
+ ],
236
+ "check_order": [
237
+ "bincika oda",
238
+ "ina oda na yake",
239
+ "tabbatar oda",
240
+ "yaushe za a kawo oda na",
241
+ "in san halin oda na",
242
+ "track order",
243
+ "where is my order",
244
+ "check order status",
245
+ "when will my order arrive",
246
+ "is my order ready",
247
+ "I want to know about my order",
248
+ ],
249
+ "reschedule": [
250
+ "sake tsara",
251
+ "ina son sake tsara lokaci",
252
+ "canjin ranar isar",
253
+ "in canza ranar kawowa",
254
+ "rana ta dabam",
255
+ "reschedule delivery",
256
+ "change delivery date",
257
+ "I want a different day",
258
+ "deliver tomorrow instead",
259
+ "postpone the delivery",
260
+ "move the delivery to later",
261
+ ],
262
+ "return_item": [
263
+ "mayar da kaya",
264
+ "ina son mayar da kaya",
265
+ "ba na son kaya",
266
+ "ina son mayarwa",
267
+ "kaya ba shi da kyau",
268
+ "return this item",
269
+ "I want to return my order",
270
+ "send it back",
271
+ "I want a refund",
272
+ "I don't want this anymore",
273
+ "the item is broken",
274
+ ],
275
+ "human_agent": [
276
+ "ina son magana da mutum",
277
+ "ka kawo mutum",
278
+ "wakili",
279
+ "in yi magana da wakilin",
280
+ "ba zan iya da bot ba",
281
+ "I want to speak to a human",
282
+ "connect me to an agent",
283
+ "transfer me to a person",
284
+ "I need to talk to someone",
285
+ "real person please",
286
+ "agent please",
287
+ ],
288
+ }
289
 
290
 
291
+ # Confidence threshold: cosine similarities below this become 'unknown'.
292
+ # Tuned by hand at 0.4; lower if too many things are routed to 'unknown',
293
+ # raise if too many incorrect intents get through. See nlu/tests for the
294
+ # validation methodology.
295
+ CONFIDENCE_THRESHOLD = 0.4
296
+
297
+ # Embedding model. Multilingual (50+ languages), 420MB, CPU-fast.
298
+ EMBEDDING_MODEL_ID = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
 
300
 
301
  # ---------------------------------------------------------------------------
302
+ # Embedding model + centroid cache (lazy-loaded)
303
  # ---------------------------------------------------------------------------
304
+ _encoder = None
305
+ _intent_centroids: Optional[dict] = None # intent_name -> np.ndarray
306
+ _embed_failed = False
307
 
308
 
309
+ def _load_encoder():
310
+ """Lazy-load the sentence encoder + compute intent centroids."""
311
+ global _encoder, _intent_centroids, _embed_failed
312
+ if _embed_failed:
313
+ return None
314
+ if _encoder is not None:
315
+ return _encoder
316
  try:
317
+ import numpy as np
318
+ from sentence_transformers import SentenceTransformer
319
+ logger.info(f"Loading embedding model {EMBEDDING_MODEL_ID}…")
320
+ _encoder = SentenceTransformer(EMBEDDING_MODEL_ID)
321
+ logger.info("Computing intent centroids…")
322
+ _intent_centroids = {}
323
+ for intent, phrases in INTENT_EXAMPLES.items():
324
+ # normalize_embeddings=True ⇒ unit vectors ⇒ dot product = cosine sim
325
+ embeddings = _encoder.encode(phrases, normalize_embeddings=True)
326
+ centroid = embeddings.mean(axis=0)
327
+ # Re-normalize the centroid so cosine math stays clean
328
+ centroid = centroid / np.linalg.norm(centroid)
329
+ _intent_centroids[intent] = centroid
330
+ logger.info(f"Encoder ready, {len(_intent_centroids)} intents.")
331
+ return _encoder
332
  except Exception as e:
333
+ logger.warning(f"Encoder load failed: {e}")
334
+ _embed_failed = True
335
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
 
337
 
338
+ def _classify_with_embedding(text: str, expected: Optional[str]) -> Optional[tuple[str, float]]:
339
+ """Cosine similarity vs intent centroids. Returns (intent, confidence)
340
+ or None on failure. Respects expected_slot if it constrains valid intents."""
341
+ encoder = _load_encoder()
342
+ if encoder is None or _intent_centroids is None:
343
  return None
 
 
 
 
 
 
 
 
 
 
 
344
  try:
345
+ import numpy as np
346
+ query = encoder.encode(text, normalize_embeddings=True)
347
+
348
+ # If expected_slot constrains the answer space, filter candidates.
349
+ # For 'yesno', embedding NLU shouldn't fire — yes/no is handled by
350
+ # the structural layer. If we get here with yesno expected, it means
351
+ # the user said something non-standard; we treat that as a possible
352
+ # intent pivot (any intent is fair game).
353
+ valid_intents = list(_intent_centroids.keys())
354
+
355
+ scores = {}
356
+ for intent in valid_intents:
357
+ centroid = _intent_centroids[intent]
358
+ scores[intent] = float(np.dot(query, centroid))
359
+
360
+ best_intent = max(scores, key=scores.get)
361
+ best_score = scores[best_intent]
362
+ logger.info(f"NLU embedding: top match {best_intent}@{best_score:.3f}, "
363
+ f"all scores: { {k: round(v,3) for k,v in sorted(scores.items(), key=lambda x: -x[1])[:3]} }")
364
+ return best_intent, best_score
 
 
 
 
 
365
  except Exception as e:
366
+ logger.warning(f"Embedding classification failed: {e}")
367
  return None
368
 
369
 
 
373
  def parse(text: str, expected: Optional[str] = None,
374
  use_llm: bool = True) -> tuple[str, dict, str]:
375
  """
376
+ NLU entry point. Returns (intent, entities, source) where source is:
377
+ - 'structural': digit/amount/yes-no/name/date regex matched
378
+ - 'keyword': keyword fast-path matched
379
+ - 'embedding': sentence encoder matched above threshold
380
+ - 'human_keyword': escape-hatch keyword caught
381
+ - 'unknown': nothing matched
382
+
383
+ `use_llm` is a misnomer kept for backward compat with the legacy module's
384
+ signature — here it means "use the embedding layer". Set False to test
385
+ rule-only behavior.
386
  """
387
  entities: dict = {}
388
  if not text or not text.strip():
389
  return "unknown", entities, "unknown"
390
 
391
+ # Layer 0: Always-on human-agent escape
392
  if _contains_human_keyword(text):
393
  return "human_agent", entities, "human_keyword"
394
 
395
+ # Layer 1: Structural extractors for strict-format slots
396
  if expected == "digits":
397
  d = _extract_digits(text)
398
  if d:
 
411
  return yn, entities, "structural"
412
 
413
  if expected == "name":
 
414
  name = text.strip().split()[-1] if text.strip() else ""
415
  if name:
416
  entities["name"] = name
 
420
  entities["date"] = text.strip()
421
  return "provide_date", entities, "structural"
422
 
423
+ # Layer 1.5: Keyword fast-path (cheap, runs in any state so users can
424
+ # pivot intent mid-flow).
 
 
 
 
425
  kw_intent = _match_intent_keyword(text)
426
  if kw_intent:
427
+ logger.info(f"NLU keyword: matched {text!r} → {kw_intent}")
428
  return kw_intent, entities, "keyword"
429
 
430
+ # Layer 2: Embedding similarity
431
  if not use_llm:
432
  logger.info(f"NLU: use_llm=False, returning unknown for {text!r}")
433
  return "unknown", entities, "unknown"
434
 
435
+ embed_result = _classify_with_embedding(text, expected)
436
+ if embed_result is None:
437
+ logger.warning(f"NLU embedding unavailable, returning unknown for {text!r}")
 
 
 
 
 
 
 
 
 
 
 
 
438
  return "unknown", entities, "unknown"
439
 
440
+ intent, confidence = embed_result
441
+ if confidence < CONFIDENCE_THRESHOLD:
442
+ logger.info(f"NLU embedding: {intent}@{confidence:.3f} below threshold "
443
+ f"{CONFIDENCE_THRESHOLD}, returning unknown")
444
+ return "unknown", entities, "unknown"
445
 
446
+ # Free-text slot pass-through (preserve original Hausa)
447
  if expected == "bundle":
448
  t = text.lower()
449
  for b in ("rana", "mako", "wata"):
450
  if b in t:
451
+ entities["bundle"] = b
452
  break
 
453
  if expected == "text":
454
+ entities["text"] = text.strip()
455
 
456
+ logger.info(f"NLU embedding accepted: {text!r} → {intent} (conf={confidence:.3f})")
457
+ return intent, entities, "embedding"