amirhossein mohammadpour committed
Commit 3f6908f · 1 Parent(s): 91d84fa

Add app and deps

.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
+ # Default ignored files
+ /shelf/
+ /workspace.xml
+ # Editor-based HTTP Client requests
+ /httpRequests/
+ # Datasource local storage ignored files
+ /dataSources/
+ /dataSources.local.xml
.idea/deployment.xml ADDED
@@ -0,0 +1,21 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="PublishConfigData" remoteFilesAllowedToDisappearOnAutoupload="false">
+     <serverData>
+       <paths name="social@95.216.162.70:22 password">
+         <serverdata>
+           <mappings>
+             <mapping local="$PROJECT_DIR$" web="/" />
+           </mappings>
+         </serverdata>
+       </paths>
+       <paths name="social@95.216.162.70:22 password (2)">
+         <serverdata>
+           <mappings>
+             <mapping local="$PROJECT_DIR$" web="/" />
+           </mappings>
+         </serverdata>
+       </paths>
+     </serverData>
+   </component>
+ </project>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,18 @@
+ <component name="InspectionProjectProfileManager">
+   <profile version="1.0">
+     <option name="myName" value="Project Default" />
+     <inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
+       <Languages>
+         <language minSize="122" name="Python" />
+       </Languages>
+     </inspection_tool>
+     <inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
+     <inspection_tool class="PyUnresolvedReferencesInspection" enabled="true" level="WARNING" enabled_by_default="true">
+       <option name="ignoredIdentifiers">
+         <list>
+           <option value="dict.__setitem__" />
+         </list>
+       </option>
+     </inspection_tool>
+   </profile>
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
+ <component name="InspectionProjectProfileManager">
+   <settings>
+     <option name="USE_PROJECT_PROFILE" value="false" />
+     <version value="1.0" />
+   </settings>
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,6 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="Black">
+     <option name="sdkName" value="Python 3.11" />
+   </component>
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="ProjectModuleManager">
+     <modules>
+       <module fileurl="file://$PROJECT_DIR$/.idea/multimodal-rag-demo.iml" filepath="$PROJECT_DIR$/.idea/multimodal-rag-demo.iml" />
+     </modules>
+   </component>
+ </project>
.idea/multimodal-rag-demo.iml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <module type="PYTHON_MODULE" version="4">
+   <component name="NewModuleRootManager">
+     <content url="file://$MODULE_DIR$" />
+     <orderEntry type="jdk" jdkName="Python 3.11" jdkType="Python SDK" />
+     <orderEntry type="sourceFolder" forTests="false" />
+   </component>
+ </module>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="VcsDirectoryMappings">
+     <mapping directory="" vcs="Git" />
+   </component>
+ </project>
app.py ADDED
@@ -0,0 +1,462 @@
+ import os, io, json, re, ast
+ import numpy as np
+ import pandas as pd
+ import faiss
+ import torch
+ import torch.nn.functional as F
+ from typing import List, Dict, Any
+ from PIL import Image, ImageFilter, ImageOps, ImageEnhance
+ import gradio as gr
+ from huggingface_hub import hf_hub_download
+ from sentence_transformers import SentenceTransformer
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ # =========================
+ # Config (override in Space → Settings → Variables & secrets)
+ # =========================
+ DATASET_REPO = os.getenv("DATASET_REPO", "ahm1378/NLP-Project")  # <--- CHANGE to your repo
+ CSV_FILE = os.getenv("CSV_FILE", "final_merged_images.csv")
+ E5_INDEX_FILE = os.getenv("E5_INDEX_FILE", "faiss_e5_rag_v15.ip")
+ E5_EMB_FILE = os.getenv("E5_EMB_FILE", "doc_embeds_e5_rag_v15.npy")
+ FUSION_INDEX_FILE = os.getenv("FUSION_INDEX_FILE", "faiss_fusion.ip")
+ FUSION_EMB_FILE = os.getenv("FUSION_EMB_FILE", "fusion_doc_emb.npy")
+ FT_HEAD_FILE = os.getenv("FT_HEAD_FILE", "finetune_clip_fa.pt")  # your finetuned text projection (CLIP space)
+
+ HF_TOKEN = os.getenv("HF_TOKEN", None)  # needed if DATASET_REPO is private
+
+ # Models (CPU-friendly defaults; override via env if desired)
+ E5_ID = os.getenv("E5_ID", "intfloat/multilingual-e5-small")
+ CLIP_TXT_ID = os.getenv("CLIP_TXT_ID", "sentence-transformers/clip-ViT-B-32-multilingual-v1")
+ LLM_ID = os.getenv("LLM_ID", "Qwen/Qwen2-0.5B-Instruct")  # small enough for free CPU
+
+ # Generation defaults (also controllable from UI)
+ MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS", "192"))
+ TEMPERATURE_DEFAULT = float(os.getenv("TEMPERATURE", "0.0"))  # deterministic by default on CPU
+ TOP_P_DEFAULT = float(os.getenv("TOP_P", "0.9"))
+ TOP_K_DEFAULT = int(os.getenv("TOP_K", "50"))
+
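+ # Expected artifacts in DATASET_REPO: the CSV and the E5 FAISS index are
+ # required; the fusion index and finetuned projection head are optional and,
+ # when absent, the app falls back to text-only retrieval (see below).
+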
+ # =========================
+ # Helpers
+ # =========================
+ def normalize_digits_months(s: str) -> str:
+     # Map Persian and Arabic-Indic digits to ASCII and replace ZWNJ with a space.
+     if not isinstance(s, str):
+         s = str(s)
+     trans = str.maketrans("۰۱۲۳۴۵۶۷۸۹٠١٢٣٤٥٦٧٨٩", "01234567890123456789")
+     s = s.translate(trans).replace("\u200c", " ").strip()
+     return s
+
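+ # e.g. normalize_digits_months("سال ۱۴۰۲") -> "سال 1402"
+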
+ def _truncate_chars(s: str, limit: int) -> str:
+     return s if (limit is None or len(s) <= limit) else s[:limit] + "…"
+
+ def _maybe_hub(file, repo=DATASET_REPO, repo_type="dataset") -> str:
+     # If present locally, use it. Otherwise download from the Hub.
+     if os.path.isfile(file):
+         return file
+     return hf_hub_download(repo_id=repo, filename=file, repo_type=repo_type, token=HF_TOKEN)
+
+ # =========================
+ # Fetch artifacts
+ # =========================
+ def _optional_hub(file):
+     # Optional artifacts: tolerate absence instead of crashing at startup.
+     try:
+         return _maybe_hub(file) if file else None
+     except Exception as e:
+         print(f"[WARN] optional artifact {file!r} unavailable:", e)
+         return None
+
+ CSV_PATH = _maybe_hub(CSV_FILE)
+ E5_INDEX_PATH = _maybe_hub(E5_INDEX_FILE)
+ # (E5_EMB_PATH not strictly needed at runtime)
+ FUSION_INDEX_PATH = _optional_hub(FUSION_INDEX_FILE)
+ FT_HEAD_PATH = _optional_hub(FT_HEAD_FILE)
+
+ # =========================
+ # Load dataframe
+ # =========================
+ if not os.path.isfile(CSV_PATH):
+     raise FileNotFoundError(f"CSV missing: {CSV_PATH}")
+
+ df = pd.read_csv(CSV_PATH)
+
+ # Expect columns: 'id', 'bio', 'image_paths_abs' (list or stringified list)
+ def first_image(x):
+     if isinstance(x, list) and x:
+         return x[0]
+     if isinstance(x, str) and x.strip():
+         # try JSON list
+         try:
+             lst = json.loads(x)
+             if isinstance(lst, list) and lst:
+                 return lst[0]
+         except Exception:
+             # try Python literal list (handles single quotes)
+             try:
+                 lst = ast.literal_eval(x)
+                 if isinstance(lst, list) and lst:
+                     return lst[0]
+             except Exception:
+                 return x  # treat as single path
+     return ""
+
+ if "image_paths_abs" in df.columns:
+     df["first_image"] = df["image_paths_abs"].apply(first_image)
+ else:
+     df["first_image"] = ""
+
+ if "bio" not in df.columns:
+     raise KeyError("Expected 'bio' column in CSV.")
+ df["bio"] = df["bio"].astype(str)
+
+ # =========================
+ # Indices
+ # =========================
+ if not os.path.isfile(E5_INDEX_PATH):
+     raise FileNotFoundError(f"E5 index not found: {E5_INDEX_PATH}")
+ index_e5 = faiss.read_index(E5_INDEX_PATH)
+
+ index_fusion = None
+ if FUSION_INDEX_PATH and os.path.isfile(FUSION_INDEX_PATH):
+     index_fusion = faiss.read_index(FUSION_INDEX_PATH)
+
+ # =========================
+ # Models (CPU-only)
+ # =========================
+ device = "cpu"
+ dtype = torch.float32
+
+ # Text retrieval encoder (E5)
+ st_e5 = SentenceTransformer(E5_ID, device=device)
+
+ # CLIP multilingual text encoder: the normalized fallback when no FT head
+ # exists, and the source of raw embeddings for the projection path below.
+ st_clip_txt = SentenceTransformer(CLIP_TXT_ID, device=device).eval()
+ mclip = st_clip_txt  # same model; loading a second copy would only double CPU memory
+
+ # Optional: finetuned CLIP text projection head (512->512, bias=False)
+ proj_txt = None
+ if FT_HEAD_PATH and os.path.isfile(FT_HEAD_PATH):
+     try:
+         proj_txt = torch.nn.Linear(512, 512, bias=False)
+         ckpt = torch.load(FT_HEAD_PATH, map_location="cpu")
+         if "proj_txt" in ckpt:
+             proj_txt.load_state_dict(ckpt["proj_txt"])
+         elif "state_dict" in ckpt:
+             proj_txt.load_state_dict(ckpt["state_dict"])
+         else:
+             raise KeyError("No 'proj_txt' or 'state_dict' key in FT checkpoint.")
+         proj_txt.eval()
+         print("[OK] loaded finetuned projection head:", FT_HEAD_PATH)
+     except Exception as e:
+         print("[WARN] failed to load finetuned head:", e)
+         proj_txt = None
+
+ # Lazy CLIP image encoder (only load if user actually does fusion)
+ clip_model = None
+ clip_preprocess = None
+
+ def _ensure_clip_loaded():
+     global clip_model, clip_preprocess
+     if clip_model is None:
+         import open_clip  # lazy import
+         model, _, preprocess_val = open_clip.create_model_and_transforms(
+             "ViT-B-32", pretrained="laion2b_s34b_b79k", device="cpu"
+         )
+         clip_model = model.eval()
+         clip_preprocess = preprocess_val
+         print("[OK] CLIP ViT-B/32 loaded on CPU")
+
+ # LLM (small; CPU-friendly)
+ tokenizer = AutoTokenizer.from_pretrained(LLM_ID, use_fast=True)
+ model = AutoModelForCausalLM.from_pretrained(
+     LLM_ID,
+     torch_dtype=dtype,
+ ).to("cpu").eval()
+
+ # =========================
+ # Retrieval helpers
+ # =========================
+ @torch.no_grad()
+ def _encode_query_e5(q: str) -> np.ndarray:
+     qn = "query: " + normalize_digits_months(q)
+     v = st_e5.encode([qn], batch_size=1, convert_to_numpy=True, normalize_embeddings=True)[0]
+     return v.astype("float32")
+
+ def _faiss_search(index, q_vec: np.ndarray, k: int):
+     if q_vec.ndim == 1:
+         q_vec = q_vec[None, :]
+     s, I = index.search(q_vec.astype("float32"), k)
+     return list(zip(I[0].tolist(), s[0].tolist()))
+
+ def search_text_rag(query_text: str, k: int = 5):
+     q = _encode_query_e5(query_text)
+     return _faiss_search(index_e5, q, k)
+
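+ # Note: E5 checkpoints use asymmetric prefixes: queries get "query: " (above),
+ # which assumes the passages behind index_e5 were embedded with the matching
+ # "passage: " prefix when the index was built.
+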
+ # ---- Fusion (CLIP space) ----
+ def _jpeg(img, quality=40):
+     buf = io.BytesIO()
+     img.save(buf, format="JPEG", quality=quality, optimize=False)
+     buf.seek(0)
+     return Image.open(buf).convert("RGB")
+
+ def _rand_resized_crop(img, scale=(0.7, 0.9)):
+     w, h = img.size
+     s = np.random.uniform(*scale)
+     nw, nh = max(1, int(w * s)), max(1, int(h * s))
+     left = np.random.randint(0, max(1, w - nw))
+     top = np.random.randint(0, max(1, h - nh))
+     return img.crop((left, top, left + nw, top + nh)).resize((w, h), Image.BICUBIC)
+
+ def _color_jitter(img, b=(0.9, 1.1), c=(0.9, 1.1)):
+     img = ImageOps.autocontrast(img)
+     img = ImageEnhance.Brightness(img).enhance(np.random.uniform(*b))
+     img = ImageEnhance.Contrast(img).enhance(np.random.uniform(*c))
+     return img
+
+ def augment_once(img: Image.Image, level="medium"):
+     if level == "mild":
+         img = _rand_resized_crop(img, (0.85, 0.95))
+         img = _jpeg(img, 60)
+     elif level == "medium":
+         img = _rand_resized_crop(img, (0.7, 0.9))
+         img = img.filter(ImageFilter.GaussianBlur(1.0))
+         img = _color_jitter(img, (0.9, 1.1), (0.9, 1.1))
+         img = _jpeg(img, 40)
+     else:
+         img = _rand_resized_crop(img, (0.6, 0.8))
+         img = img.filter(ImageFilter.GaussianBlur(1.2))
+         img = _jpeg(img, 30)
+     return img
+
+ @torch.no_grad()
+ def _encode_pil_clip(img: Image.Image) -> np.ndarray:
+     _ensure_clip_loaded()
+     t = clip_preprocess(img).unsqueeze(0)
+     feat = clip_model.encode_image(t)
+     feat = F.normalize(feat.float(), dim=-1)
+     return feat.cpu().numpy().astype("float32")  # (1, 512)
+
+ @torch.no_grad()
+ def _encode_query_text_clipspace(q: str) -> np.ndarray:
+     qn = normalize_digits_months(q)
+     if proj_txt is not None:
+         # mclip raw → proj → normalize
+         t = torch.tensor(
+             mclip.encode([qn], convert_to_numpy=True, normalize_embeddings=False),
+             dtype=torch.float32
+         )
+         t = proj_txt(t)
+         t = F.normalize(t, dim=-1).cpu().numpy().astype("float32")
+         return t
+     else:
+         # fallback: CLIP multilingual text encoder (already normalized)
+         t = st_clip_txt.encode([qn], batch_size=1, convert_to_numpy=True, normalize_embeddings=True)
+         return t.astype("float32")
+
+ @torch.no_grad()
+ def make_query_embed(query_text: str,
+                      image: Image.Image = None,
+                      alpha_q: float = 0.7,
+                      use_aug: bool = True,
+                      n_aug: int = 3) -> np.ndarray:
+     qt = _encode_query_text_clipspace(query_text)  # (1, 512)
+     qi = None
+     if image is not None:  # _encode_pil_clip lazy-loads CLIP on first use
+         if use_aug:
+             # average features over a few augmented views for robustness
+             feats = [_encode_pil_clip(augment_once(image, "medium")) for _ in range(max(1, int(n_aug)))]
+             qi = np.mean(np.vstack(feats), axis=0, keepdims=True).astype("float32")
+         else:
+             qi = _encode_pil_clip(image)
+     if qi is not None:
+         qv = torch.from_numpy(alpha_q * qt + (1.0 - alpha_q) * qi)
+         qv = F.normalize(qv, dim=-1).cpu().numpy().astype("float32")
+         return qv
+     return qt
+
+ def search_fusion(query_text: str, image: Image.Image, k: int = 5, alpha_q: float = 0.7):
+     if index_fusion is None:
+         raise RuntimeError("Fusion index not available (upload FUSION_INDEX_FILE to dataset repo).")
+     qv = make_query_embed(query_text, image=image, alpha_q=alpha_q, use_aug=True, n_aug=3)
+     return _faiss_search(index_fusion, qv, k)
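+
+ # Late fusion recap: with an image present, the query vector is the
+ # renormalized mix qv = normalize(alpha_q * text + (1 - alpha_q) * image),
+ # so the default alpha_q = 0.7 favors the text query; text-only queries use
+ # the CLIP-space text embedding unchanged.
+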
+ # =========================
+ # RAG + LLM
+ # =========================
+ def retrieve_context_auto(question: str, k: int = 5, image: Image.Image = None) -> Dict[str, Any]:
+     q = normalize_digits_months(question)
+     if image is not None:
+         route = "fusion"
+         try:
+             hits = search_fusion(q, image=image, k=k)
+         except Exception as e:
+             print("[WARN] fusion failed, falling back to text:", e)
+             route = "text_e5"  # graceful fallback
+             hits = search_text_rag(q, k=k)
+     else:
+         route = "text_e5"
+         hits = search_text_rag(q, k=k)
+
+     ctxs = []
+     for idx, score in hits:
+         if 0 <= idx < len(df):
+             row = df.iloc[idx]
+             ctxs.append({"index": int(idx), "id": row.get("id", idx), "score": float(score), "bio": str(row["bio"])})
+     return {"route": route, "contexts": ctxs}
+
+ def build_prompt(question: str, contexts: List[Dict[str, Any]], lang="fa", max_chars=5000) -> str:
+     sys_fa = "تو یک دستیار پاسخ‌گو هستی که فقط بر اساس متن‌های داده‌شده پاسخ می‌دهی. اگر پاسخی در متن‌ها نبود، صادقانه بگو «در متن‌های بازیابی‌شده پاسخی پیدا نشد.»"
+     sys_en = "You are a helpful assistant. Answer only using retrieved passages. If not found, say 'No answer found in retrieved passages.'"
+     system_text = sys_fa if lang == "fa" else sys_en
+
+     parts = []
+     for i, c in enumerate(contexts, 1):
+         bi = c["bio"].strip()
+         if bi:
+             parts.append(f"[{i}] {bi}")
+     joined = _truncate_chars("\n\n".join(parts), max_chars)
+
+     user = (
+         f"سؤال: {question}\n\nمتون بازیابی‌شده:\n{joined}\n\n"
+         f"فقط با اتکا به متون بالا پاسخ بده و منابع را با [1], [2], ... ارجاع بده."
+     ) if lang == "fa" else (
+         f"Question: {question}\n\nRetrieved passages:\n{joined}\n\n"
+         f"Answer only using the passages, cite sources as [1], [2], ..."
+     )
+     msgs = [{"role": "system", "content": system_text},
+             {"role": "user", "content": user}]
+     return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
+
+ @torch.inference_mode()
+ def llm_generate(prompt: str,
+                  max_new_tokens=MAX_NEW_TOKENS_DEFAULT,
+                  temperature=TEMPERATURE_DEFAULT,
+                  top_p=TOP_P_DEFAULT,
+                  top_k=TOP_K_DEFAULT,
+                  do_sample=False) -> str:
+     inputs = tokenizer(prompt, return_tensors="pt")
+     out = model.generate(
+         **inputs,
+         max_new_tokens=int(max_new_tokens),
+         do_sample=bool(do_sample),
+         temperature=float(temperature),
+         top_p=float(top_p),
+         top_k=int(top_k),
+         pad_token_id=tokenizer.eos_token_id,
+         eos_token_id=tokenizer.eos_token_id,
+     )
+     # Decode only the newly generated tokens; decoding the full sequence and
+     # stripping the prompt with startswith() fails once special tokens are
+     # removed, which would leave the prompt echoed in the answer.
+     gen_tokens = out[0][inputs["input_ids"].shape[1]:]
+     return tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
+
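+ # Note: with do_sample=False generation is greedy, so temperature/top_p/top_k
+ # only take effect when sampling is enabled (the UI enables it for T > 0).
+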
+ # ---- MCQ helpers ----
+ def build_mcq_prompt(question: str, options: List[str], contexts: List[Dict[str, Any]], lang="fa", max_chars=5000) -> str:
+     sys_fa = "تو یک دستیار پاسخ‌گو هستی که فقط بر اساس متن‌های داده‌شده پاسخ می‌دهی."
+     sys_en = "You are a helpful assistant. Answer only using the retrieved passages."
+     system_text = sys_fa if lang == "fa" else sys_en
+
+     parts = []
+     for i, c in enumerate(contexts, 1):
+         bi = c["bio"].strip()
+         if bi:
+             parts.append(f"[{i}] {bi}")
+     joined = _truncate_chars("\n\n".join(parts), max_chars)
+
+     labels = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+     opts_str = "\n".join([f"{labels[i]}) {o}" for i, o in enumerate(options)])
+
+     if lang == "fa":
+         user = (
+             f"سؤال: {question}\n\nگزینه‌ها:\n{opts_str}\n\nمتون بازیابی‌شده:\n{joined}\n\n"
+             'فقط براساس متون بالا پاسخ بده. دقیقاً در این قالب برگردان:\n{"answer_index": X, "reason": "…"}'
+         )
+     else:
+         user = (
+             f"Question: {question}\n\nOptions:\n{opts_str}\n\nRetrieved:\n{joined}\n\n"
+             'Answer strictly based on passages. Return exactly:\n{"answer_index": X, "reason": "..."}'
+         )
+     msgs = [{"role": "system", "content": system_text},
+             {"role": "user", "content": user}]
+     return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
+
+ def parse_mcq_output(text: str, n: int) -> Dict[str, Any]:
+     # Preferred: the JSON object the prompt asks for (whitespace-tolerant).
+     m = re.search(r'\{\s*"answer_index"\s*:\s*([0-9]+)\s*,\s*"reason"\s*:\s*"(.*?)"\s*\}', text, re.S)
+     if m:
+         idx = int(m.group(1))
+         reason = m.group(2).strip()
+         if 0 <= idx < n:
+             return {"answer_index": idx, "reason": reason}
+     # Fallback 1: a bare option letter, bounded by the actual option count.
+     letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+     m2 = re.search(rf'\b([A-{letters[min(n, 26) - 1]}])\b', text, re.I)
+     if m2:
+         idx = letters.index(m2.group(1).upper())
+         if idx < n:
+             return {"answer_index": idx, "reason": text.strip()}
+     # Fallback 2: a bare 1-based option number.
+     m3 = re.search(r'\b([1-9])\b', text)
+     if m3:
+         idx = int(m3.group(1)) - 1
+         if 0 <= idx < n:
+             return {"answer_index": idx, "reason": text.strip()}
+     return {"answer_index": None, "reason": text.strip()}
+
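+ # e.g. parse_mcq_output('{"answer_index": 2, "reason": "stated in [1]"}', 4)
+ # -> {"answer_index": 2, "reason": "stated in [1]"}
+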
+ # =========================
+ # Gradio UI
+ # =========================
+ def ui_answer(question, image, topk, max_tokens, temperature, top_p, top_k):
+     if not question or not question.strip():
+         return "Please enter a question.", [], ""
+     # Retrieve
+     ret = retrieve_context_auto(question, k=int(topk), image=image)
+     prompt = build_prompt(question, ret["contexts"], lang="fa", max_chars=5000)
+     # Sample only when temperature > 0, so the UI sliders actually take effect.
+     ans = llm_generate(prompt, max_new_tokens=int(max_tokens),
+                        temperature=float(temperature), top_p=float(top_p),
+                        top_k=int(top_k), do_sample=float(temperature) > 0)
+     # Sources
+     rows = []
+     for i, c in enumerate(ret["contexts"], 1):
+         snip = c["bio"][:180] + ("…" if len(c["bio"]) > 180 else "")
+         rows.append([i, c["id"], round(c["score"], 4), snip])
+     return ans, rows, ret["route"]
+
+ def ui_mcq(question, options_txt, image, topk, max_tokens, temperature, top_p, top_k):
+     opts = [o.strip() for o in (options_txt or "").splitlines() if o.strip()]
+     if not question or len(opts) < 2:
+         return "Provide a question and at least 2 options.", "", [], ""
+     ret = retrieve_context_auto(question, k=int(topk), image=image)
+     prompt = build_mcq_prompt(question, opts, ret["contexts"], lang="fa", max_chars=5000)
+     out = llm_generate(prompt, max_new_tokens=int(max_tokens),
+                        temperature=float(temperature), top_p=float(top_p),
+                        top_k=int(top_k), do_sample=float(temperature) > 0)
+     parsed = parse_mcq_output(out, len(opts))
+     pred = parsed["answer_index"]
+     pred_text = (opts[pred] if (pred is not None and 0 <= pred < len(opts)) else "N/A")
+     rows = []
+     for i, c in enumerate(ret["contexts"], 1):
+         snip = c["bio"][:180] + ("…" if len(c["bio"]) > 180 else "")
+         rows.append([i, c["id"], round(c["score"], 4), snip])
+     result = f"Pred: index={pred} text={pred_text}\nReason: {parsed['reason']}"
+     return result, out, rows, ret["route"]
+
+ with gr.Blocks(title="Multimodal RAG (CPU) • E5 + CLIP Fusion + Qwen 0.5B") as demo:
+     gr.Markdown("### Free-tier CPU demo: text RAG (E5) + optional fusion (CLIP) → Qwen 0.5B")
+     with gr.Tab("Ask"):
+         with gr.Row():
+             q = gr.Textbox(label="Question", placeholder="سؤال خود را بنویسید…", lines=3)
+             img = gr.Image(type="pil", label="Optional image (fusion if provided)")
+         with gr.Row():
+             topk = gr.Slider(1, 20, value=5, step=1, label="Top-K retrieve")
+             max_tokens = gr.Slider(32, 1024, value=MAX_NEW_TOKENS_DEFAULT, step=16, label="Max new tokens")
+         with gr.Row():
+             temperature = gr.Slider(0.0, 1.5, value=TEMPERATURE_DEFAULT, step=0.1, label="Temperature")
+             top_p = gr.Slider(0.1, 1.0, value=TOP_P_DEFAULT, step=0.05, label="Top-p")
+             top_k = gr.Slider(1, 100, value=TOP_K_DEFAULT, step=1, label="Top-k")
+         btn = gr.Button("Answer")
+         ans = gr.Textbox(label="Answer", lines=8)
+         route = gr.Textbox(label="Route used (text_e5 or fusion)")
+         table = gr.Dataframe(headers=["#", "id", "score", "snippet"], interactive=False)
+         btn.click(ui_answer, [q, img, topk, max_tokens, temperature, top_p, top_k], [ans, table, route])
+
+     with gr.Tab("MCQ"):
+         with gr.Row():
+             q_mcq = gr.Textbox(label="Question", lines=3)
+             opts_mcq = gr.Textbox(label="Options (one per line)", lines=6)
+             img_mcq = gr.Image(type="pil", label="Optional image (fusion if provided)")
+         with gr.Row():
+             topk2 = gr.Slider(1, 20, value=5, step=1, label="Top-K retrieve")
+             max_tokens2 = gr.Slider(32, 1024, value=MAX_NEW_TOKENS_DEFAULT, step=16, label="Max new tokens")
+         with gr.Row():
+             temperature2 = gr.Slider(0.0, 1.5, value=TEMPERATURE_DEFAULT, step=0.1, label="Temperature")
+             top_p2 = gr.Slider(0.1, 1.0, value=TOP_P_DEFAULT, step=0.05, label="Top-p")
+             top_k2 = gr.Slider(1, 100, value=TOP_K_DEFAULT, step=1, label="Top-k")
+         btn2 = gr.Button("Answer MCQ")
+         result = gr.Textbox(label="Prediction")
+         raw = gr.Textbox(label="Raw LLM output", lines=6)
+         route2 = gr.Textbox(label="Route used")
+         table2 = gr.Dataframe(headers=["#", "id", "score", "snippet"], interactive=False)
+         btn2.click(ui_mcq, [q_mcq, opts_mcq, img_mcq, topk2, max_tokens2, temperature2, top_p2, top_k2],
+                    [result, raw, table2, route2])
+
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ gradio>=4.29.0
+ huggingface_hub>=0.23.0
+ pandas
+ numpy
+ pillow
+ faiss-cpu
+ sentence-transformers>=2.5.0
+ transformers>=4.41.0
+ open_clip_torch>=2.24.0
+ tqdm
+ torch