amirhossein mohammadpour commited on
Commit
73a17b2
·
1 Parent(s): 6ca41c8

hanlde speed

Browse files
Files changed (1) hide show
  1. app.py +150 -64
app.py CHANGED
@@ -1,4 +1,6 @@
1
  import os, io, gc, json, re, ast
 
 
2
  import numpy as np
3
  import pandas as pd
4
  import faiss
@@ -10,7 +12,9 @@ import gradio as gr
10
  from huggingface_hub import hf_hub_download
11
  from sentence_transformers import SentenceTransformer
12
  from transformers import AutoTokenizer, AutoModelForCausalLM
13
-
 
 
14
  # =========================
15
  # Config (override in Space → Settings → Variables & secrets)
16
  # =========================
@@ -27,13 +31,15 @@ HF_TOKEN = os.getenv("HF_TOKEN", None) # needed if DATASET_REPO is pri
27
  # Models (CPU-friendly defaults; override via env if desired)
28
  E5_ID = os.getenv("E5_ID", "intfloat/multilingual-e5-small")
29
  CLIP_TXT_ID = os.getenv("CLIP_TXT_ID", "sentence-transformers/clip-ViT-B-32-multilingual-v1")
30
- LLM_ID = os.getenv("LLM_ID", "Qwen/Qwen2-0.5B-Instruct") # small enough for free CPU
 
 
 
31
 
32
- # Generation defaults (also controllable from UI)
33
- MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS", "192"))
34
- TEMPERATURE_DEFAULT = float(os.getenv("TEMPERATURE", "0.0")) # deterministic by default on CPU
35
- TOP_P_DEFAULT = float(os.getenv("TOP_P", "0.9"))
36
- TOP_K_DEFAULT = int(os.getenv("TOP_K", "50"))
37
 
38
  # =========================
39
  # Helpers
@@ -163,15 +169,20 @@ model = AutoModelForCausalLM.from_pretrained(
163
  torch_dtype=dtype,
164
  ).to("cpu").eval()
165
 
 
166
  # =========================
167
  # Retrieval helpers
168
  # =========================
169
- @torch.no_grad()
170
- def _encode_query_e5(q: str) -> np.ndarray:
171
  qn = "query: " + normalize_digits_months(q)
172
  v = st_e5.encode([qn], batch_size=1, convert_to_numpy=True, normalize_embeddings=True)[0]
173
  return v.astype("float32")
174
 
 
 
 
 
175
  def _faiss_search(index, q_vec: np.ndarray, k: int):
176
  if q_vec.ndim == 1:
177
  q_vec = q_vec[None, :]
