Spaces:
Sleeping
Sleeping
| """ | |
| evo_inference.py — FLAN-optimized + topic router + anti-echo/off-topic | |
| - Routes queries to the right topic (passport / driving / civil status / business) | |
| - Prefers chunks whose filename/text match the topic; filters placeholders | |
| - FLAN-friendly prompt; cleans prompt-echo; falls back if echo/too short/off-topic | |
| - Labels outputs: [Generative] / [Extractive] | |
| """ | |
import logging
import re
from typing import Dict, List

from utils_lang import L, normalize_lang
| # Try to load your real Evo plugin first; else use the example; else None. | |
| _GENERATOR = None | |
| try: | |
| from evo_plugin import load_model as _load_real | |
| _GENERATOR = _load_real() | |
| except Exception: | |
| try: | |
| from evo_plugin_example import load_model as _load_example | |
| _GENERATOR = _load_example() | |
| except Exception: | |
| _GENERATOR = None | |
| # Keep snippets short so FLAN-T5 stays within encoder limit (512) | |
| MAX_SNIPPET_CHARS = 220 | |
| def _snippet(text: str) -> str: | |
| text = " ".join(text.split()) | |
| return text[:MAX_SNIPPET_CHARS] + ("..." if len(text) > MAX_SNIPPET_CHARS else "") | |
# Generic procedure steps appended to every extractive answer, per language.
_EXTRACTIVE_STEPS = {
    "en": [
        "• Step 1: Check eligibility & gather required documents.",
        "• Step 2: Confirm fees & payment options.",
        "• Step 3: Apply online or at the indicated office.",
        "• Step 4: Keep reference/receipt; track processing time.",
    ],
    "fr": [
        "• Étape 1 : Vérifiez l’éligibilité et rassemblez les documents requis.",
        "• Étape 2 : Confirmez les frais et les moyens de paiement.",
        "• Étape 3 : Déposez la demande en ligne ou au bureau indiqué.",
        "• Étape 4 : Conservez le reçu/la référence et suivez le délai de traitement.",
    ],
    "mfe": [
        "• Step 1: Get dokiman neseser ek verifie si to elegib.",
        "• Step 2: Konfirm fre ek manyer peyman.",
        "• Step 3: Fer demand online ouswa dan biro ki indike.",
        "• Step 4: Gard referans/reso; swiv letan tretman.",
    ],
}


def _extractive_answer(user_query: str, lang: str, hits: List[Dict]) -> str:
    """Build a retrieval-only markdown answer labelled "[Extractive]".

    Args:
        user_query: The user's question, echoed under "**Q:**".
        lang: Language code; passed to ``L`` for the intro text and to
            ``normalize_lang`` to choose the step list.
        hits: Retrieved chunks (dicts with a "text" entry); the first four
            become "Key information" bullets.

    Returns:
        The "[Extractive]" header plus ``L(lang, "intro_err")`` when *hits* is
        empty; otherwise intro, question, key-information bullets and the
        suggested steps.
    """
    if not hits:
        return "**[Extractive]**\n\n" + L(lang, "intro_err")
    # Tolerate hits whose "text" is missing or None instead of crashing.
    bullets = [f"- {_snippet(h.get('text') or '')}" for h in hits[:4]]
    # Unknown language codes fall back to English, like _lang_name does.
    steps = _EXTRACTIVE_STEPS.get(normalize_lang(lang), _EXTRACTIVE_STEPS["en"])
    return (
        "**[Extractive]**\n\n"
        f"**{L(lang, 'intro_ok')}**\n\n"
        f"**Q:** {user_query}\n\n"
        "**Key information:**\n" + "\n".join(bullets) + "\n\n"
        "**Suggested steps:**\n" + "\n".join(steps)
    )
