# mbaza-model / assistant.py
# Author: mugwaneza
# Deploy Mbaza Legal AI Model with inference endpoint (commit fc1c893)
from pathlib import Path
import json
import re
from typing import Dict, Any, List, Optional
import random
import pandas as pd
from retriever import get_retriever
from config import KINYARWANDA_STOPWORDS, calculate_similarity_score
ROOT = Path(__file__).parent
CONTEXT_PATH = ROOT / 'conversation_contexts.json'
class Assistant:
    """Kinyarwanda legal Q&A assistant.

    Routes each user message to a greeting, law-lookup, or punishment-lookup
    handler using simple keyword-based intent detection, and keeps a rolling
    per-user conversation history persisted to ``conversation_contexts.json``.
    """

    def __init__(self):
        self.retriever = get_retriever()
        # Datasets are populated by load_datasets(); None means unavailable.
        self.laws = None
        self.punishments = None
        self.greetings = None
        # Per-user rolling history: user_id -> list of {'role', 'text', ...} entries.
        self.contexts: Dict[str, List[Dict[str, Any]]] = {}
        # Keyword lists driving detect_intent(). NOTE: duplicates removed from
        # the 'law' list ('article', 'ingingo' appeared twice); scores are only
        # compared against > 0, so deduping does not change intent outcomes.
        self.intent_keywords = {
            'greeting': ['mwaramutse', 'muraho', 'amakuru', 'bite', 'mwiriwe', 'urabeho', 'urakomeye', 'ndabona'],
            'law': ['itegeko', 'ingingo', 'article', 'ingingo ya', 'itegeko rya', 'law'],
            'punishment': ['igihano', 'ibihano', 'ihazabu', 'igifungo', 'fine', 'imyaka', 'years', 'punishment'],
        }
        self.load_datasets()
        self._load_contexts()

    def load_datasets(self):
        """Load the laws, punishments, and greetings datasets (best-effort).

        Any dataset that cannot be loaded is left as None; callers must
        handle the None case.
        """
        # Laws are loaded by the retriever; ensure they are available.
        try:
            if self.retriever.laws_df is None:
                self.retriever.load_laws()
            self.laws = self.retriever.laws_df
        except Exception:
            self.laws = None
        # Punishments: fall back to a local penal_code.csv if present.
        ppath = ROOT / 'penal_code.csv'
        if ppath.exists():
            try:
                self.punishments = pd.read_csv(ppath).fillna('')
            except Exception:
                self.punishments = None
        else:
            self.punishments = None
        # Greetings: prefer a local greetings.csv, else reuse the retriever's copy.
        gpath = ROOT / 'greetings.csv'
        if gpath.exists():
            try:
                self.greetings = pd.read_csv(gpath).fillna('')
            except Exception:
                self.greetings = None
        else:
            try:
                self.greetings = self.retriever.greetings_df
            except Exception:
                self.greetings = None

    def _load_contexts(self):
        """Restore persisted conversation contexts from disk, if any."""
        if CONTEXT_PATH.exists():
            try:
                with open(CONTEXT_PATH, 'r', encoding='utf-8') as f:
                    self.contexts = json.load(f)
            except Exception:
                # Corrupt or unreadable file: start with an empty history.
                self.contexts = {}

    def _save_contexts(self):
        """Persist conversation contexts to disk (best-effort, never raises)."""
        try:
            with open(CONTEXT_PATH, 'w', encoding='utf-8') as f:
                json.dump(self.contexts, f, ensure_ascii=False, indent=2)
        except Exception:
            pass

    def tokenize(self, text: str) -> List[str]:
        """Lowercase, strip punctuation, and drop Kinyarwanda stopwords.

        Returns an empty list for empty/None input.
        """
        if not text:
            return []
        txt = str(text).lower()
        txt = re.sub(r"[\r\n]+", " ", txt)
        # Keep word chars plus Latin-1/Latin Extended-A letters (accented chars).
        txt = re.sub(r"[^\w\s\u00C0-\u017F]", " ", txt)
        return [t for t in txt.split() if t and t not in KINYARWANDA_STOPWORDS]

    def detect_intent(self, text: str) -> str:
        """Classify a message as 'greeting', 'punishment', 'law', or 'unclear'.

        Keyword hits are checked both against the raw lowercased text
        (catches multi-word keywords) and the stopword-filtered token set.
        Priority order on ties: greeting > punishment > law.
        """
        t = str(text).lower()
        toks = set(self.tokenize(t))
        scores = {k: 0 for k in self.intent_keywords}
        for intent, keywords in self.intent_keywords.items():
            for kw in keywords:
                if kw in t or kw in toks:
                    scores[intent] += 1
        if scores.get('greeting', 0) > 0:
            return 'greeting'
        if scores.get('punishment', 0) > 0:
            return 'punishment'
        if scores.get('law', 0) > 0:
            return 'law'
        # Last resort: explicit mentions of law/article terms.
        if re.search(r'\bingingo\b|\bingingo ya\b|\barticle\b|\bitegeko\b', t):
            return 'law'
        return 'unclear'

    def _update_context(self, user_id: str, entry: Dict[str, Any]):
        """Append one history entry for user_id, cap at 50 entries, persist."""
        self.contexts.setdefault(user_id, []).append(entry)
        if len(self.contexts[user_id]) > 50:
            self.contexts[user_id] = self.contexts[user_id][-50:]
        self._save_contexts()

    def handle_query(self, user_id: str, text: str) -> Dict[str, Any]:
        """Process one user message and return a typed response dict.

        The response 'type' is one of 'greeting', 'law', 'punishment', or
        'unclear'. Every user message and every assistant reply (including
        'unclear' fallbacks) is recorded in the per-user context.
        """
        # Record the raw user message once.
        self._update_context(user_id, {'role': 'user', 'text': text})
        # Language-aware greeting detection first (handles rw/en/fr greetings).
        try:
            reply = self.retriever.detect_and_reply_greeting(text)
        except Exception:
            reply = None
        if reply:
            out = {'type': 'greeting', 'response': reply.get('response', ''), 'followup': reply.get('followup', '')}
            self._update_context(user_id, {'role': 'assistant', 'text': out})
            return out
        # No greeting detected: classify the message.
        intent = self.detect_intent(text)
        # Annotate the already-recorded user entry with its intent.
        # (Fix: the original appended a second, duplicate user entry here.)
        history = self.contexts.get(user_id)
        if history:
            history[-1]['intent'] = intent
            self._save_contexts()
        if intent == 'law':
            try:
                # Ensure embeddings are built before querying.
                self.retriever.build_or_load_embeddings()
                results = self.retriever.find_similar(text, top_k=1)
            except Exception:
                results = []
            if results:
                score, meta = results[0]
                out = {'type': 'law', 'score': score, 'law': meta.get('row')}
                self._update_context(user_id, {'role': 'assistant', 'text': out})
                return out
            # Fix: record the fallback reply in context like every other path.
            out = {'type': 'unclear', 'text': "I couldn't find a matching law. Can you be more specific?"}
            self._update_context(user_id, {'role': 'assistant', 'text': out})
            return out
        if intent == 'punishment':
            # Score each penal-code row by textual overlap with the query.
            if self.punishments is not None:
                best = None
                best_score = 0.0
                for _, row in self.punishments.iterrows():
                    desc = ' '.join(str(row.get(c, '')) for c in row.index)
                    s = calculate_similarity_score(text, desc)
                    if s > best_score:
                        best_score = s
                        best = row.to_dict()
                if best is not None and best_score > 0:
                    out = {'type': 'punishment', 'score': best_score, 'punishment_row': best}
                    self._update_context(user_id, {'role': 'assistant', 'text': out})
                    return out
            # Fix: record the fallback reply in context like every other path.
            out = {'type': 'unclear', 'text': "I couldn't find a matching punishment. Can you provide more detail?"}
            self._update_context(user_id, {'role': 'assistant', 'text': out})
            return out
        # Unclear intent: generic prompt back to the user.
        out = {'type': 'unclear', 'text': "Can you please try a legal question? I'm here to assist you."}
        self._update_context(user_id, {'role': 'assistant', 'text': out})
        return out
# Module-level singleton: one shared Assistant per process.
_ASSISTANT: Optional[Assistant] = None


def get_assistant() -> Assistant:
    """Return the process-wide Assistant, constructing it on first call."""
    global _ASSISTANT
    if _ASSISTANT is not None:
        return _ASSISTANT
    _ASSISTANT = Assistant()
    return _ASSISTANT