amirhossein mohammadpour
handle speed
73a17b2
import os, io, gc, json, re, ast
from functools import lru_cache
import numpy as np
import pandas as pd
import faiss
import torch
import torch.nn.functional as F
from typing import List, Dict, Any
from PIL import Image, ImageFilter, ImageOps, ImageEnhance
import gradio as gr
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
import os, torch
torch.set_num_threads(2) # Space vCPUs are usually limited to 2
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# =========================
# Config (override in Space → Settings → Variables & secrets)
# =========================
# Every value below is read from the environment, falling back to the default shown.
DATASET_REPO = os.getenv("DATASET_REPO", "ahm1378/NLP-Project") # <--- CHANGE to your repo
CSV_FILE = os.getenv("CSV_FILE", "final_merged_images.csv")
E5_INDEX_FILE = os.getenv("E5_INDEX_FILE", "faiss_e5_rag_v15.ip")
E5_EMB_FILE = os.getenv("E5_EMB_FILE", "doc_embeds_e5_rag_v15.npy")
FUSION_INDEX_FILE = os.getenv("FUSION_INDEX_FILE", "faiss_fusion.ip")
FUSION_EMB_FILE = os.getenv("FUSION_EMB_FILE", "fusion_doc_emb.npy")
FT_HEAD_FILE = os.getenv("FT_HEAD_FILE", "finetune_clip_fa.pt") # your finetuned text projection (CLIP space)
HF_TOKEN = os.getenv("HF_TOKEN", None) # needed if DATASET_REPO is private
# Models (CPU-friendly defaults; override via env if desired)
E5_ID = os.getenv("E5_ID", "intfloat/multilingual-e5-small")
CLIP_TXT_ID = os.getenv("CLIP_TXT_ID", "sentence-transformers/clip-ViT-B-32-multilingual-v1")
LLM_ID = os.getenv("LLM_ID", "Qwen/Qwen2-0.5B-Instruct")
# Shorter output
MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS_DEFAULT", "96")) # previously 256
# Sampling off (deterministic and faster)
TEMPERATURE_DEFAULT = float(os.getenv("TEMPERATURE_DEFAULT", "0.0"))
TOP_P_DEFAULT = float(os.getenv("TOP_P_DEFAULT", "1.0"))
TOP_K_DEFAULT = int(os.getenv("TOP_K_DEFAULT", "50"))
# =========================
# Helpers
# =========================
# One-pass translation table: Persian and Arabic-Indic digits map to ASCII
# digits, and the zero-width non-joiner (U+200C) maps to a plain space.
_DIGIT_ZWNJ_TABLE = {
    **{ord(ch): str(pos % 10) for pos, ch in enumerate("۰۱۲۳۴۵۶۷۸۹٠١٢٣٤٥٦٧٨٩")},
    0x200C: " ",
}


def normalize_digits_months(s: str) -> str:
    """Return *s* with Persian/Arabic-Indic digits converted to ASCII digits.

    Non-string input is first converted with ``str()``. Zero-width non-joiners
    are replaced by a space and surrounding whitespace is stripped. Despite
    the name, no month handling happens here.
    """
    text = s if isinstance(s, str) else str(s)
    return text.translate(_DIGIT_ZWNJ_TABLE).strip()
def _truncate_chars(s: str, limit: int) -> str:
    """Cap *s* at *limit* characters, appending an ellipsis when text is cut.

    A *limit* of ``None`` disables truncation and returns *s* unchanged.
    """
    if limit is None:
        return s
    if len(s) <= limit:
        return s
    return s[:limit] + "…"
def _maybe_hub(file, repo=DATASET_REPO, repo_type="dataset") -> str:
    """Resolve *file* to a usable path, preferring a copy already on disk.

    If *file* exists locally it is returned unchanged; otherwise it is
    requested from the Hub repo *repo* (of type *repo_type*) using
    ``HF_TOKEN``, and whatever ``hf_hub_download`` returns is passed back.
    """
    if not os.path.isfile(file):
        # Not present locally: fetch it from the Hub.
        return hf_hub_download(repo_id=repo, filename=file, repo_type=repo_type, token=HF_TOKEN)
    return file
# =========================
# Fetch artifacts
# =========================
def _maybe_hub_optional(file):
    """Resolve an optional artifact, returning ``None`` when it is unavailable.

    The fusion index and the finetuned head are optional: the code that
    consumes their paths already copes with ``None``. A failed lookup is
    therefore reported as a warning instead of aborting startup.
    """
    if not file:
        return None
    try:
        return _maybe_hub(file)
    except Exception as e:
        print(f"[WARN] optional artifact unavailable: {file} ({e})")
        return None


