# Provenance (web-page header captured with this file): uploaded by Toadoum,
# commit ae6619f "Update nlu.py" (verified).
"""
NLU — NLLB + Qwen pivot-through-English architecture with keyword fast-path.
Flow:
1. Deterministic structural extractors run FIRST on the original Hausa
text (digits, amounts, yes/no keywords). These MUST be deterministic
because "1234" → "provide_digits" with digits="1234" is non-negotiable
for banks, and regex is faster + more reliable than any model for
this sub-task.
2. Keyword fast-path for common Hausa + English intent phrases. Matches
"check balance", "duba ma'auni", "canjin kuɗi", etc. in <10ms without
loading any model. This is what real voice bots use for 90% of turns.
3. If structural + keyword layers don't match, the text is translated
Hausa → English via NLLB-200 (skipped if input is already English),
then classified by Qwen2.5-1.5B in English (where it is strong) into
one of a small fixed set of intent labels.
4. If NLLB or Qwen fails, we return "unknown" cleanly — the dialogue
manager routes to a vertical-specific fallback prompt.
All heavy models are lazy-loaded on first use. Cold-start downloads:
- NLLB-200-distilled-600M: ~2.4 GB
- Qwen2.5-1.5B-Instruct: ~3 GB
"""
from __future__ import annotations
import re
import json
import logging
from typing import Optional
logger = logging.getLogger("plotweaver.nlu")
# ---------------------------------------------------------------------------
# Deterministic structural extractors (run on raw Hausa text)
# ---------------------------------------------------------------------------
# Hausa number words -> the single digit character each one stands for.
# Words containing the hooked letter ɗ are listed with a plain-d spelling too.
WORD_DIGITS = {
    "sifili": "0",
    "daya": "1",
    "ɗaya": "1",
    "biyu": "2",
    "uku": "3",
    "hudu": "4",
    "huɗu": "4",
    "biyar": "5",
    "shida": "6",
    "bakwai": "7",
    "takwas": "8",
    "tara": "9",
}

# Spoken amount phrases -> integer value. Lookup is by phrase, so only the
# amounts listed here are recognised.
WORD_AMOUNTS = {
    "dubu goma": 10000,
    "dubu biyar": 5000,
    "dubu biyu": 2000,
    "dubu": 1000,
    "ɗari biyar": 500,
    "dari biyar": 500,
    "ɗari": 100,
    "dari": 100,
}

# Affirmative and negative replies (Hausa plus a few English ones) that the
# yes/no slot accepts without consulting a model.
HAUSA_YES = {
    "i",
    "eh",
    "haka ne",
    "haka",
    "ok",
    "okay",
    "yes",
}
HAUSA_NO = {
    "a'a",
    "a'aa",
    "ba haka",
    "ba",
    "no",
}

# Words that route the caller to a human agent.
HUMAN_KEYWORDS = {
    "mutum",
    "wakili",
    "agent",
    "human",
}
def _extract_digits(text: str) -> Optional[str]:
    """Return the digit string contained in *text*, or None if there is none.

    Numerals take precedence: when the text contains any numeric characters,
    every run of them is concatenated in order and number words are ignored.
    Otherwise each Hausa number word found in WORD_DIGITS contributes one
    digit, in spoken order.

    Args:
        text: Raw utterance (typed or transcribed).

    Returns:
        The concatenated digits, or None when neither numerals nor known
        number words are present.
    """
    numeral_runs = re.findall(r"\d+", text)
    if numeral_runs:
        return "".join(numeral_runs)

    # Tokenise on letter runs rather than whitespace so punctuation glued to
    # a word ("daya, biyu." or "daya-biyu") does not hide the digit.
    words = re.findall(r"[^\W\d_]+", text.lower())
    spoken = [WORD_DIGITS[word] for word in words if word in WORD_DIGITS]
    return "".join(spoken) or None
def _extract_amount(text: str) -> Optional[int]:
    """Return the monetary amount mentioned in *text*, or None.

    The first numeral in the text wins. Comma-grouped thousands such as
    "5,000" or "1,250,000" are read as one number instead of stopping at the
    first comma. When the text has no numerals, the longest phrase from
    WORD_AMOUNTS that occurs in the lower-cased text supplies the value.

    NOTE(review): the word path is a phrase table, not a numeral grammar; a
    phrase that is not listed (e.g. "dubu uku") resolves to whatever shorter
    listed phrase it contains ("dubu"). Confirm that is acceptable upstream.

    Args:
        text: Raw utterance (typed or transcribed).

    Returns:
        The amount as an int, or None when nothing matched.
    """
    # Leftmost match; the grouped form is only accepted when the final group
    # is exactly three digits, otherwise the plain \d+ alternative applies.
    numeral = re.search(r"\d{1,3}(?:,\d{3})+(?!\d)|\d+", text)
    if numeral:
        return int(numeral.group().replace(",", ""))

    lowered = text.lower()
    # Longest phrase first so "dubu biyar" is preferred over bare "dubu".
    for phrase in sorted(WORD_AMOUNTS, key=len, reverse=True):
        if phrase in lowered:
            return WORD_AMOUNTS[phrase]
    return None
