Spaces:
Sleeping
Sleeping
| """ | |
| NLU — NLLB + Qwen pivot-through-English architecture with keyword fast-path. | |
| Flow: | |
| 1. Deterministic structural extractors run FIRST on the original Hausa | |
| text (digits, amounts, yes/no keywords). These MUST be deterministic | |
| because "1234" → "provide_digits" with digits="1234" is non-negotiable | |
| for banks, and regex is faster + more reliable than any model for | |
| this sub-task. | |
| 2. Keyword fast-path for common Hausa + English intent phrases. Matches | |
| "check balance", "duba ma'auni", "canjin kuɗi", etc. in <10ms without | |
| loading any model. This is what real voice bots use for 90% of turns. | |
| 3. If structural + keyword layers don't match, the text is translated | |
| Hausa → English via NLLB-200 (skipped if input is already English), | |
| then classified by Qwen2.5-1.5B in English (where it is strong) into | |
| one of a small fixed set of intent labels. | |
| 4. If NLLB or Qwen fails, we return "unknown" cleanly — the dialogue | |
| manager routes to a vertical-specific fallback prompt. | |
| All heavy models are lazy-loaded on first use. Cold-start downloads: | |
| - NLLB-200-distilled-600M: ~2.4 GB | |
| - Qwen2.5-1.5B-Instruct: ~3 GB | |
| """ | |
| from __future__ import annotations | |
| import re | |
| import json | |
| import logging | |
| from typing import Optional | |
| logger = logging.getLogger("plotweaver.nlu") | |
| # --------------------------------------------------------------------------- | |
| # Deterministic structural extractors (run on raw Hausa text) | |
| # --------------------------------------------------------------------------- | |
# Hausa number words mapped to the digit character they stand for. Both the
# hooked-letter spelling and the plain-ASCII spelling are listed where the
# word contains ɗ.
WORD_DIGITS = {
    "sifili": "0",
    "daya": "1",
    "ɗaya": "1",
    "biyu": "2",
    "uku": "3",
    "hudu": "4",
    "huɗu": "4",
    "biyar": "5",
    "shida": "6",
    "bakwai": "7",
    "takwas": "8",
    "tara": "9",
}

# Spoken amount phrases mapped to their integer value. Lookups try the longest
# phrase first, so "dubu goma" is preferred over the bare "dubu".
WORD_AMOUNTS = {
    "dubu goma": 10000,
    "dubu biyar": 5000,
    "dubu biyu": 2000,
    "dubu": 1000,
    "ɗari biyar": 500,
    "dari biyar": 500,
    "ɗari": 100,
    "dari": 100,
}

# Affirmative / negative keywords recognised by the deterministic yes/no
# matcher (Hausa plus a few English forms).
HAUSA_YES = {
    "i",
    "eh",
    "haka ne",
    "haka",
    "ok",
    "okay",
    "yes",
}
HAUSA_NO = {
    "a'a",
    "a'aa",
    "ba haka",
    "ba",
    "no",
}

# Words that trigger the hand-over-to-a-human escape hatch.
HUMAN_KEYWORDS = {
    "mutum",
    "wakili",
    "agent",
    "human",
}
| def _extract_digits(text: str) -> Optional[str]: | |
| m = re.findall(r"\d+", text) | |
| if m: | |
| return "".join(m) | |
| tokens = text.lower().split() | |
| d = [WORD_DIGITS[tok] for tok in tokens if tok in WORD_DIGITS] | |
| return "".join(d) if d else None | |
| def _extract_amount(text: str) -> Optional[int]: | |
| m = re.search(r"\d+", text) | |
| if m: | |
| return int(m.group()) | |
| t = text.lower() | |
| for phrase in sorted(WORD_AMOUNTS.keys(), key=len, reverse=True): | |
| if phrase in t: | |
| return WORD_AMOUNTS[phrase] | |
| return None | |
def _match_yesno(text: str) -> Optional[str]:
    """Classify *text* as "yes" or "no" by keyword, or return None.

    A keyword from HAUSA_YES / HAUSA_NO matches when it appears as a whole
    word (or whole phrase) in the lower-cased text. Affirmative keywords are
    checked first, so an utterance containing both kinds yields "yes".

    Before matching, punctuation is stripped and whitespace is collapsed so
    that transcripts such as "Eh.", "yes," or "haka  ne" still match.
    Apostrophes inside a word are kept because keywords like "a'a" need them;
    typographic apostrophes are folded to the ASCII one.
    """
    lowered = text.lower().replace("\u2019", "'").replace("\u2018", "'")
    # Everything that is neither a word character nor an apostrophe becomes a
    # separator; quotes wrapped around a word are then trimmed per token.
    tokens = [tok.strip("'") for tok in re.sub(r"[^\w']+", " ", lowered).split()]
    padded = " " + " ".join(tok for tok in tokens if tok) + " "

    if any(f" {kw} " in padded for kw in HAUSA_YES):
        return "yes"
    if any(f" {kw} " in padded for kw in HAUSA_NO):
        return "no"
    return None