# Mandatory artifacts: any failure here should stop the app.
CSV_PATH = _maybe_hub(CSV_FILE)
E5_INDEX_PATH = _maybe_hub(E5_INDEX_FILE)
# (E5_EMB_PATH not strictly needed at runtime)
# Optional artifacts: resolve to None when not configured or not retrievable.
FUSION_INDEX_PATH = _maybe_hub_optional(FUSION_INDEX_FILE)
FT_HEAD_PATH = _maybe_hub_optional(FT_HEAD_FILE)
# =========================
# Load dataframe
# =========================
if not os.path.isfile(CSV_PATH):
    raise FileNotFoundError(f"CSV missing: {CSV_PATH}")
df = pd.read_csv(CSV_PATH)
# Expect columns: 'id', 'bio', 'image_paths_abs' (list or stringified list)
def first_image(x):
    """Pick the first image path out of a list, or a string holding one.

    Accepted inputs:
    - a list: its first element, or "" when empty;
    - a non-blank string: parsed as JSON, then as a Python literal. A parsed
      non-empty list yields its first element; any other successfully parsed
      value yields "". A string that neither parser accepts is returned
      unchanged and treated as a single path;
    - anything else (including blank strings): "".
    """
    if isinstance(x, list):
        return x[0] if x else ""
    if not isinstance(x, str) or not x.strip():
        return ""
    try:
        parsed = json.loads(x)
    except Exception:
        # Not JSON; a Python literal list handles single-quoted items.
        try:
            parsed = ast.literal_eval(x)
        except Exception:
            return x  # treat as single path
    if isinstance(parsed, list) and parsed:
        return parsed[0]
    return ""
# Derive a single representative image path per row (empty when the column is absent).
if "image_paths_abs" in df.columns:
    df["first_image"] = df["image_paths_abs"].apply(first_image)
else:
    df["first_image"] = ""
if "bio" not in df.columns:
    raise KeyError("Expected 'bio' column in CSV.")
# Fill missing bios before casting: astype(str) alone would turn NaN into the
# literal text "nan", which would then be treated as a non-empty passage.
df["bio"] = df["bio"].fillna("").astype(str)
# =========================
# Indices
# =========================
# The E5 text index is mandatory; fail fast when it is missing.
if not os.path.isfile(E5_INDEX_PATH):
    raise FileNotFoundError(f"E5 index not found: {E5_INDEX_PATH}")
index_e5 = faiss.read_index(E5_INDEX_PATH)
# The fusion index is optional; it stays None unless a file is actually present.
index_fusion = (
    faiss.read_index(FUSION_INDEX_PATH)
    if (FUSION_INDEX_PATH and os.path.isfile(FUSION_INDEX_PATH))
    else None
)
# =========================
# Models (CPU-only)
# =========================
device = "cpu"
dtype = torch.float32
# Text retrieval encoder (E5)
st_e5 = SentenceTransformer(E5_ID, device=device)
# CLIP text encoder (fallback when no FT head)
st_clip_txt = SentenceTransformer(CLIP_TXT_ID, device=device).eval()
# The finetuned-head path uses the very same checkpoint on the same device in
# eval mode, so share the instance instead of loading a second copy into RAM.
# Normalisation is chosen per encode() call, so sharing does not couple the two paths.
mclip = st_clip_txt
# Optional: finetuned CLIP text projection head (512->512, bias=False)
proj_txt = None
if FT_HEAD_PATH and os.path.isfile(FT_HEAD_PATH):
    try:
        proj_txt = torch.nn.Linear(512, 512, bias=False)
        # NOTE(review): torch.load unpickles the file; only point FT_HEAD_FILE at
        # checkpoints you trust (consider weights_only=True if the torch version allows).
        ckpt = torch.load(FT_HEAD_PATH, map_location="cpu")
        if "proj_txt" in ckpt:
            proj_txt.load_state_dict(ckpt["proj_txt"])
        elif "state_dict" in ckpt:
            proj_txt.load_state_dict(ckpt["state_dict"])
        else:
            raise KeyError("No 'proj_txt' or 'state_dict' key in FT checkpoint.")
        proj_txt.eval()
        print("[OK] loaded finetuned projection head:", FT_HEAD_PATH)
    except Exception as e:
        # Best effort: without a usable head, queries fall back to the plain CLIP text encoder.
        print("[WARN] failed to load finetuned head:", e)
        proj_txt = None
# Lazy CLIP image encoder (only load if user actually does fusion)
# Populated on first use by _ensure_clip_loaded(); None until then.
clip_model = None
clip_preprocess = None


def _ensure_clip_loaded():
    """Load the CLIP image model and its preprocessing transform once.

    Does nothing when ``clip_model`` is already set. Otherwise imports
    ``open_clip`` on demand, builds the model on CPU, and stores the
    eval-mode model and the third returned transform in the module globals.
    """
    global clip_model, clip_preprocess
    if clip_model is not None:
        return
    import open_clip  # lazy import

    net, _, eval_transform = open_clip.create_model_and_transforms(
        "ViT-B-32", pretrained="laion2b_s34b_b79k", device="cpu"
    )
    clip_model = net.eval()
    clip_preprocess = eval_transform
    print("[OK] CLIP ViT-B/32 loaded on CPU")
