# -*- coding: utf-8 -*-
# Mahoon — Minimal RAG + Generation (ZeroGPU-ready, no training)
import os
import json
import gradio as gr
# =========================
# ZeroGPU shim & marker
# =========================
try:
    import spaces  # provided by the HF Spaces runtime
except Exception:
    # Local fallback: a no-op decorator so the app also runs outside HF Spaces.
    class _NoSpaces:
        @staticmethod
        def GPU(*a, **k):
            # Support both @spaces.GPU and @spaces.GPU(duration=...)
            if a and callable(a[0]) and not k:
                return a[0]
            def w(fn):
                return fn
            return w
    spaces = _NoSpaces()

# Having at least one @spaces.GPU-decorated function prevents the
# "No @spaces.GPU function detected" error on ZeroGPU Spaces.
@spaces.GPU
def _zgpu_marker():
    return "ok"
# =========================
# RAG (Chroma)
# =========================
import chromadb
from chromadb.config import Settings

CHROMA_DIR = os.environ.get("CHROMA_DIR", "./chroma_db")
CHROMA_COLLECTION = os.environ.get("CHROMA_COLLECTION", "legal_articles")

def _norm_id(x: str) -> str:
    # Normalize article IDs: unify Arabic yeh/kaf with their Persian forms,
    # map Arabic-Indic and Persian digits to ASCII, and drop spaces.
    x = (x or "").replace("\u064A", "ی").replace("\u0643", "ک")
    trans = {ord(a): b for a, b in zip("٠١٢٣٤٥٦٧٨٩۰۱۲۳۴۵۶۷۸۹", "01234567890123456789")}
    return x.translate(trans).replace(" ", "")
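# Example (illustrative): _norm_id("ماده ۱۲۳") -> "ماده123"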

def build_rag():
    client = chromadb.PersistentClient(
        path=CHROMA_DIR,
        settings=Settings(anonymized_telemetry=False)
    )
    try:
        col = client.get_or_create_collection(CHROMA_COLLECTION)
    except Exception:
        col = client.get_collection(CHROMA_COLLECTION)
    return col
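
# For local testing without a pre-built index, the collection could be seeded
# roughly like this (hypothetical sample data; the Space itself assumes a
# ready-made index under CHROMA_DIR):
#
#   col = build_rag()
#   col.add(
#       ids=["m1", "m2"],
#       documents=["متن ماده ۱ ...", "متن ماده ۲ ..."],
#       metadatas=[{"article_id": "۱"}, {"article_id": "۲"}],
#   )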

def retrieve(col, query: str, top_k: int, thr: float):
    try:
        res = col.query(
            query_texts=[query],
            n_results=int(top_k),
            include=["documents", "metadatas", "distances"]
        )
        docs = res.get("documents", [[]])[0]
        metas = res.get("metadatas", [[]])[0]
        dists = res.get("distances", [[]])[0]
        out = []
        for i, (d, m, dist) in enumerate(zip(docs, metas, dists)):
            # Assumes the collection uses cosine distance, so similarity = 1 - distance.
            sim = 1.0 - float(dist)
            if sim >= float(thr):
                out.append({
                    "article_id": _norm_id((m or {}).get("article_id", f"unk_{i}")),
                    "text": d,
                    "similarity": sim
                })
        return out
    except Exception:
        return []

def build_context(arts, limit=320):
    if not arts:
        return ""
    bullets = [f"• ماده {a['article_id']}: {a['text'][:limit]}..." for a in arts]
    return "مواد مرتبط:\n" + "\n".join(bullets)
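
# Illustrative output for two retrieved articles:
#   مواد مرتبط:
#   • ماده 10: <first 320 chars of the article text>...
#   • ماده 11: <first 320 chars of the article text>...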

# =========================
# Generation (Transformers)
# =========================
# To avoid requiring torch at import time, transformers is imported inside the
# functions that need it.
MODEL_CHOICES = {
    "Qwen2.5-7B Instruct": "Qwen/Qwen2.5-7B-Instruct",
    "Llama 3.2 3B Instruct": "meta-llama/Llama-3.2-3B-Instruct",
    "Mistral 7B Instruct v0.2": "mistralai/Mistral-7B-Instruct-v0.2"
}
DEFAULT_MODEL_KEY = os.environ.get("DEFAULT_MODEL_KEY", "Llama 3.2 3B Instruct")

_loader = {"tk": None, "model_id": None}
_rag_col = None

def lazy_bootstrap(selected_key: str):
    """Connect to the RAG index and warm up the tokenizer.
    Model weights are loaded later, inside the GPU function."""
    global _rag_col, _loader
    # RAG
    if _rag_col is None:
        try:
            _rag_col = build_rag()
        except Exception as e:
            return f"❌ خطا در اتصال RAG: {e}"
    # Tokenizer
    wanted = MODEL_CHOICES.get(selected_key, MODEL_CHOICES[DEFAULT_MODEL_KEY])
    if _loader["model_id"] != wanted or _loader["tk"] is None:
        from transformers import AutoTokenizer
        tk = AutoTokenizer.from_pretrained(wanted)
        if tk.pad_token is None and tk.eos_token:
            tk.pad_token = tk.eos_token
        _loader.update({"tk": tk, "model_id": wanted})
    return f"✅ آماده · ایندکس: {CHROMA_COLLECTION} · مدل: {wanted}"

def _format_prompt(context: str, question: str) -> str:
    if context:
        return f"{context}\nسوال: {question}\nپاسخ:"
    return f"سوال: {question}\nپاسخ:"
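
# Note: the instruct models above generally respond better when the prompt is
# wrapped with the tokenizer's chat template rather than passed raw; a minimal
# sketch (using an already-loaded tokenizer `tk`):
#
#   msgs = [{"role": "user", "content": _format_prompt(ctx, question)}]
#   enc = tk.apply_chat_template(msgs, add_generation_prompt=True, return_tensors="pt")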

@spaces.GPU
def answer_gpu(model_key, question, use_rag, top_k, thr, max_new_tokens, temperature, top_p):
    """GPU inference (ZeroGPU attaches a GPU per call)."""
    try:
        if not question or not question.strip():
            return "لطفاً سؤال را وارد کنید.", ""
        # RAG
        arts = retrieve(_rag_col, question, int(top_k), float(thr)) if use_rag else []
        ctx = build_context(arts) if arts else ""
        prompt = _format_prompt(ctx, question)
        # Load the weights onto the reserved GPU
        from transformers import AutoTokenizer, AutoModelForCausalLM
        model_id = _loader["model_id"] or MODEL_CHOICES.get(model_key) or MODEL_CHOICES[DEFAULT_MODEL_KEY]
        tk = _loader["tk"] or AutoTokenizer.from_pretrained(model_id)
        mdl = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")  # ZeroGPU → GPU attach
        enc = tk(prompt, return_tensors="pt")
        enc = {k: v.to(mdl.device) for k, v in enc.items()}
        out = mdl.generate(
            **enc,
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            pad_token_id=tk.pad_token_id or tk.eos_token_id
        )
        # Decode only the newly generated tokens so the prompt is not echoed back.
        gen = out[0][enc["input_ids"].shape[-1]:]
        text = tk.decode(gen, skip_special_tokens=True)
        refs = ""
        if arts:
            refs = "\n\n" + "\n".join([
                f"**ماده {a['article_id']}** (شباهت: {a['similarity']:.2f})\n{a['text'][:380]}..."
                for a in arts
            ])
        return text, refs
    except Exception as e:
        return f"❌ خطای اینفرنس: {e}", ""

# =========================
# UI (Gradio 5.47)
# =========================
with gr.Blocks(title="Mahoon — Minimal RAG+Gen", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    <div style='text-align:center;padding:14px'>
      <h2 style='margin:0'>ماحون (مینیمال) — پاسخ حقوقی با RAG</h2>
      <p style='color:#666'>اینفرنس ZeroGPU · ایندکس آماده · بدون آموزش</p>
    </div>
    """)
    with gr.Row():
        model_dd = gr.Dropdown(choices=list(MODEL_CHOICES.keys()),
                               value=DEFAULT_MODEL_KEY,
                               label="مدل تولید")
        use_rag = gr.Checkbox(value=True, label="استفاده از RAG؟")
        top_k = gr.Slider(1, 10, value=5, step=1, label="Top-K")
        thr = gr.Slider(0.50, 0.95, value=0.60, step=0.01, label="آستانه شباهت")
    with gr.Accordion("پارامترهای تولید", open=False):
        max_new_tokens = gr.Slider(64, 1024, value=256, step=16, label="max_new_tokens")
        temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
    question = gr.Textbox(lines=3, label="سؤال")
    ask_btn = gr.Button("پرسش", variant="primary")
    answer = gr.Markdown(label="پاسخ")
    refs = gr.Markdown(label="مواد مرتبط")
    status = gr.Markdown("⏳ آمادهسازی…")

    def _warmup(mkey):
        try:
            return lazy_bootstrap(mkey)
        except Exception as e:
            return f"❌ Bootstrap error: {e}"

    demo.load(_warmup, inputs=[model_dd], outputs=status)

    ask_btn.click(
        answer_gpu,
        inputs=[model_dd, question, use_rag, top_k, thr, max_new_tokens, temperature, top_p],
        outputs=[answer, refs]
    )

if __name__ == "__main__":
    try:
        demo = demo.queue()  # more stable on Gradio 5.x
    except TypeError:
        pass
    demo.launch(ssr_mode=False)
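
# Assumed local dependencies (not pinned here): gradio, chromadb, transformers,
# torch, and accelerate (needed by device_map="auto").
# Run locally with:  python app.py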