# -*- coding: utf-8 -*-
# Mahoon — Minimal RAG + Generation (ZeroGPU-ready, no training)
import os
import json
import gradio as gr
# =========================
# ZeroGPU shim & marker
# =========================
try:
    import spaces  # provided by the HF Spaces runtime
except Exception:
    # Local fallback: a no-op GPU decorator so the app also runs outside Spaces.
    class _NoSpaces:
        @staticmethod
        def GPU(*a, **k):
            def w(fn):
                return fn
            return w
    spaces = _NoSpaces()
@spaces.GPU(duration=180)  # having this decorated function present avoids the "No @spaces.GPU" startup error
def _zgpu_marker():
    return "ok"
# =========================
# RAG (Chroma)
# =========================
import chromadb
from chromadb.config import Settings
CHROMA_DIR = os.environ.get("CHROMA_DIR", "./chroma_db")
CHROMA_COLLECTION = os.environ.get("CHROMA_COLLECTION", "legal_articles")
def _norm_id(x: str) -> str:
    """Normalize an article id: unify Arabic letter variants, map Arabic-Indic and Persian digits to ASCII, strip spaces."""
    x = (x or "").replace("\u064A", "ی").replace("\u0643", "ک")
    trans = {ord(a): b for a, b in zip("٠١٢٣٤٥٦٧٨٩۰۱۲۳۴۵۶۷۸۹", "01234567890123456789")}
    return x.translate(trans).replace(" ", "")
def build_rag():
    client = chromadb.PersistentClient(
        path=CHROMA_DIR,
        settings=Settings(anonymized_telemetry=False)
    )
    try:
        col = client.get_or_create_collection(CHROMA_COLLECTION)
    except Exception:
        col = client.get_collection(CHROMA_COLLECTION)
    return col
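# The app expects CHROMA_DIR to already contain a populated index. A minimal
# sketch of how such an index could be built offline, assuming each record is an
# (article_id, text) pair and Chroma's default embedding function. This helper is
# illustrative only and is not called anywhere in the app:
def _index_articles_sketch(articles):
    """articles: iterable of (article_id, text) pairs."""
    col = build_rag()
    col.add(
        ids=[str(aid) for aid, _ in articles],
        documents=[text for _, text in articles],
        metadatas=[{"article_id": str(aid)} for aid, _ in articles],
    )
    return col.count()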
def retrieve(col, query: str, top_k: int, thr: float):
    try:
        res = col.query(
            query_texts=[query],
            n_results=int(top_k),
            include=["documents", "metadatas", "distances"]
        )
        docs = res.get("documents", [[]])[0]
        metas = res.get("metadatas", [[]])[0]
        dists = res.get("distances", [[]])[0]
        out = []
        for i, (d, m, dist) in enumerate(zip(docs, metas, dists)):
            # Chroma returns distances; this treats similarity as 1 - distance,
            # which assumes a cosine-style space on the collection.
            sim = 1.0 - float(dist)
            if sim >= float(thr):
                out.append({
                    "article_id": _norm_id((m or {}).get("article_id", f"unk_{i}")),
                    "text": d,
                    "similarity": sim
                })
        return out
    except Exception:
        return []
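# Illustrative shape of the return value (values are made up):
# [{"article_id": "667", "text": "متن ماده ...", "similarity": 0.78},
#  {"article_id": "10",  "text": "متن ماده ...", "similarity": 0.64}]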
def build_context(arts, limit=320):
    if not arts:
        return ""
    bullets = [f"• ماده {a['article_id']}: {a['text'][:limit]}..." for a in arts]
    return "مواد مرتبط:\n" + "\n".join(bullets)
# =========================
# Generation (Transformers)
# =========================
# To avoid needing torch before a GPU is attached, transformers/torch imports
# happen inside the functions that use them.
MODEL_CHOICES = {
    "Qwen2.5-7B Instruct": "Qwen/Qwen2.5-7B-Instruct",
    "Llama 3.2 3B Instruct": "meta-llama/Llama-3.2-3B-Instruct",
    "Mistral 7B Instruct v0.2": "mistralai/Mistral-7B-Instruct-v0.2"
}
DEFAULT_MODEL_KEY = os.environ.get("DEFAULT_MODEL_KEY", "Llama 3.2 3B Instruct")
_loader = {"tk": None, "model_id": None}
_rag_col = None
def lazy_bootstrap(selected_key: str):
    """Connect to the RAG index and warm up the tokenizer. Model weights are loaded later, inside the GPU function."""
    global _rag_col, _loader
    # RAG
    if _rag_col is None:
        try:
            _rag_col = build_rag()
        except Exception as e:
            return f"❌ خطا در اتصال RAG: {e}"
    # Tokenizer
    wanted = MODEL_CHOICES.get(selected_key, MODEL_CHOICES[DEFAULT_MODEL_KEY])
    if _loader["model_id"] != wanted or _loader["tk"] is None:
        from transformers import AutoTokenizer
        tk = AutoTokenizer.from_pretrained(wanted)
        if tk.pad_token is None and tk.eos_token:
            tk.pad_token = tk.eos_token
        _loader.update({"tk": tk, "model_id": wanted})
    return f"✅ آماده · ایندکس: {CHROMA_COLLECTION} · مدل: {wanted}"
def _format_prompt(context: str, question: str) -> str:
    if context:
        return f"{context}\nسوال: {question}\nپاسخ:"
    return f"سوال: {question}\nپاسخ:"
@spaces.GPU(duration=240)
def answer_gpu(model_key, question, use_rag, top_k, thr, max_new_tokens, temperature, top_p):
    """Run inference on the GPU (ZeroGPU attaches a device per call)."""
    try:
        if not question or not question.strip():
            return "لطفاً سؤال را وارد کنید.", ""
        # RAG
        arts = retrieve(_rag_col, question, int(top_k), float(thr)) if use_rag else []
        ctx = build_context(arts) if arts else ""
        prompt = _format_prompt(ctx, question)
        # Load the weights onto the GPU reserved for this call.
        from transformers import AutoTokenizer, AutoModelForCausalLM
        model_id = _loader["model_id"] or MODEL_CHOICES.get(model_key) or MODEL_CHOICES[DEFAULT_MODEL_KEY]
        tk = _loader["tk"] or AutoTokenizer.from_pretrained(model_id)
        mdl = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")  # ZeroGPU → GPU attach
        enc = tk(prompt, return_tensors="pt")
        enc = {k: v.to(mdl.device) for k, v in enc.items()}
        out = mdl.generate(
            **enc,
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            pad_token_id=tk.pad_token_id or tk.eos_token_id
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        new_tokens = out[0][enc["input_ids"].shape[1]:]
        text = tk.decode(new_tokens, skip_special_tokens=True)
        refs = ""
        if arts:
            refs = "\n\n" + "\n".join([
                f"**ماده {a['article_id']}** (شباهت: {a['similarity']:.2f})\n{a['text'][:380]}..."
                for a in arts
            ])
        return text, refs
    except Exception as e:
        return f"❌ خطای اینفرنس: {e}", ""
# =========================
# UI (Gradio 5.47)
# =========================
with gr.Blocks(title="Mahoon — Minimal RAG+Gen", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    <div style='text-align:center;padding:14px'>
      <h2 style='margin:0'>ماحون (مینیمال) — پاسخ حقوقی با RAG</h2>
      <p style='color:#666'>اینفرنس ZeroGPU · ایندکس آماده · بدون آموزش</p>
    </div>
    """)
    with gr.Row():
        model_dd = gr.Dropdown(choices=list(MODEL_CHOICES.keys()),
                               value=DEFAULT_MODEL_KEY,
                               label="مدل تولید")
        use_rag = gr.Checkbox(value=True, label="استفاده از RAG؟")
        top_k = gr.Slider(1, 10, value=5, step=1, label="Top-K")
        thr = gr.Slider(0.50, 0.95, value=0.60, step=0.01, label="آستانه شباهت")
    with gr.Accordion("پارامترهای تولید", open=False):
        max_new_tokens = gr.Slider(64, 1024, value=256, step=16, label="max_new_tokens")
        temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
    question = gr.Textbox(lines=3, label="سؤال")
    ask_btn = gr.Button("پرسش", variant="primary")
    answer = gr.Markdown(label="پاسخ")
    refs = gr.Markdown(label="مواد مرتبط")
    status = gr.Markdown("⏳ آماده‌سازی…")
    def _warmup(mkey):
        try:
            return lazy_bootstrap(mkey)
        except Exception as e:
            return f"❌ Bootstrap error: {e}"
    demo.load(_warmup, inputs=[model_dd], outputs=status)
    ask_btn.click(
        answer_gpu,
        inputs=[model_dd, question, use_rag, top_k, thr, max_new_tokens, temperature, top_p],
        outputs=[answer, refs]
    )
if __name__ == "__main__":
    try:
        demo = demo.queue()  # more stable on Gradio 5.x
    except TypeError:
        pass
    demo.launch(ssr_mode=False)