| import os |
| import time |
| import pickle |
| import subprocess |
| from pathlib import Path |
|
|
| import re |
| import json |
| import textwrap |
| from typing import List, Dict, Any, Optional |
| from pydantic import BaseModel |
|
|
| import numpy as np |
| import faiss |
| import gradio as gr |
| from sentence_transformers import SentenceTransformer |
|
|
|
|
| |
# Gemini API key, read once at import time from the environment.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")


# Lazily-populated module singletons (filled by load_index_and_embedder).
_index = None        # FAISS index over passage embeddings
_texts = None        # list[str] of passage texts, aligned with the index
_metas = None        # list[dict] of per-passage metadata, aligned with the index
_embed_model = None  # SentenceTransformer used to embed queries


# google-genai is optional at import time; generation calls raise a clear
# RuntimeError later if it is absent.
try:
    from google import genai
except Exception:
    genai = None


# Filesystem layout and embedding model configuration.
ROOT = Path(__file__).resolve().parent
FLATTENED_JSONL = ROOT / "data" / "processed" / "flattened_docs.jsonl"
INDEX_FILE = ROOT / "law_index.faiss"
META_FILE = ROOT / "law_meta.pkl"
EMB_MODEL = "paraphrase-multilingual-MiniLM-L12-v2"


TOP_K = 5                    # default number of retrieved passages
MAX_NEW_TOKENS = 1024        # default generation budget
EMB_MODEL_NAME = EMB_MODEL   # alias kept for backward compatibility


# Arabic combining diacritics (tashkeel and related marks) and a small
# stopword list, used by the keyword-normalization helpers below.
ARABIC_DIACRITICS_RE = re.compile(r'[\u0610-\u061A\u064B-\u065F\u06D6-\u06DC\u06DF-\u06E8\u06EA-\u06ED]')
ARABIC_STOPWORDS = {"من","في","على","إلى","عن","ما","هو","هي","لم","لن","إن","أن","كل","قد","أو","و","التي","الذي","الذين","هذا","هذه","ذلك","تلك","مع","أنّ","إلا","كان","كانت","هناك","أي","سواء","بعد","قبل","حتى"}
|
|
def strip_diacritics_arabic(text: str) -> str:
    """Remove Arabic diacritical marks from *text*; falsy input yields ""."""
    if not text:
        return ""
    # Same character class as the module-level ARABIC_DIACRITICS_RE.
    return re.sub(r'[\u0610-\u061A\u064B-\u065F\u06D6-\u06DC\u06DF-\u06E8\u06EA-\u06ED]', "", text)
|
|
def normalize_text_for_match(s: str) -> str:
    """De-diacritize, lowercase, and collapse *s* to single-spaced word characters.

    Non-word characters outside the Arabic Unicode block become spaces.
    """
    if not s:
        return ""
    cleaned = strip_diacritics_arabic(s).lower()
    # Keep \w plus the Arabic block; squash every other run to one space.
    cleaned = re.sub(r"[^\w\u0600-\u06FF]+", " ", cleaned)
    return re.sub(r"\s+", " ", cleaned).strip()
|
|
def strip_prefixes(token: str) -> str:
    """Strip common Arabic clitic prefixes from *token*.

    Order: every leading waw ("و"), then one definite article ("ال"),
    then one more waw that the article removal may expose. The token is
    never reduced below one character.
    """
    if not token:
        return token
    t = token
    while len(t) > 1 and t.startswith("و"):
        t = t[1:]
    if len(t) > 2 and t.startswith("ال"):
        t = t[2:]
    if len(t) > 1 and t.startswith("و"):
        t = t[1:]
    return t
|
|
def tokenize_and_clean(s: str):
    """Normalize *s*, split on whitespace, strip clitic prefixes, drop stopwords."""
    if not s:
        return []
    kept = []
    for raw in normalize_text_for_match(s).split():
        tok = strip_prefixes(raw).strip()
        if tok and tok not in ARABIC_STOPWORDS:
            kept.append(tok)
    return kept