# LLM (small; CPU-friendly)
# Tokenizer and causal LM for LLM_ID; the model is placed on CPU in eval mode
# with the module-level dtype.
tokenizer = AutoTokenizer.from_pretrained(LLM_ID, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    LLM_ID,
    torch_dtype=dtype,
).to("cpu").eval()
# =========================
# Retrieval helpers
# =========================
@lru_cache(maxsize=4096)
def _encode_query_e5_cached(q: str) -> np.ndarray:
    """Embed a query with the E5 encoder, memoised on the raw query string.

    The text is digit-normalised and prefixed with "query: " before encoding.
    The first row of the normalised encoder output is returned as float32.
    """
    prefixed = "query: " + normalize_digits_months(q)
    batch = st_e5.encode([prefixed], batch_size=1, convert_to_numpy=True, normalize_embeddings=True)
    return batch[0].astype("float32")
# Use this in place of the older, uncached encoder:
def _encode_query_e5(q: str) -> np.ndarray:
    """Return the E5 embedding for *q* by delegating to the cached encoder."""
    return _encode_query_e5_cached(q)
def _faiss_search(index, q_vec: np.ndarray, k: int):
    """Run a top-*k* search on *index* and return ``[(row_id, score), ...]``.

    A 1-D query is promoted to a single-row batch, and the query is cast to
    float32 before searching. Only the first result row is returned, as plain
    Python values.
    """
    query = q_vec if q_vec.ndim != 1 else q_vec[None, :]
    scores, ids = index.search(query.astype("float32"), k)
    return [(row_id, score) for row_id, score in zip(ids[0].tolist(), scores[0].tolist())]
def search_text_rag(query_text: str, k: int = 5):
q = _encode_query_e5(query_text)
return _faiss_search(index_e5, q, k)
# ---- Fusion (CLIP space) ----
def _jpeg(img, quality=40):
    """Round-trip *img* through an in-memory JPEG at *quality* and return it as RGB."""
    stream = io.BytesIO()
    img.save(stream, format="JPEG", quality=quality, optimize=False)
    stream.seek(0)
    reloaded = Image.open(stream)
    return reloaded.convert("RGB")
def _rand_resized_crop(img, scale=(0.7, 0.9)):
    """Crop a random sub-window of *img* and resize it back to the original size.

    One factor is drawn uniformly from *scale* and applied to both sides,
    with a minimum crop side of 1 pixel. The crop origin is drawn with NumPy's
    global RNG, and the result is resized with bicubic interpolation.
    """
    width, height = img.size
    factor = np.random.uniform(*scale)
    crop_w = max(1, int(width * factor))
    crop_h = max(1, int(height * factor))
    x0 = np.random.randint(0, max(1, width - crop_w))
    y0 = np.random.randint(0, max(1, height - crop_h))
    window = (x0, y0, x0 + crop_w, y0 + crop_h)
    return img.crop(window).resize((width, height), Image.BICUBIC)
def _color_jitter(img, b=(0.9,1.1), c=(0.9,1.1)):
    """Auto-contrast *img*, then apply random brightness and contrast factors.

    The brightness factor is drawn uniformly from *b* first, then the
    contrast factor from *c*, both with NumPy's global RNG.
    """
    out = ImageOps.autocontrast(img)
    brightness_enhancer = ImageEnhance.Brightness(out)
    out = brightness_enhancer.enhance(np.random.uniform(*b))
    contrast_enhancer = ImageEnhance.Contrast(out)
    return contrast_enhancer.enhance(np.random.uniform(*c))
def augment_once(img: Image.Image, level="medium"):
    """Apply one random augmentation pass to *img* at the given *level*.

    - "mild": crop (0.85-0.95), JPEG quality 60.
    - "medium": crop (0.7-0.9), Gaussian blur 1.0, colour jitter, JPEG quality 40.
    - any other value: crop (0.6-0.8), Gaussian blur 1.2, JPEG quality 30.
    """
    if level == "mild":
        cropped = _rand_resized_crop(img, (0.85, 0.95))
        return _jpeg(cropped, 60)
    if level == "medium":
        cropped = _rand_resized_crop(img, (0.7, 0.9))
        blurred = cropped.filter(ImageFilter.GaussianBlur(1.0))
        jittered = _color_jitter(blurred, (0.9,1.1), (0.9,1.1))
        return _jpeg(jittered, 40)
    # Strongest setting for every other level name.
    cropped = _rand_resized_crop(img, (0.6, 0.8))
    blurred = cropped.filter(ImageFilter.GaussianBlur(1.2))
    return _jpeg(blurred, 30)
