# app.py — HF Space: hybrid text+image RAG demo (Persian-ready)
import os, json
from dataclasses import dataclass
from typing import List, Optional

import gradio as gr
import numpy as np
import faiss
from PIL import Image
from huggingface_hub import snapshot_download
from sentence_transformers import SentenceTransformer
import torch
from transformers import CLIPModel, CLIPProcessor
# ========= CONFIG (override in Space → Settings → Variables) =========
TEXT_MODEL_REPO = os.environ.get("TEXT_MODEL_REPO", "mamathew/text-ft-food-rag")
CLIP_MODEL_REPO = os.environ.get("CLIP_MODEL_REPO", "mamathew/clip-ft-food-rag")
DATASET_REPO = os.environ.get("DATASET_REPO", "mamathew/food-rag-index")
# Inference API chat model (Gemma IT by default).
LLM_ID = os.environ.get("LLM_ID", "google/gemma-2-2b-it")
# =====================================================================

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---- dataset snapshot (FAISS + metas + optionally images/) ----
DATA_DIR = snapshot_download(repo_id=DATASET_REPO, repo_type="dataset")
def read_jsonl(path: str):
    out = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                out.append(json.loads(line))
    return out
# Load metas & FAISS
TEXT_META = read_jsonl(os.path.join(DATA_DIR, "text_meta.jsonl"))
IMAGE_META = read_jsonl(os.path.join(DATA_DIR, "image_meta.jsonl"))
T_INDEX = faiss.read_index(os.path.join(DATA_DIR, "faiss_text.bin"))
I_INDEX = faiss.read_index(os.path.join(DATA_DIR, "faiss_image.bin"))
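# Startup sanity check (a minimal sketch): this assumes the indices were built
# row-aligned with the metadata files, with inner product over L2-normalized
# embeddings so that search scores behave like cosine similarities. Warn
# (rather than crash) if the vector counts and metadata row counts disagree.
if T_INDEX.ntotal != len(TEXT_META):
    print(f"[warn] faiss_text.bin has {T_INDEX.ntotal} vectors but text_meta.jsonl has {len(TEXT_META)} rows")
if I_INDEX.ntotal != len(IMAGE_META):
    print(f"[warn] faiss_image.bin has {I_INDEX.ntotal} vectors but image_meta.jsonl has {len(IMAGE_META)} rows")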
# Load encoders
text_enc = SentenceTransformer(TEXT_MODEL_REPO, device=DEVICE)
clip_model = CLIPModel.from_pretrained(CLIP_MODEL_REPO).to(DEVICE)
clip_proc = CLIPProcessor.from_pretrained(CLIP_MODEL_REPO)
# Inference API client (chat-first, with fallback)
try:
    from huggingface_hub import InferenceClient
    HF_TOKEN = os.environ.get("HF_TOKEN")  # set in Space → Settings → Repository secrets
    client = InferenceClient(model=LLM_ID, token=HF_TOKEN) if HF_TOKEN else InferenceClient(model=LLM_ID)
except Exception:
    client = None
# ---------------------- utils & dataclasses ----------------------
def _resolve_path(rel_or_abs: str) -> str:
    # If relative, make it under the dataset snapshot root
    p = rel_or_abs if os.path.isabs(rel_or_abs) else os.path.join(DATA_DIR, rel_or_abs)
    # Resolve symlinks to a canonical path (helps in HF cache)
    return os.path.realpath(p)

def _open_image_safe(path: str) -> Optional[Image.Image]:
    try:
        return Image.open(path).convert("RGB")
    except Exception:
        return None
def normalize_fa(s: str) -> str:
    # Map Arabic yeh/kaf to their Persian forms; replace ZWNJ (U+200C) with a space
    if not s:
        return s
    return s.replace("ي", "ی").replace("ك", "ک").replace("\u200c", " ").strip()

def _truncate(s: str, max_chars: int = 1200) -> str:
    if not s:
        return ""
    s = s.strip().replace("\r", " ")
    return s[:max_chars]

def _get_meta_text(m: dict) -> Optional[str]:
    for k in ("text", "content", "passage", "body", "chunk", "article", "description"):
        if m.get(k):
            return m[k]
    p = m.get("path") or m.get("filepath")
    if p:
        fp = p if os.path.isabs(p) else os.path.join(DATA_DIR, p)
        if os.path.exists(fp):
            try:
                with open(fp, "r", encoding="utf-8") as f:
                    return f.read()
            except Exception:
                pass
    return None
@dataclass
class Pair:
    rank: int
    idx: int
    doc_id: str
    title: Optional[str]
    score: float
    image_path: Optional[str]
    text: Optional[str] = None
    tscore: Optional[float] = None
    iscore: Optional[float] = None
    hscore: Optional[float] = None

@dataclass
class ImgHit:
    rank: int
    idx: int
    id: Optional[str]
    title: Optional[str]
    caption: Optional[str]
    score: float
    image_path: Optional[str]
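# NOTE (assumption): _pair_from_idx below treats TEXT_META and IMAGE_META as
# row-aligned, i.e. index i in either FAISS index refers to the same document
# in both metadata files. The `idx < len(IMAGE_META)` guard only protects
# against a shorter image table, not against misalignment.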
def _pair_from_idx(idx: int, score: float, rank: int) -> Pair:
    m = TEXT_META[idx]
    img_path = IMAGE_META[idx].get("image_path") if idx < len(IMAGE_META) else None
    return Pair(
        rank=rank, idx=idx, doc_id=m.get("id"), title=m.get("title"),
        score=float(score), image_path=img_path, text=_get_meta_text(m)
    )

def _pair_from_image_idx(idx: int, score: float, rank: int) -> ImgHit:
    m = IMAGE_META[idx]
    return ImgHit(
        rank=rank, idx=idx, id=m.get("id"),
        title=m.get("title") or m.get("name"),
        caption=m.get("caption") or m.get("alt"),
        score=float(score),
        image_path=m.get("image_path"),
    )
# ---------------------- retrieval funcs ----------------------
def search_text(q: str, topk: int = 10) -> List[Pair]:
    q = normalize_fa(q)
    qv = text_enc.encode([q], convert_to_numpy=True, normalize_embeddings=True).astype("float32")
    D, I = T_INDEX.search(qv, topk)
    out = []
    for r, (i, s) in enumerate(zip(I[0].tolist(), D[0].tolist()), start=1):
        if i < 0:
            continue
        out.append(_pair_from_idx(i, s, r))
    return out

def search_image(img: Image.Image, topk: int = 10) -> List[Pair]:
    inputs = clip_proc(images=[img.convert("RGB")], return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        qv = clip_model.get_image_features(**inputs)
    qv = torch.nn.functional.normalize(qv, dim=1).float().cpu().numpy().astype(np.float32)
    D, I = I_INDEX.search(qv, topk)
    out = []
    for r, (i, s) in enumerate(zip(I[0].tolist(), D[0].tolist()), start=1):
        if i < 0:
            continue
        out.append(_pair_from_idx(i, s, r))
    return out

def search_image_by_text(q: str, topk: int = 8) -> List[ImgHit]:
    q = normalize_fa(q)
    # padding/truncation keep long queries within CLIP's 77-token limit
    inputs = clip_proc(text=[q], return_tensors="pt", padding=True, truncation=True).to(DEVICE)
    with torch.no_grad():
        qv = clip_model.get_text_features(**inputs)
    qv = torch.nn.functional.normalize(qv, dim=1).float().cpu().numpy().astype(np.float32)
    D, I = I_INDEX.search(qv, topk)
    out = []
    for r, (i, s) in enumerate(zip(I[0].tolist(), D[0].tolist()), start=1):
        if i < 0:
            continue
        out.append(_pair_from_image_idx(i, s, r))
    return out
def _normalize_scores(score_dict):
    if not score_dict:
        return {}
    vals = list(score_dict.values())
    mn, mx = min(vals), max(vals)
    if mx - mn < 1e-9:
        return {k: 0.5 for k in score_dict}
    return {k: (v - mn) / (mx - mn) for k, v in score_dict.items()}
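# Illustrative example (not executed): min-max normalization maps raw scores
# {7: 0.20, 3: 0.60, 9: 1.00} to {7: 0.0, 3: 0.5, 9: 1.0}, which makes the
# text and image score ranges comparable before fusion. A degenerate set
# where all scores are equal maps to 0.5 everywhere instead of dividing by zero.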
def _topk_dict(D, I):
    out = {}
    for i, s in zip(I[0].tolist(), D[0].tolist()):
        if i >= 0:
            out[i] = float(s)
    return out
def hybrid_search(question: Optional[str], image: Optional[Image.Image], topk: int, alpha_image: float):
    # alpha_image in [0, 1]: 0 -> pure text; 1 -> pure image
    t_scores = {}
    if question and question.strip():
        q = normalize_fa(question)
        qv = text_enc.encode([q], convert_to_numpy=True, normalize_embeddings=True).astype("float32")
        D_t, I_t = T_INDEX.search(qv, max(topk, 20))
        t_scores = _topk_dict(D_t, I_t)
    i_scores = {}
    if image is not None:
        inputs = clip_proc(images=[image.convert("RGB")], return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            qv = clip_model.get_image_features(**inputs)
        qv = torch.nn.functional.normalize(qv, dim=1).float().cpu().numpy().astype(np.float32)
        D_i, I_i = I_INDEX.search(qv, max(topk, 20))
        i_scores = _topk_dict(D_i, I_i)
    keys = set(t_scores) | set(i_scores)
    tN = _normalize_scores(t_scores)
    iN = _normalize_scores(i_scores)
    hybrid = {k: (1.0 - alpha_image) * tN.get(k, 0.0) + alpha_image * iN.get(k, 0.0) for k in keys}
    sorted_idxs = sorted(hybrid.items(), key=lambda kv: kv[1], reverse=True)[:topk]
    pairs = []
    for r, (idx, h) in enumerate(sorted_idxs, start=1):
        m = TEXT_META[idx]
        img_path = IMAGE_META[idx].get("image_path") if idx < len(IMAGE_META) else None
        pairs.append(Pair(
            rank=r, idx=idx, doc_id=m.get("id"), title=m.get("title"),
            score=h, image_path=img_path, text=_get_meta_text(m),
            tscore=t_scores.get(idx), iscore=i_scores.get(idx), hscore=h
        ))
    return pairs
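# The fusion above is a late-fusion convex combination:
#   h(d) = (1 - alpha_image) * t_hat(d) + alpha_image * i_hat(d)
# where t_hat / i_hat are the min-max-normalized per-query scores, and a
# document missing from one candidate list contributes 0.0 on that side.
# For example, with alpha_image = 0.5 a doc with t_hat = 1.0 and i_hat = 0.4
# gets h = 0.7.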
# ---------------------- LLM prompt & call ----------------------
def build_prompt(question: str, ctx: List[Pair]) -> str:
    # Persian instruction: "Use the context below and answer in Persian.
    # If the answer is not in the context, say 'I don't know'."
    lines = [
        "از زمینهٔ زیر استفاده کن و به فارسی پاسخ بده. اگر پاسخ در زمینه نبود، بگو «نمیدانم».",
        "",
        "### زمینه:",  # "### Context:"
    ]
    for p in ctx:
        snippet = _truncate(p.text or "")
        # Per-passage fields: عنوان = title, متن = text
        lines.append(
            f"- عنوان: {p.title or '—'} (id={p.doc_id}, score={p.hscore if p.hscore is not None else p.score:.3f})\n"
            f"  متن: {snippet if snippet else '—'}"
        )
    lines.append(f"\n### پرسش: {question}\n### پاسخ:")  # "### Question: ... ### Answer:"
    return "\n".join(lines)
def call_llm(prompt: str) -> str:
    if client is None:
        return "(LLM not configured)\n\n" + prompt
    # Prefer chat (Gemma IT and many instruct models are conversational)
    try:
        resp = client.chat_completion(
            messages=[
                {"role": "system", "content": (
                    "You are a helpful assistant. Use the provided context to answer in Persian; "
                    "if it's not in the context, say you don't know."
                )},
                {"role": "user", "content": prompt},
            ],
            max_tokens=256,
            temperature=0.2,
        )
        return resp.choices[0].message.content.strip()
    except Exception as e_chat:
        # Fall back to text-generation if the model supports it
        try:
            out = client.text_generation(
                prompt=prompt,
                max_new_tokens=256,
                temperature=0.2,
                do_sample=True,
            )
            return out.strip()
        except Exception as e_text:
            return f"(LLM error: {e_chat} / {e_text})\n\n" + prompt
# ---------------------- gallery helpers ----------------------
def display_gallery_pairs(pairs):
    items = []
    for p in pairs:
        if not p.image_path:
            continue
        local_path = _resolve_path(p.image_path)
        if os.path.exists(local_path):
            img = _open_image_safe(local_path)
            if img is not None:
                caption = f"#{p.rank} — {p.title or ''}\nscore={(p.hscore if p.hscore is not None else p.score):.3f}"
                items.append((img, caption))  # PIL image instead of path
    return items

def display_gallery_images(img_hits):
    items = []
    for h in img_hits:
        if not h.image_path:
            continue
        local_path = _resolve_path(h.image_path)
        if os.path.exists(local_path):
            img = _open_image_safe(local_path)
            if img is not None:
                caption = f"#{h.rank} — {h.title or ''}\nscore={h.score:.3f}"
                items.append((img, caption))  # PIL image instead of path
    return items
# ---------------------- main app logic ----------------------
def answer(question: str, image: Optional[Image.Image], topk: int, k_ctx: int, use_image: bool, alpha_image: float = 0.5):
    # Hybrid when an image is provided and the checkbox is on; else text-only
    if use_image and image is not None:
        top_pairs = hybrid_search(question, image, topk=topk, alpha_image=alpha_image)
    else:
        top_pairs = search_text(question, topk=topk)
    # LLM
    ctx = top_pairs[:max(1, k_ctx)]
    prompt = build_prompt(question, ctx)
    gen = call_llm(prompt)
    # Gallery
    gallery = display_gallery_pairs(top_pairs)
    if not gallery and not (use_image and image is not None):
        # Text-only path: still try text->image retrieval to show visuals
        img_hits = search_image_by_text(question, topk=min(8, topk))
        gallery = display_gallery_images(img_hits)
    top_image = gallery[0][0] if gallery else None
    # Table
    def fmt(x):
        return "—" if x is None else f"{x:.3f}"
    # `is not None` (not `or`) so a legitimate hybrid score of 0.0 is kept
    table = [[p.rank, p.title or "", fmt(p.tscore), fmt(p.iscore), fmt(p.hscore if p.hscore is not None else p.score), p.doc_id] for p in top_pairs]
    return gen, table, gallery, top_image
# ---------------------- UI ----------------------
with gr.Blocks() as demo:
    gr.Markdown("# 🍜 RAG (متن + تصویر) — Hybrid Retrieval + Persian LLM")  # "(text + image)"
    with gr.Row():
        # Placeholder: "e.g., how to make havij polo (carrot rice)"
        q = gr.Textbox(label="پرسش (Question)", placeholder="مثلاً: طرز تهیه هویج پلو")
        img = gr.Image(label="تصویر اختیاری (Optional image)", type="pil")
    with gr.Row():
        topk = gr.Slider(1, 20, value=10, step=1, label="Top-K")
        kctx = gr.Slider(1, 10, value=4, step=1, label="K متن زمینه برای LLM")  # "context passages for the LLM"
        use_img = gr.Checkbox(label="Hybrid (از تصویر هم استفاده شود؟)", value=False)  # "also use the image?"
        alpha = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="وزن تصویر در Hybrid")  # "image weight in hybrid"
    btn = gr.Button("اجرا (Run)")
    out_text = gr.Textbox(label="پاسخ (Answer)")
    out_table = gr.Dataframe(headers=["Rank", "Title", "Text S", "Image S", "Hybrid S", "Doc ID"], label="Top-K retrieval")
    out_gallery = gr.Gallery(label="تصاویر مرتبط (Image matches)", columns=5, height=240)
    out_img_top = gr.Image(label="Top image match")
    btn.click(
        answer,
        inputs=[q, img, topk, kctx, use_img, alpha],
        outputs=[out_text, out_table, out_gallery, out_img_top],
    )
ALLOWED = [
    DATA_DIR,
    os.path.join(DATA_DIR, "data"),
    os.path.join(DATA_DIR, "data", "interim"),
    os.path.join(DATA_DIR, "data", "interim", "images_cache"),
]

if __name__ == "__main__":
    demo.launch(allowed_paths=[os.path.realpath(p) for p in ALLOWED])
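# A minimal requirements.txt sketch for this Space, inferred from the imports
# above (package names assumed, e.g. faiss-cpu vs. faiss-gpu; pin versions as
# needed):
#   gradio
#   numpy
#   faiss-cpu
#   Pillow
#   huggingface_hub
#   sentence-transformers
#   torch
#   transformers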