import os, io, gc, json, re, ast
from functools import lru_cache

import numpy as np
import pandas as pd
import faiss
import torch
import torch.nn.functional as F
from typing import List, Dict, Any
from PIL import Image, ImageFilter, ImageOps, ImageEnhance
import gradio as gr
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
torch.set_num_threads(2)                 # Space free-tier vCPUs are usually 2
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# =========================
# Config (override in Space → Settings → Variables & secrets)
# =========================
DATASET_REPO      = os.getenv("DATASET_REPO", "ahm1378/NLP-Project")  # <--- CHANGE to your repo
CSV_FILE          = os.getenv("CSV_FILE", "final_merged_images.csv")
E5_INDEX_FILE     = os.getenv("E5_INDEX_FILE", "faiss_e5_rag_v15.ip")
E5_EMB_FILE       = os.getenv("E5_EMB_FILE", "doc_embeds_e5_rag_v15.npy")
FUSION_INDEX_FILE = os.getenv("FUSION_INDEX_FILE", "faiss_fusion.ip")
FUSION_EMB_FILE   = os.getenv("FUSION_EMB_FILE", "fusion_doc_emb.npy")
FT_HEAD_FILE      = os.getenv("FT_HEAD_FILE", "finetune_clip_fa.pt")  # your finetuned text projection (CLIP space)

HF_TOKEN          = os.getenv("HF_TOKEN", None)  # needed if DATASET_REPO is private

# Models (CPU-friendly defaults; override via env if desired)
E5_ID       = os.getenv("E5_ID", "intfloat/multilingual-e5-small")
CLIP_TXT_ID = os.getenv("CLIP_TXT_ID", "sentence-transformers/clip-ViT-B-32-multilingual-v1")
LLM_ID = os.getenv("LLM_ID", "Qwen/Qwen2-0.5B-Instruct")

# Shorter output
MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS_DEFAULT", "96"))  # previously 256

# Sampling disabled (deterministic and faster)
TEMPERATURE_DEFAULT = float(os.getenv("TEMPERATURE_DEFAULT", "0.0"))
TOP_P_DEFAULT       = float(os.getenv("TOP_P_DEFAULT", "1.0"))
TOP_K_DEFAULT       = int(os.getenv("TOP_K_DEFAULT", "50"))
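
# Example (hypothetical values): any of the settings above can be overridden in
# Space → Settings → Variables & secrets, e.g.:
#   DATASET_REPO=your-user/your-dataset
#   LLM_ID=Qwen/Qwen2-0.5B-Instruct
#   MAX_NEW_TOKENS_DEFAULT=64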

# =========================
# Helpers
# =========================
def normalize_digits_months(s: str) -> str:
    """Map Persian/Arabic-Indic digits to ASCII and replace ZWNJ (U+200C) with a space."""
    if not isinstance(s, str):
        s = str(s)
    trans = str.maketrans("۰۱۲۳۴۵۶۷۸۹٠١٢٣٤٥٦٧٨٩", "01234567890123456789")
    s = s.translate(trans).replace("\u200c", " ").strip()
    return s
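
# Example: both Persian and Arabic-Indic digits map to ASCII, ZWNJ becomes a space:
#   normalize_digits_months("۱۲ مرداد ١٣٧٥")  ->  "12 مرداد 1375"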

def _truncate_chars(s: str, limit: int) -> str:
    return s if (limit is None or len(s) <= limit) else s[:limit] + "…"

def _maybe_hub(file, repo=DATASET_REPO, repo_type="dataset") -> str:
    # If present locally, use it. Otherwise download from Hub.
    if os.path.isfile(file):
        return file
    return hf_hub_download(repo_id=repo, filename=file, repo_type=repo_type, token=HF_TOKEN)

# =========================
# Fetch artifacts
# =========================
CSV_PATH          = _maybe_hub(CSV_FILE)
E5_INDEX_PATH     = _maybe_hub(E5_INDEX_FILE)
# (E5_EMB_PATH not strictly needed at runtime)
FUSION_INDEX_PATH = _maybe_hub(FUSION_INDEX_FILE) if FUSION_INDEX_FILE else None
FT_HEAD_PATH      = _maybe_hub(FT_HEAD_FILE) if FT_HEAD_FILE else None

# =========================
# Load dataframe
# =========================
if not os.path.isfile(CSV_PATH):
    raise FileNotFoundError(f"CSV missing: {CSV_PATH}")

df = pd.read_csv(CSV_PATH)

# Expect columns: 'id', 'bio', 'image_paths_abs' (list or stringified list)
def first_image(x):
    if isinstance(x, list) and x:
        return x[0]
    if isinstance(x, str) and x.strip():
        # try JSON list
        try:
            lst = json.loads(x)
            if isinstance(lst, list) and lst:
                return lst[0]
        except Exception:
            # try Python literal list (handles single quotes)
            try:
                lst = ast.literal_eval(x)
                if isinstance(lst, list) and lst:
                    return lst[0]
            except Exception:
                return x  # treat as single path
    return ""

if "image_paths_abs" in df.columns:
    df["first_image"] = df["image_paths_abs"].apply(first_image)
else:
    df["first_image"] = ""

if "bio" not in df.columns:
    raise KeyError("Expected 'bio' column in CSV.")
df["bio"] = df["bio"].astype(str)

# =========================
# Indices
# =========================
if not os.path.isfile(E5_INDEX_PATH):
    raise FileNotFoundError(f"E5 index not found: {E5_INDEX_PATH}")
index_e5 = faiss.read_index(E5_INDEX_PATH)

index_fusion = None
if FUSION_INDEX_PATH and os.path.isfile(FUSION_INDEX_PATH):
    index_fusion = faiss.read_index(FUSION_INDEX_PATH)

# =========================
# Models (CPU-only)
# =========================
device = "cpu"
dtype  = torch.float32

# Text retrieval encoder (E5)
st_e5 = SentenceTransformer(E5_ID, device=device)

# CLIP text encoder (fallback when no FT head)
st_clip_txt = SentenceTransformer(CLIP_TXT_ID, device=device).eval()
mclip = st_clip_txt  # reuse the same weights instead of loading the model twice

# Optional: finetuned CLIP text projection head (512->512, bias=False)
proj_txt = None
if FT_HEAD_PATH and os.path.isfile(FT_HEAD_PATH):
    try:
        proj_txt = torch.nn.Linear(512, 512, bias=False)
        ckpt = torch.load(FT_HEAD_PATH, map_location="cpu")
        if "proj_txt" in ckpt:
            proj_txt.load_state_dict(ckpt["proj_txt"])
        elif "state_dict" in ckpt:
            proj_txt.load_state_dict(ckpt["state_dict"])
        else:
            raise KeyError("No 'proj_txt' or 'state_dict' key in FT checkpoint.")
        proj_txt.eval()
        print("[OK] loaded finetuned projection head:", FT_HEAD_PATH)
    except Exception as e:
        print("[WARN] failed to load finetuned head:", e)
        proj_txt = None

# Lazy CLIP image encoder (only load if user actually does fusion)
clip_model = None
clip_preprocess = None
def _ensure_clip_loaded():
    global clip_model, clip_preprocess
    if clip_model is None:
        import open_clip  # lazy import
        model, _, preprocess_val = open_clip.create_model_and_transforms(
            "ViT-B-32", pretrained="laion2b_s34b_b79k", device="cpu"
        )
        clip_model = model.eval()
        clip_preprocess = preprocess_val
        print("[OK] CLIP ViT-B/32 loaded on CPU")

# LLM (small; CPU-friendly)
tokenizer = AutoTokenizer.from_pretrained(LLM_ID, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    LLM_ID,
    torch_dtype=dtype,
).to("cpu").eval()


# =========================
# Retrieval helpers
# =========================
@lru_cache(maxsize=4096)
def _encode_query_e5_cached(q: str) -> np.ndarray:
    # E5 models expect a "query: " prefix on queries (and "passage: " on documents).
    qn = "query: " + normalize_digits_months(q)
    v = st_e5.encode([qn], batch_size=1, convert_to_numpy=True, normalize_embeddings=True)[0]
    return v.astype("float32")

# Drop-in replacement for the old (uncached) encoder:
def _encode_query_e5(q: str) -> np.ndarray:
    return _encode_query_e5_cached(q)

def _faiss_search(index, q_vec: np.ndarray, k: int):
    if q_vec.ndim == 1:
        q_vec = q_vec[None, :]
    s, I = index.search(q_vec.astype("float32"), k)
    return list(zip(I[0].tolist(), s[0].tolist()))

def search_text_rag(query_text: str, k: int = 5):
    q = _encode_query_e5(query_text)
    return _faiss_search(index_e5, q, k)
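
# Usage sketch (hypothetical query; assumes the artifacts above loaded successfully):
#   hits = search_text_rag("زادگاه این شاعر کجاست؟", k=3)
#   # -> [(row_index, score), ...]; the query vector is L2-normalized, so the
#   #    inner-product scores are cosine similarities, assuming the document
#   #    embeddings in the IP index were normalized the same way.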

# ---- Fusion (CLIP space) ----
def _jpeg(img, quality=40):
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=quality, optimize=False)
    buf.seek(0)
    return Image.open(buf).convert("RGB")

def _rand_resized_crop(img, scale=(0.7, 0.9)):
    w, h = img.size
    s = np.random.uniform(*scale)
    nw, nh = max(1, int(w * s)), max(1, int(h * s))
    left = np.random.randint(0, max(1, w - nw))
    top  = np.random.randint(0, max(1, h - nh))
    return img.crop((left, top, left + nw, top + nh)).resize((w, h), Image.BICUBIC)

def _color_jitter(img, b=(0.9,1.1), c=(0.9,1.1)):
    img = ImageOps.autocontrast(img)
    img = ImageEnhance.Brightness(img).enhance(np.random.uniform(*b))
    img = ImageEnhance.Contrast(img).enhance(np.random.uniform(*c))
    return img

def augment_once(img: Image.Image, level="medium"):
    if level == "mild":
        img = _rand_resized_crop(img, (0.85, 0.95))
        img = _jpeg(img, 60)
    elif level == "medium":
        img = _rand_resized_crop(img, (0.7, 0.9))
        img = img.filter(ImageFilter.GaussianBlur(1.0))
        img = _color_jitter(img, (0.9, 1.1), (0.9, 1.1))
        img = _jpeg(img, 40)
    else:
        img = _rand_resized_crop(img, (0.6, 0.8))
        img = img.filter(ImageFilter.GaussianBlur(1.2))
        img = _jpeg(img, 30)
    return img

@torch.no_grad()
def _encode_pil_clip(img: Image.Image) -> np.ndarray:
    _ensure_clip_loaded()
    t = clip_preprocess(img).unsqueeze(0)
    feat = clip_model.encode_image(t)
    feat = F.normalize(feat.float(), dim=-1)
    return feat.cpu().numpy().astype("float32")  # (1,512)

@torch.no_grad()
def _encode_query_text_clipspace(q: str) -> np.ndarray:
    qn = normalize_digits_months(q)
    if proj_txt is not None:
        # mclip raw → proj → normalize
        t = torch.tensor(
            mclip.encode([qn], convert_to_numpy=True, normalize_embeddings=False),
            dtype=torch.float32
        )
        t = proj_txt(t)
        t = F.normalize(t, dim=-1).cpu().numpy().astype("float32")
        return t
    else:
        # fallback: CLIP multilingual text encoder (already normalized)
        t = st_clip_txt.encode([qn], batch_size=1, convert_to_numpy=True, normalize_embeddings=True)
        return t.astype("float32")

@torch.no_grad()
def make_query_embed(query_text: str,
                     image: Image.Image = None,
                     alpha_q: float = 0.7,
                     use_aug: bool = True,
                     n_aug: int = 3) -> np.ndarray:
    qt = _encode_query_text_clipspace(query_text)  # (1,512)
    qi = None
    if image is not None:
        if clip_model is None:  # ensure loaded only if needed
            _ensure_clip_loaded()
        if use_aug:
            feats = [ _encode_pil_clip(augment_once(image, "medium")) for _ in range(max(1,int(n_aug))) ]
            qi = np.mean(np.vstack(feats), axis=0, keepdims=True).astype("float32")
        else:
            qi = _encode_pil_clip(image)
    if qi is not None:
        qv = torch.from_numpy(alpha_q*qt + (1.0-alpha_q)*qi)
        qv = F.normalize(qv, dim=-1).cpu().numpy().astype("float32")
        return qv
    return qt
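
# Usage sketch ("sample.jpg" is hypothetical; the CLIP image encoder loads lazily
# on first use):
#   qv = make_query_embed("توصیف تصویر", image=Image.open("sample.jpg"), alpha_q=0.7)
#   # qv has shape (1, 512) and is L2-normalized, ready for the fusion FAISS index.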

def search_fusion(query_text: str, image: Image.Image, k: int = 5, alpha_q: float = 0.7):
    if index_fusion is None:
        raise RuntimeError("Fusion index not available (upload FUSION_INDEX_FILE to dataset repo).")
    qv = make_query_embed(query_text, image=image, alpha_q=alpha_q, use_aug=False, n_aug=3)
    return _faiss_search(index_fusion, qv, k)

# =========================
# RAG + LLM
# =========================
def retrieve_context_auto(question: str, k: int = 5, image: Image.Image = None) -> Dict[str, Any]:
    q = normalize_digits_months(question)
    if image is not None:
        route = "fusion"
        try:
            hits = search_fusion(q, image=image, k=k)
        except Exception as e:
            print(f"[WARN] fusion search failed, falling back to text retrieval: {e}")
            route = "text_e5"  # graceful fallback
            hits = search_text_rag(q, k=k)
    else:
        route = "text_e5"
        hits = search_text_rag(q, k=k)

    ctxs = []
    for idx, score in hits:
        if 0 <= idx < len(df):
            row = df.iloc[idx]
            ctxs.append({"index": int(idx), "id": row.get("id", idx), "score": float(score), "bio": str(row["bio"])})
    return {"route": route, "contexts": ctxs}

def build_prompt(question: str, contexts: List[Dict[str, Any]], lang="fa", max_chars=1800) -> str:
    sys_fa = "تو یک دستیار پاسخ‌گو هستی که فقط بر اساس متن‌های داده‌شده پاسخ می‌دهی. اگر پاسخی در متن‌ها نبود، صادقانه بگو «در متن‌های بازیابی‌شده پاسخی پیدا نشد.»"
    sys_en = "You are a helpful assistant. Answer only using retrieved passages. If not found, say 'No answer found in retrieved passages.'"
    system_text = sys_fa if lang == "fa" else sys_en

    parts = []
    for i, c in enumerate(contexts, 1):
        bi = c["bio"].strip()
        if bi:
            parts.append(f"[{i}] {bi}")
    joined = _truncate_chars("\n\n".join(parts), max_chars)

    user = (f"سؤال: {question}\n\nمتون بازیابی‌شده:\n{joined}\n\n"
            f"فقط با اتکا به متون بالا پاسخ بده و منابع را با [1], [2], ... ارجاع بده."
           ) if lang == "fa" else (
            f"Question: {question}\n\nRetrieved passages:\n{joined}\n\n"
            f"Answer only using the passages, cite sources as [1], [2], ..."
           )
    msgs = [{"role": "system", "content": system_text},
            {"role": "user", "content": user}]
    return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)

@torch.inference_mode()
def llm_generate(prompt: str, max_new_tokens=96, temperature=0.0, top_p=1.0, top_k=50, do_sample=False) -> str:
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    out = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,         # deterministic when False (the default)
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=1,                 # no beam search
        use_cache=True,              # faster
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens: comparing the decoded text against
    # the raw chat-template prompt fails once special tokens are stripped.
    gen = out[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(gen, skip_special_tokens=True).strip()

# ---- MCQ helpers ----
def build_mcq_prompt(question: str, options: List[str], contexts: List[Dict[str, Any]], lang="fa", max_chars=1800) -> str:
    sys_fa = (
        "تو یک دستیار پاسخ‌گو هستی که فقط بر اساس متن‌های داده‌شده پاسخ می‌دهی. "
        "باید دقیقاً فقط یک شیء JSON برگردانی و هیچ متن دیگری ننویسی."
    )
    sys_en = (
        "You are a helpful assistant. Answer ONLY using the retrieved passages. "
        "You MUST return a single JSON object and nothing else."
    )
    system_text = sys_fa if lang == "fa" else sys_en

    parts = []
    for i, c in enumerate(contexts, 1):
        bi = c["bio"].strip()
        if bi:
            parts.append(f"[{i}] {bi}")
    joined = _truncate_chars("\n\n".join(parts), max_chars)

    labels = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    opts_str = "\n".join([f"{labels[i]}) {o}" for i, o in enumerate(options)])

    if lang == "fa":
        user = (
            f"سؤال: {question}\n\nگزینه‌ها:\n{opts_str}\n\nمتون بازیابی‌شده:\n{joined}\n\n"
            "دقیقاً و فقط یک JSON برگردان. فرمت اجباری: "
            '{"answer_index": X, "reason": "…"} '
            "که در آن X اندیس گزینه (۰-بِیس) است. هیچ متن دیگری ننویس."
        )
    else:
        user = (
            f"Question: {question}\n\nOptions:\n{opts_str}\n\nRetrieved:\n{joined}\n\n"
            'Return EXACTLY one JSON: {"answer_index": X, "reason": "..."} '
            "where X is the 0-based option index. Do not write anything else."
        )

    msgs = [{"role": "system", "content": system_text},
            {"role": "user", "content": user}]
    return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)

import json as _json
import re as _re
import numpy as _np

def _strict_json_from_text(text: str):
    # Grab only the first {...} block and try to JSON-parse it
    m = _re.search(r'\{.*\}', text, _re.S)
    if not m:
        return None
    frag = m.group(0)
    try:
        obj = _json.loads(frag)
        return obj
    except Exception:
        return None
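
# Example: the parser tolerates chatter around the JSON object:
#   _strict_json_from_text('Sure! {"answer_index": 1, "reason": "see [2]"} hope that helps')
#   -> {"answer_index": 1, "reason": "see [2]"}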

def _norm_text_for_match(s: str) -> str:
    # Simple normalization: Persian/Arabic digits, ZWNJ, extra whitespace
    s = normalize_digits_months(s or "")
    s = s.replace("\u200c", " ").strip()
    # Lowercase and collapse runs of whitespace to a single space
    s = _re.sub(r"\s+", " ", s.lower())
    return s

def _find_snippet(hay: str, needle: str, win: int = 60) -> str:
    """یک تکه متن کوتاه اطراف اولین مچ را بده."""
    try:
        i = hay.index(needle)
        start = max(0, i - win)
        end   = min(len(hay), i + len(needle) + win)
        return hay[start:end].replace("\n", " ")
    except ValueError:
        return ""

def score_options_by_context(
    options: List[str],
    contexts: List[Dict[str, Any]],
    return_snippet: bool = False
):
    """
    Smart fallback:
      1) boundary-aware substring match in each context (high score + occurrence count)
      2) if nothing matches -> embedding similarity (mE5) between each option and the joined contexts
    Returns:
      - return_snippet=False -> best_idx (int) only
      - return_snippet=True  -> (best_idx, snippet)
    """
    # Prepare contexts
    raw_ctxs = [c.get("bio", "") for c in contexts]
    norm_ctxs = [_norm_text_for_match(x) for x in raw_ctxs]
    joined_norm = " \n ".join(norm_ctxs)

    # 1) Stricter search: word boundary + occurrence count.
    # Works well for Persian/Arabic too, since boundaries are keyed on whitespace.
    best_idx, best_score, best_snip = 0, -1.0, ""
    for i, opt in enumerate(options):
        o_raw = str(opt).strip()
        o = _norm_text_for_match(o_raw)
        if not o:
            continue

        # Simple boundary pattern: (start/whitespace) + phrase + (end/whitespace).
        # This also handles multi-word options; a stricter regex could be
        # written if it ever proves necessary.
        pat = r"(?<!\S)" + _re.escape(o) + r"(?!\S)"

        total_hits = 0
        first_snip = ""
        for raw, norm in zip(raw_ctxs, norm_ctxs):
            for m in _re.finditer(pat, norm):
                total_hits += 1
                if not first_snip:
                    # Prefer a snippet from the raw text (more readable); locate it
                    # with a plain substring search, falling back to the normalized
                    # text when normalization shifted the position too much.
                    sn = _find_snippet(raw, o_raw) or _find_snippet(norm, o)
                    first_snip = sn
        if total_hits > 0:
            # High base score for an explicit match, plus the occurrence count
            score = 10000.0 + total_hits
            if score > best_score:
                best_score, best_idx, best_snip = score, i, first_snip

    if best_score > 0:
        return (best_idx, best_snip) if return_snippet else best_idx

    # 2) No exact match -> embedding similarity (mE5)
    try:
        # One vector for the joined contexts (computed once)
        ctx_vec = _encode_query_e5(joined_norm)  # (dim,)
        sims = []
        for opt in options:
            qv = _encode_query_e5(str(opt))
            sims.append(float(_np.dot(qv, ctx_vec)))
        best_idx = int(_np.argmax(sims))

        # For the snippet on this path, find the context closest to the winning
        # option via per-context dot products (fast and good enough).
        best_snip = ""
        try:
            opt_vec = _encode_query_e5(str(options[best_idx]))
            # Cosine equals the inner product here since the vectors are normalized.
            # Score each context against the winning option:
            c_scores = []
            for raw, norm in zip(raw_ctxs, norm_ctxs):
                c_vec = _encode_query_e5(norm)
                c_scores.append(float(_np.dot(opt_vec, c_vec)))
            j = int(_np.argmax(c_scores))
            best_snip = _find_snippet(raw_ctxs[j], str(options[best_idx])) or raw_ctxs[j][:120].replace("\n", " ")
        except Exception:
            pass

        return (best_idx, best_snip) if return_snippet else best_idx
    except Exception:
        return (0, "") if return_snippet else 0  # conservative default


def parse_mcq_output_strict(text: str, options: List[str], contexts: List[Dict[str, Any]]) -> Dict[str, Any]:
    obj = _strict_json_from_text(text)
    if obj and "answer_index" in obj:
        idx = obj["answer_index"]
        if isinstance(idx, int) and 0 <= idx < len(options):
            reason = str(obj.get("reason", "")).strip() or "—"
            return {"answer_index": idx, "reason": reason}
    idx, snip = score_options_by_context(options, contexts, return_snippet=True)
    return {"answer_index": idx, "reason": snip or "matched by context"}

def parse_mcq_output(text: str, n: int) -> Dict[str, Any]:
    m = re.search(r'{"\s*answer_index"\s*:\s*([0-9]+)\s*,\s*"reason"\s*:\s*"(.*?)"}', text, re.S)
    if m:
        idx = int(m.group(1)); reason = m.group(2).strip()
        if 0 <= idx < n:
            return {"answer_index": idx, "reason": reason}
    letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    m2 = re.search(r'\b([A-D])\b', text, re.I)
    if m2:
        idx = letters.index(m2.group(1).upper())
        if idx < n:
            return {"answer_index": idx, "reason": text.strip()}
    m3 = re.search(r'\b([1-9])\b', text)
    if m3:
        idx = int(m3.group(1)) - 1
        if 0 <= idx < n:
            return {"answer_index": idx, "reason": text.strip()}
    return {"answer_index": None, "reason": text.strip()}

# =========================
# Gradio UI
# =========================
def ui_answer(question, image, use_fusion, topk, max_tokens, temperature, top_p, top_k):
    if not question or not question.strip():
        return "Please enter a question.", [], ""
    # Retrieve (pass the image through only when fusion is requested);
    # the original signature dropped use_fusion, so the 8 Gradio inputs
    # were shifted onto the wrong parameters.
    ret = retrieve_context_auto(question, k=int(topk), image=image if use_fusion else None)
    prompt = build_prompt(question, ret["contexts"], lang="fa", max_chars=1800)
    ans = llm_generate(prompt, max_new_tokens=int(max_tokens),
                       temperature=float(temperature), top_p=float(top_p),
                       top_k=int(top_k), do_sample=False)
    # Sources
    rows = []
    for i, c in enumerate(ret["contexts"], 1):
        snip = c["bio"][:180] + ("…" if len(c["bio"]) > 180 else "")
        rows.append([i, c["id"], round(c["score"], 4), snip])
    return ans, rows, ret["route"]

def ui_mcq(question, options_txt, image, topk, max_tokens, temperature, top_p, top_k):
    opts = [o.strip() for o in (options_txt or "").splitlines() if o.strip()]
    if not question or len(opts) < 2:
        return "Provide a question and at least 2 options.", "", [], ""
    ret = retrieve_context_auto(question, k=int(topk), image=image)
    prompt = build_mcq_prompt(question, opts, ret["contexts"], lang="fa", max_chars=5000)
    out = llm_generate(prompt, max_new_tokens=int(max_tokens),
                       temperature=float(temperature), top_p=float(top_p),
                       top_k=int(top_k), do_sample=False)  # deterministic on CPU
    parsed = parse_mcq_output_strict(out, opts, ret["contexts"])
    pred = parsed["answer_index"]
    pred_text = (opts[pred] if (pred is not None and 0 <= pred < len(opts)) else "N/A")

    rows = []
    for i, c in enumerate(ret["contexts"], 1):
        snip = c["bio"][:180] + ("…" if len(c["bio"]) > 180 else "")
        rows.append([i, c["id"], round(c["score"], 4), snip])

    result = f"Pred: index={pred}  text={pred_text}\nReason: {parsed['reason']}"
    return result, out, rows, ret["route"]


with gr.Blocks(title="Multimodal RAG (CPU) • E5 + CLIP Fusion + Qwen 0.5B") as demo:
    gr.Markdown("### Free-tier CPU demo: text RAG (E5) + optional fusion (CLIP) → Qwen 0.5B")
    with gr.Tab("Ask"):
        with gr.Row():
            q = gr.Textbox(label="Question", lines=3)
            img = gr.Image(type="pil", label="Optional image")
            use_fusion = gr.Checkbox(label="Use image fusion (slower on CPU)", value=False)
        with gr.Row():
            topk = gr.Slider(1, 20, value=3, step=1, label="Top-K retrieve")
            max_tokens = gr.Slider(16, 512, value=96, step=16, label="Max new tokens")
        with gr.Row():
            temperature = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=1.0, step=0.05, label="Top-p")
            top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
        btn = gr.Button("Answer")
        ans = gr.Textbox(label="Answer", lines=8)
        route = gr.Textbox(label="Route used (text_e5 or fusion)")
        table = gr.Dataframe(headers=["#", "id", "score", "snippet"], interactive=False)
        btn.click(ui_answer, [q, img, use_fusion, topk, max_tokens, temperature, top_p, top_k], [ans, table, route])

    with gr.Tab("MCQ"):
        with gr.Row():
            q_mcq = gr.Textbox(label="Question", lines=3)
            opts_mcq = gr.Textbox(label="Options (one per line)", lines=8)
        img_mcq = gr.Image(type="pil", label="Optional image (fusion if enabled)")
        with gr.Row():
            topk2 = gr.Slider(1, 20, value=3, step=1, label="Top-K retrieve")
            max_tokens2 = gr.Slider(16, 512, value=96, step=16, label="Max new tokens")
        with gr.Row():
            temperature2 = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Temperature")
            top_p2 = gr.Slider(0.1, 1.0, value=1.0, step=0.05, label="Top-p")
            top_k2 = gr.Slider(1, 100, value=50, step=1, label="Top-k")
        btn2 = gr.Button("Answer MCQ")
        result = gr.Textbox(label="Prediction", lines=12, max_lines=20)
        raw = gr.Textbox(label="Raw LLM output", lines=12, max_lines=20)
        route2 = gr.Textbox(label="Route used")
        table2 = gr.Dataframe(headers=["#", "id", "score", "snippet"], interactive=False)
        btn2.click(ui_mcq, [q_mcq, opts_mcq, img_mcq, topk2, max_tokens2, temperature2, top_p2, top_k2],
                   [result, raw, table2, route2])

demo.launch()