def _contains_human_keyword(text: str) -> bool:
    """Report whether any HUMAN_KEYWORDS entry occurs in *text*.

    The comparison is case-insensitive and substring-based: a keyword counts
    as present wherever it appears in the lower-cased text, even inside a
    longer word.
    """
    lowered = text.lower()
    for keyword in HUMAN_KEYWORDS:
        if keyword in lowered:
            return True
    return False
| # Keyword fast-path for common intents. Runs BEFORE NLLB+Qwen so that the | |
| # scripted demo flows don't require a 6GB LLM load. Phrases are Hausa and | |
| # English pairs that customers actually use. When none match, we fall | |
| # through to NLLB+Qwen for paraphrases. | |
| INTENT_KEYWORDS = { | |
| "check_balance": [ | |
| "duba ma'auni", "ma'auni", "balance", "check balance", | |
| "account balance", "how much", "kudin asusu", | |
| ], | |
| "block_card": [ | |
| "toshe kati", "block card", "cancel card", "freeze card", | |
| "toshe", "lost card", "ɓatar da kati", | |
| ], | |
| "transfer_money": [ | |
| "canjin kuɗi", "canjin kudi", "transfer", "transfer money", | |
| "send money", "aiki kuɗi", "aiki kudi", | |
| ], | |
| "buy_airtime": [ | |
| "saya airtime", "airtime", "buy airtime", "top up", "topup", | |
| "recharge", "karɓi airtime", | |
| ], | |
| "buy_bundle": [ | |
| "saya bundle", "bundle", "buy bundle", "buy data", "data", | |
| "internet", "megabyte", | |
| ], | |
| "complaint": [ | |
| "yin korafi", "korafi", "complaint", "complain", "problem", | |
| "matsala", "file complaint", | |
| ], | |
| "check_order": [ | |
| "bincika oda", "oda", "check order", "order status", "my order", | |
| "where is my order", "track order", | |
| ], | |
| "reschedule": [ | |
| "sake tsara", "reschedule", "change time", "another day", | |
| "later", "tomorrow", | |
| ], | |
| "return_item": [ | |
| "mayar da kaya", "return", "return item", "send back", "mayar", | |
| ], | |
| } | |
| def _match_intent_keyword(text: str) -> Optional[str]: | |
| """Keyword fast-path for common customer-service intents. | |
| Returns the intent name if a keyword matches, else None.""" | |
| t = text.lower().strip() | |
| # Check longer phrases first so "check balance" wins over "check order" | |
| all_kw = [(intent, kw) for intent, kws in INTENT_KEYWORDS.items() for kw in kws] | |
| all_kw.sort(key=lambda x: len(x[1]), reverse=True) | |
| for intent, kw in all_kw: | |
| if kw in t: | |
| return intent | |
| return None | |
| def _looks_english(text: str) -> bool: | |
| """Heuristic: if text contains no Hausa-specific characters and is majority | |
| ASCII, treat as English and skip NLLB translation. Hausa uses ɓ, ɗ, ƙ, ƴ | |
| and the apostrophe in 'a'a', 'ma'auni', 'jumma'a' etc.""" | |
| hausa_chars = set("ɓɗƙƴƁƊƘƳ") | |
| if any(c in hausa_chars for c in text): | |
| return False | |
| # Common Hausa words — if any match, treat as Hausa | |
| hausa_markers = { | |
| "duba", "ma'auni", "toshe", "kati", "canjin", "kuɗi", "kudi", | |
| "saya", "airtime", "bundle", "korafi", "bincika", "oda", | |
| "sake", "tsara", "mayar", "kaya", "wakili", "mutum", | |
| "sannu", "nagode", "don", "allah", "ka", "yana", "tana", | |
| "dubu", "ɗari", "dari", "biyar", "biyu", "uku", "hudu", "huɗu", | |
| } | |
| tokens = set(text.lower().split()) | |
| return not bool(tokens & hausa_markers) | |
| # --------------------------------------------------------------------------- | |
| # NLLB-200 Ha → En translation (lazy-loaded) | |
| # --------------------------------------------------------------------------- | |
# Module-level cache for the translation model. `_nllb_failed` latches to True
# after the first unsuccessful load so later calls do not retry.
_nllb_model = None
_nllb_tokenizer = None
_nllb_failed = False