|
|
| |
# Per-script character probes used by detect_language below.
ARABIC_CHAR_RE = re.compile(r'[\u0600-\u06FF]')       # Arabic Unicode block
LATIN_CHAR_RE = re.compile(r'[A-Za-zÀ-ÖØ-öø-ÿ]')      # ASCII + Latin-1 letters
|
|
def detect_language(s: str) -> str:
    """Classify *s* as "ar", "fr", or "other" by counting script characters.

    Arabic wins ties; any Latin-dominant text is labeled "fr" (the corpus
    only holds Arabic and French). Non-strings and empty strings → "other".
    """
    if not isinstance(s, str) or not s:
        return "other"
    # Same character classes as the module-level ARABIC_CHAR_RE / LATIN_CHAR_RE.
    arabic_hits = len(re.findall(r'[\u0600-\u06FF]', s))
    latin_hits = len(re.findall(r'[A-Za-zÀ-ÖØ-öø-ÿ]', s))
    if arabic_hits and arabic_hits >= latin_hits:
        return "ar"
    if latin_hits > arabic_hits:
        return "fr"
    return "other"
|
|
def prepare_index_and_meta():
    """Load the flattened corpus and return (texts, metas, index).

    Reads every record from FLATTENED_JSONL, then either loads the cached
    FAISS index + metadata pickle, or builds them from scratch (normalized
    embeddings in an inner-product index) and persists them.

    Returns:
        tuple: (texts: list[str], metas: list[dict], index: faiss.Index)

    Raises:
        FileNotFoundError: if FLATTENED_JSONL is missing.
    """
    if not FLATTENED_JSONL.exists():
        raise FileNotFoundError(f"{FLATTENED_JSONL} not found.")

    texts = []
    metas = []
    with FLATTENED_JSONL.open("r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            line = line.strip()
            if not line:
                continue
            obj = json.loads(line)
            text = obj.get("text", "")
            mada = obj.get("mada") or obj.get("id", f"m{i+1}")
            bab = obj.get("bab") or obj.get("fasl") or ""
            source = obj.get("source") or bab or ""
            _id = obj.get("id") or f"{i+1:05d}"
            texts.append(text)
            lang = detect_language(text or mada or bab or source)
            metas.append({"id": _id, "mada": mada, "bab": bab, "source": source, "lang": lang})

    index = None
    if INDEX_FILE.exists() and META_FILE.exists():
        try:
            index = faiss.read_index(str(INDEX_FILE))
            with open(META_FILE, "rb") as f:
                meta_pkl = pickle.load(f)
            if isinstance(meta_pkl, list) and len(meta_pkl) == len(texts):
                # Backfill a missing/empty language tag on older pickles.
                for j, m in enumerate(meta_pkl):
                    if not m.get("lang"):
                        m["lang"] = detect_language(
                            texts[j] or m.get("mada") or m.get("bab") or m.get("source") or ""
                        )
                metas = meta_pkl
        except Exception:
            # Bug fix: the old code swallowed this error and then returned an
            # unbound `index` (NameError). A corrupt/unreadable cache now
            # falls through to a full rebuild instead.
            index = None

    if index is None:
        embedder = SentenceTransformer(EMB_MODEL_NAME)
        embs = embedder.encode(texts, show_progress_bar=True, convert_to_numpy=True)
        if embs.dtype != np.float32:
            embs = embs.astype(np.float32)
        # Normalize so inner product == cosine similarity.
        faiss.normalize_L2(embs)
        index = faiss.IndexFlatIP(embs.shape[1])
        index.add(embs)
        faiss.write_index(index, str(INDEX_FILE))
        with open(META_FILE, "wb") as f:
            pickle.dump(metas, f)
    return texts, metas, index
|
|
def load_index_and_embedder():
    """Lazily initialize and return the (index, texts, metas, embed_model) singletons.

    Safe to call repeatedly; expensive work happens only on first use.
    """
    global _index, _texts, _metas, _embed_model
    if _index is None or _texts is None or _metas is None:
        _texts, _metas, _index = prepare_index_and_meta()
    # Bug fix: the embedder was previously created only alongside the index,
    # so a pre-populated index could leave _embed_model as None and crash
    # embed_query. Check it independently.
    if _embed_model is None:
        _embed_model = SentenceTransformer(EMB_MODEL_NAME)
    return _index, _texts, _metas, _embed_model
|
|
|
|
| |
# --- Startup: make sure the FAISS index exists, then load everything once. ---
if not INDEX_FILE.exists() or not META_FILE.exists():
    print("⚠️ Index missing → building with build.py")
    # NOTE(review): assumes "python" is on PATH and build.py is in the CWD —
    # confirm this matches the deployment environment.
    subprocess.run(["python", "build.py"], check=True)
    print("✅ Index built")
else:
    print("✅ Index already exists")


print("🔄 Loading FAISS + embedder...")


# Warm the module-level singletons at import time.
_index, _texts, _metas, _embed_model = load_index_and_embedder()


# Unprefixed aliases for the singletons above.
texts = _texts
metas = _metas
index = _index
embedder = _embed_model


print("✅ Ready")
|
|
def embed_query(text):
    """Embed *text* into a unit-normalized float32 array of shape (1, dim)."""
    _, _, _, model = load_index_and_embedder()
    vec = model.encode([text], convert_to_numpy=True)
    vec = vec.astype(np.float32, copy=False)
    # Normalize in place so inner-product search behaves as cosine similarity.
    faiss.normalize_L2(vec)
    return vec
|
|
|
|
|
|
| |
def call_gemini_generate(prompt, model_name="gemini-2.5-flash", max_output_tokens=None, temperature=0.0):
    """Send *prompt* to the Gemini API and return the extracted answer text.

    Args:
        prompt: Full prompt string (system line + question + context).
        model_name: Gemini model identifier.
        max_output_tokens: Optional cap on generated tokens; None keeps the API default.
        temperature: Sampling temperature.

    Raises:
        RuntimeError: if the google-genai library or the API key is missing.
    """
    if genai is None:
        raise RuntimeError("google-genai library not installed. pip install google-genai")
    # Fall back to the live environment in case the key was set after import.
    api_key = GEMINI_API_KEY or os.getenv("GEMINI_API_KEY")
    if not api_key:
        raise RuntimeError("GEMINI_API_KEY not found in environment. Set it before running.")

    client = genai.Client(api_key=api_key)
    kwargs = {"model": model_name, "contents": prompt}
    # Bug fix: `temperature` and `max_output_tokens` used to be accepted but
    # silently ignored (the UI's temperature slider had no effect). Forward
    # them via the SDK's generation config when available.
    try:
        from google.genai import types as genai_types
        cfg = {"temperature": temperature}
        if max_output_tokens is not None:
            cfg["max_output_tokens"] = max_output_tokens
        kwargs["config"] = genai_types.GenerateContentConfig(**cfg)
    except Exception:
        pass  # older SDK without a config type: keep previous behavior

    resp = client.models.generate_content(**kwargs)
    return extract_text_from_gemini_response(resp)
|
|
|
|
def retrieve(query, top_k=TOP_K, prefer_same_language=True, strict_same_language=False):
    """Semantic search over the law corpus.

    Over-fetches 6x *top_k* candidates from FAISS, then, relative to the
    detected query language: with ``strict_same_language`` only same-language
    hits are returned; with ``prefer_same_language`` same-language hits come
    first and any shortfall is topped up with the best other-language hits.
    """
    index, texts, metas, _ = load_index_and_embedder()
    q_lang = detect_language(query)
    scores, ids = index.search(embed_query(query), max(top_k * 6, top_k))

    candidates = [
        {"score": float(s), "idx": int(i), "text": texts[i], "meta": metas[i]}
        for s, i in zip(scores[0], ids[0])
        if 0 <= i < len(texts)
    ]

    def ranked(items):
        # Highest similarity first.
        return sorted(items, key=lambda c: c["score"], reverse=True)

    if q_lang in ("ar", "fr") and (strict_same_language or prefer_same_language):
        same = [c for c in candidates if c.get("meta", {}).get("lang") == q_lang]
        if strict_same_language or len(same) >= top_k:
            return ranked(same)[:top_k]
        # Not enough same-language hits: top up with the best of the rest.
        chosen = ranked(same)
        others = ranked([c for c in candidates if c.get("meta", {}).get("lang") != q_lang])
        chosen.extend(others[: max(0, top_k - len(chosen))])
        return chosen

    return ranked(candidates)[:top_k]
|
|
|
|
def build_instructional_prompt_from_retrieved(query, retrieved):
    """Assemble the full LLM prompt: system line + question + instructions + context.

    The instruction/system language follows the detected query language
    (French when "fr", Arabic otherwise).
    """
    context_parts = []
    for i, r in enumerate(retrieved, start=1):
        m = r.get("meta", {})
        mada = m.get("mada", "")
        bab = m.get("bab", "")
        src = m.get("source", "") or bab or ""
        text = r.get("text", "")
        # One numbered source header per snippet: (article : chapter : source).
        context_parts.append(f"المصدر {i}: ({mada} : {bab} : {src})\n{text}")

    context = "\n\n".join(context_parts)
    q_lang = detect_language(query)

    if q_lang == "fr":
        # Swap the Arabic source label for its French counterpart.
        context = context.replace("المصدر ", "Source ")
        system_line = "SYSTEM: Vous êtes un avocat virtuel spécialisé en droit marocain."
        instr = textwrap.dedent(f"""\
Reformulez les extraits suivants et produisez une réponse juridique unique et structurée — en français — comprenant, dans l'ordre :
1) Un résumé bref (2-3 phrases).
2) Une analyse juridique détaillée en s'appuyant exclusivement sur les extraits, en citant après chaque point (Article : Chapitre : Source).
3) Une conclusion / conseil pratique court.
4) Liste des références utilisées.

Ne rajoutez pas d'informations extérieures aux extraits. Si les extraits sont insuffisants, indiquez-le clairement.

Extraits:
--------------------
{context}
--------------------

Exigence : réponse organisée avec sous-titres (Résumé, Analyse juridique, Conclusion/Conseil, Références).
""")
    else:
        system_line = "SYSTEM: أنت محامٍ افتراضي متخصص في القانون المغربي."
        instr = textwrap.dedent(f"""\
أعد صياغة المقتطفات التالية وأنتج إجابة قانونية واحدة ومتكاملة — باللغة العربية الفصحى — وتتضمن بالترتيب:
1) خلاصة موجزة (2-3 جمل).
2) تحليل قانوني مفصّل يستند حصريًا إلى المقتطفات مع الإشارة بعد كل نقطة بالشكل (المادة : الباب : المصدر).
3) استنتاج / نصيحة عملية قصيرة.
4) قائمة المراجع المستخدمة.

التزم بالمقتطفات ولا تضف معلومات خارجها. إن كانت المقتطفات غير كافية فاذكر ذلك صراحة.

المقتطفات:
--------------------
{context}
--------------------

المطلوب: إجابة واحدة منظمة مع عناوين فرعية: (الخلاصة، التحليل القانوني، الاستنتاج/النصيحة العملية، المراجع).
""")
    full_prompt = f"{system_line}\nQUESTION: {query}\n\n{instr}\nANSWER:\n"
    return full_prompt
| |
def extract_text_from_gemini_response(resp) -> str:
    """Best-effort extraction of the generated text from a Gemini response.

    Tries, in order: a top-level ``text`` attribute (possibly callable),
    the first candidate's ``content.parts[0].text`` (object or dict shape),
    the first candidate's own ``text``, and finally ``str(resp)``.
    """
    # 1) Direct .text attribute (property or zero-arg method).
    try:
        direct = getattr(resp, "text", None)
        if callable(direct):
            direct = direct()
        if direct:
            return direct
    except Exception:
        pass

    # 2) Candidate objects/dicts carrying content parts.
    try:
        cands = getattr(resp, "candidates", None) or getattr(resp, "Candidates", None)
        if cands:
            first = cands[0]
            for key in ("content", "Content"):
                content = getattr(first, key, None) or (
                    first.get(key) if isinstance(first, dict) else None
                )
                if not content:
                    continue
                parts = getattr(content, "parts", None) or (
                    content.get("parts") if isinstance(content, dict) else None
                )
                if not parts:
                    continue
                head = parts[0]
                if isinstance(head, dict):
                    txt = head.get("text") or head.get("Text")
                else:
                    txt = getattr(head, "text", None) or getattr(head, "Text", None)
                if txt:
                    return txt
            # 3) The candidate itself may expose .text.
            fallback = getattr(first, "text", None)
            if callable(fallback):
                fallback = fallback()
            if fallback:
                return fallback
    except Exception:
        pass

    # 4) Last resort: the response's string form.
    try:
        return str(resp)
    except Exception:
        return None
|
|
|
|
|
|
class QueryRequest(BaseModel):
    """Request payload for a legal question."""
    query: str                              # the question text (Arabic or French)
    top_k: int = TOP_K                      # number of passages to retrieve
    max_tokens: int = MAX_NEW_TOKENS        # generation token budget
    temperature: float = 0.0                # sampling temperature
    model_name: str = "gemini-2.5-flash"    # Gemini model identifier
    prefer_same_language: bool = True       # rank same-language passages first
    strict_same_language: bool = False      # return only same-language passages
    include_prompt: bool = False            # echo the built prompt in the response
|
|
class QueryResponse(BaseModel):
    """Response payload: the generated answer plus retrieval details and timings."""
    answer: str                             # generated legal answer
    retrieved: List[Dict[str, Any]]         # retrieved passages with scores/metadata
    prompt: Optional[str] = None            # built prompt, when include_prompt was set
    total_time: float                       # end-to-end seconds
    retrieval_time: float                   # seconds spent retrieving
    generation_time: float                  # seconds spent generating
    query_lang: str                         # detected language of the query
|
|
|
|
| |
|
|
def show_loading():
    # Spinner text shown in the loading markdown slot while the pipeline runs.
    return "🔄 **جارٍ البحث والتوليد…**\n\n⏳"


def hide_loading():
    # NOTE(review): despite the name this returns visible=True — likely
    # intended to be visible=False to hide the spinner; confirm against the
    # live UI before changing.
    return gr.update(visible=True)
|
|
def render_history(history):
    """Render the Q&A history as numbered markdown; Arabic placeholder when empty."""
    if not history:
        return "لا يوجد سجل بعد"
    lines = []
    for pos, entry in enumerate(history, start=1):
        lines.append(f"**{pos}. {entry['query']}** \n⏱️ {entry['total']:.2f}s")
    return "\n".join(lines)
|
|
def select_item(idx, history):
    """Return (answer, sources, stats) for history entry *idx*; blanks when invalid.

    Bug fix: negative indices (e.g. the 1-based UI counter set to 0 yields
    idx == -1) used to fall through to Python's negative indexing and
    silently display the LAST entry; they now return blanks like any other
    out-of-range index.
    """
    if idx is None or idx < 0 or idx >= len(history):
        return "", "", ""
    h = history[idx]
    stats = (
        f"⏱️ الإجمالي: {h['total']:.2f}s | "
        f"🔍 الاسترجاع: {h['retr']:.2f}s | "
        f"✨ التوليد: {h['gen']:.2f}s"
    )
    return h["answer"], h["sources"], stats
|
|
def format_sources(retrieved):
    """Render retrieved passages as markdown; Arabic placeholder when none."""
    if not retrieved:
        return "لا توجد مصادر"
    blocks = []
    for pos, item in enumerate(retrieved, start=1):
        label = item["meta"].get("mada", "")
        blocks.append(f"**{pos}. {label}** \n{item['text']}")
    return "\n\n".join(blocks)
|
|
|
|
def update_history(q, ans, src, hist):
    """Append one Q&A record (timings zeroed) in place; return (hist, markdown)."""
    record = {
        "query": q,
        "answer": ans,
        "sources": src,
        "total": 0.0,
        "retr": 0.0,
        "gen": 0.0,
    }
    hist.append(record)
    return hist, render_history(hist)
|
|
|
|
| |
# ---------------------------------------------------------------------------
# Gradio UI: three-column layout (history | Q&A | settings) plus the event
# pipeline wiring the "ask" button through retrieval and generation.
# ---------------------------------------------------------------------------
with gr.Blocks(title="⚖️ محامي افتراضي مغربي") as demo:
    gr.Markdown("# ⚖️ محامي افتراضي\n### القانون المغربي بين يديك")

    # Cross-callback state: Q&A history, last retrieval hits, last built prompt.
    history_state = gr.State([])
    retrieved_state = gr.State()
    prompt_state = gr.State()

    with gr.Row():
        # Left column: question history and a 1-based selector.
        with gr.Column(scale=1):
            gr.Markdown("## 📝 السجل")
            history_md = gr.Markdown("لا يوجد سجل بعد")
            history_index = gr.Number(
                label="اختر رقم السؤال",
                precision=0,
                value=1
            )

        # Middle column: question box, answer, sources, timing stats.
        with gr.Column(scale=3):
            query = gr.Textbox(
                placeholder="اكتب هنا / Écrivez ici",
                label="❓ سؤالك القانوني",
                lines=2
            )

            ask = gr.Button("🔍 اسأل", variant="primary")

            # Hidden until a question is submitted.
            loading_md = gr.Markdown("🔄 **جارٍ البحث والتوليد…**\n\n⏳", visible=False)

            answer = gr.Markdown(label="📜 الجواب")

            with gr.Accordion("📚 المصادر المستخدمة", open=False):
                sources = gr.Markdown("—")

            stats = gr.Markdown("—")

        # Right column: retrieval/generation knobs.
        with gr.Column(scale=1):
            gr.Markdown("## ⚙️ الإعدادات")
            top_k = gr.Slider(1, 8, value=3, step=1, label="عدد المقتطفات")
            temperature = gr.Slider(0, 1, value=0.0, step=0.05, label="العشوائية")

    # Pipeline: spinner → retrieve → build prompt → generate → render sources
    # → record in history → clear spinner.
    ask.click(
        show_loading,
        outputs=loading_md
    ).then(
        lambda q, k: retrieve(q, top_k=k),
        inputs=[query, top_k],
        outputs=retrieved_state
    ).then(
        lambda q, r: build_instructional_prompt_from_retrieved(q, r),
        inputs=[query, retrieved_state],
        outputs=prompt_state
    ).then(
        lambda p, t: call_gemini_generate(p, temperature=t),
        inputs=[prompt_state, temperature],
        outputs=answer
    ).then(
        lambda r: format_sources(r),
        inputs=retrieved_state,
        outputs=sources
    ).then(
        update_history,
        inputs=[query, answer, sources, history_state],
        outputs=[history_state, history_md]
    ).then(
        hide_loading,
        outputs=loading_md
    )

    # Map the 1-based UI counter to the 0-based history list.
    history_index.change(
        lambda i, h: select_item(int(i) - 1, h),
        inputs=[history_index, history_state],
        outputs=[answer, sources, stats]
    )


demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
|