File size: 11,069 Bytes
08d20bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
"""
NLU — Hybrid Hausa intent + entity extraction.

Three-tier architecture:
  1. Rule-based keyword matcher (fast path, ~80% of demo utterances)
  2. Qwen2.5-1.5B-Instruct zero-shot JSON extractor (paraphrases, novel phrasings)
  3. Rule-based fallback (if LLM fails or returns unparseable output)

The LLM is lazy-loaded on first non-matched utterance so the Space boots fast.
In production this would be replaced with a fine-tuned classifier on
PlotWeaver's Hausa intent corpus.
"""
from __future__ import annotations
import re
import json
import logging
from typing import Optional

logger = logging.getLogger("plotweaver.nlu")

# ---------------------------------------------------------------------------
# Layer 1: rule-based fast path (covers common demo phrases)
# ---------------------------------------------------------------------------
# Keyword lists per intent. Matching is plain substring search against the
# normalised utterance (see _norm), so stems such as "kati" also hit "katin".
#
# ORDER MATTERS: _match_intent_kw returns the FIRST intent (in insertion order)
# that has any matching keyword. Intents whose keywords are generic words that
# also occur inside other requests ("kudi", "asusu", "duba") are therefore
# listed late, so "tura kudi" resolves to transfer_money, not check_balance.
# "no" precedes "yes" so negated phrases such as "ba haka ne ba" are not
# caught by the "haka ne" keyword.
#
# Keywords padded with spaces (" i ", " no ", " ok ") are whole words: _norm
# pads the utterance and turns punctuation into spaces, so the bare "i" no
# longer matches every word that merely starts or ends with "i" (e.g. "ina").
INTENT_KEYWORDS = {
    "block_card": ["toshe", "kati", "block"],
    "transfer_money": ["tura", "canji", "canjin", "aika", "transfer"],
    "buy_airtime": ["airtime", "caji"],
    "buy_bundle": ["bundle", "data", "intanet"],
    "complaint": ["korafi", "matsala", "complain"],
    "check_order": ["bincika", "order", "oda"],
    "reschedule": ["sake tsara", "reschedule", "canja lokaci"],
    "return_item": ["mayar", "mayarwa", "return"],
    "human_agent": ["mutum", "wakili", "agent", "human"],
    "check_balance": ["duba", "ma'auni", "balance", "kudi", "asusu"],
    "no": ["a'a", "a'aa", "ba haka", " no "],
    "yes": [" i ", " ii ", "eh", "haka ne", "yes", " ok ", " okay "],
}

WORD_DIGITS = {
    "sifili": "0", "daya": "1", "ɗaya": "1", "biyu": "2", "uku": "3",
    "hudu": "4", "huɗu": "4", "biyar": "5", "shida": "6", "bakwai": "7",
    "takwas": "8", "tara": "9",
}

WORD_AMOUNTS = {
    "dubu goma": 10000, "dubu biyar": 5000, "dubu biyu": 2000,
    "dubu": 1000, "ɗari biyar": 500, "dari biyar": 500,
    "ɗari": 100, "dari": 100,
}

# Typographic single quotes and the modifier-letter apostrophe (U+02BC) are
# folded to ASCII "'" so keywords such as "a'a" and "ma'auni" still match.
_APOSTROPHE_MAP = str.maketrans({"\u2018": "'", "\u2019": "'", "\u02bc": "'"})
# Punctuation that typing/ASR glues onto words ("i.", "a'a,"); replaced by
# spaces. The ASCII apostrophe is deliberately absent: it is part of words.
_PUNCT_RE = re.compile(r'[-.,!?;:"“”«»()\[\]{}…/\\]+')


def _norm(t: str) -> str:
    """Normalise *t* for keyword matching.

    Lower-cases, folds apostrophe variants, turns punctuation into spaces,
    collapses runs of whitespace and pads one space on each side, so
    space-padded keywords act as whole-word matches.
    """
    t = _PUNCT_RE.sub(" ", t.lower().translate(_APOSTROPHE_MAP))
    return " " + " ".join(t.split()) + " "


def _match_intent_kw(text: str) -> Optional[str]:
    """Return the first intent in INTENT_KEYWORDS with a keyword in *text*.

    Keywords are searched as substrings of ``_norm(text)``; the dict's
    insertion order decides which intent wins when several match.
    Returns None when no keyword occurs.
    """
    t = _norm(text)
    return next(
        (intent for intent, kws in INTENT_KEYWORDS.items()
         if any(kw in t for kw in kws)),
        None,
    )