def _load_nllb():
    """Return the cached ``(model, tokenizer)`` pair for NLLB-200, loading it on first use.

    Returns ``(None, None)`` if loading fails now or has failed before. Any
    exception raised while importing or loading is logged as a warning and
    swallowed.
    """
    global _nllb_model, _nllb_tokenizer, _nllb_failed

    if _nllb_failed:
        return None, None
    if _nllb_model is not None:
        return _nllb_model, _nllb_tokenizer

    try:
        # Heavy dependencies are imported here so that importing this module
        # stays cheap when the model is never needed.
        import torch
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

        checkpoint = "facebook/nllb-200-distilled-600M"
        logger.info("Loading NLLB-200-distilled-600M…")
        _nllb_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        _nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
            checkpoint,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
        )
        _nllb_model.eval()
        logger.info("NLLB-200 ready.")
    except Exception as exc:
        logger.warning(f"NLLB load failed: {exc}")
        _nllb_failed = True
        return None, None

    return _nllb_model, _nllb_tokenizer
def translate_ha_to_en(text: str) -> Optional[str]:
    """Translate Hausa *text* to English via NLLB.

    Returns the stripped translation, or None when *text* is blank, the model
    is unavailable, or translation raises (the error is logged as a warning).

    Blank input is rejected before the model is requested, so an empty
    utterance never triggers the multi-gigabyte lazy load.
    """
    if not text.strip():
        return None

    model, tokenizer = _load_nllb()
    if model is None:
        return None

    try:
        import torch

        # NLLB needs the source language set on the tokenizer...
        tokenizer.src_lang = "hau_Latn"
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
        # ...and the target language forced as the first generated token.
        forced_bos_id = tokenizer.convert_tokens_to_ids("eng_Latn")
        with torch.no_grad():
            out = model.generate(
                **inputs,
                forced_bos_token_id=forced_bos_id,
                max_new_tokens=128,
                num_beams=2,
            )
        translated = tokenizer.batch_decode(out, skip_special_tokens=True)[0].strip()
        logger.info(f"NLLB Ha→En: {text!r} → {translated!r}")
        return translated
    except Exception as e:
        logger.warning(f"NLLB translate failed: {e}")
        return None
| # --------------------------------------------------------------------------- | |
| # Qwen2.5-1.5B intent classifier (operates on English text) | |
| # --------------------------------------------------------------------------- | |
# Module-level cache for the intent classifier. `_llm_failed` latches to True
# after the first unsuccessful load so later calls do not retry.
_llm_model = None
_llm_tokenizer = None
_llm_failed = False


def _load_llm():
    """Return the cached ``(model, tokenizer)`` pair for Qwen2.5-1.5B-Instruct, loading it on first use.

    Returns ``(None, None)`` if loading fails now or has failed before. Any
    exception raised while importing or loading is logged as a warning and
    swallowed.
    """
    global _llm_model, _llm_tokenizer, _llm_failed

    if _llm_failed:
        return None, None
    if _llm_model is not None:
        return _llm_model, _llm_tokenizer

    try:
        # Imported lazily so the module can be imported without these packages.
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        checkpoint = "Qwen/Qwen2.5-1.5B-Instruct"
        logger.info("Loading Qwen2.5-1.5B-Instruct…")
        _llm_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        _llm_model = AutoModelForCausalLM.from_pretrained(
            checkpoint,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
        )
        _llm_model.eval()
        logger.info("Qwen2.5-1.5B ready.")
    except Exception as exc:
        logger.warning(f"Qwen load failed: {exc}")
        _llm_failed = True
        return None, None

    return _llm_model, _llm_tokenizer