def _match_yesno(text: str) -> Optional[str]:
    """Classify *text* as "yes" or "no" by keyword, or return None.

    Keywords from HAUSA_YES and HAUSA_NO are matched as whole words (or whole
    word sequences) after punctuation is removed, so "eh." and "a'a!" match.
    Longer keywords are tried first: this lets the negative "ba haka" win
    over the affirmative "haka" it contains. When a yes keyword and a no
    keyword of equal length both occur, "yes" is returned.

    NOTE(review): "i" is a yes keyword, so an English sentence containing the
    pronoun "I" and no longer keyword is still read as "yes".

    Args:
        text: Raw utterance.

    Returns:
        "yes", "no", or None when no keyword is present.
    """
    lowered = text.lower().replace("’", "'")
    # Keep apostrophes inside words ("a'a") but drop all other punctuation.
    words = re.findall(r"[^\W_]+(?:'[^\W_]+)*", lowered)
    padded = " " + " ".join(words) + " "

    labelled = [(kw, "yes") for kw in HAUSA_YES]
    labelled += [(kw, "no") for kw in HAUSA_NO]
    # Longest keyword first; False sorts before True, so "yes" wins ties.
    labelled.sort(key=lambda pair: (-len(pair[0]), pair[1] == "no"))

    for keyword, answer in labelled:
        if f" {keyword} " in padded:
            return answer
    return None
def _contains_human_keyword(text: str) -> bool:
    """Report whether any HUMAN_KEYWORDS entry occurs in the lower-cased text.

    Matching is by substring, so a keyword embedded in a longer word counts.
    """
    lowered = text.lower()
    for keyword in HUMAN_KEYWORDS:
        if keyword in lowered:
            return True
    return False
# Keyword fast-path for common intents. Runs BEFORE NLLB+Qwen so that the
# scripted demo flows don't require a 6GB LLM load. Phrases are Hausa and
# English pairs that customers actually use. When none match, we fall
# through to NLLB+Qwen for paraphrases.
# Intent name -> trigger phrases (Hausa and English, all lower-case).
INTENT_KEYWORDS = {
    "check_balance": [
        "duba ma'auni",
        "ma'auni",
        "balance",
        "check balance",
        "account balance",
        "how much",
        "kudin asusu",
    ],
    "block_card": [
        "toshe kati",
        "block card",
        "cancel card",
        "freeze card",
        "toshe",
        "lost card",
        "ɓatar da kati",
    ],
    "transfer_money": [
        "canjin kuɗi",
        "canjin kudi",
        "transfer",
        "transfer money",
        "send money",
        "aiki kuɗi",
        "aiki kudi",
    ],
    "buy_airtime": [
        "saya airtime",
        "airtime",
        "buy airtime",
        "top up",
        "topup",
        "recharge",
        "karɓi airtime",
    ],
    "buy_bundle": [
        "saya bundle",
        "bundle",
        "buy bundle",
        "buy data",
        "data",
        "internet",
        "megabyte",
    ],
    "complaint": [
        "yin korafi",
        "korafi",
        "complaint",
        "complain",
        "problem",
        "matsala",
        "file complaint",
    ],
    "check_order": [
        "bincika oda",
        "oda",
        "check order",
        "order status",
        "my order",
        "where is my order",
        "track order",
    ],
    "reschedule": [
        "sake tsara",
        "reschedule",
        "change time",
        "another day",
        "later",
        "tomorrow",
    ],
    "return_item": [
        "mayar da kaya",
        "return",
        "return item",
        "send back",
        "mayar",
    ],
}
def _match_intent_keyword(text: str) -> Optional[str]:
    """Keyword fast-path for common customer-service intents.

    Every phrase in INTENT_KEYWORDS is tried against the lower-cased text,
    longest phrase first, and the intent of the first phrase found is
    returned. A phrase must begin at the start of a word, so "oda" no longer
    fires inside "today"; it may still be followed by more letters, which
    keeps inflected forms such as "returned" or "complaints" matching.

    Args:
        text: Raw utterance.

    Returns:
        The matched intent name, or None when no phrase occurs.
    """
    lowered = text.lower().strip()
    # Longest phrase first so the most specific phrase decides when keywords
    # of several intents occur in the same utterance. sorted() is stable, so
    # equal-length phrases keep their INTENT_KEYWORDS order.
    ranked = sorted(
        ((phrase, intent)
         for intent, phrases in INTENT_KEYWORDS.items()
         for phrase in phrases),
        key=lambda pair: len(pair[0]),
        reverse=True,
    )
    for phrase, intent in ranked:
        # (?<!\w): the phrase must not continue a preceding word.
        if re.search(r"(?<!\w)" + re.escape(phrase), lowered):
            return intent
    return None
