hajimammad committed on
Commit
9ef9347
·
verified ·
1 Parent(s): c6902d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +200 -12
app.py CHANGED
@@ -1,20 +1,208 @@
1
- # app.py (smoke test)
2
- import gradio as gr, os
3
 
4
- def ping(q: str):
5
- return f"pong · echo: {q or ''}"
 
6
 
7
- with gr.Blocks(title="Smoke Test") as demo:
8
- gr.Markdown("✅ UI up. If you see this, Gradio is fine.")
9
- inp = gr.Textbox(label="input")
10
- out = gr.Textbox(label="output")
11
- btn = gr.Button("Ping")
12
- btn.click(ping, inp, out)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  if __name__ == "__main__":
15
- # روی Spaces نیازی به پارامتر خاصی نیست؛ SSR را هم خاموش نگه‌دار که safe باشد
16
  try:
17
- demo = demo.queue() # پایدارتر روی 5.x
18
  except TypeError:
19
  pass
20
  demo.launch(ssr_mode=False)
 
1
+ # -*- coding: utf-8 -*-
2
+ # Mahoon Minimal RAG + Generation (ZeroGPU-ready, no training)
3
 
4
+ import os
5
+ import json
6
+ import gradio as gr
7
 
8
# =========================
# ZeroGPU shim & marker
# =========================
try:
    # `spaces` is injected by the Hugging Face Spaces runtime.
    import spaces
except Exception:
    # Outside Spaces the module is missing: install a stub whose GPU()
    # decorator factory is a no-op, so @spaces.GPU(...) still works.
    class _NoSpaces:
        @staticmethod
        def GPU(*a, **k):
            return lambda fn: fn

    spaces = _NoSpaces()
21
@spaces.GPU(duration=180)  # having at least one @spaces.GPU function prevents the "No @spaces.GPU" startup error
def _zgpu_marker():
    """Dummy GPU entry point so the ZeroGPU runtime detects a GPU function."""
    return "ok"
25
# =========================
# RAG (Chroma)
# =========================
import chromadb
from chromadb.config import Settings

# Both values are overridable via environment for deployments that mount a
# pre-built index elsewhere.
CHROMA_DIR = os.environ.get("CHROMA_DIR", "./chroma_db")
CHROMA_COLLECTION = os.environ.get("CHROMA_COLLECTION", "legal_articles")
34
+ def _norm_id(x: str) -> str:
35
+ x = (x or "").replace("\u064A","ی").replace("\u0643","ک")
36
+ trans = {ord(a): b for a,b in zip("٠١٢٣٤٥٦٧٨٩۰۱۲۳۴۵۶۷۸۹","01234567890123456789")}
37
+ return "".join((x.translate(trans))).replace(" ", "")
38
+
39
def build_rag():
    """Open the persistent Chroma collection used for retrieval.

    Creates the collection when absent; if creation fails (e.g. it already
    exists with incompatible parameters), falls back to a plain lookup.
    """
    client = chromadb.PersistentClient(
        path=CHROMA_DIR,
        settings=Settings(anonymized_telemetry=False),
    )
    try:
        return client.get_or_create_collection(CHROMA_COLLECTION)
    except Exception:
        return client.get_collection(CHROMA_COLLECTION)
50
def retrieve(col, query: str, top_k: int, thr: float):
    """Query `col` and keep hits whose similarity (1 - distance) meets `thr`.

    Returns a list of {"article_id", "text", "similarity"} dicts; any backend
    failure yields [] (deliberate best-effort behavior for the UI).
    """
    try:
        res = col.query(
            query_texts=[query],
            n_results=int(top_k),
            include=["documents", "metadatas", "distances"],
        )
        docs = res.get("documents", [[]])[0]
        metas = res.get("metadatas", [[]])[0]
        dists = res.get("distances", [[]])[0]

        hits = []
        for idx, (doc, meta, dist) in enumerate(zip(docs, metas, dists)):
            score = 1.0 - float(dist)
            if score < float(thr):
                continue
            hits.append({
                "article_id": _norm_id((meta or {}).get("article_id", f"unk_{idx}")),
                "text": doc,
                "similarity": score,
            })
        return hits
    except Exception:
        return []
73
def build_context(arts, limit=320):
    """Render retrieved articles as a bulleted Persian context block.

    Each article text is clipped to `limit` characters. Returns "" when there
    are no articles.
    """
    if not arts:
        return ""
    bullets = ["• ماده {0}: {1}...".format(a["article_id"], a["text"][:limit]) for a in arts]
    return "مواد مرتبط:\n" + "\n".join(bullets)
78
+ # =========================
79
+ # Generation (Transformers)
80
+ # =========================
81
+ # برای اجتناب از نیاز زودهنگام به torch، import را داخل توابع انجام می‌دهیم.
82
+ MODEL_CHOICES = {
83
+ "Qwen2.5-7B Instruct": "Qwen/Qwen2.5-7B-Instruct",
84
+ "Llama 3.2 3B Instruct": "meta-llama/Llama-3.2-3B-Instruct",
85
+ "Mistral 7B Instruct v0.2": "mistralai/Mistral-7B-Instruct-v0.2"
86
+ }
87
+ DEFAULT_MODEL_KEY = os.environ.get("DEFAULT_MODEL_KEY", "Llama 3.2 3B Instruct")
88
+
89
+ _loader = {"tk": None, "model_id": None}
90
+ _rag_col = None
91
+
92
def lazy_bootstrap(selected_key: str):
    """Connect to the RAG index and warm the tokenizer.

    Model weights are loaded later, inside the GPU-decorated function.
    Returns a human-readable status string (errors are reported as text, not
    raised, so the UI status bar can display them).
    """
    global _rag_col, _loader

    # RAG index (connect once).
    if _rag_col is None:
        try:
            _rag_col = build_rag()
        except Exception as e:
            return f"❌ خطا در اتصال RAG: {e}"

    # Tokenizer for the selected model (re-warm only when the choice changes).
    wanted = MODEL_CHOICES.get(selected_key, MODEL_CHOICES[DEFAULT_MODEL_KEY])
    if _loader["tk"] is None or _loader["model_id"] != wanted:
        from transformers import AutoTokenizer
        tk = AutoTokenizer.from_pretrained(wanted)
        if tk.pad_token is None and tk.eos_token:
            tk.pad_token = tk.eos_token
        _loader.update({"tk": tk, "model_id": wanted})

    return f"✅ آماده · ایندکس: {CHROMA_COLLECTION} · مدل: {wanted}"
