import os, io, gc, json, re, ast
from functools import lru_cache
import numpy as np
import pandas as pd
import faiss
import torch
import torch.nn.functional as F
from typing import List, Dict, Any
from PIL import Image, ImageFilter, ImageOps, ImageEnhance
import gradio as gr
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM

torch.set_num_threads(2)  # Space vCPUs are usually 2
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# =========================
# Config (override in Space → Settings → Variables & secrets)
# =========================
DATASET_REPO = os.getenv("DATASET_REPO", "ahm1378/NLP-Project")  # <--- CHANGE to your repo
CSV_FILE = os.getenv("CSV_FILE", "final_merged_images.csv")
E5_INDEX_FILE = os.getenv("E5_INDEX_FILE", "faiss_e5_rag_v15.ip")
E5_EMB_FILE = os.getenv("E5_EMB_FILE", "doc_embeds_e5_rag_v15.npy")
FUSION_INDEX_FILE = os.getenv("FUSION_INDEX_FILE", "faiss_fusion.ip")
FUSION_EMB_FILE = os.getenv("FUSION_EMB_FILE", "fusion_doc_emb.npy")
FT_HEAD_FILE = os.getenv("FT_HEAD_FILE", "finetune_clip_fa.pt")  # your finetuned text projection (CLIP space)
HF_TOKEN = os.getenv("HF_TOKEN", None)  # needed if DATASET_REPO is private

# Models (CPU-friendly defaults; override via env if desired)
E5_ID = os.getenv("E5_ID", "intfloat/multilingual-e5-small")
CLIP_TXT_ID = os.getenv("CLIP_TXT_ID", "sentence-transformers/clip-ViT-B-32-multilingual-v1")
LLM_ID = os.getenv("LLM_ID", "Qwen/Qwen2-0.5B-Instruct")
# Shorter output by default
MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS_DEFAULT", "96"))  # previously 256
# Sampling off (deterministic and faster)
TEMPERATURE_DEFAULT = float(os.getenv("TEMPERATURE_DEFAULT", "0.0"))
TOP_P_DEFAULT = float(os.getenv("TOP_P_DEFAULT", "1.0"))
TOP_K_DEFAULT = int(os.getenv("TOP_K_DEFAULT", "50"))
# =========================
# Helpers
# =========================
def normalize_digits_months(s: str) -> str:
    if not isinstance(s, str):
        s = str(s)
    trans = str.maketrans("۰۱۲۳۴۵۶۷۸۹٠١٢٣٤٥٦٧٨٩", "01234567890123456789")
    s = s.translate(trans).replace("\u200c", " ").strip()
    return s
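# Example: normalize_digits_months("۱۳۷۵") -> "1375" (Persian and Arabic-Indic
# digits map to ASCII; ZWNJ "\u200c" becomes a plain space).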
def _truncate_chars(s: str, limit: int) -> str:
    return s if (limit is None or len(s) <= limit) else s[:limit] + "…"

def _maybe_hub(file, repo=DATASET_REPO, repo_type="dataset") -> str:
    # If present locally, use it. Otherwise download from Hub.
    if os.path.isfile(file):
        return file
    return hf_hub_download(repo_id=repo, filename=file, repo_type=repo_type, token=HF_TOKEN)
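# Usage sketch: resolve an artifact from the working directory first, else
# download it from DATASET_REPO (the filename below is illustrative):
#   path = _maybe_hub("some_artifact.npy")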
# =========================
# Fetch artifacts
# =========================
CSV_PATH = _maybe_hub(CSV_FILE)
E5_INDEX_PATH = _maybe_hub(E5_INDEX_FILE)
# (E5_EMB_PATH not strictly needed at runtime)
FUSION_INDEX_PATH = _maybe_hub(FUSION_INDEX_FILE) if FUSION_INDEX_FILE else None
FT_HEAD_PATH = _maybe_hub(FT_HEAD_FILE) if FT_HEAD_FILE else None
# =========================
# Load dataframe
# =========================
if not os.path.isfile(CSV_PATH):
    raise FileNotFoundError(f"CSV missing: {CSV_PATH}")
df = pd.read_csv(CSV_PATH)

# Expect columns: 'id', 'bio', 'image_paths_abs' (list or stringified list)
def first_image(x):
    if isinstance(x, list) and x:
        return x[0]
    if isinstance(x, str) and x.strip():
        # try JSON list
        try:
            lst = json.loads(x)
            if isinstance(lst, list) and lst:
                return lst[0]
        except Exception:
            # try Python literal list (handles single quotes)
            try:
                lst = ast.literal_eval(x)
                if isinstance(lst, list) and lst:
                    return lst[0]
            except Exception:
                return x  # treat as single path
    return ""

if "image_paths_abs" in df.columns:
    df["first_image"] = df["image_paths_abs"].apply(first_image)
else:
    df["first_image"] = ""

if "bio" not in df.columns:
    raise KeyError("Expected 'bio' column in CSV.")
df["bio"] = df["bio"].astype(str)
# =========================
# Indices
# =========================
if not os.path.isfile(E5_INDEX_PATH):
    raise FileNotFoundError(f"E5 index not found: {E5_INDEX_PATH}")
index_e5 = faiss.read_index(E5_INDEX_PATH)

index_fusion = None
if FUSION_INDEX_PATH and os.path.isfile(FUSION_INDEX_PATH):
    index_fusion = faiss.read_index(FUSION_INDEX_PATH)
# =========================
# Models (CPU-only)
# =========================
device = "cpu"
dtype = torch.float32

# Text retrieval encoder (E5)
st_e5 = SentenceTransformer(E5_ID, device=device)

# CLIP text encoder (fallback when no FT head)
st_clip_txt = SentenceTransformer(CLIP_TXT_ID, device=device).eval()

# Optional finetuned CLIP text projection head (512->512, bias=False) sits on
# top of the same multilingual encoder; reuse the instance instead of loading
# the identical model a second time.
mclip = st_clip_txt
proj_txt = None
if FT_HEAD_PATH and os.path.isfile(FT_HEAD_PATH):
    try:
        proj_txt = torch.nn.Linear(512, 512, bias=False)
        ckpt = torch.load(FT_HEAD_PATH, map_location="cpu")
        if "proj_txt" in ckpt:
            proj_txt.load_state_dict(ckpt["proj_txt"])
        elif "state_dict" in ckpt:
            proj_txt.load_state_dict(ckpt["state_dict"])
        else:
            raise KeyError("No 'proj_txt' or 'state_dict' key in FT checkpoint.")
        proj_txt.eval()
        print("[OK] loaded finetuned projection head:", FT_HEAD_PATH)
    except Exception as e:
        print("[WARN] failed to load finetuned head:", e)
        proj_txt = None
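# Sketch (assumption, not from the source): a checkpoint compatible with the
# loader above could have been saved at training time as either
#   torch.save({"proj_txt": proj.state_dict()}, "finetune_clip_fa.pt")
# or
#   torch.save({"state_dict": proj.state_dict()}, "finetune_clip_fa.pt")
# where `proj` is a torch.nn.Linear(512, 512, bias=False).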
# Lazy CLIP image encoder (only load if user actually does fusion)
clip_model = None
clip_preprocess = None

def _ensure_clip_loaded():
    global clip_model, clip_preprocess
    if clip_model is None:
        import open_clip  # lazy import
        model, _, preprocess_val = open_clip.create_model_and_transforms(
            "ViT-B-32", pretrained="laion2b_s34b_b79k", device="cpu"
        )
        clip_model = model.eval()
        clip_preprocess = preprocess_val
        print("[OK] CLIP ViT-B/32 loaded on CPU")
# LLM (small; CPU-friendly)
tokenizer = AutoTokenizer.from_pretrained(LLM_ID, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    LLM_ID,
    torch_dtype=dtype,
).to("cpu").eval()
# =========================
# Retrieval helpers
# =========================
@lru_cache(maxsize=256)  # cache repeated queries (the name promised this; the import was unused)
def _encode_query_e5_cached(q: str) -> np.ndarray:
    qn = "query: " + normalize_digits_months(q)
    v = st_e5.encode([qn], batch_size=1, convert_to_numpy=True, normalize_embeddings=True)[0]
    return v.astype("float32")

# Use this (cached) wrapper in place of the old encoder:
def _encode_query_e5(q: str) -> np.ndarray:
    return _encode_query_e5_cached(q)
def _faiss_search(index, q_vec: np.ndarray, k: int):
    if q_vec.ndim == 1:
        q_vec = q_vec[None, :]
    s, I = index.search(q_vec.astype("float32"), k)
    return list(zip(I[0].tolist(), s[0].tolist()))

def search_text_rag(query_text: str, k: int = 5):
    q = _encode_query_e5(query_text)
    return _faiss_search(index_e5, q, k)
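# Usage sketch (text route): returns k (row_index, inner_product) pairs whose
# indices map back into `df`; the query string is illustrative:
#   hits = search_text_rag("زادگاه سعدی کجاست؟", k=3)
#   for idx, score in hits:
#       print(round(score, 3), df.iloc[idx]["bio"][:80])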
# ---- Fusion (CLIP space) ----
def _jpeg(img, quality=40):
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=quality, optimize=False)
    buf.seek(0)
    return Image.open(buf).convert("RGB")

def _rand_resized_crop(img, scale=(0.7, 0.9)):
    w, h = img.size
    s = np.random.uniform(*scale)
    nw, nh = max(1, int(w * s)), max(1, int(h * s))
    left = np.random.randint(0, max(1, w - nw))
    top = np.random.randint(0, max(1, h - nh))
    return img.crop((left, top, left + nw, top + nh)).resize((w, h), Image.BICUBIC)

def _color_jitter(img, b=(0.9, 1.1), c=(0.9, 1.1)):
    img = ImageOps.autocontrast(img)
    img = ImageEnhance.Brightness(img).enhance(np.random.uniform(*b))
    img = ImageEnhance.Contrast(img).enhance(np.random.uniform(*c))
    return img
def augment_once(img: Image.Image, level="medium"):
    if level == "mild":
        img = _rand_resized_crop(img, (0.85, 0.95))
        img = _jpeg(img, 60)
    elif level == "medium":
        img = _rand_resized_crop(img, (0.7, 0.9))
        img = img.filter(ImageFilter.GaussianBlur(1.0))
        img = _color_jitter(img, (0.9, 1.1), (0.9, 1.1))
        img = _jpeg(img, 40)
    else:
        img = _rand_resized_crop(img, (0.6, 0.8))
        img = img.filter(ImageFilter.GaussianBlur(1.2))
        img = _jpeg(img, 30)
    return img

def _encode_pil_clip(img: Image.Image) -> np.ndarray:
    _ensure_clip_loaded()
    t = clip_preprocess(img).unsqueeze(0)
    # no_grad: .numpy() raises on tensors that require grad, and we never train here
    with torch.no_grad():
        feat = clip_model.encode_image(t)
        feat = F.normalize(feat.float(), dim=-1)
    return feat.cpu().numpy().astype("float32")  # (1, 512)
def _encode_query_text_clipspace(q: str) -> np.ndarray:
    qn = normalize_digits_months(q)
    if proj_txt is not None:
        # mclip raw → proj → normalize (no_grad, so .numpy() works on the output)
        t = torch.tensor(
            mclip.encode([qn], convert_to_numpy=True, normalize_embeddings=False),
            dtype=torch.float32,
        )
        with torch.no_grad():
            t = proj_txt(t)
            t = F.normalize(t, dim=-1)
        return t.cpu().numpy().astype("float32")
    else:
        # fallback: CLIP multilingual text encoder (already normalized)
        t = st_clip_txt.encode([qn], batch_size=1, convert_to_numpy=True, normalize_embeddings=True)
        return t.astype("float32")
def make_query_embed(query_text: str,
                     image: Image.Image = None,
                     alpha_q: float = 0.7,
                     use_aug: bool = True,
                     n_aug: int = 3) -> np.ndarray:
    qt = _encode_query_text_clipspace(query_text)  # (1, 512)
    qi = None
    if image is not None:
        if clip_model is None:  # ensure loaded only if needed
            _ensure_clip_loaded()
        if use_aug:
            feats = [_encode_pil_clip(augment_once(image, "medium")) for _ in range(max(1, int(n_aug)))]
            qi = np.mean(np.vstack(feats), axis=0, keepdims=True).astype("float32")
        else:
            qi = _encode_pil_clip(image)
    if qi is not None:
        qv = torch.from_numpy(alpha_q * qt + (1.0 - alpha_q) * qi)
        qv = F.normalize(qv, dim=-1).cpu().numpy().astype("float32")
        return qv
    return qt
def search_fusion(query_text: str, image: Image.Image, k: int = 5, alpha_q: float = 0.7):
    if index_fusion is None:
        raise RuntimeError("Fusion index not available (upload FUSION_INDEX_FILE to dataset repo).")
    qv = make_query_embed(query_text, image=image, alpha_q=alpha_q, use_aug=False, n_aug=3)
    return _faiss_search(index_fusion, qv, k)
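# Usage sketch (fusion route): text and image are blended in CLIP space
# (alpha_q weights the text side) and searched against the fusion index;
# requires FUSION_INDEX_FILE. The image path and query are illustrative:
#   pil = Image.open("example.jpg").convert("RGB")
#   hits = search_fusion("این بنا کجاست؟", pil, k=5, alpha_q=0.7)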
# =========================
# RAG + LLM
# =========================
def retrieve_context_auto(question: str, k: int = 5, image: Image.Image = None) -> Dict[str, Any]:
    q = normalize_digits_months(question)
    if image is not None:
        route = "fusion"
        try:
            hits = search_fusion(q, image=image, k=k)
        except Exception as e:
            print("[WARN] fusion search failed, falling back to text:", e)
            route = "text_e5"  # graceful fallback
            hits = search_text_rag(q, k=k)
    else:
        route = "text_e5"
        hits = search_text_rag(q, k=k)
    ctxs = []
    for idx, score in hits:
        if 0 <= idx < len(df):
            row = df.iloc[idx]
            ctxs.append({"index": int(idx), "id": row.get("id", idx), "score": float(score), "bio": str(row["bio"])})
    return {"route": route, "contexts": ctxs}
def build_prompt(question: str, contexts: List[Dict[str, Any]], lang="fa", max_chars=1800) -> str:
    sys_fa = "تو یک دستیار پاسخگو هستی که فقط بر اساس متن‌های داده‌شده پاسخ می‌دهی. اگر پاسخی در متن‌ها نبود، صادقانه بگو «در متن‌های بازیابی‌شده پاسخی پیدا نشد.»"
    sys_en = "You are a helpful assistant. Answer only using retrieved passages. If not found, say 'No answer found in retrieved passages.'"
    system_text = sys_fa if lang == "fa" else sys_en
    parts = []
    for i, c in enumerate(contexts, 1):
        bi = c["bio"].strip()
        if bi:
            parts.append(f"[{i}] {bi}")
    joined = _truncate_chars("\n\n".join(parts), max_chars)
    user = (f"سؤال: {question}\n\nمتون بازیابی‌شده:\n{joined}\n\n"
            f"فقط با اتکا به متون بالا پاسخ بده و منابع را با [1], [2], ... ارجاع بده."
            ) if lang == "fa" else (
        f"Question: {question}\n\nRetrieved passages:\n{joined}\n\n"
        f"Answer only using the passages, cite sources as [1], [2], ..."
    )
    msgs = [{"role": "system", "content": system_text},
            {"role": "user", "content": user}]
    return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
def llm_generate(prompt: str, max_new_tokens=96, temperature=0.0, top_p=1.0, top_k=50, do_sample=False) -> str:
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,              # deterministic by default
        num_beams=1,                      # no beam search
        use_cache=True,                   # faster
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    if do_sample:  # sampling knobs are ignored (and warned about) when do_sample=False
        gen_kwargs.update(temperature=temperature, top_p=top_p, top_k=top_k)
    out = model.generate(**inputs, **gen_kwargs)
    # Decode only the newly generated tokens; stripping the prompt as a string
    # prefix is unreliable because skip_special_tokens drops chat-template tokens.
    new_tokens = out[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
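# End-to-end sketch of the free-text answer path (the same calls ui_answer
# makes below); the question string is illustrative:
#   ret = retrieve_context_auto("زادگاه حافظ کجاست؟", k=3)
#   prompt = build_prompt("زادگاه حافظ کجاست؟", ret["contexts"], lang="fa")
#   answer = llm_generate(prompt, max_new_tokens=96)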
# ---- MCQ helpers ----
def build_mcq_prompt(question: str, options: List[str], contexts: List[Dict[str, Any]], lang="fa", max_chars=1800) -> str:
    sys_fa = (
        "تو یک دستیار پاسخگو هستی که فقط بر اساس متن‌های داده‌شده پاسخ می‌دهی. "
        "باید دقیقاً فقط یک شیء JSON برگردانی و هیچ متن دیگری ننویسی."
    )
    sys_en = (
        "You are a helpful assistant. Answer ONLY using the retrieved passages. "
        "You MUST return a single JSON object and nothing else."
    )
    system_text = sys_fa if lang == "fa" else sys_en
    parts = []
    for i, c in enumerate(contexts, 1):
        bi = c["bio"].strip()
        if bi:
            parts.append(f"[{i}] {bi}")
    joined = _truncate_chars("\n\n".join(parts), max_chars)
    labels = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    opts_str = "\n".join([f"{labels[i]}) {o}" for i, o in enumerate(options)])
    if lang == "fa":
        user = (
            f"سؤال: {question}\n\nگزینه‌ها:\n{opts_str}\n\nمتون بازیابی‌شده:\n{joined}\n\n"
            "دقیقاً و فقط یک JSON برگردان. فرمت اجباری: "
            '{"answer_index": X, "reason": "…"} '
            "که در آن X اندیس گزینه (۰-بِیس) است. هیچ متن دیگری ننویس."
        )
    else:
        user = (
            f"Question: {question}\n\nOptions:\n{opts_str}\n\nRetrieved:\n{joined}\n\n"
            'Return EXACTLY one JSON: {"answer_index": X, "reason": "..."} '
            "where X is the 0-based option index. Do not write anything else."
        )
    msgs = [{"role": "system", "content": system_text},
            {"role": "user", "content": user}]
    return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
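# MCQ sketch: options are lettered A), B), ... inside the prompt, but the model
# is asked for the 0-based index as JSON. Question and options are illustrative:
#   p = build_mcq_prompt("پایتخت ایران کجاست؟", ["شیراز", "تهران"], ret["contexts"])
#   # a well-behaved reply: {"answer_index": 1, "reason": "..."}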
import json as _json
import re as _re
import numpy as _np

def _strict_json_from_text(text: str):
    # Grab only the first {...} block and JSON-parse it.
    m = _re.search(r'\{.*\}', text, _re.S)
    if not m:
        return None
    frag = m.group(0)
    try:
        return _json.loads(frag)
    except Exception:
        return None
def _norm_text_for_match(s: str) -> str:
    # Simple normalization: Persian/Arabic digits, ZWNJ, extra whitespace.
    s = normalize_digits_months(s or "")
    s = s.replace("\u200c", " ").strip()
    # lowercase and collapse runs of whitespace
    s = _re.sub(r"\s+", " ", s.lower())
    return s

def _find_snippet(hay: str, needle: str, win: int = 60) -> str:
    """Return a short snippet around the first match."""
    try:
        i = hay.index(needle)
        start = max(0, i - win)
        end = min(len(hay), i + len(needle) + win)
        return hay[start:end].replace("\n", " ")
    except ValueError:
        return ""
def score_options_by_context(
    options: List[str],
    contexts: List[Dict[str, Any]],
    return_snippet: bool = False
):
    """
    Smart fallback:
    1) boundary-aware substring match in each context (high score + hit count)
    2) if nothing matches → embedding similarity (mE5) between each option
       and the joined contexts
    Returns:
    - return_snippet=False → best_idx (int) only
    - return_snippet=True → (best_idx, snippet)
    """
    # Prepare the contexts
    raw_ctxs = [c.get("bio", "") for c in contexts]
    norm_ctxs = [_norm_text_for_match(x) for x in raw_ctxs]
    joined_norm = " \n ".join(norm_ctxs)

    # 1) Exact-ish search: word boundaries + occurrence count.
    # Works well for Persian/Arabic too, since we key on whitespace.
    best_idx, best_score, best_snip = 0, -1.0, ""
    for i, opt in enumerate(options):
        o_raw = str(opt).strip()
        o = _norm_text_for_match(o_raw)
        if not o:
            continue
        # Simple boundary pattern: (start/space) + phrase + (end/space).
        # Handles multi-word options as well; a stricter regex could be
        # written if ever needed.
        pat = r"(?<!\S)" + _re.escape(o) + r"(?!\S)"
        total_hits = 0
        first_snip = ""
        for raw, norm in zip(raw_ctxs, norm_ctxs):
            for m in _re.finditer(pat, norm):
                total_hits += 1
                if not first_snip:
                    # Take the snippet from the raw text (more readable);
                    # if normalization shifted the match too much, fall
                    # back to the normalized text.
                    sn = _find_snippet(raw, o_raw) or _find_snippet(norm, o)
                    first_snip = sn
        if total_hits > 0:
            # High score for an explicit match, plus the occurrence count.
            score = 10000.0 + total_hits
            if score > best_score:
                best_score, best_idx, best_snip = score, i, first_snip
    if best_score > 0:
        return (best_idx, best_snip) if return_snippet else best_idx

    # 2) No match at all → embedding similarity (mE5)
    try:
        # One vector for all joined contexts
        ctx_vec = _encode_query_e5(joined_norm)  # (dim,)
        sims = []
        for opt in options:
            qv = _encode_query_e5(str(opt))
            sims.append(float(_np.dot(qv, ctx_vec)))
        best_idx = int(_np.argmax(sims))
        # Snippet on this path: find the closest single context via separate
        # dot products (fast and good enough).
        best_snip = ""
        try:
            opt_vec = _encode_query_e5(str(options[best_idx]))
            # Cosine ≈ inner product here since the vectors are normalized.
            # Score each context against the winning option:
            c_scores = []
            for raw, norm in zip(raw_ctxs, norm_ctxs):
                c_vec = _encode_query_e5(norm)
                c_scores.append(float(_np.dot(opt_vec, c_vec)))
            j = int(_np.argmax(c_scores))
            best_snip = _find_snippet(raw_ctxs[j], str(options[best_idx])) or raw_ctxs[j][:120].replace("\n", " ")
        except Exception:
            pass
        return (best_idx, best_snip) if return_snippet else best_idx
    except Exception:
        return (0, "") if return_snippet else 0  # conservative default
def parse_mcq_output_strict(text: str, options: List[str], contexts: List[Dict[str, Any]]) -> Dict[str, Any]:
    obj = _strict_json_from_text(text)
    if obj and "answer_index" in obj:
        idx = obj["answer_index"]
        if isinstance(idx, int) and 0 <= idx < len(options):
            reason = str(obj.get("reason", "")).strip() or "—"
            return {"answer_index": idx, "reason": reason}
    idx, snip = score_options_by_context(options, contexts, return_snippet=True)
    return {"answer_index": idx, "reason": snip or "matched by context"}
def parse_mcq_output(text: str, n: int) -> Dict[str, Any]:
    # Tolerate whitespace inside the JSON object (the old pattern required
    # the quote to follow the brace immediately).
    m = re.search(r'\{\s*"answer_index"\s*:\s*([0-9]+)\s*,\s*"reason"\s*:\s*"(.*?)"\s*\}', text, re.S)
    if m:
        idx = int(m.group(1))
        reason = m.group(2).strip()
        if 0 <= idx < n:
            return {"answer_index": idx, "reason": reason}
    letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    m2 = re.search(r'\b([A-D])\b', text, re.I)
    if m2:
        idx = letters.index(m2.group(1).upper())
        if idx < n:
            return {"answer_index": idx, "reason": text.strip()}
    m3 = re.search(r'\b([1-9])\b', text)
    if m3:
        idx = int(m3.group(1)) - 1
        if 0 <= idx < n:
            return {"answer_index": idx, "reason": text.strip()}
    return {"answer_index": None, "reason": text.strip()}
# =========================
# Gradio UI
# =========================
def ui_answer(question, image, use_fusion, topk, max_tokens, temperature, top_p, top_k):
    if not question or not question.strip():
        return "Please enter a question.", [], ""
    # Retrieve (pass the image through only when fusion is enabled; the click
    # handler below wires 8 inputs, so the checkbox must be a parameter here)
    ret = retrieve_context_auto(question, k=int(topk), image=image if use_fusion else None)
    prompt = build_prompt(question, ret["contexts"], lang="fa", max_chars=1800)
    ans = llm_generate(prompt, max_new_tokens=int(max_tokens),
                       temperature=float(temperature), top_p=float(top_p),
                       top_k=int(top_k), do_sample=False)
    # Sources
    rows = []
    for i, c in enumerate(ret["contexts"], 1):
        snip = c["bio"][:180] + ("…" if len(c["bio"]) > 180 else "")
        rows.append([i, c["id"], round(c["score"], 4), snip])
    return ans, rows, ret["route"]
def ui_mcq(question, options_txt, image, topk, max_tokens, temperature, top_p, top_k):
    opts = [o.strip() for o in (options_txt or "").splitlines() if o.strip()]
    if not question or len(opts) < 2:
        return "Provide a question and at least 2 options.", "", [], ""
    ret = retrieve_context_auto(question, k=int(topk), image=image)
    prompt = build_mcq_prompt(question, opts, ret["contexts"], lang="fa", max_chars=5000)
    out = llm_generate(prompt, max_new_tokens=int(max_tokens),
                       temperature=float(temperature), top_p=float(top_p),
                       top_k=int(top_k), do_sample=False)  # deterministic on CPU
    parsed = parse_mcq_output_strict(out, opts, ret["contexts"])
    pred = parsed["answer_index"]
    pred_text = (opts[pred] if (pred is not None and 0 <= pred < len(opts)) else "N/A")
    rows = []
    for i, c in enumerate(ret["contexts"], 1):
        snip = c["bio"][:180] + ("…" if len(c["bio"]) > 180 else "")
        rows.append([i, c["id"], round(c["score"], 4), snip])
    result = f"Pred: index={pred} text={pred_text}\nReason: {parsed['reason']}"
    return result, out, rows, ret["route"]
with gr.Blocks(title="Multimodal RAG (CPU) • E5 + CLIP Fusion + Qwen 0.5B") as demo:
    gr.Markdown("### Free-tier CPU demo: text RAG (E5) + optional fusion (CLIP) → Qwen 0.5B")
    with gr.Tab("Ask"):
        with gr.Row():
            q = gr.Textbox(label="Question", lines=3)
            img = gr.Image(type="pil", label="Optional image")
        use_fusion = gr.Checkbox(label="Use image fusion (slower on CPU)", value=False)
        with gr.Row():
            topk = gr.Slider(1, 20, value=3, step=1, label="Top-K retrieve")
            max_tokens = gr.Slider(16, 512, value=96, step=16, label="Max new tokens")
        with gr.Row():
            temperature = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=1.0, step=0.05, label="Top-p")
            top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
        btn = gr.Button("Answer")
        ans = gr.Textbox(label="Answer", lines=8)
        route = gr.Textbox(label="Route used (text_e5 or fusion)")
        table = gr.Dataframe(headers=["#", "id", "score", "snippet"], interactive=False)
        btn.click(ui_answer, [q, img, use_fusion, topk, max_tokens, temperature, top_p, top_k],
                  [ans, table, route])
    with gr.Tab("MCQ"):
        with gr.Row():
            q_mcq = gr.Textbox(label="Question", lines=3)
            opts_mcq = gr.Textbox(label="Options (one per line)", lines=8)
            img_mcq = gr.Image(type="pil", label="Optional image (fusion if enabled)")
        with gr.Row():
            topk2 = gr.Slider(1, 20, value=3, step=1, label="Top-K retrieve")
            max_tokens2 = gr.Slider(16, 512, value=96, step=16, label="Max new tokens")
        with gr.Row():
            temperature2 = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Temperature")
            top_p2 = gr.Slider(0.1, 1.0, value=1.0, step=0.05, label="Top-p")
            top_k2 = gr.Slider(1, 100, value=50, step=1, label="Top-k")
        btn2 = gr.Button("Answer MCQ")
        result = gr.Textbox(label="Prediction", lines=12, max_lines=20)
        raw = gr.Textbox(label="Raw LLM output", lines=12, max_lines=20)
        route2 = gr.Textbox(label="Route used")
        table2 = gr.Dataframe(headers=["#", "id", "score", "snippet"], interactive=False)
        btn2.click(ui_mcq, [q_mcq, opts_mcq, img_mcq, topk2, max_tokens2, temperature2, top_p2, top_k2],
                   [result, raw, table2, route2])

demo.launch()