File size: 7,877 Bytes
6a2c742
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9c8d4f
 
6a2c742
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70c1c99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a2c742
 
 
 
70c1c99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a2c742
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70c1c99
 
 
 
 
6a2c742
70c1c99
 
 
 
 
 
6a2c742
 
 
d9c8d4f
6a2c742
 
 
d9c8d4f
 
 
 
 
 
 
 
 
 
 
 
6a2c742
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207

# Minimal Hugging Face Space for testing RAG with text and image search over your Hub models.
# To deploy, copy this file (along with requirements.txt and README.md) into a new Space.
import os, json
from dataclasses import dataclass
from typing import List, Optional, Tuple

import gradio as gr
import numpy as np
import faiss
from PIL import Image

from huggingface_hub import snapshot_download
from sentence_transformers import SentenceTransformer
import torch
from transformers import CLIPModel, CLIPProcessor

# ========== CONFIG (edit to your repos) ==========
# Each setting can be overridden by an environment variable of the same name;
# the "<your-username>/..." defaults are placeholders to replace.
# TEXT_MODEL_REPO is loaded with SentenceTransformer, CLIP_MODEL_REPO with
# CLIPModel/CLIPProcessor, and DATASET_REPO is downloaded as a dataset repo.
TEXT_MODEL_REPO = os.environ.get("TEXT_MODEL_REPO", "<your-username>/text-ft-food-rag")
CLIP_MODEL_REPO = os.environ.get("CLIP_MODEL_REPO", "<your-username>/clip-ft-food-rag")
DATASET_REPO    = os.environ.get("DATASET_REPO",    "<your-username>/food-rag-index")
# LLM via Inference API (set HF_TOKEN in Space secrets). Change to your preferred instruct model.
LLM_ID = os.environ.get("LLM_ID", "google/gemma-2-2b-it")

# =================================================

# Run the encoders on GPU when CUDA is available, otherwise on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---- Download dataset snapshot (FAISS + metas + optionally images/) ----
# Runs at import time. The returned local directory is the base against which
# the index/metadata files and any relative text/image paths are resolved.
DATA_DIR = snapshot_download(repo_id=DATASET_REPO, repo_type="dataset")

# Expected files inside DATA_DIR:
#   faiss_text.bin, faiss_image.bin, text_meta.jsonl, image_meta.jsonl
#   images/ (optional) if you want to show pictures next to results

def read_jsonl(path: str):
    """Parse a UTF-8 JSON Lines file into a list of decoded records.

    Each line is stripped of surrounding whitespace; lines that end up empty
    are skipped, and every remaining line is decoded with ``json.loads``.
    """
    with open(path, "r", encoding="utf-8") as fh:
        rows = (raw.strip() for raw in fh)
        return [json.loads(row) for row in rows if row]

# Load metas & FAISS
# Metadata rows are looked up by the integer ids FAISS returns, so each
# *_meta.jsonl is presumably row-aligned with its index file.
# NOTE(review): confirm against the script that built the indexes.
TEXT_META  = read_jsonl(os.path.join(DATA_DIR, "text_meta.jsonl"))
IMAGE_META = read_jsonl(os.path.join(DATA_DIR, "image_meta.jsonl"))
T_INDEX = faiss.read_index(os.path.join(DATA_DIR, "faiss_text.bin"))
I_INDEX = faiss.read_index(os.path.join(DATA_DIR, "faiss_image.bin"))

# Load encoders
# text_enc embeds text queries for T_INDEX; clip_model/clip_proc embed image
# queries for I_INDEX. Both models are placed on DEVICE.
text_enc = SentenceTransformer(TEXT_MODEL_REPO, device=DEVICE)
clip_model = CLIPModel.from_pretrained(CLIP_MODEL_REPO).to(DEVICE)
clip_proc  = CLIPProcessor.from_pretrained(CLIP_MODEL_REPO)

# Optional: LLM via HF Inference API (so Spaces don't need to run an LLM locally).
# HF_TOKEN is read up front so it is always defined, even if the client setup below fails.
HF_TOKEN = os.environ.get("HF_TOKEN")  # set this in Space -> Settings -> Repository secrets
try:
    from huggingface_hub import InferenceClient
    client = InferenceClient(model=LLM_ID, token=HF_TOKEN)
except Exception as exc:
    # Best-effort: retrieval still works without an LLM, but report why it is
    # unavailable instead of failing silently.
    print(f"[warn] LLM client unavailable ({type(exc).__name__}: {exc}); answers will not be generated.")
    client = None

@dataclass
class Pair:
    """One retrieval hit joined with its metadata, used for display and prompting."""

    rank: int  # 1-based position in the FAISS result list
    idx: int  # row index into the FAISS index and the metadata lists
    doc_id: Optional[str]  # metadata "id"; None when the row has no id
    title: Optional[str]  # metadata "title"; None when absent
    score: float  # FAISS score for this hit (similarity vs distance depends on index type — TODO confirm)
    image_path: Optional[str]  # absolute, or relative to DATA_DIR; None when absent
    text: Optional[str] = None  # passage text passed to the LLM as context, if found