def _extract_digits(text: str) -> Optional[str]:
    """Pull a digit string out of *text*, or return None.

    Numerals win: every run of decimal digits is concatenated in order, so
    "0803 123 4567" -> "08031234567". Only when *text* has no numerals are
    spelled-out digits from WORD_DIGITS mapped, e.g. "ɗaya biyu" -> "12".
    """
    numerals = re.findall(r"\d+", text)
    if numerals:
        return "".join(numerals)
    # Tokenise on word characters, not whitespace, so punctuation attached by
    # typing/ASR ("ɗaya, biyu.") does not defeat the WORD_DIGITS lookup.
    words = re.findall(r"\w+", text.lower())
    spelled = [WORD_DIGITS[w] for w in words if w in WORD_DIGITS]
    return "".join(spelled) if spelled else None


def _extract_amount(text: str) -> Optional[int]:
    """Return the money amount mentioned in *text*, or None.

    The first numeral wins and may use comma thousands separators, so
    "N5,000" yields 5000 rather than 5. Without numerals, the longest
    matching WORD_AMOUNTS phrase is used, so "dubu biyar" (5000) beats its
    prefix "dubu" (1000).
    """
    m = re.search(r"\d+(?:,\d{3})*", text)
    if m:
        return int(m.group().replace(",", ""))
    t = text.lower()
    # Longest phrase first so multi-word amounts shadow their prefixes.
    for phrase in sorted(WORD_AMOUNTS, key=len, reverse=True):
        if phrase in t:
            return WORD_AMOUNTS[phrase]
    return None


def _rule_based_parse(text: str, expected: Optional[str]) -> tuple[str, dict]:
    """Layer 1 + 3: deterministic keyword + slot matcher.

    Args:
        text: the user's utterance.
        expected: slot the dialog is waiting for ("digits", "amount", "name",
            "date", "bundle", "text", "yesno"); any other value, including
            None, means only keyword intents are tried.

    Returns:
        ``(intent, entities)``: a keyword intent, a ``provide_<slot>`` intent
        with the slot value in ``entities``, "yes"/"no", or "unknown" when
        nothing matched.
    """
    entities: dict = {}
    if not text or not text.strip():
        return "unknown", entities

    # Universal escape: if the first keyword intent is a human hand-off,
    # honour it whatever slot we were waiting for.
    if _match_intent_kw(text) == "human_agent":
        return "human_agent", entities

    if expected == "digits":
        d = _extract_digits(text)
        if d:
            entities["digits"] = d
            return "provide_digits", entities

    if expected == "amount":
        a = _extract_amount(text)
        if a is not None:
            entities["amount"] = a
            return "provide_amount", entities

    if expected == "name":
        # Treat the last word as the name ("Sunana Musa." -> "Musa"),
        # dropping punctuation that typing/ASR attaches to it.
        name = text.split()[-1].strip('.,!?;:"')
        if name:
            entities["name"] = name
            return "provide_name", entities

    if expected == "date":
        # Stored verbatim; no date parsing is attempted here.
        entities["date"] = text.strip()
        return "provide_date", entities

    if expected == "bundle":
        t = text.lower()
        for b in ("rana", "mako", "wata"):
            if b in t:
                entities["bundle"] = b
                return "provide_bundle", entities

    if expected == "text":
        entities["text"] = text.strip()
        return "provide_text", entities

    if expected == "yesno":
        # Look for the yes/no words themselves instead of taking the first
        # keyword intent; otherwise a refusal that still mentions tura/kudi
        # ("A'a, kar ka tura kudi") would come back as a service intent.
        # "no" is tried first so "ba haka ne ba" is not read as "haka ne".
        t = _norm(text)
        for answer in ("no", "yes"):
            if any(kw in t for kw in INTENT_KEYWORDS[answer]):
                return answer, entities

    i = _match_intent_kw(text)
    if i:
        return i, entities

    return "unknown", entities


# ---------------------------------------------------------------------------
# Layer 2: Qwen2.5-1.5B-Instruct zero-shot NLU
# ---------------------------------------------------------------------------
# Lazily-populated module state for the layer-2 model (filled by _load_llm).
_llm_model = None
_llm_tokenizer = None
_llm_failed = False  # flips to True after a failed load; loading is never retried


