amirhossein mohammadpour committed
Commit 6ca41c8 · 1 Parent(s): 3f6908f

change interface

Files changed (1): app.py (+78 -10)
app.py CHANGED
@@ -335,8 +335,14 @@ def llm_generate(prompt: str,
 
 # ---- MCQ helpers ----
 def build_mcq_prompt(question: str, options: List[str], contexts: List[Dict[str, Any]], lang="fa", max_chars=5000) -> str:
-    sys_fa = "تو یک دستیار پاسخ‌گو هستی که فقط بر اساس متن‌های داده‌شده پاسخ می‌دهی."
-    sys_en = "You are a helpful assistant. Answer only using the retrieved passages."
+    sys_fa = (
+        "تو یک دستیار پاسخ‌گو هستی که فقط بر اساس متن‌های داده‌شده پاسخ می‌دهی. "
+        "باید دقیقاً فقط یک شیء JSON برگردانی و هیچ متن دیگری ننویسی."
+    )
+    sys_en = (
+        "You are a helpful assistant. Answer ONLY using the retrieved passages. "
+        "You MUST return a single JSON object and nothing else."
+    )
     system_text = sys_fa if lang == "fa" else sys_en
 
     parts = []
@@ -352,17 +358,76 @@ def build_mcq_prompt(question: str, options: List[str], contexts: List[Dict[str,
     if lang == "fa":
         user = (
             f"سؤال: {question}\n\nگزینه‌ها:\n{opts_str}\n\nمتون بازیابی‌شده:\n{joined}\n\n"
-            'فقط براساس متون بالا پاسخ بده. دقیقاً در این قالب برگردان:\n{"answer_index": X, "reason": "…"}'
+            "دقیقاً و فقط یک JSON برگردان. فرمت اجباری: "
+            '{"answer_index": X, "reason": "…"} '
+            "که در آن X اندیس گزینه (۰-بِیس) است. هیچ متن دیگری ننویس."
         )
     else:
         user = (
             f"Question: {question}\n\nOptions:\n{opts_str}\n\nRetrieved:\n{joined}\n\n"
-            'Answer strictly based on passages. Return exactly:\n{"answer_index": X, "reason": "..."}'
+            'Return EXACTLY one JSON: {"answer_index": X, "reason": "..."} '
+            "where X is the 0-based option index. Do not write anything else."
         )
+
     msgs = [{"role": "system", "content": system_text},
             {"role": "user", "content": user}]
     return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
 
+import json as _json
+import re as _re
+import numpy as _np
+
+def _strict_json_from_text(text: str):
+    # Take only the first {...} block and JSON-parse it
+    m = _re.search(r'\{.*\}', text, _re.S)
+    if not m:
+        return None
+    frag = m.group(0)
+    try:
+        obj = _json.loads(frag)
+        return obj
+    except Exception:
+        return None
+
+def score_options_by_context(options: List[str], contexts: List[Dict[str, Any]]) -> int:
+    """
+    Fallback:
+    1) If an option appears verbatim (substring) in the passages -> very high score
+    2) Otherwise, mE5 embedding similarity between each option and the concatenated contexts
+    """
+    text_blob = "\n".join([c.get("bio", "") for c in contexts]).lower()
+    # 1) substring hit
+    hits = []
+    for i, opt in enumerate(options):
+        o = normalize_digits_months(str(opt).strip().lower())
+        score = 0
+        if o and (o in text_blob):
+            score += 10_000
+        hits.append((score, i))
+    hits.sort(reverse=True)
+    if hits and hits[0][0] > 0:
+        return hits[0][1]
+
+    # 2) embedding similarity (mE5)
+    try:
+        q_vecs = [_encode_query_e5(opt) for opt in options]  # (n, dim)
+        ctx_vec = _encode_query_e5(text_blob)                # (dim,)
+        sims = [float(_np.dot(qv, ctx_vec)) for qv in q_vecs]
+        return int(_np.argmax(sims))
+    except Exception:
+        return 0  # conservative default
+
+def parse_mcq_output_strict(text: str, options: List[str], contexts: List[Dict[str, Any]]) -> Dict[str, Any]:
+    obj = _strict_json_from_text(text)
+    if obj and "answer_index" in obj:
+        idx = obj["answer_index"]
+        if isinstance(idx, int) and 0 <= idx < len(options):
+            reason = str(obj.get("reason", "")).strip() or "—"
+            return {"answer_index": idx, "reason": reason}
+    # Malformed or missing JSON -> fall back to context matching
+    idx = score_options_by_context(options, contexts)
+    return {"answer_index": idx, "reason": "fallback_by_context_matching"}
+
 def parse_mcq_output(text: str, n: int) -> Dict[str, Any]:
     m = re.search(r'{"\s*answer_index"\s*:\s*([0-9]+)\s*,\s*"reason"\s*:\s*"(.*?)"}', text, re.S)
     if m:
@@ -409,17 +474,20 @@ def ui_mcq(question, options_txt, image, topk, max_tokens, temperature, top_p, t
     prompt = build_mcq_prompt(question, opts, ret["contexts"], lang="fa", max_chars=5000)
     out = llm_generate(prompt, max_new_tokens=int(max_tokens),
                        temperature=float(temperature), top_p=float(top_p),
-                       top_k=int(top_k), do_sample=False)
-    parsed = parse_mcq_output(out, len(opts))
+                       top_k=int(top_k), do_sample=False)  # deterministic on CPU
+    parsed = parse_mcq_output_strict(out, opts, ret["contexts"])
     pred = parsed["answer_index"]
     pred_text = (opts[pred] if (pred is not None and 0 <= pred < len(opts)) else "N/A")
+
     rows = []
     for i, c in enumerate(ret["contexts"], 1):
         snip = c["bio"][:180] + ("…" if len(c["bio"]) > 180 else "")
         rows.append([i, c["id"], round(c["score"], 4), snip])
+
     result = f"Pred: index={pred} text={pred_text}\nReason: {parsed['reason']}"
     return result, out, rows, ret["route"]
 
+
 with gr.Blocks(title="Multimodal RAG (CPU) • E5 + CLIP Fusion + Qwen 0.5B") as demo:
     gr.Markdown("### Free-tier CPU demo: text RAG (E5) + optional fusion (CLIP) → Qwen 0.5B")
     with gr.Tab("Ask"):
@@ -438,11 +506,10 @@ with gr.Blocks(title="Multimodal RAG (CPU) • E5 + CLIP Fusion + Qwen 0.5B") as
         route = gr.Textbox(label="Route used (text_e5 or fusion)")
         table = gr.Dataframe(headers=["#", "id", "score", "snippet"], interactive=False)
         btn.click(ui_answer, [q, img, topk, max_tokens, temperature, top_p, top_k], [ans, table, route])
-
     with gr.Tab("MCQ"):
         with gr.Row():
             q_mcq = gr.Textbox(label="Question", lines=3)
-            opts_mcq = gr.Textbox(label="Options (one per line)", lines=6)
+            opts_mcq = gr.Textbox(label="Options (one per line)", lines=8)
             img_mcq = gr.Image(type="pil", label="Optional image (fusion if provided)")
         with gr.Row():
             topk2 = gr.Slider(1, 20, value=5, step=1, label="Top-K retrieve")
@@ -452,8 +519,9 @@ with gr.Blocks(title="Multimodal RAG (CPU) • E5 + CLIP Fusion + Qwen 0.5B") as
             top_p2 = gr.Slider(0.1, 1.0, value=TOP_P_DEFAULT, step=0.05, label="Top-p")
             top_k2 = gr.Slider(1, 100, value=TOP_K_DEFAULT, step=1, label="Top-k")
         btn2 = gr.Button("Answer MCQ")
-        result = gr.Textbox(label="Prediction")
-        raw = gr.Textbox(label="Raw LLM output", lines=6)
+        # larger output boxes
+        result = gr.Textbox(label="Prediction", lines=12, max_lines=20)
+        raw = gr.Textbox(label="Raw LLM output", lines=12, max_lines=20)
         route2 = gr.Textbox(label="Route used")
         table2 = gr.Dataframe(headers=["#", "id", "score", "snippet"], interactive=False)
         btn2.click(ui_mcq, [q_mcq, opts_mcq, img_mcq, topk2, max_tokens2, temperature2, top_p2, top_k2],
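
The core of this change is the parsing strategy: prompt the model for exactly one JSON object, strictly parse the first `{...}` block, and fall back to retrieval-based matching only when that fails. A minimal standalone sketch of that behavior (the function body is copied from the diff above; the sample model outputs are illustrative, not taken from the app):

```python
import json
import re

def _strict_json_from_text(text: str):
    # Take only the first {...} block and JSON-parse it
    m = re.search(r'\{.*\}', text, re.S)
    if not m:
        return None
    try:
        return json.loads(m.group(0))
    except Exception:
        return None

# Well-formed output: parsed directly, no fallback needed.
print(_strict_json_from_text('Sure! {"answer_index": 2, "reason": "stated in passage 1"}'))
# {'answer_index': 2, 'reason': 'stated in passage 1'}

# No JSON at all: returns None, which routes parse_mcq_output_strict
# into score_options_by_context (substring hit first, then mE5 similarity).
print(_strict_json_from_text("The answer is option 2 because ..."))
# None
```

One caveat worth noting: `\{.*\}` with `re.S` is greedy, so it spans from the first `{` to the last `}` in the output. A model that emits two JSON objects, or a stray `}` after the object, makes that span invalid JSON; the parse then fails and the code drops into `score_options_by_context`, which appears to be the intended safety net.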