| def _lang_name(code: str) -> str: | |
| return {"en": "English", "fr": "French", "mfe": "Kreol Morisien"}.get(code, "English") | |
| # --- Topic routing ------------------------------------------------------------- | |
| _TOPIC_MAP = { | |
| "passport": { | |
| "file_hints": ["passport_renewal", "passport"], | |
| "word_hints": ["passport", "passeport", "paspor", "renew", "renouvel"], | |
| "forbid_words": ["business", "cbrd", "brn", "driving", "licence", "license", "civil status"], | |
| }, | |
| "driving": { | |
| "file_hints": ["driving_licence", "driving_license"], | |
| "word_hints": ["driving", "licence", "license", "permit", "idp", "pf-77"], | |
| "forbid_words": ["passport", "cbrd", "brn", "civil status"], | |
| }, | |
| "civil": { | |
| "file_hints": ["birth_marriage_certificate", "civil_status"], | |
| "word_hints": ["birth", "naissance", "nesans", "marriage", "mariage", "maryaz", "certificate", "extract"], | |
| "forbid_words": ["passport", "driving", "cbrd", "brn"], | |
| }, | |
| "business": { | |
| "file_hints": ["business_registration_cbrd", "cbrd"], | |
| "word_hints": ["business", "brn", "cbrd", "register", "trade fee"], | |
| "forbid_words": ["passport", "driving", "civil status"], | |
| }, | |
| } | |
| def _guess_topic(query: str) -> str: | |
| q = (query or "").lower() | |
| if any(w in q for w in _TOPIC_MAP["passport"]["word_hints"]): | |
| return "passport" | |
| if any(w in q for w in _TOPIC_MAP["driving"]["word_hints"]): | |
| return "driving" | |
| if any(w in q for w in _TOPIC_MAP["civil"]["word_hints"]): | |
| return "civil" | |
| if any(w in q for w in _TOPIC_MAP["business"]["word_hints"]): | |
| return "business" | |
| return "" # unknown → no routing | |
| def _hit_file(h: Dict) -> str: | |
| # Try several common fields for filepath | |
| return ( | |
| h.get("file") | |
| or h.get("source") | |
| or (h.get("meta") or {}).get("file") | |
| or "" | |
| ).lower() | |
def _filter_hits(hits: List[Dict], query: str, keep: int = 4) -> List[Dict]:
    """Select up to *keep* hits, preferring real, on-topic chunks.

    1) Drop chunks whose text contains "placeholder" or "disclaimer"; if that
       would drop everything, keep all hits.
    2) If ``_guess_topic(query)`` finds a topic, rank the remaining chunks:
       +2.0 if the file path contains any of the topic's file hints,
       +1.0 per topic word hint found in the text,
       -1.5 per forbidden word found in the text or file path.
       ``sorted`` is stable, so equal scores keep the retriever's order.
    3) Return the first *keep* chunks.

    Hits with a missing or None "text" are treated as empty text instead of
    raising, since this runs outside synthesize_with_evo's fallback handler.
    """
    if not hits:
        return []

    def text_of(h: Dict) -> str:
        return (h.get("text") or "").lower()

    # 1) remove placeholders
    pool = [
        h for h in hits
        if "placeholder" not in text_of(h) and "disclaimer" not in text_of(h)
    ] or hits

    topic = _guess_topic(query)
    if not topic:
        return pool[:keep]

    hints = _TOPIC_MAP[topic]
    file_hints = hints["file_hints"]
    word_hints = set(hints["word_hints"])
    forbid = set(hints["forbid_words"])

    def score(h: Dict) -> float:
        fname = _hit_file(h)
        text = text_of(h)
        s = 2.0 if any(k in fname for k in file_hints) else 0.0  # filename boost
        s += sum(1.0 for w in word_hints if w in text)  # keyword overlap boost
        s -= sum(1.5 for w in forbid if w in text or w in fname)  # off-topic penalty
        return s

    return sorted(pool, key=score, reverse=True)[:keep]