113
+ def _format_prompt(context: str, question: str) -> str:
114
+ if context:
115
+ return f"{context}\nسوال: {question}\nپاسخ:"
116
+ return f"سوال: {question}\nپاسخ:"
117
+
118
@spaces.GPU(duration=240)
def answer_gpu(model_key, question, use_rag, top_k, thr, max_new_tokens, temperature, top_p):
    """Run (optional) retrieval + generation on the per-call ZeroGPU slot.

    Args mirror the Gradio inputs: model choice key, the user question,
    whether to use RAG, retrieval top-k/threshold, and sampling parameters.
    Returns (answer_markdown, references_markdown); on failure returns an
    error string and "" instead of raising, so the UI stays responsive.
    """
    try:
        if not question or not question.strip():
            return "لطفاً سؤال را وارد کنید.", ""

        # RAG
        arts = retrieve(_rag_col, question, int(top_k), float(thr)) if use_rag else []
        ctx = build_context(arts) if arts else ""
        prompt = _format_prompt(ctx, question)

        # Load weights on the freshly attached GPU (ZeroGPU is per-call, so
        # the model is intentionally loaded inside this function).
        from transformers import AutoTokenizer, AutoModelForCausalLM
        model_id = _loader["model_id"] or MODEL_CHOICES.get(model_key) or MODEL_CHOICES[DEFAULT_MODEL_KEY]
        tk = _loader["tk"] or AutoTokenizer.from_pretrained(model_id)
        mdl = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")  # ZeroGPU → GPU attach

        enc = tk(prompt, return_tensors="pt")
        enc = {k: v.to(mdl.device) for k, v in enc.items()}
        out = mdl.generate(
            **enc,
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            pad_token_id=tk.pad_token_id or tk.eos_token_id
        )
        # BUGFIX: for causal LMs, generate() returns prompt + completion;
        # decoding out[0] in full would echo the whole prompt back as the
        # "answer". Decode only the newly generated tokens.
        new_tokens = out[0][enc["input_ids"].shape[1]:]
        text = tk.decode(new_tokens, skip_special_tokens=True)

        refs = ""
        if arts:
            refs = "\n\n" + "\n".join([
                f"**ماده {a['article_id']}** (شباهت: {a['similarity']:.2f})\n{a['text'][:380]}..."
                for a in arts
            ])
        return text, refs
    except Exception as e:
        return f"❌ خطای اینفرنس: {e}", ""
158
# =========================
# UI (Gradio 5.47)
# =========================
with gr.Blocks(title="Mahoon — Minimal RAG+Gen", theme=gr.themes.Soft()) as demo:
    # Centered Persian header banner.
    gr.Markdown("""
    <div style='text-align:center;padding:14px'>
    <h2 style='margin:0'>ماحون (مینیمال) — پاسخ حقوقی با RAG</h2>
    <p style='color:#666'>اینفرنس ZeroGPU · ایندکس آماده · بدون آموزش</p>
    </div>
    """)

    # Model + retrieval controls.
    with gr.Row():
        model_dd = gr.Dropdown(
            choices=list(MODEL_CHOICES.keys()),
            value=DEFAULT_MODEL_KEY,
            label="مدل تولید",
        )
        use_rag = gr.Checkbox(value=True, label="استفاده از RAG؟")
        top_k = gr.Slider(1, 10, value=5, step=1, label="Top-K")
        thr = gr.Slider(0.50, 0.95, value=0.60, step=0.01, label="آستانه شباهت")

    # Sampling parameters, collapsed by default.
    with gr.Accordion("پارامترهای تولید", open=False):
        max_new_tokens = gr.Slider(64, 1024, value=256, step=16, label="max_new_tokens")
        temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")

    question = gr.Textbox(lines=3, label="سؤال")
    ask_btn = gr.Button("پرسش", variant="primary")
    answer = gr.Markdown(label="پاسخ")
    refs = gr.Markdown(label="مواد مرتبط")

    status = gr.Markdown("⏳ آماده‌سازی…")

    def _warmup(mkey):
        # Surface bootstrap failures in the status bar instead of crashing the UI.
        try:
            return lazy_bootstrap(mkey)
        except Exception as e:
            return f"❌ Bootstrap error: {e}"

    demo.load(_warmup, inputs=[model_dd], outputs=status)

    ask_btn.click(
        answer_gpu,
        inputs=[model_dd, question, use_rag, top_k, thr, max_new_tokens, temperature, top_p],
        outputs=[answer, refs],
    )
202
 
203
if __name__ == "__main__":
    # queue() is more stable on Gradio 5.x; tolerate older signatures that
    # reject the no-arg call.
    try:
        demo = demo.queue()
    except TypeError:
        pass
    demo.launch(ssr_mode=False)