mamathew committed
Commit 3d1e402 · verified · 1 Parent(s): 70c1c99

Update app.py

Files changed (1):
  1. app.py +203 -72
app.py CHANGED
@@ -1,6 +1,5 @@
- # Simple HF Space to test your RAG + image/text search with your Hub models.
- # Move this file (and requirements.txt + README.md) into a new Space.
  import os, json
  from dataclasses import dataclass
  from typing import List, Optional, Tuple
@@ -15,24 +14,20 @@ from sentence_transformers import SentenceTransformer
  import torch
  from transformers import CLIPModel, CLIPProcessor

- # ========== CONFIG (edit to your repos) ==========
  TEXT_MODEL_REPO = os.environ.get("TEXT_MODEL_REPO", "<your-username>/text-ft-food-rag")
  CLIP_MODEL_REPO = os.environ.get("CLIP_MODEL_REPO", "<your-username>/clip-ft-food-rag")
  DATASET_REPO = os.environ.get("DATASET_REPO", "<your-username>/food-rag-index")
- # LLM via Inference API (set HF_TOKEN in Space secrets). Change to your preferred instruct model.
- LLM_ID = os.environ.get("LLM_ID", "google/gemma-2-2b-it")
- # =================================================

  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

- # ---- Download dataset snapshot (FAISS + metas + optionally images/) ----
  DATA_DIR = snapshot_download(repo_id=DATASET_REPO, repo_type="dataset")

- # Expected files inside DATA_DIR:
- #   faiss_text.bin, faiss_image.bin, text_meta.jsonl, image_meta.jsonl
- #   images/ (optional) if you want to show pictures next to results
-
  def read_jsonl(path: str):
      out = []
      with open(path, "r", encoding="utf-8") as f:
@@ -53,34 +48,30 @@ text_enc = SentenceTransformer(TEXT_MODEL_REPO, device=DEVICE)
  clip_model = CLIPModel.from_pretrained(CLIP_MODEL_REPO).to(DEVICE)
  clip_proc = CLIPProcessor.from_pretrained(CLIP_MODEL_REPO)

- # Optional: LLM via HF Inference API (so Spaces don't need to run an LLM locally)
  try:
      from huggingface_hub import InferenceClient
-     HF_TOKEN = os.environ.get("HF_TOKEN")  # set this in Space -> Settings -> Repository secrets
-     client = InferenceClient(model=LLM_ID, token=HF_TOKEN)
- except Exception as e:
      client = None
- @dataclass
- class Pair:
-     rank: int
-     idx: int
-     doc_id: str
-     title: Optional[str]
-     score: float
-     image_path: Optional[str]
-     text: Optional[str] = None  # <-- NEW

  def _get_meta_text(m: dict) -> Optional[str]:
-     # Try common keys first
-     for k in ("text", "content", "passage", "body", "chunk", "article"):
-         if m.get(k):
-             return m[k]
-     # If you stored a local file path for the text, read it
      p = m.get("path") or m.get("filepath")
      if p:
-         import os
          fp = p if os.path.isabs(p) else os.path.join(DATA_DIR, p)
          if os.path.exists(fp):
              try:
@@ -90,26 +81,51 @@ def _get_meta_text(m: dict) -> Optional[str]:
              pass
      return None

  def _pair_from_idx(idx: int, score: float, rank: int) -> Pair:
      m = TEXT_META[idx]
-     img_path = IMAGE_META[idx].get("image_path")
      return Pair(
-         rank=rank,
-         idx=idx,
-         doc_id=m.get("id"),
-         title=m.get("title"),
          score=float(score),
-         image_path=img_path,
-         text=_get_meta_text(m),  # <-- NEW
      )

- def _truncate(s: str, max_chars: int = 1200) -> str:
-     if not s: return ""
-     s = s.strip().replace("\r", " ")
-     return s[:max_chars]
-
  def search_text(q: str, topk: int = 10) -> List[Pair]:
      qv = text_enc.encode([q], convert_to_numpy=True, normalize_embeddings=True).astype("float32")
      D, I = T_INDEX.search(qv, topk)
      out = []
@@ -130,30 +146,94 @@ def search_image(img: Image.Image, topk: int = 10) -> List[Pair]:
          out.append(_pair_from_idx(i, s, r))
      return out

  def build_prompt(question: str, ctx: List[Pair]) -> str:
      lines = [
          "از زمینهٔ زیر استفاده کن و به فارسی پاسخ بده. اگر پاسخ در زمینه نبود، بگو «نمی‌دانم».",
          "",
-         "### زمینه:",
      ]
      for p in ctx:
          snippet = _truncate(p.text or "")
          lines.append(
-             f"- عنوان: {p.title or '—'} (id={p.doc_id}, score={p.score:.3f})\n"
              f"  متن: {snippet if snippet else '—'}"
          )
      lines.append(f"\n### پرسش: {question}\n### پاسخ:")
      return "\n".join(lines)

  def call_llm(prompt: str) -> str:
-     # prompt already includes your Context + Question text
      if client is None:
          return "(LLM not configured)\n\n" + prompt
      try:
          resp = client.chat_completion(
              messages=[
                  {"role": "system", "content": (
-                     "You are a helpful assistant. Use the provided context to answer in Persian language; "
                      "if it's not in the context, say you don't know."
                  )},
                  {"role": "user", "content": prompt},
@@ -162,45 +242,96 @@ def call_llm(prompt: str) -> str:
              temperature=0.2,
          )
          return resp.choices[0].message.content.strip()
-     except Exception as e:
-         return f"(LLM error: {e})\n\n" + prompt

- def display_gallery(pairs: List[Pair]) -> List[Tuple[str, str]]:
-     # Return [(image_path, caption), ...] for Gradio Gallery. Works if images/ folder is included.
      items = []
      for p in pairs:
-         if p.image_path:
-             local_path = os.path.join(DATA_DIR, p.image_path) if not os.path.isabs(p.image_path) else p.image_path
-             if os.path.exists(local_path):
-                 caption = f"#{p.rank} — {p.title or ''}\nscore={p.score:.3f}"
-                 items.append((local_path, caption))
      return items

- def answer(question: str, image: Optional[Image.Image], topk: int, k_ctx: int, use_image: bool):
      if use_image and image is not None:
-         top = search_image(image, topk=topk)
      else:
-         top = search_text(question, topk=topk)
-     ctx = top[:max(1, k_ctx)]
      prompt = build_prompt(question, ctx)
      gen = call_llm(prompt)
-     gal = display_gallery(top)
-     return gen, [[p.rank, p.title or "", f"{p.score:.3f}", p.doc_id] for p in top], gal

  with gr.Blocks() as demo:
-     gr.Markdown("# 🍜 Food RAG Demo (text+image search)")
      with gr.Row():
-         q = gr.Textbox(label="Question", placeholder="Ask something about a dish, ingredient, etc.")
-         img = gr.Image(label="Optional image", type="pil")
      with gr.Row():
-         topk = gr.Slider(1, 20, value=10, step=1, label="Top-K search")
-         kctx = gr.Slider(1, 10, value=4, step=1, label="K context to LLM")
-         use_img = gr.Checkbox(label="Use image for search", value=False)
-     btn = gr.Button("Run")
-     out_text = gr.Textbox(label="Answer")
-     out_table = gr.Dataframe(headers=["Rank", "Title", "Score", "Doc ID"], label="Top-K retrieval")
-     out_gallery = gr.Gallery(label="Matches (if images available)", columns=5, height=200)
-     btn.click(answer, inputs=[q, img, topk, kctx, use_img], outputs=[out_text, out_table, out_gallery])

  if __name__ == "__main__":
      demo.launch()
+ # app.py — HF Space: hybrid text+image RAG demo (Persian-ready)

  import os, json
  from dataclasses import dataclass
  from typing import List, Optional, Tuple

  import torch
  from transformers import CLIPModel, CLIPProcessor

+ # ========= CONFIG (override in Space → Settings → Variables) =========
  TEXT_MODEL_REPO = os.environ.get("TEXT_MODEL_REPO", "<your-username>/text-ft-food-rag")
  CLIP_MODEL_REPO = os.environ.get("CLIP_MODEL_REPO", "<your-username>/clip-ft-food-rag")
  DATASET_REPO = os.environ.get("DATASET_REPO", "<your-username>/food-rag-index")

+ # Inference API chat model (Gemma IT by default).
+ LLM_ID = os.environ.get("LLM_ID", "google/gemma-2-2b-it")
+ # =====================================================================

  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

+ # ---- dataset snapshot (FAISS + metas + optionally images/) ----
  DATA_DIR = snapshot_download(repo_id=DATASET_REPO, repo_type="dataset")

  def read_jsonl(path: str):
      out = []
      with open(path, "r", encoding="utf-8") as f:
  clip_model = CLIPModel.from_pretrained(CLIP_MODEL_REPO).to(DEVICE)
  clip_proc = CLIPProcessor.from_pretrained(CLIP_MODEL_REPO)

+ # Inference API client (chat-first, with fallback)
  try:
      from huggingface_hub import InferenceClient
+     HF_TOKEN = os.environ.get("HF_TOKEN")  # set in Space → Settings → Repository secrets
+     client = InferenceClient(model=LLM_ID, token=HF_TOKEN) if HF_TOKEN else InferenceClient(model=LLM_ID)
+ except Exception:
      client = None
+ # ---------------------- utils & dataclasses ----------------------

+ def normalize_fa(s: str) -> str:
+     if not s: return s
+     return s.replace("ي", "ی").replace("ك", "ک").replace("\u200c", " ").strip()
+
+ def _truncate(s: str, max_chars: int = 1200) -> str:
+     if not s: return ""
+     s = s.strip().replace("\r", " ")
+     return s[:max_chars]
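
For intuition, normalize_fa only maps the Arabic yeh/kaf code points to their Persian forms and flattens ZWNJ (U+200C) to a plain space; a quick check with illustrative inputs:

    normalize_fa("كباب كوبيده")    # -> "کباب کوبیده"
    normalize_fa("می\u200cخواهم")   # -> "می خواهم"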
  def _get_meta_text(m: dict) -> Optional[str]:
+     for k in ("text", "content", "passage", "body", "chunk", "article", "description"):
+         if m.get(k): return m[k]
      p = m.get("path") or m.get("filepath")
      if p:
          fp = p if os.path.isabs(p) else os.path.join(DATA_DIR, p)
          if os.path.exists(fp):
              try:
              pass
      return None
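
For reference, _get_meta_text first probes a list of common inline keys and only then falls back to reading a file path relative to DATA_DIR, so a text_meta.jsonl record could look like either of these (field names assumed for illustration, not prescribed by the commit):

    {"id": "doc-0421", "title": "هویج پلو", "text": "هویج پلو یک غذای ایرانی است ...", "image_path": "images/0421.jpg"}
    {"id": "doc-0422", "title": "قرمه سبزی", "path": "texts/0422.txt"}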
+ @dataclass
+ class Pair:
+     rank: int
+     idx: int
+     doc_id: str
+     title: Optional[str]
+     score: float
+     image_path: Optional[str]
+     text: Optional[str] = None
+     tscore: Optional[float] = None
+     iscore: Optional[float] = None
+     hscore: Optional[float] = None
+
+ @dataclass
+ class ImgHit:
+     rank: int
+     idx: int
+     id: Optional[str]
+     title: Optional[str]
+     caption: Optional[str]
+     score: float
+     image_path: Optional[str]
+
  def _pair_from_idx(idx: int, score: float, rank: int) -> Pair:
      m = TEXT_META[idx]
+     img_path = IMAGE_META[idx].get("image_path") if idx < len(IMAGE_META) else None
      return Pair(
+         rank=rank, idx=idx, doc_id=m.get("id"), title=m.get("title"),
+         score=float(score), image_path=img_path, text=_get_meta_text(m)
+     )
+
+ def _pair_from_image_idx(idx: int, score: float, rank: int) -> ImgHit:
+     m = IMAGE_META[idx]
+     return ImgHit(
+         rank=rank, idx=idx, id=m.get("id"),
+         title=m.get("title") or m.get("name"),
+         caption=m.get("caption") or m.get("alt"),
          score=float(score),
+         image_path=m.get("image_path"),
      )
+ # ---------------------- retrieval funcs ----------------------

  def search_text(q: str, topk: int = 10) -> List[Pair]:
+     q = normalize_fa(q)
      qv = text_enc.encode([q], convert_to_numpy=True, normalize_embeddings=True).astype("float32")
      D, I = T_INDEX.search(qv, topk)
      out = []
          out.append(_pair_from_idx(i, s, r))
      return out

+ def search_image_by_text(q: str, topk: int = 8) -> List[ImgHit]:
+     q = normalize_fa(q)
+     inputs = clip_proc(text=[q], return_tensors="pt").to(DEVICE)
+     with torch.no_grad():
+         qv = clip_model.get_text_features(**inputs)
+     qv = torch.nn.functional.normalize(qv, dim=1).float().cpu().numpy().astype(np.float32)
+     D, I = I_INDEX.search(qv, topk)
+     out = []
+     for r, (i, s) in enumerate(zip(I[0].tolist(), D[0].tolist()), start=1):
+         if i < 0: continue
+         out.append(_pair_from_image_idx(i, s, r))
+     return out
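
Both search paths assume the FAISS indexes were built for inner-product search over L2-normalized vectors, so raw scores behave like cosine similarities. A minimal offline sketch of how a compatible faiss_text.bin could have been produced (passages is a hypothetical list of strings; this build step is not part of the commit):

    import faiss
    vecs = text_enc.encode(passages, convert_to_numpy=True,
                           normalize_embeddings=True).astype("float32")
    index = faiss.IndexFlatIP(vecs.shape[1])   # inner product == cosine on unit vectors
    index.add(vecs)
    faiss.write_index(index, "faiss_text.bin")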
+ def _normalize_scores(score_dict):
+     if not score_dict: return {}
+     vals = list(score_dict.values())
+     mn, mx = min(vals), max(vals)
+     if mx - mn < 1e-9:
+         return {k: 0.5 for k in score_dict}
+     return {k: (v - mn) / (mx - mn) for k, v in score_dict.items()}
+
+ def _topk_dict(D, I):
+     out = {}
+     for i, s in zip(I[0].tolist(), D[0].tolist()):
+         if i >= 0: out[i] = float(s)
+     return out
+
+ def hybrid_search(question: Optional[str], image: Optional[Image.Image], topk: int, alpha_image: float):
+     # alpha_image in [0, 1]: 0 -> pure text ; 1 -> pure image
+     t_scores = {}
+     if question and question.strip():
+         q = normalize_fa(question)
+         qv = text_enc.encode([q], convert_to_numpy=True, normalize_embeddings=True).astype("float32")
+         D_t, I_t = T_INDEX.search(qv, max(topk, 20))
+         t_scores = _topk_dict(D_t, I_t)
+
+     i_scores = {}
+     if image is not None:
+         inputs = clip_proc(images=[image.convert("RGB")], return_tensors="pt").to(DEVICE)
+         with torch.no_grad():
+             qv = clip_model.get_image_features(**inputs)
+         qv = torch.nn.functional.normalize(qv, dim=1).float().cpu().numpy().astype(np.float32)
+         D_i, I_i = I_INDEX.search(qv, max(topk, 20))
+         i_scores = _topk_dict(D_i, I_i)
+
+     keys = set(t_scores) | set(i_scores)
+     tN = _normalize_scores(t_scores)
+     iN = _normalize_scores(i_scores)
+     hybrid = {k: (1.0 - alpha_image) * tN.get(k, 0.0) + alpha_image * iN.get(k, 0.0) for k in keys}
+
+     sorted_idxs = sorted(hybrid.items(), key=lambda kv: kv[1], reverse=True)[:topk]
+     pairs = []
+     for r, (idx, h) in enumerate(sorted_idxs, start=1):
+         m = TEXT_META[idx]
+         img_path = IMAGE_META[idx].get("image_path") if idx < len(IMAGE_META) else None
+         pairs.append(Pair(
+             rank=r, idx=idx, doc_id=m.get("id"), title=m.get("title"),
+             score=h, image_path=img_path, text=_get_meta_text(m),
+             tscore=t_scores.get(idx), iscore=i_scores.get(idx), hscore=h
+         ))
+     return pairs
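
The fusion itself is plain late fusion: each retriever's raw scores are min-max scaled to [0, 1], then blended per document index with weight alpha_image. A toy run with made-up scores:

    t_scores = {0: 0.82, 1: 0.74, 2: 0.60}   # text cosine scores per doc index
    i_scores = {1: 0.91, 3: 0.55}            # image cosine scores per doc index
    tN = _normalize_scores(t_scores)          # {0: 1.0, 1: 0.636..., 2: 0.0}
    iN = _normalize_scores(i_scores)          # {1: 1.0, 3: 0.0}
    # alpha_image = 0.5 -> doc 1 wins: 0.5 * 0.636 + 0.5 * 1.0 = 0.818
    # doc 0 scores 0.5; docs 2 and 3 score 0.0

Note that a document seen by only one retriever gets 0.0 from the other modality, so single-modality hits are penalized by construction.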
+ # ---------------------- LLM prompt & call ----------------------

  def build_prompt(question: str, ctx: List[Pair]) -> str:
      lines = [
          "از زمینهٔ زیر استفاده کن و به فارسی پاسخ بده. اگر پاسخ در زمینه نبود، بگو «نمی‌دانم».",
          "",
+         "### زمینه:"
      ]
      for p in ctx:
          snippet = _truncate(p.text or "")
          lines.append(
+             f"- عنوان: {p.title or '—'} (id={p.doc_id}, score={(p.hscore if p.hscore is not None else p.score):.3f})\n"
              f"  متن: {snippet if snippet else '—'}"
          )
      lines.append(f"\n### پرسش: {question}\n### پاسخ:")
      return "\n".join(lines)
  def call_llm(prompt: str) -> str:
      if client is None:
          return "(LLM not configured)\n\n" + prompt
+     # Prefer chat (Gemma IT & many instruct models are conversational)
      try:
          resp = client.chat_completion(
              messages=[
                  {"role": "system", "content": (
+                     "You are a helpful assistant. Use the provided context to answer in Persian; "
                      "if it's not in the context, say you don't know."
                  )},
                  {"role": "user", "content": prompt},
              temperature=0.2,
          )
          return resp.choices[0].message.content.strip()
+     except Exception as e_chat:
+         # Fall back to text_generation if the model supports it
+         try:
+             out = client.text_generation(
+                 prompt=prompt,
+                 max_new_tokens=256,
+                 temperature=0.2,
+                 do_sample=True,
+             )
+             return out.strip()
+         except Exception as e_text:
+             return f"(LLM error: {e_chat} / {e_text})\n\n" + prompt
+ # ---------------------- gallery helpers ----------------------

+ def display_gallery_pairs(pairs: List[Pair]) -> List[Tuple[str, str]]:
      items = []
      for p in pairs:
+         if not p.image_path: continue
+         local_path = os.path.join(DATA_DIR, p.image_path) if not os.path.isabs(p.image_path) else p.image_path
+         if os.path.exists(local_path):
+             caption = f"#{p.rank} — {p.title or ''}\nscore={(p.hscore if p.hscore is not None else p.score):.3f}"
+             items.append((local_path, caption))
+     return items
+
+ def display_gallery_images(img_hits: List[ImgHit]) -> List[Tuple[str, str]]:
+     items = []
+     for h in img_hits:
+         if not h.image_path: continue
+         local_path = os.path.join(DATA_DIR, h.image_path) if not os.path.isabs(h.image_path) else h.image_path
+         if os.path.exists(local_path):
+             caption = f"#{h.rank} — {h.title or ''}\nscore={h.score:.3f}"
+             items.append((local_path, caption))
      return items
+ # ---------------------- main app logic ----------------------

+ def answer(question: str, image: Optional[Image.Image], topk: int, k_ctx: int, use_image: bool, alpha_image: float = 0.5):
+     # Hybrid retrieval when an image is provided and the checkbox is on; otherwise text-only
      if use_image and image is not None:
+         top_pairs = hybrid_search(question, image, topk=topk, alpha_image=alpha_image)
      else:
+         top_pairs = search_text(question, topk=topk)
+
+     # LLM
+     ctx = top_pairs[:max(1, k_ctx)]
      prompt = build_prompt(question, ctx)
      gen = call_llm(prompt)
+
+     # Gallery
+     gallery = display_gallery_pairs(top_pairs)
+     if not gallery and not (use_image and image is not None):
+         # text-only path: still try text->image to show visuals
+         img_hits = search_image_by_text(question, topk=min(8, topk))
+         gallery = display_gallery_images(img_hits)
+
+     top_image_path = gallery[0][0] if gallery else None
+
+     # Table
+     def fmt(x): return "—" if x is None else f"{x:.3f}"
+     table = [[p.rank, p.title or "", fmt(p.tscore), fmt(p.iscore), fmt(p.hscore if p.hscore is not None else p.score), p.doc_id] for p in top_pairs]
+
+     return gen, table, gallery, top_image_path
+ # ---------------------- UI ----------------------

  with gr.Blocks() as demo:
+     gr.Markdown("# 🍜 RAG (متن + تصویر) — Hybrid Retrieval + Persian LLM")
+
      with gr.Row():
+         q = gr.Textbox(label="پرسش (Question)", placeholder="مثلاً: طرز تهیه هویج پلو")
+         img = gr.Image(label="تصویر اختیاری (Optional image)", type="pil")
+
      with gr.Row():
+         topk = gr.Slider(1, 20, value=10, step=1, label="Top-K")
+         kctx = gr.Slider(1, 10, value=4, step=1, label="K متن زمینه برای LLM")
+         use_img = gr.Checkbox(label="Hybrid (از تصویر هم استفاده شود؟)", value=False)
+         alpha = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="وزن تصویر در Hybrid")
+
+     btn = gr.Button("اجرا (Run)")
+     out_text = gr.Textbox(label="پاسخ (Answer)")
+     out_table = gr.Dataframe(headers=["Rank", "Title", "Text S", "Image S", "Hybrid S", "Doc ID"], label="Top-K retrieval")
+     out_gallery = gr.Gallery(label="تصاویر مرتبط (Image matches)", columns=5, height=240)
+     out_img_top = gr.Image(label="بهترین تصویر (Top image)")
+
+     btn.click(
+         answer,
+         inputs=[q, img, topk, kctx, use_img, alpha],
+         outputs=[out_text, out_table, out_gallery, out_img_top]
+     )

  if __name__ == "__main__":
      demo.launch()
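
A hedged smoke test once the Space is up (hypothetical call mirroring the signature of answer above; not part of the commit):

    gen, table, gallery, top_img = answer(
        "طرز تهیه هویج پلو", image=None,
        topk=5, k_ctx=3, use_image=False,
    )
    print(gen)        # Persian answer grounded in the retrieved context
    print(table[:3])  # rows: [rank, title, text/image/hybrid scores, doc id]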