Commit 6ca41c8 by amirhossein mohammadpour
Parent(s): 3f6908f
change interface
app.py CHANGED
@@ -335,8 +335,14 @@ def llm_generate(prompt: str,
 
 # ---- MCQ helpers ----
 def build_mcq_prompt(question: str, options: List[str], contexts: List[Dict[str, Any]], lang="fa", max_chars=5000) -> str:
-    sys_fa = …
-    …
+    sys_fa = (
+        "تو یک دستیار پاسخگو هستی که فقط بر اساس متن‌های داده‌شده پاسخ می‌دهی. "
+        "باید دقیقاً فقط یک شیء JSON برگردانی و هیچ متن دیگری ننویسی."
+    )
+    sys_en = (
+        "You are a helpful assistant. Answer ONLY using the retrieved passages. "
+        "You MUST return a single JSON object and nothing else."
+    )
     system_text = sys_fa if lang == "fa" else sys_en
 
     parts = []
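For reference, the added Persian system prompt (sys_fa) translates roughly as: "You are an assistant that answers only based on the given texts. You must return exactly one JSON object and write nothing else," the same contract as sys_en. A minimal sketch of how such a message pair is rendered downstream by apply_chat_template; the exact Qwen checkpoint id is an assumption, since the Space only advertises "Qwen 0.5B":

from transformers import AutoTokenizer

# Assumed checkpoint id; in app.py the tokenizer is already loaded globally.
tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
msgs = [
    {"role": "system", "content": "You MUST return a single JSON object and nothing else."},
    {"role": "user", "content": 'Question: ...\nReturn EXACTLY one JSON: {"answer_index": X, "reason": "..."}'},
]
# Produces the ChatML-style prompt string that llm_generate() consumes.
print(tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True))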
@@ -352,17 +358,76 @@ def build_mcq_prompt(question: str, options: List[str], contexts: List[Dict[str,
     if lang == "fa":
         user = (
             f"سؤال: {question}\n\nگزینه‌ها:\n{opts_str}\n\nمتون بازیابی‌شده:\n{joined}\n\n"
-            …
+            "دقیقاً و فقط یک JSON برگردان. فرمت اجباری: "
+            '{"answer_index": X, "reason": "…"} '
+            "که در آن X اندیس گزینه (۰-بِیس) است. هیچ متن دیگری ننویس."
         )
     else:
         user = (
             f"Question: {question}\n\nOptions:\n{opts_str}\n\nRetrieved:\n{joined}\n\n"
-            '…'
+            'Return EXACTLY one JSON: {"answer_index": X, "reason": "..."} '
+            "where X is the 0-based option index. Do not write anything else."
         )
+
     msgs = [{"role": "system", "content": system_text},
             {"role": "user", "content": user}]
     return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
 
+import json as _json
+import re as _re
+import numpy as _np
+
+def _strict_json_from_text(text: str):
+    # Grab only the first {...} block and JSON-parse it
+    m = _re.search(r'\{.*\}', text, _re.S)
+    if not m:
+        return None
+    frag = m.group(0)
+    try:
+        obj = _json.loads(frag)
+        return obj
+    except Exception:
+        return None
+
+def score_options_by_context(options: List[str], contexts: List[Dict[str, Any]]) -> int:
+    """
+    Fallback:
+    1) If an option occurs as a substring in the retrieved texts -> very high score
+    2) Otherwise mE5 embedding similarity between the option and the combined contexts
+    """
+    text_blob = "\n".join([c.get("bio", "") for c in contexts]).lower()
+    # 1) substring hit
+    hits = []
+    for i, opt in enumerate(options):
+        o = normalize_digits_months(str(opt).strip().lower())
+        score = 0
+        if o and (o in text_blob):
+            score += 10_000
+        hits.append((score, i))
+    hits.sort(reverse=True)
+    if hits and hits[0][0] > 0:
+        return hits[0][1]
+
+    # 2) embedding similarity (mE5)
+    try:
+        q_vecs = [_encode_query_e5(opt) for opt in options]  # (n, dim)
+        ctx_vec = _encode_query_e5(text_blob)                # (dim,)
+        sims = [float(_np.dot(qv, ctx_vec)) for qv in q_vecs]
+        return int(_np.argmax(sims))
+    except Exception:
+        return 0  # conservative default
+
+def parse_mcq_output_strict(text: str, options: List[str], contexts: List[Dict[str, Any]]) -> Dict[str, Any]:
+    obj = _strict_json_from_text(text)
+    if obj and "answer_index" in obj:
+        idx = obj["answer_index"]
+        if isinstance(idx, int) and 0 <= idx < len(options):
+            reason = str(obj.get("reason", "")).strip() or "—"
+            return {"answer_index": idx, "reason": reason}
+    # Malformed JSON -> fall back to context matching
+    idx = score_options_by_context(options, contexts)
+    return {"answer_index": idx, "reason": "fallback_by_context_matching"}
+
 def parse_mcq_output(text: str, n: int) -> Dict[str, Any]:
     m = re.search(r'{"\s*answer_index"\s*:\s*([0-9]+)\s*,\s*"reason"\s*:\s*"(.*?)"}', text, re.S)
     if m:
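The motivation for the new helpers: a 0.5B model frequently wraps its JSON in extra chatter, or emits none at all, so parsing is staged: strict JSON first, then lexical overlap with the retrieved texts, then embeddings. A self-contained sketch of the first two stages, with stubbed contexts and the logic inlined so it runs without the app's globals:

import json, re

def first_json(text):
    # Same idea as _strict_json_from_text: take the first {...} span, parse it or give up.
    m = re.search(r'\{.*\}', text, re.S)
    if not m:
        return None
    try:
        return json.loads(m.group(0))
    except Exception:
        return None

contexts = [{"bio": "Alan Turing was born in London in 1912."}]
options = ["Paris", "London", "Berlin"]

# Well-behaved output: the strict parse succeeds despite the leading chatter.
assert first_json('Sure! {"answer_index": 1, "reason": "stated in passage"}') == \
    {"answer_index": 1, "reason": "stated in passage"}

# Garbage output: the substring fallback picks "London", which occurs in the contexts.
blob = "\n".join(c["bio"] for c in contexts).lower()
scores = [(10_000 if opt.lower() in blob else 0, i) for i, opt in enumerate(options)]
print(max(scores)[1])  # 1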
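The last-resort branch leans on a property of E5-style encoders: if _encode_query_e5 returns L2-normalized vectors (an assumption here, though standard for E5), the plain dot product in score_options_by_context equals cosine similarity, so argmax picks the option closest in meaning to the concatenated contexts. A toy sketch of that ranking step with hand-made unit vectors in place of mE5 embeddings:

import numpy as np

def norm(v):
    v = np.asarray(v, dtype=float)
    return v / np.linalg.norm(v)

# Stand-ins for mE5 option/context embeddings.
opt_vecs = [norm([1.0, 0.2]), norm([0.1, 1.0]), norm([0.7, 0.7])]
ctx_vec = norm([0.2, 1.0])  # the "context" points mostly along the second axis

sims = [float(np.dot(v, ctx_vec)) for v in opt_vecs]
print(int(np.argmax(sims)))  # 1: the option vector most aligned with the context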
@@ -409,17 +474,20 @@ def ui_mcq(question, options_txt, image, topk, max_tokens, temperature, top_p, t
     prompt = build_mcq_prompt(question, opts, ret["contexts"], lang="fa", max_chars=5000)
     out = llm_generate(prompt, max_new_tokens=int(max_tokens),
                        temperature=float(temperature), top_p=float(top_p),
-                       top_k=int(top_k), do_sample=False)
-    parsed = …
+                       top_k=int(top_k), do_sample=False)  # deterministic on CPU
+    parsed = parse_mcq_output_strict(out, opts, ret["contexts"])
     pred = parsed["answer_index"]
     pred_text = (opts[pred] if (pred is not None and 0 <= pred < len(opts)) else "N/A")
+
     rows = []
     for i, c in enumerate(ret["contexts"], 1):
         snip = c["bio"][:180] + ("…" if len(c["bio"]) > 180 else "")
         rows.append([i, c["id"], round(c["score"], 4), snip])
+
     result = f"Pred: index={pred} text={pred_text}\nReason: {parsed['reason']}"
     return result, out, rows, ret["route"]
 
+
 with gr.Blocks(title="Multimodal RAG (CPU) • E5 + CLIP Fusion + Qwen 0.5B") as demo:
     gr.Markdown("### Free-tier CPU demo: text RAG (E5) + optional fusion (CLIP) → Qwen 0.5B")
     with gr.Tab("Ask"):
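One behavioral note on the `# deterministic on CPU` comment: with do_sample=False, transformers' generate() performs greedy decoding, so the temperature/top_p/top_k sliders are effectively ignored on the MCQ path and repeated runs of the same prompt return identical text. A sketch of that property (checkpoint id assumed, as above):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "Qwen/Qwen2.5-0.5B-Instruct"  # assumption: any small causal LM shows the same behavior
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

ids = tok("2 + 2 =", return_tensors="pt")
with torch.no_grad():
    a = model.generate(**ids, max_new_tokens=8, do_sample=False)
    b = model.generate(**ids, max_new_tokens=8, do_sample=False)
assert torch.equal(a, b)  # greedy decoding is reproducible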
@@ -438,11 +506,10 @@ with gr.Blocks(title="Multimodal RAG (CPU) • E5 + CLIP Fusion + Qwen 0.5B") as
         route = gr.Textbox(label="Route used (text_e5 or fusion)")
         table = gr.Dataframe(headers=["#", "id", "score", "snippet"], interactive=False)
         btn.click(ui_answer, [q, img, topk, max_tokens, temperature, top_p, top_k], [ans, table, route])
-
     with gr.Tab("MCQ"):
         with gr.Row():
             q_mcq = gr.Textbox(label="Question", lines=3)
-            opts_mcq = gr.Textbox(label="Options (one per line)", lines=…)
+            opts_mcq = gr.Textbox(label="Options (one per line)", lines=8)
             img_mcq = gr.Image(type="pil", label="Optional image (fusion if provided)")
         with gr.Row():
             topk2 = gr.Slider(1, 20, value=5, step=1, label="Top-K retrieve")
@@ -452,8 +519,9 @@ with gr.Blocks(title="Multimodal RAG (CPU) • E5 + CLIP Fusion + Qwen 0.5B") as
             top_p2 = gr.Slider(0.1, 1.0, value=TOP_P_DEFAULT, step=0.05, label="Top-p")
             top_k2 = gr.Slider(1, 100, value=TOP_K_DEFAULT, step=1, label="Top-k")
         btn2 = gr.Button("Answer MCQ")
-        …
-        …
+        # 👇 bigger boxes
+        result = gr.Textbox(label="Prediction", lines=12, max_lines=20)
+        raw = gr.Textbox(label="Raw LLM output", lines=12, max_lines=20)
         route2 = gr.Textbox(label="Route used")
         table2 = gr.Dataframe(headers=["#", "id", "score", "snippet"], interactive=False)
         btn2.click(ui_mcq, [q_mcq, opts_mcq, img_mcq, topk2, max_tokens2, temperature2, top_p2, top_k2],
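The commit message "change interface" mostly amounts to these sizing tweaks: in Gradio, lines sets a Textbox's initial visible height and max_lines caps how far it can grow, so long raw LLM outputs stay readable. A standalone sketch of just the layout piece:

import gradio as gr

with gr.Blocks() as demo:
    # Same sizing as the commit: tall boxes for the prediction and the raw model output.
    pred = gr.Textbox(label="Prediction", lines=12, max_lines=20)
    raw = gr.Textbox(label="Raw LLM output", lines=12, max_lines=20)

if __name__ == "__main__":
    demo.launch()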