def _looks_english(text: str) -> bool:
    """Heuristic: decide whether *text* can skip Hausa -> English translation.

    The text is treated as Hausa (returns False) when it contains one of the
    hooked Hausa letters ɓ, ɗ, ƙ, ƴ (either case) or when any of its words is
    in a small list of common Hausa words. Everything else, including the
    empty string, is treated as English (returns True).

    Words are compared after stripping surrounding punctuation, so "Sannu,"
    and "nagode!" are recognised; apostrophes inside a word ("ma'auni",
    "don't") are kept so the word stays in one piece.
    """
    hausa_chars = set("ɓɗƙƴƁƊƘƳ")
    if any(ch in hausa_chars for ch in text):
        return False

    # Common Hausa words — a single hit is enough to treat the text as Hausa.
    hausa_markers = {
        "duba", "ma'auni", "toshe", "kati", "canjin", "kuɗi", "kudi",
        "saya", "airtime", "bundle", "korafi", "bincika", "oda",
        "sake", "tsara", "mayar", "kaya", "wakili", "mutum",
        "sannu", "nagode", "don", "allah", "ka", "yana", "tana",
        "dubu", "ɗari", "dari", "biyar", "biyu", "uku", "hudu", "huɗu",
    }
    normalised = text.lower().replace("’", "'")
    words = set(re.findall(r"[^\W_]+(?:'[^\W_]+)*", normalised))
    return words.isdisjoint(hausa_markers)
# ---------------------------------------------------------------------------
# NLLB-200 Ha → En translation (lazy-loaded)
# ---------------------------------------------------------------------------
# Module-level cache for the translation model. _nllb_failed remembers a
# failed load so later calls return immediately instead of retrying.
_nllb_model = None
_nllb_tokenizer = None
_nllb_failed = False


def _load_nllb():
    """Return the cached (model, tokenizer) pair for NLLB-200-distilled-600M.

    The pair is loaded on the first call and reused afterwards. Any exception
    during loading is logged, recorded in _nllb_failed, and reported to the
    caller as (None, None); every later call then also returns (None, None).
    """
    global _nllb_model, _nllb_tokenizer, _nllb_failed

    if _nllb_failed:
        return None, None
    if _nllb_model is not None:
        return _nllb_model, _nllb_tokenizer

    try:
        # Imported here so the heavy dependencies are only needed on first use.
        import torch
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

        logger.info("Loading NLLB-200-distilled-600M…")
        checkpoint = "facebook/nllb-200-distilled-600M"
        _nllb_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        _nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
            checkpoint,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
        )
        _nllb_model.eval()
        logger.info("NLLB-200 ready.")
    except Exception as exc:
        logger.warning(f"NLLB load failed: {exc}")
        _nllb_failed = True
        return None, None
    return _nllb_model, _nllb_tokenizer
def translate_ha_to_en(text: str) -> Optional[str]:
    """Translate Hausa text to English with NLLB.

    Args:
        text: Hausa utterance. Input is truncated to 128 tokens.

    Returns:
        The stripped English translation, or None when the text is empty or
        whitespace, the model is unavailable, or translation raises.
    """
    # Reject blank input before touching the model so an empty utterance
    # never triggers the lazy model load.
    if not text or not text.strip():
        return None

    model, tokenizer = _load_nllb()
    if model is None:
        return None

    try:
        import torch

        # NLLB requires the source language to be set on the tokenizer.
        tokenizer.src_lang = "hau_Latn"
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
        # Force English output via forced_bos_token_id.
        forced_bos_id = tokenizer.convert_tokens_to_ids("eng_Latn")
        with torch.no_grad():
            out = model.generate(
                **inputs,
                forced_bos_token_id=forced_bos_id,
                max_new_tokens=128,
                num_beams=2,
            )
        translated = tokenizer.batch_decode(out, skip_special_tokens=True)[0].strip()
        # Separator added: source and translation were logged run together.
        logger.info(f"NLLB Ha→En: {text!r} → {translated!r}")
        return translated
    except Exception as e:
        logger.warning(f"NLLB translate failed: {e}")
        return None