@@ -262,7 +273,7 @@ def make_query_embed(query_text: str,
262
  def search_fusion(query_text: str, image: Image.Image, k: int = 5, alpha_q: float = 0.7):
263
  if index_fusion is None:
264
  raise RuntimeError("Fusion index not available (upload FUSION_INDEX_FILE to dataset repo).")
265
- qv = make_query_embed(query_text, image=image, alpha_q=alpha_q, use_aug=True, n_aug=3)
266
  return _faiss_search(index_fusion, qv, k)
267
 
268
  # =========================
@@ -288,7 +299,7 @@ def retrieve_context_auto(question: str, k: int = 5, image: Image.Image = None)
288
  ctxs.append({"index": int(idx), "id": row.get("id", idx), "score": float(score), "bio": str(row["bio"])})
289
  return {"route": route, "contexts": ctxs}
290
 
291
- def build_prompt(question: str, contexts: List[Dict[str, Any]], lang="fa", max_chars=5000) -> str:
292
  sys_fa = "تو یک دستیار پاسخ‌گو هستی که فقط بر اساس متن‌های داده‌شده پاسخ می‌دهی. اگر پاسخی در متن‌ها نبود، صادقانه بگو «در متن‌های بازیابی‌شده پاسخی پیدا نشد.»"
293
  sys_en = "You are a helpful assistant. Answer only using retrieved passages. If not found, say 'No answer found in retrieved passages.'"
294
  system_text = sys_fa if lang == "fa" else sys_en
@@ -311,20 +322,19 @@ def build_prompt(question: str, contexts: List[Dict[str, Any]], lang="fa", max_c
311
  return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
312
 
313
  @torch.inference_mode()
314
- def llm_generate(prompt: str,
315
- max_new_tokens=MAX_NEW_TOKENS_DEFAULT,
316
- temperature=TEMPERATURE_DEFAULT,
317
- top_p=TOP_P_DEFAULT,
318
- top_k=TOP_K_DEFAULT,
319
- do_sample=False) -> str:
320
  inputs = tokenizer(prompt, return_tensors="pt")
 
 
321
  out = model.generate(
322
  **inputs,
323
- max_new_tokens=int(max_new_tokens),
324
- do_sample=bool(do_sample),
325
- temperature=float(temperature),
326
- top_p=float(top_p),
327
- top_k=int(top_k),
 
 
328
  pad_token_id=tokenizer.eos_token_id,
329
  eos_token_id=tokenizer.eos_token_id,
330
  )
@@ -334,7 +344,7 @@ def llm_generate(prompt: str,
334
  return text.strip()
335
 
336
  # ---- MCQ helpers ----
337
- def build_mcq_prompt(question: str, options: List[str], contexts: List[Dict[str, Any]], lang="fa", max_chars=5000) -> str:
338
  sys_fa = (
339
  "تو یک دستیار پاسخ‌گو هستی که فقط بر اساس متن‌های داده‌شده پاسخ می‌دهی. "
340
  "باید دقیقاً فقط یک شیء JSON برگردانی و هیچ متن دیگری ننویسی."
@@ -389,33 +399,109 @@ def _strict_json_from_text(text: str):
389
  except Exception:
390
  return None
391
 
392
- def score_options_by_context(options: List[str], contexts: List[Dict[str, Any]]) -> int:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
  """
394
- فال‌بک:
395
- 1) اگر اسم گزینه به‌صورت substring در متون بود امتیاز خیلی بالا
396
- 2) وگرنه شباهت embedding با mE5 بین گزینه و کل کانتکست‌ها
 
 
 
397
  """
398
- text_blob = "\n".join([c.get("bio","") for c in contexts]).lower()
399
- # 1) substring hit
400
- hits = []
 
 
 
 
 
401
  for i, opt in enumerate(options):
402
- o = normalize_digits_months(str(opt).strip().lower())
403
- score = 0
404
- if o and (o in text_blob):
405
- score += 10_000
406
- hits.append((score, i))
407
- hits.sort(reverse=True)
408
- if hits and hits[0][0] > 0:
409
- return hits[0][1]
410
-
411
- # 2) embedding similarity (mE5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
  try:
413
- q_vecs = [_encode_query_e5(opt) for opt in options] # (n, dim)
414
- ctx_vec = _encode_query_e5(text_blob) # (dim,)
415
- sims = [float(_np.dot(qv, ctx_vec)) for qv in q_vecs]
416
- return int(_np.argmax(sims))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
417
  except Exception:
418
- return 0 # پیش‌فرض محافظه‌کارانه
 
419
 
420
  def parse_mcq_output_strict(text: str, options: List[str], contexts: List[Dict[str, Any]]) -> Dict[str, Any]:
421
  obj = _strict_json_from_text(text)
@@ -424,9 +510,8 @@ def parse_mcq_output_strict(text: str, options: List[str], contexts: List[Dict[s
424
  if isinstance(idx, int) and 0 <= idx < len(options):
425
  reason = str(obj.get("reason", "")).strip() or "—"
426
  return {"answer_index": idx, "reason": reason}
427
- # اگر JSON درست نبود → فال‌بک
428
- idx = score_options_by_context(options, contexts)
429
- return {"answer_index": idx, "reason": "fallback_by_context_matching"}
430
 
431
  def parse_mcq_output(text: str, n: int) -> Dict[str, Any]:
432
  m = re.search(r'{"\s*answer_index"\s*:\s*([0-9]+)\s*,\s*"reason"\s*:\s*"(.*?)"}', text, re.S)
@@ -455,7 +540,7 @@ def ui_answer(question, image, topk, max_tokens, temperature, top_p, top_k):
455
  return "Please enter a question.", [], ""
456
  # Retrieve
457
  ret = retrieve_context_auto(question, k=int(topk), image=image)
458
- prompt = build_prompt(question, ret["contexts"], lang="fa", max_chars=5000)
459
  ans = llm_generate(prompt, max_new_tokens=int(max_tokens),
460
  temperature=float(temperature), top_p=float(top_p),
461
  top_k=int(top_k), do_sample=False)
@@ -492,34 +577,35 @@ with gr.Blocks(title="Multimodal RAG (CPU) • E5 + CLIP Fusion + Qwen 0.5B") as
492
  gr.Markdown("### Free-tier CPU demo: text RAG (E5) + optional fusion (CLIP) → Qwen 0.5B")
493
  with gr.Tab("Ask"):
494
  with gr.Row():
495
- q = gr.Textbox(label="Question", placeholder="سؤال خود را بنویسید…", lines=3)
496
- img = gr.Image(type="pil", label="Optional image (fusion if provided)")
 
497
  with gr.Row():
498
- topk = gr.Slider(1, 20, value=5, step=1, label="Top-K retrieve")
499
- max_tokens = gr.Slider(32, 1024, value=MAX_NEW_TOKENS_DEFAULT, step=16, label="Max new tokens")
500
  with gr.Row():
501
- temperature = gr.Slider(0.0, 1.5, value=TEMPERATURE_DEFAULT, step=0.1, label="Temperature")
502
- top_p = gr.Slider(0.1, 1.0, value=TOP_P_DEFAULT, step=0.05, label="Top-p")
503
- top_k = gr.Slider(1, 100, value=TOP_K_DEFAULT, step=1, label="Top-k")
504
  btn = gr.Button("Answer")
505
  ans = gr.Textbox(label="Answer", lines=8)
506
  route = gr.Textbox(label="Route used (text_e5 or fusion)")
507
  table = gr.Dataframe(headers=["#", "id", "score", "snippet"], interactive=False)
508
- btn.click(ui_answer, [q, img, topk, max_tokens, temperature, top_p, top_k], [ans, table, route])
 
509
  with gr.Tab("MCQ"):
510
  with gr.Row():
511
  q_mcq = gr.Textbox(label="Question", lines=3)
512
  opts_mcq = gr.Textbox(label="Options (one per line)", lines=8)
513
- img_mcq = gr.Image(type="pil", label="Optional image (fusion if provided)")
514
  with gr.Row():
515
- topk2 = gr.Slider(1, 20, value=5, step=1, label="Top-K retrieve")
516
- max_tokens2 = gr.Slider(32, 1024, value=MAX_NEW_TOKENS_DEFAULT, step=16, label="Max new tokens")
517
  with gr.Row():
518
- temperature2 = gr.Slider(0.0, 1.5, value=TEMPERATURE_DEFAULT, step=0.1, label="Temperature")
519
- top_p2 = gr.Slider(0.1, 1.0, value=TOP_P_DEFAULT, step=0.05, label="Top-p")
520
- top_k2 = gr.Slider(1, 100, value=TOP_K_DEFAULT, step=1, label="Top-k")
521
  btn2 = gr.Button("Answer MCQ")
522
- # 👇 باکس‌ها بزرگ‌تر
523
  result = gr.Textbox(label="Prediction", lines=12, max_lines=20)
524
  raw = gr.Textbox(label="Raw LLM output", lines=12, max_lines=20)
525
  route2 = gr.Textbox(label="Route used")
 
1
  import os, io, gc, json, re, ast
2
+ from functools import lru_cache
3
+
4
  import numpy as np
5
  import pandas as pd
6
  import faiss
 
12
  from huggingface_hub import hf_hub_download
13
  from sentence_transformers import SentenceTransformer
14
  from transformers import AutoTokenizer, AutoModelForCausalLM
15
+ import os, torch
16
+ torch.set_num_threads(2) # vCPUهای Space معمولاً 2 تاست
17
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
18
  # =========================
19
  # Config (override in Space → Settings → Variables & secrets)
20
  # =========================
 
31
  # Models (CPU-friendly defaults; override via env if desired)
32
  E5_ID = os.getenv("E5_ID", "intfloat/multilingual-e5-small")
33
  CLIP_TXT_ID = os.getenv("CLIP_TXT_ID", "sentence-transformers/clip-ViT-B-32-multilingual-v1")
34
+ LLM_ID = os.getenv("LLM_ID", "Qwen/Qwen2-0.5B-Instruct")
35
+
36
+ # خروجی کوتاه‌تر
37
+ MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS_DEFAULT", "96")) # قبلاً 256
38
 
39
+ # نمونه‌برداری خاموش (قطعی و سریع‌تر)
40
+ TEMPERATURE_DEFAULT = float(os.getenv("TEMPERATURE_DEFAULT", "0.0"))
41
+ TOP_P_DEFAULT = float(os.getenv("TOP_P_DEFAULT", "1.0"))
42
+ TOP_K_DEFAULT = int(os.getenv("TOP_K_DEFAULT", "50"))
 
43
 
44
  # =========================
45
  # Helpers
 
169
  torch_dtype=dtype,
170
  ).to("cpu").eval()
171
 
172
+
173
  # =========================
174
  # Retrieval helpers
175
  # =========================
176
+ @lru_cache(maxsize=4096)
177
+ def _encode_query_e5_cached(q: str) -> np.ndarray:
178
  qn = "query: " + normalize_digits_months(q)
179
  v = st_e5.encode([qn], batch_size=1, convert_to_numpy=True, normalize_embeddings=True)[0]
180
  return v.astype("float32")
181
 
182
+ # استفاده به‌جای قدیمی:
183
+ def _encode_query_e5(q: str) -> np.ndarray:
184
+ return _encode_query_e5_cached(q)
185
+
186
  def _faiss_search(index, q_vec: np.ndarray, k: int):
187
  if q_vec.ndim == 1:
188
  q_vec = q_vec[None, :]
 
273
  def search_fusion(query_text: str, image: Image.Image, k: int = 5, alpha_q: float = 0.7):
274
  if index_fusion is None:
275
  raise RuntimeError("Fusion index not available (upload FUSION_INDEX_FILE to dataset repo).")
276
+ qv = make_query_embed(query_text, image=image, alpha_q=alpha_q, use_aug=False, n_aug=3)
277
  return _faiss_search(index_fusion, qv, k)
278
 
279
  # =========================
 
299
  ctxs.append({"index": int(idx), "id": row.get("id", idx), "score": float(score), "bio": str(row["bio"])})
300
  return {"route": route, "contexts": ctxs}
301
 
302
+ def build_prompt(question: str, contexts: List[Dict[str, Any]], lang="fa", max_chars=1800) -> str:
303
  sys_fa = "تو یک دستیار پاسخ‌گو هستی که فقط بر اساس متن‌های داده‌شده پاسخ می‌دهی. اگر پاسخی در متن‌ها نبود، صادقانه بگو «در متن‌های بازیابی‌شده پاسخی پیدا نشد.»"
304
  sys_en = "You are a helpful assistant. Answer only using retrieved passages. If not found, say 'No answer found in retrieved passages.'"
305
  system_text = sys_fa if lang == "fa" else sys_en
 
322
  return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
323
 
324
  @torch.inference_mode()
325
+ def llm_generate(prompt: str, max_new_tokens=96, temperature=0.0, top_p=1.0, top_k=50, do_sample=False) -> str:
 
 
 
 
 
326
  inputs = tokenizer(prompt, return_tensors="pt")
327
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
328
+
329
  out = model.generate(
330
  **inputs,
331
+ max_new_tokens=max_new_tokens,
332
+ do_sample=False, # قطعی
333
+ temperature=temperature,
334
+ top_p=top_p,
335
+ top_k=top_k,
336
+ num_beams=1, # بدون beam-search
337
+ use_cache=True, # سریع‌تر
338
  pad_token_id=tokenizer.eos_token_id,
339
  eos_token_id=tokenizer.eos_token_id,
340
  )
 
344
  return text.strip()
345
 
346
  # ---- MCQ helpers ----
347
+ def build_mcq_prompt(question: str, options: List[str], contexts: List[Dict[str, Any]], lang="fa", max_chars=1800) -> str:
348
  sys_fa = (
349
  "تو یک دستیار پاسخ‌گو هستی که فقط بر اساس متن‌های داده‌شده پاسخ می‌دهی. "
350
  "باید دقیقاً فقط یک شیء JSON برگردانی و هیچ متن دیگری ننویسی."
 
399
  except Exception:
400
  return None
401
 
402
+ import re as _re
403
+ import numpy as _np
404
+
405
+ def _norm_text_for_match(s: str) -> str:
406
+ # نرمال‌سازی ساده: اعداد فارسی/عربی، ZWNJ، فاصله‌های اضافه
407
+ s = normalize_digits_months(s or "")
408
+ s = s.replace("\u200c", " ").strip()
409
+ # پایین‌حرفی و تک‌فاصله
410
+ s = _re.sub(r"\s+", " ", s.lower())
411
+ return s
412
+
413
+ def _find_snippet(hay: str, needle: str, win: int = 60) -> str:
414
+ """یک تکه متن کوتاه اطراف اولین مچ را بده."""
415
+ try:
416
+ i = hay.index(needle)
417
+ start = max(0, i - win)
418
+ end = min(len(hay), i + len(needle) + win)
419
+ return hay[start:end].replace("\n", " ")
420
+ except ValueError:
421
+ return ""
422
+
423
+ def score_options_by_context(
424
+ options: List[str],
425
+ contexts: List[Dict[str, Any]],
426
+ return_snippet: bool = False
427
+ ):
428
  """
429
+ فال‌بک هوشمند:
430
+ 1) boundary-aware substring در تک‌تک کانتکست‌ها (امتیاز بالا + تعداد وقوع)
431
+ 2) اگر هیچ مچی نبود → شباهت embedding با mE5 بین هر گزینه و کل کانتکست‌ها
432
+ خروجی:
433
+ - اگر return_snippet=False → فقط best_idx (int)
434
+ - اگر return_snippet=True → (best_idx, snippet) برمی‌گرداند
435
  """
436
+ # آماده‌سازی کانتکست‌ها
437
+ raw_ctxs = [c.get("bio", "") for c in contexts]
438
+ norm_ctxs = [_norm_text_for_match(x) for x in raw_ctxs]
439
+ joined_norm = " \n ".join(norm_ctxs)
440
+
441
+ # 1) جست‌وجوی دقیق‌تر: word boundary + شمارش
442
+ # برای فارسی/عربی هم خوب جواب می‌دهد چون از فاصله استفاده می‌کنیم.
443
+ best_idx, best_score, best_snip = 0, -1.0, ""
444
  for i, opt in enumerate(options):
445
+ o_raw = str(opt).strip()
446
+ o = _norm_text_for_match(o_raw)
447
+ if not o:
448
+ continue
449
+
450
+ # الگوی boundary ساده: (شروع/فاصله) + عبارت + (پایان/فاصله)
451
+ # اگر گزینه چندکلمه‌ای است، همین هم خوب جواب می‌دهد.
452
+ # اگر لازم شد می‌توان regex دقیق‌تر نوشت.
453
+ pat = r"(?<!\S)" + _re.escape(o) + r"(?!\S)"
454
+
455
+ total_hits = 0
456
+ first_snip = ""
457
+ for raw, norm in zip(raw_ctxs, norm_ctxs):
458
+ for m in _re.finditer(pat, norm):
459
+ total_hits += 1
460
+ if not first_snip:
461
+ # اسنیپت از متن خام (خواناتر)
462
+ # موقعیت متن خام را تقریبی می‌گیریم با جست‌وجوی ساده
463
+ # (اگر اختلاف normalization زیاد بود، از norm استفاده می‌کنیم)
464
+ sn = _find_snippet(raw, o_raw) or _find_snippet(norm, o)
465
+ first_snip = sn
466
+ if total_hits > 0:
467
+ # امتیاز بالا برای مچ صریح + تعداد وقوع
468
+ score = 10000.0 + total_hits
469
+ if score > best_score:
470
+ best_score, best_idx, best_snip = score, i, first_snip
471
+
472
+ if best_score > 0:
473
+ return (best_idx, best_snip) if return_snippet else best_idx
474
+
475
+ # 2) اگر هیچ مچی نبود → شباهت embedding (mE5)
476
  try:
477
+ # وکتور کل کانتکست‌ها (یک‌بار)
478
+ ctx_vec = _encode_query_e5(joined_norm) # (dim,)
479
+ sims = []
480
+ for opt in options:
481
+ qv = _encode_query_e5(str(opt))
482
+ sims.append(float(_np.dot(qv, ctx_vec)))
483
+ best_idx = int(_np.argmax(sims))
484
+
485
+ # برای snippet در این مسیر: نزدیک‌ترین کانتکست را با dot جداگانه پیدا کنیم
486
+ # (سریع و به‌اندازه کافی خوب)
487
+ best_snip = ""
488
+ try:
489
+ opt_vec = _encode_query_e5(str(options[best_idx]))
490
+ # کوساین تقریباً همان inner-prod چون نرمال شده‌اند
491
+ # امتیاز هر کانتکست با گزینه‌ی برنده:
492
+ c_scores = []
493
+ for raw, norm in zip(raw_ctxs, norm_ctxs):
494
+ c_vec = _encode_query_e5(norm)
495
+ c_scores.append(float(_np.dot(opt_vec, c_vec)))
496
+ j = int(_np.argmax(c_scores))
497
+ best_snip = _find_snippet(raw_ctxs[j], str(options[best_idx])) or raw_ctxs[j][:120].replace("\n"," ")
498
+ except Exception:
499
+ pass
500
+
501
+ return (best_idx, best_snip) if return_snippet else best_idx
502
  except Exception:
503
+ return (0, "") if return_snippet else 0 # پیش‌فرض محافظه‌کارانه
504
+
505
 
506
  def parse_mcq_output_strict(text: str, options: List[str], contexts: List[Dict[str, Any]]) -> Dict[str, Any]:
507
  obj = _strict_json_from_text(text)
 
510
  if isinstance(idx, int) and 0 <= idx < len(options):
511
  reason = str(obj.get("reason", "")).strip() or "—"
512
  return {"answer_index": idx, "reason": reason}
513
+ idx, snip = score_options_by_context(options, contexts, return_snippet=True)
514
+ return {"answer_index": idx, "reason": snip or "matched by context"}
 
515
 
516
  def parse_mcq_output(text: str, n: int) -> Dict[str, Any]:
517
  m = re.search(r'{"\s*answer_index"\s*:\s*([0-9]+)\s*,\s*"reason"\s*:\s*"(.*?)"}', text, re.S)
 
540
  return "Please enter a question.", [], ""
541
  # Retrieve
542
  ret = retrieve_context_auto(question, k=int(topk), image=image)
543
+ prompt = build_prompt(question, ret["contexts"], lang="fa", max_chars=1800)
544
  ans = llm_generate(prompt, max_new_tokens=int(max_tokens),
545
  temperature=float(temperature), top_p=float(top_p),
546
  top_k=int(top_k), do_sample=False)
 
577
  gr.Markdown("### Free-tier CPU demo: text RAG (E5) + optional fusion (CLIP) → Qwen 0.5B")
578
  with gr.Tab("Ask"):
579
  with gr.Row():
580
+ q = gr.Textbox(label="Question", lines=3)
581
+ img = gr.Image(type="pil", label="Optional image")
582
+ use_fusion = gr.Checkbox(label="Use image fusion (slower on CPU)", value=False)
583
  with gr.Row():
584
+ topk = gr.Slider(1, 20, value=3, step=1, label="Top-K retrieve")
585
+ max_tokens = gr.Slider(16, 512, value=96, step=16, label="Max new tokens")
586
  with gr.Row():
587
+ temperature = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Temperature")
588
+ top_p = gr.Slider(0.1, 1.0, value=1.0, step=0.05, label="Top-p")
589
+ top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
590
  btn = gr.Button("Answer")
591
  ans = gr.Textbox(label="Answer", lines=8)
592
  route = gr.Textbox(label="Route used (text_e5 or fusion)")
593
  table = gr.Dataframe(headers=["#", "id", "score", "snippet"], interactive=False)
594
+ btn.click(ui_answer, [q, img, use_fusion, topk, max_tokens, temperature, top_p, top_k], [ans, table, route])
595
+
596
  with gr.Tab("MCQ"):
597
  with gr.Row():
598
  q_mcq = gr.Textbox(label="Question", lines=3)
599
  opts_mcq = gr.Textbox(label="Options (one per line)", lines=8)
600
+ img_mcq = gr.Image(type="pil", label="Optional image (fusion if enabled)")
601
  with gr.Row():
602
+ topk2 = gr.Slider(1, 20, value=3, step=1, label="Top-K retrieve")
603
+ max_tokens2 = gr.Slider(16, 512, value=96, step=16, label="Max new tokens")
604
  with gr.Row():
605
+ temperature2 = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Temperature")
606
+ top_p2 = gr.Slider(0.1, 1.0, value=1.0, step=0.05, label="Top-p")
607
+ top_k2 = gr.Slider(1, 100, value=50, step=1, label="Top-k")
608
  btn2 = gr.Button("Answer MCQ")
 
609
  result = gr.Textbox(label="Prediction", lines=12, max_lines=20)
610
  raw = gr.Textbox(label="Raw LLM output", lines=12, max_lines=20)
611
  route2 = gr.Textbox(label="Route used")