def _load_llm():
    """Return ``(model, tokenizer)`` for Qwen2.5-1.5B-Instruct, loading on first use.

    Only reached when the rule-based layer misses. A successful load is
    cached in the module globals above. If loading raises, the error is
    logged, the failure is remembered, and ``(None, None)`` is returned on
    this and every subsequent call.
    """
    global _llm_model, _llm_tokenizer, _llm_failed

    if not _llm_failed and _llm_model is None:
        try:
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer

            logger.info("Loading Qwen2.5-1.5B-Instruct for NLU…")
            repo = "Qwen/Qwen2.5-1.5B-Instruct"
            _llm_tokenizer = AutoTokenizer.from_pretrained(repo)
            _llm_model = AutoModelForCausalLM.from_pretrained(
                repo,
                torch_dtype=torch.float32,  # CPU — bfloat16 not broadly supported
                low_cpu_mem_usage=True,
            )
            _llm_model.eval()
            logger.info("Qwen2.5-1.5B-Instruct ready.")
        except Exception as exc:
            logger.warning(f"LLM load failed: {exc}")
            _llm_failed = True

    if _llm_failed:
        return None, None
    return _llm_model, _llm_tokenizer


# Allowed LLM answers per expected-slot context; this keeps the prompt short
# and lets _llm_parse reject anything outside the list.
# Free-form turns (no slot, or the explicit "intent" slot) may pick any
# service intent; a value slot only admits its own provide_<slot> intent
# plus the human hand-off and "unknown".
_SERVICE_INTENTS = (
    "check_balance", "block_card", "transfer_money",
    "buy_airtime", "buy_bundle", "complaint",
    "check_order", "reschedule", "return_item",
    "human_agent", "unknown",
)
_VALUE_SLOTS = ("digits", "amount", "name", "date", "bundle", "text")

CANDIDATE_INTENTS = {
    None: list(_SERVICE_INTENTS),
    "intent": list(_SERVICE_INTENTS),
    "yesno": ["yes", "no", "human_agent", "unknown"],
    **{slot: [f"provide_{slot}", "human_agent", "unknown"] for slot in _VALUE_SLOTS},
}


# System prompt for the zero-shot intent classifier used by _llm_parse.
# It describes every intent the module can emit; the per-call candidate
# subset (from CANDIDATE_INTENTS) is sent in the user message, and _llm_parse
# discards any answer whose intent is outside that subset.
# NOTE(review): the entity key names (e.g. "digits", "amount") are not spelled
# out here, so the model has to infer them — consider listing them.
SYSTEM_PROMPT = """You are an intent classifier for a Hausa-language customer service voice agent.

Analyze the user's Hausa utterance and return a JSON object with:
- "intent": one of the candidate intents provided
- "entities": a dict of extracted values (may be empty)

Intent meanings:
- check_balance: user wants to check their account balance
- block_card: user wants to block or freeze their bank card
- transfer_money: user wants to transfer or send money
- buy_airtime: user wants to buy phone airtime
- buy_bundle: user wants to buy a data bundle
- complaint: user wants to file a complaint
- check_order: user wants to check an order status
- reschedule: user wants to reschedule a delivery
- return_item: user wants to return an item
- human_agent: user wants to speak to a human
- yes / no: affirmative or negative response
- provide_digits / provide_amount / provide_name / provide_date / provide_bundle / provide_text: user is providing specific information
- unknown: cannot determine the intent

Return ONLY a valid JSON object, no explanation. Example: {"intent": "check_balance", "entities": {}}"""


def _extract_json_object(raw: str) -> Optional[dict]:
    """Return the first JSON object embedded in *raw*, or None if there is none.

    The model sometimes wraps its answer in markdown fences or prose, so a
    decode is attempted at every "{" in turn. ``JSONDecoder.raw_decode``
    consumes balanced, nested braces; the expected answer
    ``{"intent": "...", "entities": {}}`` itself contains a nested object,
    which a non-greedy ``{.*?}`` regex would truncate at the first "}".
    """
    decoder = json.JSONDecoder()
    start = raw.find("{")
    while start != -1:
        try:
            obj, _end = decoder.raw_decode(raw, start)
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(obj, dict):
                return obj
        start = raw.find("{", start + 1)
    return None


