Spaces: Sleeping
amirhossein mohammadpour committed
Commit · 73a17b2
Parent(s): 6ca41c8
handle speed
Browse files
app.py CHANGED
@@ -1,4 +1,6 @@
 import os, io, gc, json, re, ast
+from functools import lru_cache
+
 import numpy as np
 import pandas as pd
 import faiss
@@ -10,7 +12,9 @@ import gradio as gr
 from huggingface_hub import hf_hub_download
 from sentence_transformers import SentenceTransformer
 from transformers import AutoTokenizer, AutoModelForCausalLM
-
+import os, torch
+torch.set_num_threads(2)  # Spaces usually have 2 vCPUs
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
 # =========================
 # Config (override in Space → Settings → Variables & secrets)
 # =========================
@@ -27,13 +31,15 @@ HF_TOKEN = os.getenv("HF_TOKEN", None)  # needed if DATASET_REPO is pri
 # Models (CPU-friendly defaults; override via env if desired)
 E5_ID = os.getenv("E5_ID", "intfloat/multilingual-e5-small")
 CLIP_TXT_ID = os.getenv("CLIP_TXT_ID", "sentence-transformers/clip-ViT-B-32-multilingual-v1")
-LLM_ID …
+LLM_ID = os.getenv("LLM_ID", "Qwen/Qwen2-0.5B-Instruct")
+
+# shorter output
+MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS_DEFAULT", "96"))  # was 256
 
-# …
-…
-…
-…
-TOP_K_DEFAULT = int(os.getenv("TOP_K", "50"))
+# sampling off (deterministic and faster)
+TEMPERATURE_DEFAULT = float(os.getenv("TEMPERATURE_DEFAULT", "0.0"))
+TOP_P_DEFAULT = float(os.getenv("TOP_P_DEFAULT", "1.0"))
+TOP_K_DEFAULT = int(os.getenv("TOP_K_DEFAULT", "50"))
 
 # =========================
 # Helpers
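
All of these settings follow one pattern: read an env var, fall back to a default, cast explicitly. A minimal sketch of that pattern (the names here are illustrative, not from the app):

import os

# string-typed default, overridable in Space → Settings → Variables & secrets
MODEL_ID = os.getenv("MODEL_ID", "intfloat/multilingual-e5-small")

# env vars are always strings, so numeric settings need an explicit cast
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS_DEFAULT", "96"))
TEMPERATURE = float(os.getenv("TEMPERATURE_DEFAULT", "0.0"))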
@@ -163,15 +169,20 @@ model = AutoModelForCausalLM.from_pretrained(
     torch_dtype=dtype,
 ).to("cpu").eval()
 
+
 # =========================
 # Retrieval helpers
 # =========================
-@…
-def …
+@lru_cache(maxsize=4096)
+def _encode_query_e5_cached(q: str) -> np.ndarray:
     qn = "query: " + normalize_digits_months(q)
     v = st_e5.encode([qn], batch_size=1, convert_to_numpy=True, normalize_embeddings=True)[0]
     return v.astype("float32")
 
+# use this instead of the old helper:
+def _encode_query_e5(q: str) -> np.ndarray:
+    return _encode_query_e5_cached(q)
+
 def _faiss_search(index, q_vec: np.ndarray, k: int):
     if q_vec.ndim == 1:
         q_vec = q_vec[None, :]
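
The caching trick above works because lru_cache keys on the string argument, so a repeated question skips the sentence-transformers encode entirely. A self-contained sketch with a stub in place of st_e5.encode:

from functools import lru_cache
import numpy as np

calls = 0

@lru_cache(maxsize=4096)
def encode_cached(q: str) -> np.ndarray:
    global calls
    calls += 1  # stands in for the expensive st_e5.encode(...)
    return np.random.rand(384).astype("float32")

encode_cached("who is X?")
encode_cached("who is X?")          # served from the cache
print(calls)                        # -> 1
print(encode_cached.cache_info())   # hits=1, misses=1, currsize=1

One caveat of this design: lru_cache hands back the same ndarray object on every hit, so callers must treat the returned vector as read-only.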
@@ -262,7 +273,7 @@ def make_query_embed(query_text: str,
 def search_fusion(query_text: str, image: Image.Image, k: int = 5, alpha_q: float = 0.7):
     if index_fusion is None:
         raise RuntimeError("Fusion index not available (upload FUSION_INDEX_FILE to dataset repo).")
-    qv = make_query_embed(query_text, image=image, alpha_q=alpha_q, use_aug=…
+    qv = make_query_embed(query_text, image=image, alpha_q=alpha_q, use_aug=False, n_aug=3)
     return _faiss_search(index_fusion, qv, k)
 
 # =========================
@@ -288,7 +299,7 @@ def retrieve_context_auto(question: str, k: int = 5, image: Image.Image = None)
         ctxs.append({"index": int(idx), "id": row.get("id", idx), "score": float(score), "bio": str(row["bio"])})
     return {"route": route, "contexts": ctxs}
 
-def build_prompt(question: str, contexts: List[Dict[str, Any]], lang="fa", max_chars=…
+def build_prompt(question: str, contexts: List[Dict[str, Any]], lang="fa", max_chars=1800) -> str:
     sys_fa = "تو یک دستیار پاسخگو هستی که فقط بر اساس متنهای دادهشده پاسخ میدهی. اگر پاسخی در متنها نبود، صادقانه بگو «در متنهای بازیابیشده پاسخی پیدا نشد.»"
     sys_en = "You are a helpful assistant. Answer only using retrieved passages. If not found, say 'No answer found in retrieved passages.'"
     system_text = sys_fa if lang == "fa" else sys_en
@@ -311,20 +322,19 @@ def build_prompt(question: str, contexts: List[Dict[str, Any]], lang="fa", max_c
     return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
 
 @torch.inference_mode()
-def llm_generate(prompt: str,
-                 max_new_tokens=MAX_NEW_TOKENS_DEFAULT,
-                 temperature=TEMPERATURE_DEFAULT,
-                 top_p=TOP_P_DEFAULT,
-                 top_k=TOP_K_DEFAULT,
-                 do_sample=False) -> str:
+def llm_generate(prompt: str, max_new_tokens=96, temperature=0.0, top_p=1.0, top_k=50, do_sample=False) -> str:
     inputs = tokenizer(prompt, return_tensors="pt")
+    inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
     out = model.generate(
         **inputs,
-        max_new_tokens=…
-        do_sample=…
-        temperature=…
-        top_p=…
-        top_k=…
+        max_new_tokens=max_new_tokens,
+        do_sample=False,   # deterministic
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
+        num_beams=1,       # no beam search
+        use_cache=True,    # faster
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
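
With do_sample=False and num_beams=1, generate() performs greedy decoding: at each step it takes the argmax of the logits, so temperature/top_p/top_k have no effect (recent transformers versions may warn that these sampling flags are being ignored) and every run produces the same tokens. A toy sketch of why greedy decoding is deterministic, using stub logits instead of a real model:

import numpy as np

def next_logits(prefix):
    # stub for one forward pass: a fixed function of the prefix
    rng = np.random.default_rng(sum(prefix))
    return rng.normal(size=100)

def greedy_decode(prompt_ids, max_new_tokens=5):
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        ids.append(int(np.argmax(next_logits(ids))))  # argmax: no randomness
    return ids

print(greedy_decode([1, 2, 3]))  # identical output on every run
print(greedy_decode([1, 2, 3]))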
@@ -334,7 +344,7 @@ def llm_generate(prompt: str,
     return text.strip()
 
 # ---- MCQ helpers ----
-def build_mcq_prompt(question: str, options: List[str], contexts: List[Dict[str, Any]], lang="fa", max_chars=…
+def build_mcq_prompt(question: str, options: List[str], contexts: List[Dict[str, Any]], lang="fa", max_chars=1800) -> str:
     sys_fa = (
         "تو یک دستیار پاسخگو هستی که فقط بر اساس متنهای دادهشده پاسخ میدهی. "
         "باید دقیقاً فقط یک شیء JSON برگردانی و هیچ متن دیگری ننویسی."
@@ -389,33 +399,109 @@ def _strict_json_from_text(text: str):
     except Exception:
         return None
 
-…
+import re as _re
+import numpy as _np
+
+def _norm_text_for_match(s: str) -> str:
+    # simple normalization: Persian/Arabic digits, ZWNJ, extra whitespace
+    s = normalize_digits_months(s or "")
+    s = s.replace("\u200c", " ").strip()
+    # lowercase and collapse whitespace to single spaces
+    s = _re.sub(r"\s+", " ", s.lower())
+    return s
+
+def _find_snippet(hay: str, needle: str, win: int = 60) -> str:
+    """Return a short snippet of text around the first match."""
+    try:
+        i = hay.index(needle)
+        start = max(0, i - win)
+        end = min(len(hay), i + len(needle) + win)
+        return hay[start:end].replace("\n", " ")
+    except ValueError:
+        return ""
+
+def score_options_by_context(
+    options: List[str],
+    contexts: List[Dict[str, Any]],
+    return_snippet: bool = False
+):
     """
-    …
-    1) …
-    2) …
+    Smart fallback:
+    1) boundary-aware substring match in each context (high score + occurrence count)
+    2) if nothing matches → embedding similarity (mE5) between each option and all contexts
+    Output:
+    - if return_snippet=False → just best_idx (int)
+    - if return_snippet=True → returns (best_idx, snippet)
     """
-    …
-    …
-    …
+    # prepare the contexts
+    raw_ctxs = [c.get("bio", "") for c in contexts]
+    norm_ctxs = [_norm_text_for_match(x) for x in raw_ctxs]
+    joined_norm = " \n ".join(norm_ctxs)
+
+    # 1) stricter search: word boundaries + occurrence counting
+    # works for Persian/Arabic too, since we rely on spaces.
+    best_idx, best_score, best_snip = 0, -1.0, ""
     for i, opt in enumerate(options):
-        …
-        …
-        if …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
+        o_raw = str(opt).strip()
+        o = _norm_text_for_match(o_raw)
+        if not o:
+            continue
+
+        # simple boundary pattern: (start/whitespace) + phrase + (end/whitespace)
+        # also works when the option is a multi-word phrase;
+        # a stricter regex can be written later if needed.
+        pat = r"(?<!\S)" + _re.escape(o) + r"(?!\S)"
+
+        total_hits = 0
+        first_snip = ""
+        for raw, norm in zip(raw_ctxs, norm_ctxs):
+            for m in _re.finditer(pat, norm):
+                total_hits += 1
+                if not first_snip:
+                    # take the snippet from the raw text (more readable);
+                    # approximate its position with a simple search
+                    # (fall back to the normalized text if normalization diverges too much)
+                    sn = _find_snippet(raw, o_raw) or _find_snippet(norm, o)
+                    first_snip = sn
+        if total_hits > 0:
+            # high score for an explicit match, plus the occurrence count
+            score = 10000.0 + total_hits
+            if score > best_score:
+                best_score, best_idx, best_snip = score, i, first_snip
+
+    if best_score > 0:
+        return (best_idx, best_snip) if return_snippet else best_idx
+
+    # 2) if there was no match → embedding similarity (mE5)
     try:
-        …
-        ctx_vec = _encode_query_e5(…
-        sims = […
-        …
+        # one vector for all contexts (computed once)
+        ctx_vec = _encode_query_e5(joined_norm)  # (dim,)
+        sims = []
+        for opt in options:
+            qv = _encode_query_e5(str(opt))
+            sims.append(float(_np.dot(qv, ctx_vec)))
+        best_idx = int(_np.argmax(sims))
+
+        # for the snippet on this path: find the closest context with separate dot products
+        # (fast and good enough)
+        best_snip = ""
+        try:
+            opt_vec = _encode_query_e5(str(options[best_idx]))
+            # cosine ≈ inner product, since the vectors are normalized
+            # score every context against the winning option:
+            c_scores = []
+            for raw, norm in zip(raw_ctxs, norm_ctxs):
+                c_vec = _encode_query_e5(norm)
+                c_scores.append(float(_np.dot(opt_vec, c_vec)))
+            j = int(_np.argmax(c_scores))
+            best_snip = _find_snippet(raw_ctxs[j], str(options[best_idx])) or raw_ctxs[j][:120].replace("\n", " ")
+        except Exception:
+            pass
+
+        return (best_idx, best_snip) if return_snippet else best_idx
     except Exception:
-        return 0  # conservative default
+        return (0, "") if return_snippet else 0  # conservative default
+
 
 def parse_mcq_output_strict(text: str, options: List[str], contexts: List[Dict[str, Any]]) -> Dict[str, Any]:
     obj = _strict_json_from_text(text)
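
The (?<!\S)…(?!\S) pattern in the fallback matches an option only when it is bounded by whitespace (or the string edges), which avoids substring false positives; unlike \b it does not rely on \w character classes, so it behaves the same for Persian text. A quick stdlib-only illustration:

import re

def hits(option: str, text: str) -> int:
    pat = r"(?<!\S)" + re.escape(option) + r"(?!\S)"
    return len(re.findall(pat, text))

print(hits("cat", "cat scattered catalog cat"))   # -> 2 ("scattered"/"catalog" don't count)
print(hits("تهران", "او در تهران زندگی می کند"))  # -> 1 (Persian tokens work too)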
@@ -424,9 +510,8 @@ def parse_mcq_output_strict(text: str, options: List[str], contexts: List[Dict[s
     if isinstance(idx, int) and 0 <= idx < len(options):
         reason = str(obj.get("reason", "")).strip() or "—"
         return {"answer_index": idx, "reason": reason}
-    …
-    …
-    return {"answer_index": idx, "reason": "fallback_by_context_matching"}
+    idx, snip = score_options_by_context(options, contexts, return_snippet=True)
+    return {"answer_index": idx, "reason": snip or "matched by context"}
 
 def parse_mcq_output(text: str, n: int) -> Dict[str, Any]:
     m = re.search(r'{"\s*answer_index"\s*:\s*([0-9]+)\s*,\s*"reason"\s*:\s*"(.*?)"}', text, re.S)
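
The embedding branch of the fallback leans on the fact that the E5 helper encodes with normalize_embeddings=True: for unit vectors the plain dot product already equals cosine similarity, so no extra normalization is needed per comparison. A small numpy check:

import numpy as np

a = np.random.rand(384).astype("float32")
b = np.random.rand(384).astype("float32")
a /= np.linalg.norm(a)  # unit vectors, as normalize_embeddings=True produces
b /= np.linalg.norm(b)

cos = float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
dot = float(np.dot(a, b))
print(np.isclose(cos, dot))  # -> True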
@@ -455,7 +540,7 @@ def ui_answer(question, image, topk, max_tokens, temperature, top_p, top_k):
         return "Please enter a question.", [], ""
     # Retrieve
     ret = retrieve_context_auto(question, k=int(topk), image=image)
-    prompt = build_prompt(question, ret["contexts"], lang="fa", max_chars=…
+    prompt = build_prompt(question, ret["contexts"], lang="fa", max_chars=1800)
     ans = llm_generate(prompt, max_new_tokens=int(max_tokens),
                        temperature=float(temperature), top_p=float(top_p),
                        top_k=int(top_k), do_sample=False)
@@ -492,34 +577,35 @@ with gr.Blocks(title="Multimodal RAG (CPU) • E5 + CLIP Fusion + Qwen 0.5B") as
     gr.Markdown("### Free-tier CPU demo: text RAG (E5) + optional fusion (CLIP) → Qwen 0.5B")
     with gr.Tab("Ask"):
         with gr.Row():
-            q = gr.Textbox(label="Question", …
-            img = gr.Image(type="pil", label="Optional image …
+            q = gr.Textbox(label="Question", lines=3)
+            img = gr.Image(type="pil", label="Optional image")
+            use_fusion = gr.Checkbox(label="Use image fusion (slower on CPU)", value=False)
         with gr.Row():
-            topk = gr.Slider(1, 20, value=…
-            max_tokens = gr.Slider(…
+            topk = gr.Slider(1, 20, value=3, step=1, label="Top-K retrieve")
+            max_tokens = gr.Slider(16, 512, value=96, step=16, label="Max new tokens")
         with gr.Row():
-            temperature = gr.Slider(0.0, 1.…
-            top_p = gr.Slider(0.1, 1.0, value=…
-            top_k = gr.Slider(1, 100, value=…
+            temperature = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Temperature")
+            top_p = gr.Slider(0.1, 1.0, value=1.0, step=0.05, label="Top-p")
+            top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
         btn = gr.Button("Answer")
         ans = gr.Textbox(label="Answer", lines=8)
         route = gr.Textbox(label="Route used (text_e5 or fusion)")
         table = gr.Dataframe(headers=["#", "id", "score", "snippet"], interactive=False)
-        btn.click(ui_answer, [q, img, topk, max_tokens, temperature, top_p, top_k], [ans, table, route])
+        btn.click(ui_answer, [q, img, use_fusion, topk, max_tokens, temperature, top_p, top_k], [ans, table, route])
+
     with gr.Tab("MCQ"):
         with gr.Row():
             q_mcq = gr.Textbox(label="Question", lines=3)
             opts_mcq = gr.Textbox(label="Options (one per line)", lines=8)
-            img_mcq = gr.Image(type="pil", label="Optional image (fusion if …
+            img_mcq = gr.Image(type="pil", label="Optional image (fusion if enabled)")
         with gr.Row():
-            topk2 = gr.Slider(1, 20, value=…
-            max_tokens2 = gr.Slider(…
+            topk2 = gr.Slider(1, 20, value=3, step=1, label="Top-K retrieve")
+            max_tokens2 = gr.Slider(16, 512, value=96, step=16, label="Max new tokens")
         with gr.Row():
-            temperature2 = gr.Slider(0.0, 1.…
-            top_p2 = gr.Slider(0.1, 1.0, value=…
-            top_k2 = gr.Slider(1, 100, value=…
+            temperature2 = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Temperature")
+            top_p2 = gr.Slider(0.1, 1.0, value=1.0, step=0.05, label="Top-p")
+            top_k2 = gr.Slider(1, 100, value=50, step=1, label="Top-k")
         btn2 = gr.Button("Answer MCQ")
-        # 👇 bigger boxes
         result = gr.Textbox(label="Prediction", lines=12, max_lines=20)
         raw = gr.Textbox(label="Raw LLM output", lines=12, max_lines=20)
         route2 = gr.Textbox(label="Route used")
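
The UI wiring follows the standard Gradio Blocks pattern: declare components, then bind the callback with btn.click(fn, inputs, outputs), where the inputs list is passed to the callback positionally. Note that this hunk adds use_fusion to the inputs list, so ui_answer must accept a matching extra argument. A minimal sketch of the pattern (a hypothetical echo app, assuming gradio is installed):

import gradio as gr

def echo(question: str, upper: bool) -> str:
    return question.upper() if upper else question

with gr.Blocks(title="Demo") as demo:
    q = gr.Textbox(label="Question", lines=3)
    upper = gr.Checkbox(label="Uppercase", value=False)
    btn = gr.Button("Answer")
    ans = gr.Textbox(label="Answer")
    # inputs are passed positionally, in the listed order
    btn.click(echo, [q, upper], [ans])

if __name__ == "__main__":
    demo.launch()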