@torch.no_grad()
def _encode_pil_clip(img: Image.Image) -> np.ndarray:
    """Encode a PIL image with the lazily loaded CLIP model.

    Returns the L2-normalised image feature as a float32 NumPy array with a
    leading batch axis of 1 (the original author notes the shape as (1, 512)).
    """
    _ensure_clip_loaded()
    batch = clip_preprocess(img).unsqueeze(0)
    embedding = clip_model.encode_image(batch).float()
    embedding = F.normalize(embedding, dim=-1)
    return embedding.cpu().numpy().astype("float32")
@torch.no_grad()
def _encode_query_text_clipspace(q: str) -> np.ndarray:
    """Embed query text into the CLIP space used by the fusion index.

    Without a finetuned head, the multilingual CLIP text encoder's normalised
    output is returned. With a head, the raw (un-normalised) encoder output is
    projected through it and then L2-normalised. Either way the result is a
    float32 array for a single-item batch.
    """
    text = normalize_digits_months(q)
    if proj_txt is None:
        # Fallback: encoder output is already normalised.
        encoded = st_clip_txt.encode([text], batch_size=1, convert_to_numpy=True, normalize_embeddings=True)
        return encoded.astype("float32")
    # Raw encoder output -> projection head -> normalise.
    raw = mclip.encode([text], convert_to_numpy=True, normalize_embeddings=False)
    projected = proj_txt(torch.tensor(raw, dtype=torch.float32))
    return F.normalize(projected, dim=-1).cpu().numpy().astype("float32")
@torch.no_grad()
def make_query_embed(query_text: str,
                     image: Image.Image = None,
                     alpha_q: float = 0.7,
                     use_aug: bool = True,
                     n_aug: int = 3) -> np.ndarray:
    """Build the query vector for fusion search from text and an optional image.

    With no image, the text embedding is returned as is. Otherwise the image
    embedding is either the mean over ``max(1, int(n_aug))`` "medium"
    augmented views (``use_aug=True``) or the embedding of the image itself.
    Text and image are blended as ``alpha_q * text + (1 - alpha_q) * image``
    and the blend is L2-normalised.
    """
    text_vec = _encode_query_text_clipspace(query_text)  # (1,512)
    if image is None:
        return text_vec
    if clip_model is None:  # ensure loaded only if needed
        _ensure_clip_loaded()
    if use_aug:
        rounds = max(1, int(n_aug))
        views = [_encode_pil_clip(augment_once(image, "medium")) for _ in range(rounds)]
        image_vec = np.mean(np.vstack(views), axis=0, keepdims=True).astype("float32")
    else:
        image_vec = _encode_pil_clip(image)
    blended = torch.from_numpy(alpha_q*text_vec + (1.0-alpha_q)*image_vec)
    return F.normalize(blended, dim=-1).cpu().numpy().astype("float32")
def search_fusion(query_text: str, image: Image.Image, k: int = 5, alpha_q: float = 0.7):
    """Search the fusion index with a blended text+image query.

    Raises:
        RuntimeError: when no fusion index was loaded.

    Returns the ``(row_id, score)`` pairs produced by ``_faiss_search``.
    Augmentation is disabled for the query image.
    """
    if index_fusion is None:
        raise RuntimeError("Fusion index not available (upload FUSION_INDEX_FILE to dataset repo).")
    query_vec = make_query_embed(
        query_text,
        image=image,
        alpha_q=alpha_q,
        use_aug=False,
        n_aug=3,
    )
    return _faiss_search(index_fusion, query_vec, k)
# =========================
# RAG + LLM
# =========================
def retrieve_context_auto(question: str, k: int = 5, image: Image.Image = None) -> Dict[str, Any]:
    """Retrieve the top-*k* passages for *question*, choosing the route automatically.

    When *image* is given, the fusion (text+image) index is tried first. If
    that fails for any reason the failure is logged and retrieval falls back
    to text-only E5 search. Without an image, E5 search is used directly.

    Returns:
        ``{"route": "fusion" | "text_e5", "contexts": [...]}`` where each
        context holds ``index``, ``id``, ``score`` and ``bio``. Hits whose row
        index falls outside the dataframe are dropped.
    """
    q = normalize_digits_months(question)
    route = "text_e5"
    hits = None
    if image is not None:
        try:
            hits = search_fusion(q, image=image, k=k)
            route = "fusion"
        except Exception as e:
            # Graceful fallback, but leave a trace so a broken fusion path is visible in logs.
            print("[WARN] fusion retrieval failed; falling back to text_e5:", e)
            hits = None
    if hits is None:
        hits = search_text_rag(q, k=k)
    ctxs = []
    for idx, score in hits:
        # Skip out-of-range ids (e.g. negative placeholders for missing results).
        if 0 <= idx < len(df):
            row = df.iloc[idx]
            ctxs.append({"index": int(idx), "id": row.get("id", idx), "score": float(score), "bio": str(row["bio"])})
    return {"route": route, "contexts": ctxs}
