Spaces:
Sleeping
Sleeping
Update evo_inference.py
Browse files- evo_inference.py +166 -61
evo_inference.py
CHANGED
|
@@ -1,10 +1,9 @@
|
|
| 1 |
"""
|
| 2 |
-
evo_inference.py — FLAN-optimized + anti-echo
|
| 3 |
-
-
|
| 4 |
-
-
|
| 5 |
-
-
|
| 6 |
-
-
|
| 7 |
-
- Labeled outputs: [Generative] / [Extractive]
|
| 8 |
"""
|
| 9 |
|
| 10 |
from typing import List, Dict
|
|
@@ -23,12 +22,15 @@ except Exception:
|
|
| 23 |
except Exception:
|
| 24 |
_GENERATOR = None
|
| 25 |
|
| 26 |
-
|
|
|
|
|
|
|
| 27 |
|
| 28 |
def _snippet(text: str) -> str:
|
| 29 |
text = " ".join(text.split())
|
| 30 |
return text[:MAX_SNIPPET_CHARS] + ("..." if len(text) > MAX_SNIPPET_CHARS else "")
|
| 31 |
|
|
|
|
| 32 |
def _extractive_answer(user_query: str, lang: str, hits: List[Dict]) -> str:
|
| 33 |
if not hits:
|
| 34 |
return "**[Extractive]**\n\n" + L(lang, "intro_err")
|
|
@@ -61,57 +63,102 @@ def _extractive_answer(user_query: str, lang: str, hits: List[Dict]) -> str:
|
|
| 61 |
f"**Suggested steps:**\n" + "\n".join(steps)
|
| 62 |
)
|
| 63 |
|
|
|
|
| 64 |
def _lang_name(code: str) -> str:
|
| 65 |
return {"en": "English", "fr": "French", "mfe": "Kreol Morisien"}.get(code, "English")
|
| 66 |
|
| 67 |
-
def _filter_hits(hits: List[Dict], keep: int = 4) -> List[Dict]:
|
| 68 |
-
# Prefer non-placeholder chunks; if all are placeholders, return originals.
|
| 69 |
-
filtered = [h for h in hits if "placeholder" not in h["text"].lower() and "disclaimer" not in h["text"].lower()]
|
| 70 |
-
if not filtered:
|
| 71 |
-
filtered = hits
|
| 72 |
-
return filtered[:keep]
|
| 73 |
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
"""
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
Answer: - bullet - bullet ...
|
| 81 |
"""
|
| 82 |
-
|
| 83 |
-
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
)
|
| 91 |
-
elif lang == "mfe":
|
| 92 |
-
instruction = (
|
| 93 |
-
"To enn Copilot Gouv Moris. Reponn zis lor konteks. Pa repete kestyon. Donn 6–10 pwin kout "
|
| 94 |
-
"lor: Dokiman, Fre, Kot pou al, Letan tretman, Steps. Si info manke, dir li. Pa azout seksion anplis."
|
| 95 |
-
)
|
| 96 |
-
else:
|
| 97 |
-
instruction = (
|
| 98 |
-
"You are the Mauritius Government Copilot. Use ONLY the context. Do not repeat the question. "
|
| 99 |
-
"Write 6–10 short bullet points covering: Required documents, Fees, Where to apply, "
|
| 100 |
-
"Processing time, and Steps. If something is missing, say so. No extra sections."
|
| 101 |
-
)
|
| 102 |
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
_ECHO_PATTERNS = [
|
| 117 |
r"^\s*Instruction.*$", r"^\s*Context:.*$", r"^\s*Question:.*$", r"^\s*Answer.*$",
|
|
@@ -120,7 +167,6 @@ _ECHO_PATTERNS = [
|
|
| 120 |
]
|
| 121 |
|
| 122 |
def _clean_generated(text: str) -> str:
|
| 123 |
-
# Remove common echoed lines from the model output.
|
| 124 |
lines = [ln.strip() for ln in text.strip().splitlines()]
|
| 125 |
out = []
|
| 126 |
for ln in lines:
|
|
@@ -131,34 +177,92 @@ def _clean_generated(text: str) -> str:
|
|
| 131 |
cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)
|
| 132 |
return cleaned
|
| 133 |
|
| 134 |
-
|
| 135 |
-
|
| 136 |
a = re.sub(r"\W+", " ", (ans or "").lower()).strip()
|
| 137 |
q = re.sub(r"\W+", " ", (question or "").lower()).strip()
|
| 138 |
-
if len(a) <
|
| 139 |
return True
|
| 140 |
if q and (a.startswith(q) or q in a[: max(80, len(q) + 10)]):
|
| 141 |
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
return False
|
| 143 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
def synthesize_with_evo(
|
| 145 |
user_query: str,
|
| 146 |
lang: str,
|
| 147 |
hits: List[Dict],
|
| 148 |
mode: str = "extractive",
|
| 149 |
max_new_tokens: int = 192,
|
| 150 |
-
temperature: float = 0.
|
| 151 |
) -> str:
|
| 152 |
-
# No context → safe fallback
|
| 153 |
lang = normalize_lang(lang)
|
|
|
|
| 154 |
if not hits:
|
| 155 |
return _extractive_answer(user_query, lang, hits)
|
| 156 |
|
| 157 |
-
#
|
|
|
|
|
|
|
| 158 |
if mode != "generative" or _GENERATOR is None:
|
| 159 |
-
return _extractive_answer(user_query, lang,
|
| 160 |
|
| 161 |
-
prompt = _build_grounded_prompt(user_query, lang,
|
| 162 |
try:
|
| 163 |
text = _GENERATOR.generate(
|
| 164 |
prompt,
|
|
@@ -166,8 +270,9 @@ def synthesize_with_evo(
|
|
| 166 |
temperature=float(temperature),
|
| 167 |
)
|
| 168 |
text = _clean_generated(text)
|
| 169 |
-
|
| 170 |
-
|
|
|
|
| 171 |
return "**[Generative]**\n\n" + text
|
| 172 |
except Exception:
|
| 173 |
-
return _extractive_answer(user_query, lang,
|
|
|
|
| 1 |
"""
|
| 2 |
+
evo_inference.py — FLAN-optimized + topic router + anti-echo/off-topic
|
| 3 |
+
- Routes queries to the right topic (passport / driving / civil status / business)
|
| 4 |
+
- Prefers chunks whose filename/text match the topic; filters placeholders
|
| 5 |
+
- FLAN-friendly prompt; cleans prompt-echo; falls back if echo/too short/off-topic
|
| 6 |
+
- Labels outputs: [Generative] / [Extractive]
|
|
|
|
| 7 |
"""
|
| 8 |
|
| 9 |
from typing import List, Dict
|
|
|
|
| 22 |
except Exception:
|
| 23 |
_GENERATOR = None
|
| 24 |
|
| 25 |
+
# Keep snippets short so FLAN-T5 stays within encoder limit (512)
|
| 26 |
+
MAX_SNIPPET_CHARS = 220
|
| 27 |
+
|
| 28 |
|
| 29 |
def _snippet(text: str) -> str:
|
| 30 |
text = " ".join(text.split())
|
| 31 |
return text[:MAX_SNIPPET_CHARS] + ("..." if len(text) > MAX_SNIPPET_CHARS else "")
|
| 32 |
|
| 33 |
+
|
| 34 |
def _extractive_answer(user_query: str, lang: str, hits: List[Dict]) -> str:
|
| 35 |
if not hits:
|
| 36 |
return "**[Extractive]**\n\n" + L(lang, "intro_err")
|
|
|
|
| 63 |
f"**Suggested steps:**\n" + "\n".join(steps)
|
| 64 |
)
|
| 65 |
|
| 66 |
+
|
| 67 |
def _lang_name(code: str) -> str:
|
| 68 |
return {"en": "English", "fr": "French", "mfe": "Kreol Morisien"}.get(code, "English")
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
+
# --- Topic routing -------------------------------------------------------------
|
| 72 |
+
|
| 73 |
+
_TOPIC_MAP = {
|
| 74 |
+
"passport": {
|
| 75 |
+
"file_hints": ["passport_renewal", "passport"],
|
| 76 |
+
"word_hints": ["passport", "passeport", "paspor", "renew", "renouvel"],
|
| 77 |
+
"forbid_words": ["business", "cbrd", "brn", "driving", "licence", "license", "civil status"],
|
| 78 |
+
},
|
| 79 |
+
"driving": {
|
| 80 |
+
"file_hints": ["driving_licence", "driving_license"],
|
| 81 |
+
"word_hints": ["driving", "licence", "license", "permit", "idp", "pf-77"],
|
| 82 |
+
"forbid_words": ["passport", "cbrd", "brn", "civil status"],
|
| 83 |
+
},
|
| 84 |
+
"civil": {
|
| 85 |
+
"file_hints": ["birth_marriage_certificate", "civil_status"],
|
| 86 |
+
"word_hints": ["birth", "naissance", "nesans", "marriage", "mariage", "maryaz", "certificate", "extract"],
|
| 87 |
+
"forbid_words": ["passport", "driving", "cbrd", "brn"],
|
| 88 |
+
},
|
| 89 |
+
"business": {
|
| 90 |
+
"file_hints": ["business_registration_cbrd", "cbrd"],
|
| 91 |
+
"word_hints": ["business", "brn", "cbrd", "register", "trade fee"],
|
| 92 |
+
"forbid_words": ["passport", "driving", "civil status"],
|
| 93 |
+
},
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
def _guess_topic(query: str) -> str:
|
| 97 |
+
q = (query or "").lower()
|
| 98 |
+
if any(w in q for w in _TOPIC_MAP["passport"]["word_hints"]):
|
| 99 |
+
return "passport"
|
| 100 |
+
if any(w in q for w in _TOPIC_MAP["driving"]["word_hints"]):
|
| 101 |
+
return "driving"
|
| 102 |
+
if any(w in q for w in _TOPIC_MAP["civil"]["word_hints"]):
|
| 103 |
+
return "civil"
|
| 104 |
+
if any(w in q for w in _TOPIC_MAP["business"]["word_hints"]):
|
| 105 |
+
return "business"
|
| 106 |
+
return "" # unknown → no routing
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def _hit_file(h: Dict) -> str:
|
| 110 |
+
# Try several common fields for filepath
|
| 111 |
+
return (
|
| 112 |
+
h.get("file")
|
| 113 |
+
or h.get("source")
|
| 114 |
+
or (h.get("meta") or {}).get("file")
|
| 115 |
+
or ""
|
| 116 |
+
).lower()
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def _filter_hits(hits: List[Dict], query: str, keep: int = 4) -> List[Dict]:
    """Select up to *keep* high-signal, topic-consistent hits for *query*.

    1) Drop placeholder/disclaimer chunks (falling back to the originals if
       that would empty the pool).
    2) When the query maps to a known topic, rank chunks by filename hints
       plus keyword overlap, penalising terms belonging to other topics.
    3) Return the top *keep* chunks.
    """
    if not hits:
        return []

    # 1) Strip placeholder/disclaimer chunks; keep the originals if everything
    #    got filtered out.
    candidates = [
        hit for hit in hits
        if "placeholder" not in hit["text"].lower() and "disclaimer" not in hit["text"].lower()
    ]
    if not candidates:
        candidates = hits

    topic = _guess_topic(query)
    if not topic:
        return candidates[:keep]

    routing = _TOPIC_MAP[topic]
    file_hints = routing["file_hints"]
    wanted = set(routing["word_hints"])
    banned = set(routing["forbid_words"])

    def _score(hit: Dict) -> float:
        fname = _hit_file(hit)
        body = hit["text"].lower()
        total = 0.0
        # A filename matching the topic is a strong signal.
        if any(hint in fname for hint in file_hints):
            total += 2.0
        # One point per topic keyword present in the chunk text.
        total += sum(1.0 for word in wanted if word in body)
        # Penalise chunks mentioning other topics (in text or filename).
        total -= sum(1.5 for word in banned if word in body or word in fname)
        return total

    ranked = sorted(candidates, key=_score, reverse=True)
    return ranked[:keep]
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# --- Prompt build & cleaning ---------------------------------------------------
|
| 162 |
|
| 163 |
_ECHO_PATTERNS = [
|
| 164 |
r"^\s*Instruction.*$", r"^\s*Context:.*$", r"^\s*Question:.*$", r"^\s*Answer.*$",
|
|
|
|
| 167 |
]
|
| 168 |
|
| 169 |
def _clean_generated(text: str) -> str:
|
|
|
|
| 170 |
lines = [ln.strip() for ln in text.strip().splitlines()]
|
| 171 |
out = []
|
| 172 |
for ln in lines:
|
|
|
|
| 177 |
cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)
|
| 178 |
return cleaned
|
| 179 |
|
| 180 |
+
|
| 181 |
+
def _is_echo_or_too_short_or_offtopic(ans: str, question: str, topic: str) -> bool:
|
| 182 |
a = re.sub(r"\W+", " ", (ans or "").lower()).strip()
|
| 183 |
q = re.sub(r"\W+", " ", (question or "").lower()).strip()
|
| 184 |
+
if len(a) < 60:
|
| 185 |
return True
|
| 186 |
if q and (a.startswith(q) or q in a[: max(80, len(q) + 10)]):
|
| 187 |
return True
|
| 188 |
+
# crude off-topic guard
|
| 189 |
+
if topic == "passport" and ("business" in a or "cbrd" in a or "brn" in a):
|
| 190 |
+
return True
|
| 191 |
+
if topic == "driving" and ("passport" in a or "cbrd" in a or "brn" in a or "civil status" in a):
|
| 192 |
+
return True
|
| 193 |
+
if topic == "civil" and ("passport" in a or "driving" in a or "cbrd" in a or "brn" in a):
|
| 194 |
+
return True
|
| 195 |
+
if topic == "business" and ("passport" in a or "driving" in a or "civil status" in a):
|
| 196 |
+
return True
|
| 197 |
return False
|
| 198 |
|
| 199 |
+
|
| 200 |
+
def _build_grounded_prompt(question: str, lang: str, hits: List[Dict]) -> str:
    """Assemble a FLAN-friendly grounded prompt: instruction + numbered context + question.

    The instruction carries strong guardrails (stay on topic, bullets only) in
    the normalised language; the trailing "- " primes bullet-style output.
    """
    lang = normalize_lang(lang)
    readable = _lang_name(lang)
    topic = _guess_topic(question)

    # Strong guardrails in the instruction: stay on topic, bullets only.
    if lang == "fr":
        instruction = (
            "Tu es le Copilote Gouvernemental de Maurice. Réponds UNIQUEMENT à partir du contexte. "
            "Reste sur le SUJET demandé et ignore les autres documents. Ne répète pas la question. "
            "Écris 6–10 puces courtes couvrant: Documents requis, Frais, Où postuler, Délai, Étapes. "
            "Si une info manque, dis-le. Pas d'autres sections."
        )
    elif lang == "mfe":
        instruction = (
            "To enn Copilot Gouv Moris. Servi ZIS konteks. Reste lor SUZET ki finn demande, "
            "ignorar lezot dokiman. Pa repete kestyon. Ekri 6–10 pwin kout: Dokiman, Fre, Kot pou al, "
            "Letan tretman, Steps. Si info manke, dir li. Pa azout lezot seksion."
        )
    else:
        instruction = (
            "You are the Mauritius Government Copilot. Use ONLY the context. Stay strictly on the "
            "REQUESTED TOPIC and ignore other documents. Do NOT repeat the question. Write 6–10 short "
            "bullets covering: Required documents, Fees, Where to apply, Processing time, Steps. "
            "If something is missing, say so. No extra sections."
        )

    # An explicit topic hint helps FLAN stay on track.
    if topic:
        instruction += f" Topic: {topic}."

    numbered = []
    for idx, hit in enumerate(hits, start=1):
        numbered.append(f"{idx}) {_snippet(hit['text'])}")
    context = "\n".join(numbered) if numbered else "(none)"

    # Prime with a leading dash to bias bullet-style output.
    return (
        f"Instruction ({readable}): {instruction}\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {question}\n\n"
        f"Answer ({readable}):\n- "
    )
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
# --- Main entry ----------------------------------------------------------------

def synthesize_with_evo(
    user_query: str,
    lang: str,
    hits: List[Dict],
    mode: str = "extractive",
    max_new_tokens: int = 192,
    temperature: float = 0.0,
) -> str:
    """Answer a query from retrieved hits: generative when possible, extractive otherwise.

    Falls back to the extractive answer when there is no context, generation is
    unavailable, the model output looks echoed / too short / off-topic, or
    generation raises.
    """
    lang = normalize_lang(lang)

    if not hits:
        return _extractive_answer(user_query, lang, hits)

    # Route/filter hits so only on-topic, high-signal chunks reach the model.
    chosen = _filter_hits(hits, user_query, keep=4)

    if mode != "generative" or _GENERATOR is None:
        return _extractive_answer(user_query, lang, chosen)

    prompt = _build_grounded_prompt(user_query, lang, chosen)
    try:
        raw = _GENERATOR.generate(
            prompt,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
        )
        cleaned = _clean_generated(raw)
        topic = _guess_topic(user_query)
        # Reject echo / too-short / off-topic generations in favour of extractive.
        if _is_echo_or_too_short_or_offtopic(cleaned, user_query, topic):
            return _extractive_answer(user_query, lang, chosen)
        return "**[Generative]**\n\n" + cleaned
    except Exception:
        # Any generation failure degrades gracefully to the extractive answer.
        return _extractive_answer(user_query, lang, chosen)
|