# ---------------------------------------------------------------------------
# Qwen2.5-1.5B intent classifier (operates on English text)
# ---------------------------------------------------------------------------
# Module-level cache for the classifier model. _llm_failed remembers a failed
# load so later calls return immediately instead of retrying.
_llm_model = None
_llm_tokenizer = None
_llm_failed = False


def _load_llm():
    """Return the cached (model, tokenizer) pair for Qwen2.5-1.5B-Instruct.

    The pair is loaded on the first call and reused afterwards. Any exception
    during loading is logged, recorded in _llm_failed, and reported to the
    caller as (None, None); every later call then also returns (None, None).
    """
    global _llm_model, _llm_tokenizer, _llm_failed

    if _llm_failed:
        return None, None
    if _llm_model is not None:
        return _llm_model, _llm_tokenizer

    try:
        # Imported here so the heavy dependencies are only needed on first use.
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        logger.info("Loading Qwen2.5-1.5B-Instruct…")
        checkpoint = "Qwen/Qwen2.5-1.5B-Instruct"
        _llm_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        _llm_model = AutoModelForCausalLM.from_pretrained(
            checkpoint,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
        )
        _llm_model.eval()
        logger.info("Qwen2.5-1.5B ready.")
    except Exception as exc:
        logger.warning(f"Qwen load failed: {exc}")
        _llm_failed = True
        return None, None
    return _llm_model, _llm_tokenizer
# Labels offered to the classifier when no specific slot is being filled.
_OPEN_TURN_INTENTS = [
    "check_balance",
    "block_card",
    "transfer_money",
    "buy_airtime",
    "buy_bundle",
    "complaint",
    "check_order",
    "reschedule",
    "return_item",
    "human_agent",
    "unknown",
]

# Expected-slot name -> labels the classifier may return for that slot.
# None and "intent" both mean an open turn; each gets its own list copy.
CANDIDATE_INTENTS = {
    None: list(_OPEN_TURN_INTENTS),
    "intent": list(_OPEN_TURN_INTENTS),
    "yesno": ["yes", "no", "human_agent", "unknown"],
    **{
        slot: [f"provide_{slot}", "human_agent", "unknown"]
        for slot in ("name", "date", "bundle", "text")
    },
}
# System message sent to the Qwen classifier as the "system" turn. It lists
# the meaning of every intent label and asks for a bare JSON object with an
# "intent" string and an "entities" object, as in the closing example.
# The text is model input: editing it changes classifier behaviour.
SYSTEM_PROMPT = """You are an intent classifier for a customer-service voice bot.
You will be given an English-language utterance (translated from Hausa) and a list of candidate intents. Return JSON with the single best-matching intent and any entities you can extract.
Intent meanings:
- check_balance: user wants to check an account balance
- block_card: user wants to block, freeze, or cancel a bank card
- transfer_money: user wants to send or transfer money
- buy_airtime: user wants to buy phone airtime / top-up
- buy_bundle: user wants to buy a data bundle / internet package
- complaint: user wants to file a complaint or report a problem
- check_order: user wants to check the status of an order
- reschedule: user wants to reschedule a delivery
- return_item: user wants to return an item
- human_agent: user wants to speak to a human person
- yes / no: affirmative or negative reply
- provide_name / provide_date / provide_bundle / provide_text: user is supplying information
- unknown: cannot determine intent
Return ONLY valid JSON. No explanation, no markdown. Example: {"intent": "check_balance", "entities": {}}"""
def _first_json_object(reply: str) -> Optional[dict]:
    """Return the first complete JSON object embedded in *reply*, or None."""
    decoder = json.JSONDecoder()
    start = reply.find("{")
    while start != -1:
        try:
            # raw_decode parses one value starting at `start` and ignores
            # whatever follows, so nested braces are handled correctly.
            candidate, _end = decoder.raw_decode(reply, start)
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(candidate, dict):
                return candidate
        start = reply.find("{", start + 1)
    return None