def _get_meta_text(m: dict) -> Optional[str]:
    # Try common keys first
    for k in ("text", "content", "passage", "body", "chunk", "article"):
        if m.get(k):
            return m[k]
    # If you stored a local file path for the text, read it
    p = m.get("path") or m.get("filepath")
    if p:
        import os
        fp = p if os.path.isabs(p) else os.path.join(DATA_DIR, p)
        if os.path.exists(fp):
            try:
                with open(fp, "r", encoding="utf-8") as f:
                    return f.read()
            except:
                pass
    return None

def _pair_from_idx(idx: int, score: float, rank: int) -> Pair:
    """Build a Pair for FAISS row ``idx`` from TEXT_META and IMAGE_META.

    Title, id and text come from TEXT_META[idx]; the image path comes from
    IMAGE_META[idx]. The two lists are used with the same index, so they are
    assumed to be row-aligned (NOTE(review): confirm against the index builder).
    If either list is shorter than ``idx``, an empty row is used for it instead
    of raising IndexError. Callers pass only non-negative ids.
    """
    text_meta = TEXT_META[idx] if idx < len(TEXT_META) else {}
    image_meta = IMAGE_META[idx] if idx < len(IMAGE_META) else {}
    return Pair(
        rank=rank,
        idx=idx,
        doc_id=text_meta.get("id"),
        title=text_meta.get("title"),
        score=float(score),
        image_path=image_meta.get("image_path"),
        text=_get_meta_text(text_meta),
    )

def _truncate(s: str, max_chars: int = 1200) -> str:
    if not s: return ""
    s = s.strip().replace("\r", " ")
    return s[:max_chars]

    
def _hits_to_pairs(D: np.ndarray, I: np.ndarray) -> List[Pair]:
    """Convert the first row of a FAISS (distances, ids) result into Pairs.

    Ranks are 1-based positions in the result row. Negative ids (which FAISS
    returns for empty result slots) are skipped.
    """
    out = []
    for r, (i, s) in enumerate(zip(I[0].tolist(), D[0].tolist()), start=1):
        if i < 0:
            continue
        out.append(_pair_from_idx(i, s, r))
    return out

def search_text(q: str, topk: int = 10) -> List[Pair]:
    """Embed text query ``q`` and return up to ``topk`` hits from T_INDEX.

    The query embedding is L2-normalized float32. ``topk`` is coerced to int
    and clamped to at least 1, because FAISS ``search`` needs an integer k.
    """
    k = max(1, int(topk))
    qv = text_enc.encode([q], convert_to_numpy=True, normalize_embeddings=True).astype("float32")
    D, I = T_INDEX.search(qv, k)
    return _hits_to_pairs(D, I)