def build_prompt(question: str, contexts: List[Dict[str, Any]], lang="fa", max_chars=1800) -> str:
    """Render the chat prompt for free-form answering over retrieved passages.

    Non-empty bios are numbered by their position in *contexts* (starting at
    1), joined with blank lines and truncated to *max_chars*. Persian wording
    is used when ``lang == "fa"``, English otherwise. The system and user
    messages are passed through the tokenizer's chat template with the
    generation prompt appended.
    """
    persian = lang == "fa"
    if persian:
        system_text = "تو یک دستیار پاسخ‌گو هستی که فقط بر اساس متن‌های داده‌شده پاسخ می‌دهی. اگر پاسخی در متن‌ها نبود، صادقانه بگو «در متن‌های بازیابی‌شده پاسخی پیدا نشد.»"
    else:
        system_text = "You are a helpful assistant. Answer only using retrieved passages. If not found, say 'No answer found in retrieved passages.'"
    bios = (ctx["bio"].strip() for ctx in contexts)
    numbered = [f"[{rank}] {bio}" for rank, bio in enumerate(bios, 1) if bio]
    joined = _truncate_chars("\n\n".join(numbered), max_chars)
    if persian:
        user = (f"سؤال: {question}\n\nمتون بازیابی‌شده:\n{joined}\n\n"
                f"فقط با اتکا به متون بالا پاسخ بده و منابع را با [1], [2], ... ارجاع بده.")
    else:
        user = (f"Question: {question}\n\nRetrieved passages:\n{joined}\n\n"
                f"Answer only using the passages, cite sources as [1], [2], ...")
    messages = [
        {"role": "system", "content": system_text},
        {"role": "user", "content": user},
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
@torch.inference_mode()
def llm_generate(prompt: str, max_new_tokens=96, temperature=0.0, top_p=1.0, top_k=50, do_sample=False) -> str:
    """Generate a completion for *prompt* and return only the newly generated text.

    Args:
        prompt: Fully rendered prompt text (e.g. chat-template output).
        max_new_tokens: Upper bound on generated tokens.
        temperature, top_p, top_k: Sampling controls; used only when sampling is active.
        do_sample: Enable sampling. Sampling is active only when this is true
            and ``temperature > 0``; otherwise decoding is greedy and deterministic.

    Returns:
        The decoded continuation, stripped of surrounding whitespace.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        num_beams=1,       # no beam search
        use_cache=True,    # faster
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    if do_sample and float(temperature) > 0.0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p, top_k=top_k)
    else:
        # Greedy decoding: sampling knobs are irrelevant, so they are not passed.
        gen_kwargs["do_sample"] = False
    out = model.generate(**inputs, **gen_kwargs)
    # The output sequence starts with the prompt tokens. Slice them off by
    # length: string-prefix matching fails once special tokens are skipped on decode.
    prompt_len = inputs["input_ids"].shape[1]
    text = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
    return text.strip()
# ---- MCQ helpers ----
def build_mcq_prompt(question: str, options: List[str], contexts: List[Dict[str, Any]], lang="fa", max_chars=1800) -> str:
    """Render the chat prompt for multiple-choice answering with a JSON-only reply.

    Options are labelled A), B), ... in order. Non-empty bios are numbered by
    their position in *contexts*, joined with blank lines and truncated to
    *max_chars*. The model is asked for a single JSON object holding a 0-based
    ``answer_index`` and a ``reason``. Persian wording is used when
    ``lang == "fa"``, English otherwise.
    """
    persian = lang == "fa"
    if persian:
        system_text = (
            "تو یک دستیار پاسخ‌گو هستی که فقط بر اساس متن‌های داده‌شده پاسخ می‌دهی. "
            "باید دقیقاً فقط یک شیء JSON برگردانی و هیچ متن دیگری ننویسی."
        )
    else:
        system_text = (
            "You are a helpful assistant. Answer ONLY using the retrieved passages. "
            "You MUST return a single JSON object and nothing else."
        )
    bios = (ctx["bio"].strip() for ctx in contexts)
    numbered = [f"[{rank}] {bio}" for rank, bio in enumerate(bios, 1) if bio]
    joined = _truncate_chars("\n\n".join(numbered), max_chars)
    labels = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    option_lines = []
    for pos, option in enumerate(options):
        option_lines.append(f"{labels[pos]}) {option}")
    opts_str = "\n".join(option_lines)
    if persian:
        user = (
            f"سؤال: {question}\n\nگزینه‌ها:\n{opts_str}\n\nمتون بازیابی‌شده:\n{joined}\n\n"
            "دقیقاً و فقط یک JSON برگردان. فرمت اجباری: "
            '{"answer_index": X, "reason": "…"} '
            "که در آن X اندیس گزینه (۰-بِیس) است. هیچ متن دیگری ننویس."
        )
    else:
        user = (
            f"Question: {question}\n\nOptions:\n{opts_str}\n\nRetrieved:\n{joined}\n\n"
            'Return EXACTLY one JSON: {"answer_index": X, "reason": "..."} '
            "where X is the 0-based option index. Do not write anything else."
        )
    messages = [
        {"role": "system", "content": system_text},
        {"role": "user", "content": user},
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
import json as _json
import re as _re
import numpy as _np
def _strict_json_from_text(text: str):
    """Return the first JSON object embedded in *text*, or ``None``.

    Each ``{`` is tried in order as the start of a JSON object and the first
    one that decodes wins. Text after the object is ignored, so braces later
    in the output cannot break parsing.

    Returns:
        The decoded ``dict``, or ``None`` when *text* is not a string or
        contains no decodable object.
    """
    if not isinstance(text, str):
        return None
    decoder = _json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode stops at the end of the first complete value.
            obj, _ = decoder.raw_decode(text[start:])
        except (ValueError, RecursionError):
            start = text.find("{", start + 1)
            continue
        return obj
    return None
import re as _re
import numpy as _np
def _norm_text_for_match(s: str) -> str:
    """Normalise text for substring matching.

    Falsy input becomes "". Digits are normalised, zero-width non-joiners
    become spaces, the text is lower-cased, and every whitespace run is
    collapsed to a single space.
    """
    cleaned = normalize_digits_months(s or "")
    cleaned = cleaned.replace("\u200c", " ").strip()
    # Lower-case, then squeeze whitespace runs to one space.
    return _re.sub(r"\s+", " ", cleaned.lower())
def _find_snippet(hay: str, needle: str, win: int = 60) -> str:
    """Return a short piece of *hay* around the first occurrence of *needle*.

    Up to *win* characters are kept on each side of the match and newlines
    are replaced by spaces. Returns "" when *needle* does not occur.
    """
    pos = hay.find(needle)
    if pos < 0:
        return ""
    lo = max(0, pos - win)
    hi = min(len(hay), pos + len(needle) + win)
    return hay[lo:hi].replace("\n", " ")
def score_options_by_context(
    options: List[str],
    contexts: List[Dict[str, Any]],
    return_snippet: bool = False
):
    """
    Heuristic fallback for choosing an option from the retrieved contexts.

    1) Whitespace-bounded substring search of each option in every context;
       an explicit match gets a large score plus the number of occurrences.
    2) If no option matched at all, fall back to embedding similarity
       (via ``_encode_query_e5``) between each option and the joined contexts.

    Returns:
    - ``best_idx`` (int) when ``return_snippet`` is False;
    - ``(best_idx, snippet)`` when ``return_snippet`` is True.
    Index 0 (with an empty snippet) is the default when everything fails.
    """
    # Prepare the contexts: raw text for readable snippets, normalised text for matching.
    raw_ctxs = [c.get("bio", "") for c in contexts]
    norm_ctxs = [_norm_text_for_match(x) for x in raw_ctxs]
    joined_norm = " \n ".join(norm_ctxs)
    # 1) Stricter search: boundary-aware match plus occurrence count.
    # Boundaries are defined by whitespace, so this also works for Persian/Arabic text.
    best_idx, best_score, best_snip = 0, -1.0, ""
    for i, opt in enumerate(options):
        o_raw = str(opt).strip()
        o = _norm_text_for_match(o_raw)
        if not o:
            continue
        # Simple boundary pattern: (start or whitespace) + phrase + (end or whitespace).
        # Multi-word options are matched the same way.
        # A stricter regex can be substituted here if it becomes necessary.
        pat = r"(?<!\S)" + _re.escape(o) + r"(?!\S)"
        total_hits = 0
        first_snip = ""
        for raw, norm in zip(raw_ctxs, norm_ctxs):
            for m in _re.finditer(pat, norm):
                total_hits += 1
                if not first_snip:
                    # Prefer a snippet from the raw text (more readable), located
                    # with a plain substring search; if normalisation changed the
                    # text too much for that to hit, use the normalised text.
                    sn = _find_snippet(raw, o_raw) or _find_snippet(norm, o)
                    first_snip = sn
        if total_hits > 0:
            # High base score for an explicit match, plus the occurrence count.
            score = 10000.0 + total_hits
            if score > best_score:
                best_score, best_idx, best_snip = score, i, first_snip
    if best_score > 0:
        return (best_idx, best_snip) if return_snippet else best_idx
    # 2) No explicit match anywhere → embedding similarity.
    try:
        # Vector for all contexts joined together (computed once).
        ctx_vec = _encode_query_e5(joined_norm) # (dim,)
        sims = []
        for opt in options:
            qv = _encode_query_e5(str(opt))
            sims.append(float(_np.dot(qv, ctx_vec)))
        best_idx = int(_np.argmax(sims))
        # Snippet for this path: find the context closest to the winning option
        # with a separate dot product per context (fast and good enough).
        best_snip = ""
        try:
            opt_vec = _encode_query_e5(str(options[best_idx]))
            # The author treats the inner product as cosine similarity here,
            # on the basis that the vectors are normalised.
            # Score of each context against the winning option:
            c_scores = []
            for raw, norm in zip(raw_ctxs, norm_ctxs):
                c_vec = _encode_query_e5(norm)
                c_scores.append(float(_np.dot(opt_vec, c_vec)))
            j = int(_np.argmax(c_scores))
            best_snip = _find_snippet(raw_ctxs[j], str(options[best_idx])) or raw_ctxs[j][:120].replace("\n"," ")
        except Exception:
            # Snippet extraction is best effort; keep the chosen index regardless.
            pass
        return (best_idx, best_snip) if return_snippet else best_idx
    except Exception:
        return (0, "") if return_snippet else 0 # conservative default
def parse_mcq_output_strict(text: str, options: List[str], contexts: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Turn raw LLM output into ``{"answer_index": int, "reason": str}``.

    The JSON object extracted from *text* is used when its ``answer_index``
    is a valid 0-based index into *options*. Integer-valued strings (e.g.
    ``"2"``) and integral floats (``2.0``) are accepted; booleans are
    rejected so JSON ``true`` is not read as index 1. In every other case the
    answer comes from ``score_options_by_context``.
    """
    obj = _strict_json_from_text(text)
    if isinstance(obj, dict) and "answer_index" in obj:
        idx = obj["answer_index"]
        if isinstance(idx, str):
            # Small models often quote the number; int() also accepts Persian digits.
            try:
                idx = int(idx.strip())
            except ValueError:
                idx = None
        elif isinstance(idx, float) and idx.is_integer():
            idx = int(idx)
        if isinstance(idx, int) and not isinstance(idx, bool) and 0 <= idx < len(options):
            reason = str(obj.get("reason", "")).strip() or "—"
            return {"answer_index": idx, "reason": reason}
    # No usable index in the output: pick the option best supported by the contexts.
    idx, snip = score_options_by_context(options, contexts, return_snippet=True)
    return {"answer_index": idx, "reason": snip or "matched by context"}