def _llm_parse(text: str, expected: Optional[str]) -> Optional[tuple[str, dict]]:
    """Layer 2: zero-shot LLM classification. Returns None on any failure.

    Args:
        text: the user's utterance, embedded verbatim in the prompt.
        expected: slot type used to choose the candidate intents from
            CANDIDATE_INTENTS; unrecognised slot types fall back to the
            free-form (``None``) list.

    Returns:
        ``(intent, entities)`` where ``intent`` is one of the candidates and
        ``entities`` is a dict, or None if the model is unavailable,
        generation raises, or the output has no usable JSON object.
    """
    model, tokenizer = _load_llm()
    if model is None:
        return None

    candidates = CANDIDATE_INTENTS.get(expected, CANDIDATE_INTENTS[None])
    user_prompt = (
        f'Hausa utterance: "{text}"\n'
        f'Expected slot type: {expected or "any"}\n'
        f'Candidate intents: {", ".join(candidates)}\n\n'
        'Respond with JSON only.'
    )
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]
    try:
        import torch
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=80,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        generated = tokenizer.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()
        logger.info(f"LLM raw output: {generated}")

        parsed = _extract_json_object(generated)
        if parsed is None:
            return None
        intent = parsed.get("intent", "unknown")
        if isinstance(intent, str):
            # Tolerate cosmetic drift such as " Check_Balance".
            intent = intent.strip().lower()
        entities = parsed.get("entities", {}) or {}
        if not isinstance(entities, dict):
            entities = {}
        # Validate intent is in candidate list
        if intent not in candidates:
            logger.info(f"LLM returned out-of-candidate intent: {intent}")
            return None
        return intent, entities
    except Exception as e:
        # Best-effort layer: any inference/decoding error falls back to rules.
        logger.warning(f"LLM inference failed: {e}")
        return None


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def _clean_digits(value) -> Optional[str]:
    """Return *value* as a string of ASCII digits, or None if it is not one.

    Spaces and dashes are tolerated ("0803-123 4567"); booleans, None,
    letters or nested objects are rejected.
    """
    if value is None or isinstance(value, bool):
        return None
    s = re.sub(r"[\s-]", "", str(value))
    return s if re.fullmatch(r"[0-9]+", s) else None


def _clean_amount(value) -> Optional[int]:
    """Return *value* as a non-negative int amount, or None if it is not one.

    Accepts ints, integral floats and digit strings with comma thousands
    separators ("5,000"); booleans and everything else are rejected.
    """
    if isinstance(value, bool):
        return None
    if isinstance(value, int):
        return value if value >= 0 else None
    if isinstance(value, float):
        return int(value) if value.is_integer() and value >= 0 else None
    if isinstance(value, str):
        s = value.replace(",", "").strip()
        return int(s) if re.fullmatch(r"[0-9]+", s) else None
    return None


def parse(text: str, expected: Optional[str] = None,
          use_llm: bool = True) -> tuple[str, dict, str]:
    """
    Hybrid NLU. Returns (intent, entities, source) where source is one of
    'rule', 'llm', or 'rule_fallback'.

    Flow:
      1. Try rule-based keyword/slot matcher (fast, deterministic)
      2. If result is 'unknown' AND use_llm=True: try Qwen2.5 zero-shot
      3. If LLM fails or returns invalid output: return rule-based 'unknown'
         with source 'rule_fallback'. For the strict-format slots "digits"
         and "amount", a provide_digits / provide_amount answer without a
         well-formed value counts as invalid output.
    """
    intent, entities = _rule_based_parse(text, expected)
    if intent != "unknown" or not use_llm:
        return intent, entities, "rule"

    # Rule-based missed — try LLM
    llm_result = _llm_parse(text, expected)
    if llm_result is None:
        return intent, entities, "rule_fallback"
    llm_intent, llm_entities = llm_result

    # Strict-format slots: never pass the LLM's raw value through. Prefer the
    # deterministic extractor; otherwise accept the LLM's value only when it
    # is well-formed, so hallucinated text cannot reach the dialog as digits
    # or an amount.
    if expected == "digits" and llm_intent == "provide_digits":
        digits = _extract_digits(text) or _clean_digits(llm_entities.get("digits"))
        if digits is None:
            return intent, entities, "rule_fallback"
        llm_entities["digits"] = digits
    elif expected == "amount" and llm_intent == "provide_amount":
        amount = _extract_amount(text)
        if amount is None:
            amount = _clean_amount(llm_entities.get("amount"))
        if amount is None:
            return intent, entities, "rule_fallback"
        llm_entities["amount"] = amount

    return llm_intent, llm_entities, "llm"