ENA-Chatbot / rag_engine.py
Ines1994's picture
Upload 2 files
d808d33 verified
import os
import zipfile
import time
import json
import re
import numpy as np
import requests
from bs4 import BeautifulSoup
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from sentence_transformers import CrossEncoder
from rank_bm25 import BM25Okapi
from groq import Groq
from config import *
from utils import normalize_arabic, detect_lang
CATEGORY_KEYWORDS = {
"A1": ["عليا", "superieur", "أ1"],
"A2": ["متوسطة", "moyen", "أ2", "cadre"],
"A3": ["أعوان", "a3", "صنف"],
}
# Keywords that signal a different context (to avoid topic collision)
CONTEXT_PENALTIES = {
"journal": ["مجلة", "مقال", "نشر", "ملكية فكرية", "revue", "article"],
"competition": ["مناظرة", "دخول", "مرحلة", "concours", "cycle", "تسجيل"]
}
class ENAEngine:
def __init__(self, groq_token=None):
# Auto-extract DB if missing
if os.path.exists("chroma_ena_db.zip") and not os.path.exists(CHROMA_PATH):
try:
with zipfile.ZipFile("chroma_ena_db.zip", 'r') as zip_ref:
zip_ref.extractall(".")
except:
pass
self.embeddings = HuggingFaceEmbeddings(
model_name=EMBED_MODEL,
model_kwargs={"device":"cpu"},
encode_kwargs={"normalize_embeddings":True}
)
self.vectordb = Chroma(
persist_directory=CHROMA_PATH,
collection_name=COLLECTION_NAME,
embedding_function=self.embeddings
)
self.reranker = CrossEncoder(RERANK_MODEL, device="cpu")
self.llm = Groq(api_key=groq_token) if groq_token else None
self.bm25 = self._load_bm25()
def _load_bm25(self):
try:
col = self.vectordb._collection.get(include=["documents"])
chunks = col["documents"]
if not chunks: return None
return BM25Okapi([c.lower().split() for c in chunks])
except:
return None
def hybrid_search(self, query: str, k: int = TOP_K_SEARCH):
qn = normalize_arabic(query)
vw, bw = (0.7, 0.3) if detect_lang(qn) == "ar" else (0.8, 0.2)
# Vector Search
vdocs = self.vectordb.similarity_search(qn, k=k)
# If BM25 is not ready, return vector results in the correct format
if not self.bm25:
return [{"content": d.page_content, "meta": d.metadata, "rrf_score": 0.5} for d in vdocs]
vrank = {d.page_content: i for i, d in enumerate(vdocs)}
# BM25 Search
try:
col = self.vectordb._collection.get(include=["documents", "metadatas"])
chunks, metas = col["documents"], col["metadatas"]
bsc = self.bm25.get_scores(qn.lower().split())
btop = np.argsort(bsc)[::-1][:k]
brank = {chunks[i]: j for j, i in enumerate(btop)}
# 1. Detection: Is this about a specific stage?
target_cat = None
for cat, bits in CATEGORY_KEYWORDS.items():
if any(b in qn.lower() for b in bits):
target_cat = cat
break
# 2. Metadata Reconstruction mapping
tmeta = {d.page_content: d.metadata for d in vdocs}
for i in btop:
if chunks[i] not in tmeta: tmeta[chunks[i]] = metas[i]
# 3. RRF Fusion with Categorical Boosting
texts = set(vrank)|set(brank)
fused = {}
for t in texts:
# Rank score
score = vw/(vrank.get(t,k+10)+RRF_K) + bw/(brank.get(t,k+10)+RRF_K)
# Metadata-based Boost/Penalty
m = tmeta.get(t, {})
m_cat = m.get("category", "")
m_url = m.get("url", "")
# Boost صفحات المناظرة عند سؤال الشروط
if any(kw in qn for kw in ["شروط", "ترشح", "condition", "candidature"]):
if m_cat in ("concours_ar", "concours_fr"):
score *= 2.0
# عقوبة للصفحة الرئيسية
if m_url.rstrip("/") in ("https://www.ena.tn/ar", "https://www.ena.tn/fr"):
score *= 0.3
# Boost المعلومات العامة للشروط
if "شروط" in qn or "condition" in qn.lower():
if "informations-generales" in m_url:
score *= 1.5
content_lower = t.lower()
# Boost الاستثناءات القانونية للسن — مهمة جداً
EXCEPTION_KEYWORDS = ["استثناء", "مكتب تشغيل", "مكتب التشغيل",
"سنوات العمل", "الجماعات المحلية",
"الأمر عدد 1031", "dérogation", "bureau d'emploi"]
is_conditions_q = any(kw in qn for kw in ["شروط", "سن", "عمر", "condition", "age"])
has_exception = any(kw in content_lower for kw in EXCEPTION_KEYWORDS)
if is_conditions_q and has_exception:
score *= 3.0 # رفع قوي لضمان ظهور الاستثناءات دائماً
# 4. Contextual Penalty (The "Journal vs Competition" fix)
is_competition_q = any(b in qn.lower() for b in CONTEXT_PENALTIES["competition"])
is_journal_q = any(b in qn.lower() for b in CONTEXT_PENALTIES["journal"])
has_journal_terms = any(b in content_lower for b in CONTEXT_PENALTIES["journal"])
# If it's a competition query, penalize journal content heavily
if is_competition_q and has_journal_terms:
score *= 0.1
# If it's specifically a documents query, boost known registration terms
if "وثائق" in qn or "ملف" in qn:
if "استمارة" in content_lower or "بطاقة تعريف" in content_lower:
score *= 1.5
fused[t] = score
ranked_texts = sorted(fused, key=fused.get, reverse=True)[:k]
# Reconstruct results
results = []
for t in ranked_texts:
results.append({
"content": t,
"meta": tmeta.get(t, {}),
"rrf_score": fused[t],
"is_global": tmeta.get(t, {}).get("category") == "other"
})
return results
except Exception:
# Fallback if anything goes wrong with BM25
return [{"content": d.page_content, "meta": d.metadata, "rrf_score": 0.5} for d in vdocs]
def rerank(self, query: str, results: list):
if not results: return []
cands = results[:20]
scores = self.reranker.predict([(query, r["content"][:1024]) for r in cands])
scored = []
for r, s in zip(cands, scores):
conf = 1 / (1 + np.exp(-float(s))) # Sigmoid
scored.append({**r, "confidence": conf})
return sorted(scored, key=lambda x:x["confidence"], reverse=True)[:TOP_K_RERANK]
def expand_query(self, query: str):
if not self.llm: return [query]
try:
resp = self.llm.chat.completions.create(
model="llama-3.3-70b-versatile",
messages=[{"role":"user","content":f'Generate 3 short alternative search queries for: "{query}" in Arabic and French. Return JSON list only.'}],
max_tokens=100
)
m = re.search(r'\[.*?\]', resp.choices[0].message.content, re.DOTALL)
if m: return [query] + json.loads(m.group())
except:
pass
return [query]
def should_scrape(self, query: str) -> str:
"""Agentic decision: Should I fetch a live page?"""
# Simplified for now, can be LLM-based later
q = query.lower()
triggers = ["جديد", "موعد", "متى", "date", "nouvelle", "actualité"]
if any(t in q for t in triggers):
# Find best URL match
for key, url in PAGE_URLS.items():
if key.replace('_',' ') in q: return url
return None
def scrape_url(self, url: str):
try:
r = requests.get(url, headers={"User-Agent":"Mozilla/5.0"}, timeout=10)
soup = BeautifulSoup(r.text, "html.parser")
for t in soup(["script","style","nav","footer","header"]): t.decompose()
return soup.get_text(" ", strip=True)[:3500]
except:
return None
def get_system_prompt(self, lang="ar"):
if lang == "ar":
return """أنت 'خبير قانوني' مختص حصرياً في قوانين المدرسة الوطنية للإدارة (ENA) بتونس.
### قواعد صارمة للإجابة:
1. **الالتزام بالسياق**: أجب فقط بناءً على النصوص المرفقة (Context). لا تستخدم معلوماتك العامة أبداً.
2. **شروط الترشح — الأركان الأربعة الإلزامية**:
عند أي سؤال عن "شروط الترشح" أو "شروط المناظرة"، يجب ذكر الأركان الأربعة كاملة:
- [السن]: اذكر الرقم بدقة (35 للمرحلة العليا، 40 لأ2 وأ3).
- [الشهادات]: اذكر جميع الشهادات المذكورة في السياق.
- [الجنسية]: الجنسية التونسية.
- [الحقوق المدنية]: التمتع بالحقوق المدنية.
- [الاستثناءات]: ذكر استثناءات السن (مكتب التشغيل + سنوات العمل الإداري) وجوبي إذا وُجدت في السياق.
3. **التمييز بين المناظرة الخارجية والداخلية**:
- الخارجية: للطلبة وحاملي الشهادات من خارج الإدارة.
- الداخلية: للموظفين والأعوان العموميين المرسمين.
ميّز بينهما دائماً إذا وُجدا في السياق.
4. **التمييز بين ملفات الترشح**:
- ملف مناظرة (concours): استمارة + نسخة ب.ت.و + شهادة علمية. هذا المقصود في 99% من الأسئلة.
- ملف المجلة (Journal/Revue): مقال + سيرة داخلية + حقوق ملكية. لا تخلط بينهما أبداً!
5. **التثبت من المصدر**: إذا كان النص يتحدث عن 'المجلة التونسية للإدارة'، فهو ليس ملف المناظرة.
6. **في حال فقدان المعلومة**: إذا لم تجد الإجابة في النصوص المرفقة، قل صراحةً:
"هذه المعلومة غير متوفرة في وثائقي الحالية، يرجى التواصل مع ENA مباشرة: info@ena.tn أو 71 848 300"
7. **الدقة الرقمية**: يُمنع تجاهل أي رقم أو سن أو استثناء قانوني مذكور في السياق.
8. **المصادر**: اذكر رقم المصدر [1] بعد كل معلومة مباشرة."""
else:
return """Tu es un Expert Juridique ENA Tunisie.
### Règles strictes :
1. **Contexte uniquement** : Ne réponds que sur la base du contexte fourni. N'utilise JAMAIS tes connaissances générales.
2. **Conditions de candidature — 4 éléments obligatoires** :
Pour toute question sur les "conditions", mentionner impérativement :
- [Âge] : 35 ans max (Cycle Supérieur), 40 ans (A2/A3).
- [Diplômes] : Tous les diplômes mentionnés dans le contexte.
- [Nationalité] : Nationalité tunisienne.
- [Droits civils] : Jouissance des droits civils.
- [Dérogations] : Dérogations d'âge (bureau d'emploi + service public) si présentes.
3. **Concours Externe vs Interne** :
- Externe : Pour les étudiants et diplômés hors administration.
- Interne : Pour les fonctionnaires titulaires.
4. **Information manquante** : Si l'info n'est pas dans le contexte :
"Information non disponible. Contactez l'ENA : info@ena.tn ou 71 848 300"
5. **Rigueur numérique** : Ne jamais omettre un chiffre, un âge ou une dérogation légale.
6. **Sources** : Citer la référence [1] après chaque information."""