# Labels offered to the classifier when the dialogue is not waiting on a
# specific slot (expected is None or "intent").
_TOP_LEVEL_INTENTS = (
    "check_balance",
    "block_card",
    "transfer_money",
    "buy_airtime",
    "buy_bundle",
    "complaint",
    "check_order",
    "reschedule",
    "return_item",
    "human_agent",
    "unknown",
)

# Candidate labels per expected-slot type. Every list ends with the
# "human_agent" and "unknown" escape labels. Each key gets its own list
# object so that editing one entry never affects another.
CANDIDATE_INTENTS = {
    None: list(_TOP_LEVEL_INTENTS),
    "intent": list(_TOP_LEVEL_INTENTS),
    "yesno": ["yes", "no", "human_agent", "unknown"],
    "name": ["provide_name", "human_agent", "unknown"],
    "date": ["provide_date", "human_agent", "unknown"],
    "bundle": ["provide_bundle", "human_agent", "unknown"],
    "text": ["provide_text", "human_agent", "unknown"],
}
# System message sent to the classifier: task description, the meaning of
# every label, and the required JSON-only reply format. Written one prompt
# line per string so individual lines are easy to diff.
SYSTEM_PROMPT = "\n".join((
    "You are an intent classifier for a customer-service voice bot.",
    "You will be given an English-language utterance (translated from Hausa) and a list of candidate intents. Return JSON with the single best-matching intent and any entities you can extract.",
    "Intent meanings:",
    "- check_balance: user wants to check an account balance",
    "- block_card: user wants to block, freeze, or cancel a bank card",
    "- transfer_money: user wants to send or transfer money",
    "- buy_airtime: user wants to buy phone airtime / top-up",
    "- buy_bundle: user wants to buy a data bundle / internet package",
    "- complaint: user wants to file a complaint or report a problem",
    "- check_order: user wants to check the status of an order",
    "- reschedule: user wants to reschedule a delivery",
    "- return_item: user wants to return an item",
    "- human_agent: user wants to speak to a human person",
    "- yes / no: affirmative or negative reply",
    "- provide_name / provide_date / provide_bundle / provide_text: user is supplying information",
    "- unknown: cannot determine intent",
    'Return ONLY valid JSON. No explanation, no markdown. Example: {"intent": "check_balance", "entities": {}}',
))
| def _qwen_classify(english_text: str, expected: Optional[str]) -> Optional[tuple[str, dict]]: | |
| """Classify an English utterance into an intent. Returns None on failure.""" | |
| model, tokenizer = _load_llm() | |
| if model is None: | |
| return None | |
| candidates = CANDIDATE_INTENTS.get(expected, CANDIDATE_INTENTS[None]) | |
| user_prompt = ( | |
| f'Utterance: "{english_text}"\n' | |
| f'Candidate intents: {", ".join(candidates)}\n\n' | |
| 'Return JSON only.' | |
| ) | |
| messages = [ | |
| {"role": "system", "content": SYSTEM_PROMPT}, | |
| {"role": "user", "content": user_prompt}, | |
| ] | |
| try: | |
| import torch | |
| prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) | |
| inputs = tokenizer(prompt, return_tensors="pt") | |
| with torch.no_grad(): | |
| out = model.generate( | |
| **inputs, | |
| max_new_tokens=60, | |
| do_sample=False, | |
| pad_token_id=tokenizer.eos_token_id, | |
| ) | |
| generated = tokenizer.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip() | |
| logger.info(f"Qwen raw: {generated}") | |
| m = re.search(r"\{.*?\}", generated, re.DOTALL) | |
| if not m: | |
| return None | |
| parsed = json.loads(m.group()) | |
| intent = parsed.get("intent", "unknown") | |
| entities = parsed.get("entities", {}) or {} | |
| if not isinstance(entities, dict): | |
| entities = {} | |
| if intent not in candidates: | |
| logger.info(f"Qwen returned out-of-candidate intent: {intent}") | |
| return None | |
| return intent, entities | |
| except Exception as e: | |
| logger.warning(f"Qwen inference failed: {e}") | |
| return None | |
| # --------------------------------------------------------------------------- | |
| # Public API | |
| # --------------------------------------------------------------------------- | |
def parse(text: str, expected: Optional[str] = None,
          use_llm: bool = True) -> tuple[str, dict, str]:
    """Run the NLU pipeline on one utterance.

    Args:
        text: The raw user utterance.
        expected: Slot type the dialogue is currently waiting for ("digits",
            "amount", "yesno", "name", "date", "bundle", "text", ...), or
            None when any intent is acceptable.
        use_llm: When False, stop after the keyword layer instead of calling
            NLLB / Qwen.

    Returns:
        ``(intent, entities, source)`` where ``source`` names the layer that
        produced the answer:
          - 'human_keyword': human-agent escape word found
          - 'structural': deterministic slot extractor matched
          - 'keyword': intent keyword fast-path matched
          - 'qwen_en': text looked English and was classified by Qwen
          - 'nllb+qwen': text was translated by NLLB, then classified by Qwen
          - 'unknown': blank input, LLM disabled, or a model step failed
    """
    entities: dict = {}

    if not text or not text.strip():
        return "unknown", entities, "unknown"

    # The human-agent escape is honoured in every dialogue state.
    if _contains_human_keyword(text):
        return "human_agent", entities, "human_keyword"

    stripped = text.strip()

    # Layer 1: deterministic extractors for the slot the dialogue expects.
    # "digits", "amount" and "yesno" fall through to the later layers when
    # nothing is extracted; "name" and "date" always accept the utterance.
    if expected == "digits":
        digit_string = _extract_digits(text)
        if digit_string:
            entities["digits"] = digit_string
            return "provide_digits", entities, "structural"
    elif expected == "amount":
        amount = _extract_amount(text)
        if amount is not None:
            entities["amount"] = amount
            return "provide_amount", entities, "structural"
    elif expected == "yesno":
        answer = _match_yesno(text)
        if answer:
            return answer, entities, "structural"
    elif expected == "name":
        # Quick heuristic for free-form names: keep the last word. The text
        # is known to be non-blank here, so a last word always exists.
        entities["name"] = stripped.split()[-1]
        return "provide_name", entities, "structural"
    elif expected == "date":
        entities["date"] = stripped
        return "provide_date", entities, "structural"

    # Layer 1.5: intent keyword fast-path. It runs in every state, so a user
    # whose reply did not fill the expected slot can switch to a new intent.
    keyword_intent = _match_intent_keyword(text)
    if keyword_intent:
        logger.info(f"NLU: keyword matched {text!r} → {keyword_intent}")
        return keyword_intent, entities, "keyword"

    # Layer 2: optional translation to English, then Qwen classification.
    if not use_llm:
        logger.info(f"NLU: use_llm=False, returning unknown for {text!r}")
        return "unknown", entities, "unknown"

    if _looks_english(text):
        logger.info(f"NLU: input looks English, skipping NLLB: {text!r}")
        english_text = text
        source_tag = "qwen_en"
    else:
        logger.info(f"NLU: translating Hausa via NLLB: {text!r}")
        english_text = translate_ha_to_en(text)
        if english_text is None:
            logger.warning("NLU: NLLB failed, returning unknown")
            return "unknown", entities, "unknown"
        source_tag = "nllb+qwen"

    classification = _qwen_classify(english_text, expected)
    if classification is None:
        logger.warning(f"NLU: Qwen returned no valid intent for {english_text!r}")
        return "unknown", entities, "unknown"

    intent, llm_entities = classification
    logger.info(f"NLU: Qwen classified {english_text!r} → intent={intent}")

    # Slot values for free-text states are taken from the original utterance
    # rather than from the translation.
    if expected == "bundle":
        lowered = text.lower()
        period = next((p for p in ("rana", "mako", "wata") if p in lowered), None)
        if period is not None:
            llm_entities["bundle"] = period
    if expected == "text":
        llm_entities["text"] = stripped

    return intent, llm_entities, source_tag