def parse_mcq_output(text: str, n: int) -> Dict[str, Any]:
    """Leniently extract an answer index (``0 <= index < n``) from LLM output.

    Tried in order, first in-range hit wins:
    1. a ``{"answer_index": N, "reason": "..."}`` fragment;
    2. a standalone letter A-D (case-insensitive), mapped to 0-3;
    3. a standalone digit 1-9, mapped to 0-8.
    For 2 and 3 the whole stripped text is the reason. When nothing fits,
    ``answer_index`` is ``None``.
    """
    json_hit = re.search(r'{"\s*answer_index"\s*:\s*([0-9]+)\s*,\s*"reason"\s*:\s*"(.*?)"}', text, re.S)
    if json_hit is not None:
        candidate = int(json_hit.group(1))
        why = json_hit.group(2).strip()
        if 0 <= candidate < n:
            return {"answer_index": candidate, "reason": why}

    letter_hit = re.search(r'\b([A-D])\b', text, re.I)
    if letter_hit is not None:
        candidate = "ABCDEFGHIJKLMNOPQRSTUVWXYZ".index(letter_hit.group(1).upper())
        if candidate < n:
            return {"answer_index": candidate, "reason": text.strip()}

    digit_hit = re.search(r'\b([1-9])\b', text)
    if digit_hit is not None:
        candidate = int(digit_hit.group(1)) - 1
        if 0 <= candidate < n:
            return {"answer_index": candidate, "reason": text.strip()}

    return {"answer_index": None, "reason": text.strip()}