def search_image(img: Image.Image, topk: int = 10) -> List[Pair]:
    """Embed ``img`` with CLIP and return up to ``topk`` hits from I_INDEX.

    The image is converted to RGB, and its CLIP image features are
    L2-normalized and cast to float32 before searching. ``topk`` is coerced
    to int and clamped to at least 1.
    """
    k = max(1, int(topk))
    inputs = clip_proc(images=[img.convert("RGB")], return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        qv = clip_model.get_image_features(**inputs)
        qv = torch.nn.functional.normalize(qv, dim=1).float().cpu().numpy().astype(np.float32)
    D, I = I_INDEX.search(qv, k)
    return _hits_to_pairs(D, I)

def build_prompt(question: str, ctx: List[Pair]) -> str:
    """Render the retrieved context and the question into one Persian prompt.

    Layout: an instruction line ("Use the context below and answer in Persian;
    if the answer is not in the context, say 'I don't know'"), a blank line, a
    context header, then one entry per hit with its title, id, score and
    truncated text (a missing title or text is shown as an em dash). The
    question and an answer header come last. Parts are joined with newlines.
    """
    def _entry(hit):
        body = _truncate(hit.text or "")
        return (
            f"- عنوان: {hit.title or '—'} (id={hit.doc_id}, score={hit.score:.3f})\n"
            f"  متن: {body or '—'}"
        )

    parts = [
        "از زمینهٔ زیر استفاده کن و به فارسی پاسخ بده. اگر پاسخ در زمینه نبود، بگو «نمی‌دانم».",
        "",
        "### زمینه:",
        *(_entry(hit) for hit in ctx),
        f"\n### پرسش: {question}\n### پاسخ:",
    ]
    return "\n".join(parts)

def call_llm(prompt: str, *, max_tokens: int = 256, temperature: float = 0.2) -> str:
    """Send the prepared RAG prompt to the chat model and return its reply.

    Args:
        prompt: full user message (context + question), as built by build_prompt.
        max_tokens: generation cap forwarded to ``chat_completion``.
        temperature: sampling temperature forwarded to ``chat_completion``.

    Returns:
        The stripped reply text. An empty reply body yields "". When no client
        is configured, or the request fails, a bracketed notice followed by the
        prompt is returned so the UI still shows something useful.
    """
    if client is None:
        return "(LLM not configured)\n\n" + prompt
    try:
        resp = client.chat_completion(
            messages=[
                {"role": "system", "content": (
                    "You are a helpful assistant. Use the provided context to answer in Persian language; "
                    "if it's not in the context, say you don't know."
                )},
                {"role": "user", "content": prompt},
            ],
            max_tokens=max_tokens,
            temperature=temperature,
        )
        # content can be None (e.g. an empty completion); don't report that as an error.
        content = resp.choices[0].message.content
    except Exception as e:  # network/auth/model errors are shown in the UI instead of crashing
        return f"(LLM error: {e})\n\n" + prompt
    return (content or "").strip()

def display_gallery(pairs: List[Pair]) -> List[Tuple[str, str]]:
    """Build ``(image_path, caption)`` items for the Gradio Gallery.

    Hits without an image path, or whose file does not exist (relative paths
    are resolved against DATA_DIR), are skipped. A caption looks like
    "#<rank> <title>\\nscore=<score>", or "#<rank>\\nscore=<score>" when the
    title is missing.
    """
    items = []
    for p in pairs:
        if not p.image_path:
            continue
        local_path = p.image_path if os.path.isabs(p.image_path) else os.path.join(DATA_DIR, p.image_path)
        if not os.path.exists(local_path):
            continue
        # Separate rank and title; the old "#{rank}{title}" rendered as e.g. "#1Pho".
        label = f"#{p.rank} {p.title}" if p.title else f"#{p.rank}"
        items.append((local_path, f"{label}\nscore={p.score:.3f}"))
    return items

def answer(question: str, image: Optional[Image.Image], topk: int, k_ctx: int, use_image: bool):
    """Gradio callback: retrieve, build the prompt, query the LLM, and format the outputs.

    Args:
        question: user question. It is the text-search query unless image search is used.
        image: optional PIL image from the UI.
        topk: number of hits to retrieve.
        k_ctx: how many of the top hits are passed to the LLM as context.
        use_image: search by ``image`` instead of ``question`` when an image is given.

    Returns:
        ``(answer_text, table_rows, gallery_items)`` for the three UI outputs.
    """
    question = question or ""
    by_image = bool(use_image) and image is not None
    if not by_image and not question.strip():
        # Nothing to search with: skip embedding an empty query and an LLM call.
        return "Please enter a question, or tick 'Use image for search' and upload an image.", [], []
    # Coerce slider values to int (slicing and FAISS search need ints) and keep them >= 1.
    topk = max(1, int(topk))
    k_ctx = max(1, int(k_ctx))
    top = search_image(image, topk=topk) if by_image else search_text(question, topk=topk)
    ctx = top[:k_ctx]
    prompt = build_prompt(question, ctx)
    gen = call_llm(prompt)
    gal = display_gallery(top)
    rows = [[p.rank, p.title or "", f"{p.score:.3f}", p.doc_id] for p in top]
    return gen, rows, gal

# ---- Gradio UI ----
# Inputs: a question, an optional image and retrieval settings. Outputs: the
# LLM answer, a table of the top-K hits and a gallery of matching images.
with gr.Blocks() as demo:
    gr.Markdown("# 🍜 Food RAG Demo (text+image search)")
    # Query inputs side by side: free-text question and an optional PIL image.
    with gr.Row():
        q = gr.Textbox(label="Question", placeholder="Ask something about a dish, ingredient, etc.")
        img = gr.Image(label="Optional image", type="pil")
    # Retrieval settings, passed straight through to answer().
    with gr.Row():
        topk = gr.Slider(1, 20, value=10, step=1, label="Top-K search")
        kctx = gr.Slider(1, 10, value=4, step=1, label="K context to LLM")
        use_img = gr.Checkbox(label="Use image for search", value=False)
    btn = gr.Button("Run")
    out_text = gr.Textbox(label="Answer")
    out_table = gr.Dataframe(headers=["Rank", "Title", "Score", "Doc ID"], label="Top-K retrieval")
    out_gallery = gr.Gallery(label="Matches (if images available)", columns=5, height=200)
    # inputs are passed positionally as answer(question, image, topk, k_ctx, use_image);
    # its 3-tuple return maps onto (out_text, out_table, out_gallery).
    btn.click(answer, inputs=[q, img, topk, kctx, use_img], outputs=[out_text, out_table, out_gallery])

if __name__ == "__main__":
    demo.launch()