|
|
import json
import logging
import random
import re
from pathlib import Path
from typing import Any, Dict, List, Optional

import pandas as pd

from retriever import get_retriever
from config import KINYARWANDA_STOPWORDS, calculate_similarity_score
|
|
|
|
|
# Data files (the CSV tables and the persisted conversation history) are
# resolved relative to the directory containing this module.
ROOT = Path(__file__).parent
CONTEXT_PATH = ROOT.joinpath('conversation_contexts.json')
|
|
|
|
|
|
|
|
class Assistant:
    """Keyword-routed legal assistant answering greeting, law and punishment queries.

    Each query is first offered to the retriever's greeting responder; if that
    yields nothing, it is classified with :meth:`detect_intent` and answered
    from the retriever's law index or from the ``penal_code.csv`` table.  Each
    user's recent turns are kept in ``self.contexts`` and persisted as JSON to
    ``CONTEXT_PATH``.
    """

    # Maximum number of history entries retained per user.
    MAX_CONTEXT_ENTRIES = 50

    _log = logging.getLogger(__name__)

    def __init__(self):
        self.retriever = get_retriever()

        # Tables populated by load_datasets(); None when unavailable.
        self.laws = None
        self.punishments = None
        self.greetings = None

        # user_id -> chronological list of {'role': ..., 'text': ..., ...} entries.
        self.contexts: Dict[str, List[Dict[str, Any]]] = {}

        # Keywords are matched case-insensitively where a word starts (see
        # detect_intent); multi-word phrases are allowed.
        self.intent_keywords = {
            'greeting': ['mwaramutse', 'muraho', 'amakuru', 'bite', 'mwiriwe', 'urabeho', 'urakomeye', 'ndabona'],
            'law': ['itegeko', 'ingingo', 'article', 'ingingo ya', 'itegeko rya', 'law'],
            'punishment': ['igihano', 'ibihano', 'ihazabu', 'igifungo', 'fine', 'imyaka', 'years', 'punishment'],
        }

        self.load_datasets()
        self._load_contexts()

    # ------------------------------------------------------------------ data

    def load_datasets(self):
        """Load the law, punishment and greeting tables.

        Any table that is missing or fails to load is left as ``None``.
        Failures are logged rather than raised.
        """
        try:
            if self.retriever.laws_df is None:
                self.retriever.load_laws()
            self.laws = self.retriever.laws_df
        except Exception:
            self._log.warning('Could not load laws from the retriever', exc_info=True)
            self.laws = None

        self.punishments = self._read_csv(ROOT / 'penal_code.csv')

        gpath = ROOT / 'greetings.csv'
        if gpath.exists():
            self.greetings = self._read_csv(gpath)
        else:
            # No local greetings file: fall back to the retriever's table, if any.
            try:
                self.greetings = self.retriever.greetings_df
            except Exception:
                self.greetings = None

    @classmethod
    def _read_csv(cls, path: Path) -> Optional[pd.DataFrame]:
        """Read *path* with NaNs replaced by ``''``; return ``None`` if it is absent or unreadable."""
        if not path.exists():
            return None
        try:
            return pd.read_csv(path).fillna('')
        except Exception:
            cls._log.warning('Could not read %s', path, exc_info=True)
            return None

    # --------------------------------------------------------------- history

    def _load_contexts(self):
        """Populate ``self.contexts`` from ``CONTEXT_PATH``.

        A missing, unreadable or malformed file yields an empty history. A
        malformed file is anything other than a JSON object of lists.
        """
        if not CONTEXT_PATH.exists():
            return
        try:
            with open(CONTEXT_PATH, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError):  # ValueError covers JSON and UTF-8 decode errors
            self._log.warning('Could not read conversation contexts', exc_info=True)
            self.contexts = {}
            return
        if not isinstance(data, dict):
            self._log.warning('Ignoring conversation contexts: expected a JSON object')
            self.contexts = {}
            return
        # Drop entries that _update_context could not append to.
        self.contexts = {str(k): v for k, v in data.items() if isinstance(v, list)}

    @staticmethod
    def _json_default(obj: Any) -> Any:
        """Convert values the ``json`` module cannot encode natively.

        This handles objects exposing ``to_dict()`` or ``tolist()`` (as pandas
        and numpy objects do) and stringifies anything else. Retriever results
        may contain such objects; their exact types are not visible here.
        """
        if hasattr(obj, 'to_dict'):
            return obj.to_dict()
        if hasattr(obj, 'tolist'):
            return obj.tolist()
        return str(obj)

    def _save_contexts(self):
        """Persist ``self.contexts`` to ``CONTEXT_PATH`` as UTF-8 JSON, best-effort.

        The data is written to a sibling temporary file and then renamed over
        the target. A failure mid-write therefore cannot leave a truncated
        file, which ``_load_contexts`` would discard along with every user's
        history. Errors are logged, never raised.
        """
        try:
            tmp = CONTEXT_PATH.with_name(CONTEXT_PATH.name + '.tmp')
            with open(tmp, 'w', encoding='utf-8') as f:
                json.dump(self.contexts, f, ensure_ascii=False, indent=2, default=self._json_default)
            tmp.replace(CONTEXT_PATH)
        except Exception:
            self._log.warning('Could not save conversation contexts', exc_info=True)

    def _update_context(self, user_id: str, entry: Dict[str, Any]):
        """Append *entry* to the user's history, trim it and persist all histories."""
        # JSON object keys are always strings. Normalise so that, after a
        # reload, a user's history is not split between e.g. 5 and "5".
        history = self.contexts.setdefault(str(user_id), [])
        history.append(entry)
        limit = self.MAX_CONTEXT_ENTRIES
        if len(history) > limit:
            # Keep only the most recent `limit` entries. An explicit start
            # index also works for limit == 0, where history[:-0] would not.
            del history[:len(history) - limit]
        self._save_contexts()

    # ---------------------------------------------------------------- intent

    def tokenize(self, text: str) -> List[str]:
        """Lower-case *text*, blank out punctuation and return its non-stopword tokens."""
        if not text:
            return []
        txt = str(text).lower()
        txt = re.sub(r"[\r\n]+", " ", txt)
        # Keep word characters, whitespace and code points U+00C0-U+017F;
        # everything else becomes a separator.
        txt = re.sub(r"[^\w\s\u00C0-\u017F]", " ", txt)
        return [t for t in txt.split() if t and t not in KINYARWANDA_STOPWORDS]

    def detect_intent(self, text: str) -> str:
        """Classify *text* as ``'greeting'``, ``'punishment'``, ``'law'`` or ``'unclear'``.

        A keyword from ``self.intent_keywords`` counts only where a word starts
        in the lower-cased text. This keeps suffixed forms ("fines", "laws")
        but stops false hits inside other words, such as ``'bite'`` in
        ``'ibitekerezo'`` or ``'fine'`` in ``'define'``.

        Priority when several intents match: greeting > punishment > law.
        """
        lowered = str(text).lower()
        for intent in ('greeting', 'punishment', 'law'):
            for kw in self.intent_keywords.get(intent, ()):
                if re.search(r'(?<!\w)' + re.escape(kw), lowered):
                    return intent
        return 'unclear'

    # -------------------------------------------------------------- answering

    def handle_query(self, user_id: str, text: str) -> Dict[str, Any]:
        """Answer one user message and record the exchange in that user's history.

        Returns a dict whose ``'type'`` is ``'greeting'``, ``'law'``,
        ``'punishment'`` or ``'unclear'``. Exactly one user entry and one
        assistant entry are appended to the history per call.
        """
        try:
            reply = self.retriever.detect_and_reply_greeting(text)
        except Exception:
            self._log.warning('Greeting detection failed', exc_info=True)
            reply = None

        if reply:
            self._update_context(user_id, {'role': 'user', 'text': text, 'intent': 'greeting'})
            out = {'type': 'greeting', 'response': reply.get('response', ''), 'followup': reply.get('followup', '')}
        else:
            intent = self.detect_intent(text)
            self._update_context(user_id, {'role': 'user', 'text': text, 'intent': intent})
            if intent == 'law':
                out = self._answer_law(text)
            elif intent == 'punishment':
                out = self._answer_punishment(text)
            else:
                out = {'type': 'unclear', 'text': "Can you please try a legal question? I'm here to assist you."}

        self._update_context(user_id, {'role': 'assistant', 'text': out})
        return out

    def _answer_law(self, text: str) -> Dict[str, Any]:
        """Return the best-matching law for *text*, or an ``'unclear'`` reply if none is found."""
        try:
            # Called on every law query. Presumably cheap once embeddings
            # are cached by the retriever; TODO confirm.
            self.retriever.build_or_load_embeddings()
            results = self.retriever.find_similar(text, top_k=1)
        except Exception:
            self._log.warning('Law lookup failed', exc_info=True)
            results = []

        if results:
            score, meta = results[0]
            return {'type': 'law', 'score': score, 'law': meta.get('row')}
        return {'type': 'unclear', 'text': "I couldn't find a matching law. Can you be more specific?"}

    def _answer_punishment(self, text: str) -> Dict[str, Any]:
        """Return the ``penal_code.csv`` row scoring highest against *text*, or an ``'unclear'`` reply."""
        best = None
        best_score = 0.0
        if self.punishments is not None:
            for _, row in self.punishments.iterrows():
                # Score the query against all of the row's cells joined together.
                desc = ' '.join(map(str, row.values))
                score = calculate_similarity_score(text, desc)
                if score > best_score:
                    best_score, best = score, row.to_dict()

        # `best` is only set once some row scores strictly above 0.
        if best is not None:
            return {'type': 'punishment', 'score': best_score, 'punishment_row': best}
        return {'type': 'unclear', 'text': "I couldn't find a matching punishment. Can you provide more detail?"}
|
|
|
|
|
|
|
|
|
|
|
# Process-wide Assistant instance, created lazily by get_assistant().
_ASSISTANT: "Optional[Assistant]" = None
|
|
|
|
|
|
|
|
def get_assistant() -> Assistant:
    """Return the shared :class:`Assistant`, constructing it on first use.

    The instance is cached in the module-level ``_ASSISTANT``, so later calls
    reuse the datasets and conversation contexts loaded by its constructor.
    No locking is done: concurrent first calls may each build an instance.
    """
    global _ASSISTANT
    if _ASSISTANT is not None:
        return _ASSISTANT
    _ASSISTANT = Assistant()
    return _ASSISTANT
|
|
|