# =========================
# Gradio UI
# =========================
def ui_answer(question, image, topk, max_tokens, temperature, top_p, top_k):
    """Gradio handler for the Ask tab: retrieve, prompt, generate.

    Returns ``(answer_text, source_rows, route)``. Each source row is
    ``[rank, id, score rounded to 4 places, bio preview of up to 180 chars]``.
    A blank question short-circuits with a hint and empty outputs.
    """
    if not (question and question.strip()):
        return "Please enter a question.", [], ""
    # Retrieve
    retrieved = retrieve_context_auto(question, k=int(topk), image=image)
    contexts = retrieved["contexts"]
    prompt = build_prompt(question, contexts, lang="fa", max_chars=1800)
    answer = llm_generate(
        prompt,
        max_new_tokens=int(max_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        top_k=int(top_k),
        do_sample=False,
    )
    # Sources
    source_rows = [
        [rank, ctx["id"], round(ctx["score"], 4), ctx["bio"][:180] + ("…" if len(ctx["bio"]) > 180 else "")]
        for rank, ctx in enumerate(contexts, 1)
    ]
    return answer, source_rows, retrieved["route"]
def ui_mcq(question, options_txt, image, topk, max_tokens, temperature, top_p, top_k):
    """Gradio handler for the MCQ tab: retrieve, prompt, generate, parse.

    *options_txt* holds one option per line; blank lines are ignored.
    Returns ``(prediction_summary, raw_llm_output, source_rows, route)``.
    A missing question or fewer than two options short-circuits with a hint.
    """
    opts = [line.strip() for line in (options_txt or "").splitlines() if line.strip()]
    if not question or len(opts) < 2:
        return "Provide a question and at least 2 options.", "", [], ""
    retrieved = retrieve_context_auto(question, k=int(topk), image=image)
    contexts = retrieved["contexts"]
    prompt = build_mcq_prompt(question, opts, contexts, lang="fa", max_chars=5000)
    # Deterministic decoding on CPU.
    out = llm_generate(
        prompt,
        max_new_tokens=int(max_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        top_k=int(top_k),
        do_sample=False,
    )
    parsed = parse_mcq_output_strict(out, opts, contexts)
    pred = parsed["answer_index"]
    if pred is not None and 0 <= pred < len(opts):
        pred_text = opts[pred]
    else:
        pred_text = "N/A"
    source_rows = [
        [rank, ctx["id"], round(ctx["score"], 4), ctx["bio"][:180] + ("…" if len(ctx["bio"]) > 180 else "")]
        for rank, ctx in enumerate(contexts, 1)
    ]
    summary = f"Pred: index={pred} text={pred_text}\nReason: {parsed['reason']}"
    return summary, out, source_rows, retrieved["route"]
def _ui_answer_with_fusion_toggle(question, image, fusion_enabled, topk, max_tokens, temperature, top_p, top_k):
    """Adapter between the Ask tab's eight inputs and the seven-argument ``ui_answer``.

    The checkbox value is consumed here: the image is forwarded only when
    fusion is enabled, so an unchecked box yields text-only retrieval.
    """
    return ui_answer(question, image if fusion_enabled else None,
                     topk, max_tokens, temperature, top_p, top_k)


with gr.Blocks(title="Multimodal RAG (CPU) • E5 + CLIP Fusion + Qwen 0.5B") as demo:
    gr.Markdown("### Free-tier CPU demo: text RAG (E5) + optional fusion (CLIP) → Qwen 0.5B")
    with gr.Tab("Ask"):
        with gr.Row():
            q = gr.Textbox(label="Question", lines=3)
            img = gr.Image(type="pil", label="Optional image")
            use_fusion = gr.Checkbox(label="Use image fusion (slower on CPU)", value=False)
        with gr.Row():
            topk = gr.Slider(1, 20, value=3, step=1, label="Top-K retrieve")
            max_tokens = gr.Slider(16, 512, value=96, step=16, label="Max new tokens")
        with gr.Row():
            temperature = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=1.0, step=0.05, label="Top-p")
            top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
        btn = gr.Button("Answer")
        ans = gr.Textbox(label="Answer", lines=8)
        route = gr.Textbox(label="Route used (text_e5 or fusion)")
        table = gr.Dataframe(headers=["#", "id", "score", "snippet"], interactive=False)
        # The number of inputs must match the handler's parameters, hence the adapter.
        btn.click(_ui_answer_with_fusion_toggle,
                  [q, img, use_fusion, topk, max_tokens, temperature, top_p, top_k],
                  [ans, table, route])
    with gr.Tab("MCQ"):
        with gr.Row():
            q_mcq = gr.Textbox(label="Question", lines=3)
            opts_mcq = gr.Textbox(label="Options (one per line)", lines=8)
            img_mcq = gr.Image(type="pil", label="Optional image (fusion if enabled)")
        with gr.Row():
            topk2 = gr.Slider(1, 20, value=3, step=1, label="Top-K retrieve")
            max_tokens2 = gr.Slider(16, 512, value=96, step=16, label="Max new tokens")
        with gr.Row():
            temperature2 = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Temperature")
            top_p2 = gr.Slider(0.1, 1.0, value=1.0, step=0.05, label="Top-p")
            top_k2 = gr.Slider(1, 100, value=50, step=1, label="Top-k")
        btn2 = gr.Button("Answer MCQ")
        result = gr.Textbox(label="Prediction", lines=12, max_lines=20)
        raw = gr.Textbox(label="Raw LLM output", lines=12, max_lines=20)
        route2 = gr.Textbox(label="Route used")
        table2 = gr.Dataframe(headers=["#", "id", "score", "snippet"], interactive=False)
        btn2.click(ui_mcq, [q_mcq, opts_mcq, img_mcq, topk2, max_tokens2, temperature2, top_p2, top_k2],
                   [result, raw, table2, route2])
demo.launch()