def _qwen_classify(english_text: str, expected: Optional[str]) -> Optional[tuple[str, dict]]:
    """Classify an English utterance into one of the candidate intents.

    Args:
        english_text: Utterance in English (translated or original).
        expected: Slot the dialogue is waiting for; selects the candidate
            list from CANDIDATE_INTENTS. Unlisted values use the open-turn
            list.

    Returns:
        (intent, entities) when the model replies with a JSON object whose
        intent is one of the candidates; None when the model is unavailable,
        the reply holds no JSON object, the intent is not a candidate, or
        inference raises.
    """
    model, tokenizer = _load_llm()
    if model is None:
        return None

    candidates = CANDIDATE_INTENTS.get(expected, CANDIDATE_INTENTS[None])
    user_prompt = (
        f'Utterance: "{english_text}"\n'
        f'Candidate intents: {", ".join(candidates)}\n\n'
        'Return JSON only.'
    )
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]
    try:
        import torch

        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=60,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        generated = tokenizer.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()
        logger.info(f"Qwen raw: {generated}")

        # A non-greedy brace regex would cut the reply at the first "}",
        # truncating any object with a nested "entities" value; decode the
        # object properly instead.
        parsed = _first_json_object(generated)
        if parsed is None:
            return None

        intent = parsed.get("intent", "unknown")
        entities = parsed.get("entities")
        if not isinstance(entities, dict):
            entities = {}
        if intent not in candidates:
            logger.info(f"Qwen returned out-of-candidate intent: {intent}")
            return None
        return intent, entities
    except Exception as e:
        logger.warning(f"Qwen inference failed: {e}")
        return None
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def parse(text: str, expected: Optional[str] = None,
          use_llm: bool = True) -> tuple[str, dict, str]:
    """
    Run the NLU layers on one utterance.

    Returns (intent, entities, source). The source tag is one of:
    - 'human_keyword': the human-agent keyword check fired
    - 'structural': a deterministic slot extractor filled the expected slot
    - 'keyword': the intent keyword fast-path matched
    - 'qwen_en': text was treated as English and classified by Qwen
    - 'nllb+qwen': text was translated by NLLB, then classified by Qwen
    - 'unknown': blank input, LLM disabled, or a model step produced nothing
    """
    entities: dict = {}
    stripped = text.strip() if text else ""
    if not stripped:
        return "unknown", entities, "unknown"

    # The request for a human agent is honoured in every dialogue state.
    if _contains_human_keyword(text):
        return "human_agent", entities, "human_keyword"

    # Layer 1: deterministic extraction for the slot the dialogue expects.
    # A strict slot that fails to extract falls through to the layers below.
    if expected == "digits":
        digits = _extract_digits(text)
        if digits:
            entities["digits"] = digits
            return "provide_digits", entities, "structural"
    elif expected == "amount":
        amount = _extract_amount(text)
        if amount is not None:
            entities["amount"] = amount
            return "provide_amount", entities, "structural"
    elif expected == "yesno":
        answer = _match_yesno(text)
        if answer:
            return answer, entities, "structural"
    elif expected == "name":
        # Free-form slot: the last word is taken as the name.
        entities["name"] = stripped.split()[-1]
        return "provide_name", entities, "structural"
    elif expected == "date":
        # Free-form slot: the whole utterance is passed through.
        entities["date"] = stripped
        return "provide_date", entities, "structural"

    # Layer 1.5: keyword fast-path, active in every state so the user can
    # switch to a new intent while a slot is still being filled.
    matched = _match_intent_keyword(text)
    if matched:
        logger.info(f"NLU: keyword matched {text!r}{matched}")
        return matched, entities, "keyword"

    # Layer 2: model-based classification.
    if not use_llm:
        logger.info(f"NLU: use_llm=False, returning unknown for {text!r}")
        return "unknown", entities, "unknown"

    if _looks_english(text):
        logger.info(f"NLU: input looks English, skipping NLLB: {text!r}")
        english_text = text
        source_tag = "qwen_en"
    else:
        logger.info(f"NLU: translating Hausa via NLLB: {text!r}")
        english_text = translate_ha_to_en(text)
        if english_text is None:
            logger.warning("NLU: NLLB failed, returning unknown")
            return "unknown", entities, "unknown"
        source_tag = "nllb+qwen"

    classified = _qwen_classify(english_text, expected)
    if classified is None:
        logger.warning(f"NLU: Qwen returned no valid intent for {english_text!r}")
        return "unknown", entities, "unknown"

    intent, llm_entities = classified
    logger.info(f"NLU: Qwen classified {english_text!r} → intent={intent}")

    # Free-text slots take their value from the original utterance rather
    # than from the model's reply.
    if expected == "bundle":
        lowered = text.lower()
        period = next((p for p in ("rana", "mako", "wata") if p in lowered), None)
        if period is not None:
            llm_entities["bundle"] = period
    elif expected == "text":
        llm_entities["text"] = stripped

    return intent, llm_entities, source_tag