| # --- Prompt build & cleaning --------------------------------------------------- | |
| _ECHO_PATTERNS = [ | |
| r"^\s*Instruction.*$", r"^\s*Context:.*$", r"^\s*Question:.*$", r"^\s*Answer.*$", | |
| r"^\s*\[Instructions?\].*$", r"^\s*Be concise.*$", r"^\s*Do not invent.*$", | |
| r"^\s*(en|fr|mfe)\s*$", | |
| ] | |
| def _clean_generated(text: str) -> str: | |
| lines = [ln.strip() for ln in text.strip().splitlines()] | |
| out = [] | |
| for ln in lines: | |
| if any(re.match(pat, ln, flags=re.IGNORECASE) for pat in _ECHO_PATTERNS): | |
| continue | |
| out.append(ln) | |
| cleaned = "\n".join(out).strip() | |
| cleaned = re.sub(r"\n{3,}", "\n\n", cleaned) | |
| return cleaned | |
| def _is_echo_or_too_short_or_offtopic(ans: str, question: str, topic: str) -> bool: | |
| a = re.sub(r"\W+", " ", (ans or "").lower()).strip() | |
| q = re.sub(r"\W+", " ", (question or "").lower()).strip() | |
| if len(a) < 60: | |
| return True | |
| if q and (a.startswith(q) or q in a[: max(80, len(q) + 10)]): | |
| return True | |
| # crude off-topic guard | |
| if topic == "passport" and ("business" in a or "cbrd" in a or "brn" in a): | |
| return True | |
| if topic == "driving" and ("passport" in a or "cbrd" in a or "brn" in a or "civil status" in a): | |
| return True | |
| if topic == "civil" and ("passport" in a or "driving" in a or "cbrd" in a or "brn" in a): | |
| return True | |
| if topic == "business" and ("passport" in a or "driving" in a or "civil status" in a): | |
| return True | |
| return False | |
def _build_grounded_prompt(question: str, lang: str, hits: List[Dict]) -> str:
    """Assemble the FLAN-style prompt for a grounded, bulleted answer.

    Layout: an instruction in the user's language (suffixed with
    "Topic: <topic>." when ``_guess_topic`` recognises one), a numbered
    context block of hit snippets ("(none)" when *hits* is empty), the
    question, and an "Answer (<language>):" header ending in "- ".
    """
    code = normalize_lang(lang)
    readable = _lang_name(code)
    topic = _guess_topic(question)

    # Strong guardrails: stay on topic, bullets only.
    english_instruction = (
        "You are the Mauritius Government Copilot. Use ONLY the context. Stay strictly on the "
        "REQUESTED TOPIC and ignore other documents. Do NOT repeat the question. Write 6–10 short "
        "bullets covering: Required documents, Fees, Where to apply, Processing time, Steps. "
        "If something is missing, say so. No extra sections."
    )
    localized_instructions = (
        (
            "fr",
            "Tu es le Copilote Gouvernemental de Maurice. Réponds UNIQUEMENT à partir du contexte. "
            "Reste sur le SUJET demandé et ignore les autres documents. Ne répète pas la question. "
            "Écris 6–10 puces courtes couvrant: Documents requis, Frais, Où postuler, Délai, Étapes. "
            "Si une info manque, dis-le. Pas d'autres sections.",
        ),
        (
            "mfe",
            "To enn Copilot Gouv Moris. Servi ZIS konteks. Reste lor SUZET ki finn demande, "
            "ignorar lezot dokiman. Pa repete kestyon. Ekri 6–10 pwin kout: Dokiman, Fre, Kot pou al, "
            "Letan tretman, Steps. Si info manke, dir li. Pa azout lezot seksion.",
        ),
    )
    instruction = next(
        (text for lang_code, text in localized_instructions if code == lang_code),
        english_instruction,
    )

    if topic:
        # An explicit topic hint helps FLAN stay on track.
        instruction = f"{instruction} Topic: {topic}."

    numbered = [f"{n}) {_snippet(hit['text'])}" for n, hit in enumerate(hits, start=1)]
    context = "\n".join(numbered) if numbered else "(none)"

    # Ending on "- " primes the model to continue in bullet style.
    return (
        f"Instruction ({readable}): {instruction}\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {question}\n\n"
        f"Answer ({readable}):\n- "
    )
# --- Main entry ----------------------------------------------------------------
def synthesize_with_evo(
    user_query: str,
    lang: str,
    hits: List[Dict],
    mode: str = "extractive",
    max_new_tokens: int = 192,
    temperature: float = 0.0,
) -> str:
    """Answer *user_query* from retrieved *hits*, generatively when possible.

    Args:
        user_query: The user's question.
        lang: Language code; normalized with ``normalize_lang``.
        hits: Retrieved chunks (dicts with a "text" entry).
        mode: "generative" (case/whitespace-insensitive) to use the Evo
            generator; any other value gives an extractive answer.
        max_new_tokens: Passed to ``_GENERATOR.generate`` (as int).
        temperature: Passed to ``_GENERATOR.generate`` (as float).

    Returns:
        Markdown labelled "[Generative]" or "[Extractive]". The extractive
        answer is used when there are no hits, when generative mode is off or
        no generator loaded, when generation raises, or when the cleaned
        output is an echo, too short, or off-topic.
    """
    lang = normalize_lang(lang)
    if not hits:
        return _extractive_answer(user_query, lang, hits)

    # Route/filter hits to keep only on-topic, high-signal chunks.
    chosen = _filter_hits(hits, user_query, keep=4)

    # Accept e.g. "Generative" or " generative " from UI widgets.
    wants_generative = str(mode or "").strip().lower() == "generative"
    if not wants_generative or _GENERATOR is None:
        return _extractive_answer(user_query, lang, chosen)

    prompt = _build_grounded_prompt(user_query, lang, chosen)
    try:
        raw = _GENERATOR.generate(
            prompt,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
        )
        # A None result becomes "" and is rejected below as too short.
        text = _clean_generated(raw or "")
    except Exception:
        # Best-effort: degrade to the extractive answer, but leave a trace.
        logging.getLogger(__name__).exception(
            "Evo generation failed; falling back to extractive answer."
        )
        return _extractive_answer(user_query, lang, chosen)

    topic = _guess_topic(user_query)
    if _is_echo_or_too_short_or_offtopic(text, user_query, topic):
        return _extractive_answer(user_query, lang, chosen)
    return "**[Generative]**\n\n" + text