# File size: 5,227 Bytes
# 46b9b58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""
Weight verification: base vs time-aware, sequential to avoid memory pressure.
Runs in ~30 min on CPU.
"""
if __name__ == '__main__':
    # Script entry point: compares Hit@1 of the base Contriever checkpoint
    # against the time-aware fine-tuned checkpoint on a small sampled corpus,
    # evaluating one model at a time and freeing it before loading the next.
    import gc
    import random
    import re
    import sys

    # Project-local modules are resolved from ./src (relative to the current
    # working directory), so run this script from the repository root.
    sys.path.insert(0, 'src')

    import torch
    import numpy as np
    from datasets import load_dataset
    from transformers import AutoModel, AutoTokenizer
    from tqdm import tqdm
    from mrag_integration import (
        encode_texts,
        precompute_window_embeddings, mrag_rerank_1,
    )

    BASE_NAME     = 'facebook/contriever-msmarco'
    TIME_PATH     = 'contriever_finetuned_NEW_20k'
    N_QUESTIONS   = 100   # upper bound on evaluated questions
    N_DISTRACTORS = 400   # upper bound on non-gold passages in the mini corpus
    SEED          = 42
    MIN_LIFT      = 0.10  # minimum absolute Hit@1 gain required for a PASS

    # ── Dataset ───────────────────────────────────────────────────────
    print("Loading ChroniclingAmericaQA...")
    caqa = load_dataset("Bhawna/ChroniclingAmericaQA", split="validation")

    passages, passage_to_id = [], {}
    questions_year, gold_ids_year = [], []
    # Matches a four-digit year in the range 1800-2029.
    year_re = re.compile(r'\b(18[0-9]{2}|19[0-9]{2}|20[0-2][0-9])\b')

    for ex in caqa:
        # Field names are probed in order; the first non-empty one wins.
        q = ex.get('question') or ex.get('query')
        p = ex.get('context') or ex.get('positive_passage') or ex.get('passage')
        if not q or not p:
            continue
        # Deduplicate passages by exact text; ids are assigned in first-seen order.
        if p not in passage_to_id:
            passage_to_id[p] = len(passages)
            passages.append(p)
        # Only questions that mention a year take part in the evaluation.
        if year_re.search(q):
            questions_year.append(q)
            gold_ids_year.append(passage_to_id[p])

    print(f"Full corpus: {len(passages)} passages | {len(questions_year)} year-questions")

    # ── Mini corpus ───────────────────────────────────────────────────
    random.seed(SEED)
    qs       = questions_year[:N_QUESTIONS]
    gold_ids = gold_ids_year[:N_QUESTIONS]
    if not qs:
        # Without questions there is nothing to score; stop before loading models.
        sys.exit("No questions containing a year were found; nothing to evaluate.")
    # The dataset may yield fewer than N_QUESTIONS year-questions, so every
    # metric below is normalised by the number actually evaluated.
    n_eval = len(qs)

    gold_set = set(gold_ids)
    distractors = random.sample(
        [i for i in range(len(passages)) if i not in gold_set],
        min(N_DISTRACTORS, len(passages) - len(gold_set))
    )
    pool_ids      = sorted(gold_set | set(distractors))
    pool_passages = [passages[i] for i in pool_ids]
    # Remap full-corpus passage ids to positions inside the mini corpus.
    old2new       = {old: new for new, old in enumerate(pool_ids)}
    gold_new      = [old2new[g] for g in gold_ids]
    print(f"Mini corpus: {len(pool_passages)} passages "
          f"({len(gold_set)} gold + {len(distractors)} distractors)\n")

    # A single tokenizer (from the base checkpoint) is reused for both models.
    # NOTE(review): presumably the fine-tuned checkpoint keeps the base
    # vocabulary -- confirm.
    tokenizer = AutoTokenizer.from_pretrained(BASE_NAME)
    top_k     = min(100, len(pool_passages))
    results   = {}

    # ── Evaluate one model, then free all memory before loading next ──
    for model_path, label in [(BASE_NAME, 'base'), (TIME_PATH, 'time_aware')]:
        print(f"\n{'='*55}")
        print(f"Loading: {label}  ({model_path})")
        model = AutoModel.from_pretrained(model_path).eval()

        # Evaluation never backpropagates, so autograd bookkeeping is disabled
        # to keep memory down; this is a no-op if the helpers already do so.
        with torch.no_grad():
            print("Precomputing window embeddings...")
            win_tensor, win_map = precompute_window_embeddings(
                model, tokenizer, pool_passages)

            print("Encoding passages + queries...")
            p_embs = encode_texts(model, tokenizer, pool_passages)
            q_embs = encode_texts(model, tokenizer, qs)

            # First-stage retrieval: similarity of every query to every pool
            # passage, keeping the top_k highest-scoring candidates per query.
            sim    = q_embs @ p_embs.T
            ids    = np.argsort(-sim, axis=1)[:, :top_k]
            scores = np.take_along_axis(sim, ids, axis=1)

            hits = 0
            for qi, gold in enumerate(tqdm(gold_new, desc=f"eval [{label}]")):
                # Defensive bounds check; argsort indices should already be valid.
                cand_ids    = [int(c) for c in ids[qi] if 0 <= c < len(pool_passages)]
                cand_scores = scores[qi][:len(cand_ids)]
                cand_texts  = [pool_passages[c] for c in cand_ids]
                ranked = mrag_rerank_1(
                    qs[qi], cand_texts, cand_ids, model, tokenizer,
                    base_scores=cand_scores, blend_weight=0.0, temporal_weight=1.0,
                    window_emb_tensor=win_tensor, doc_window_map=win_map,
                )
                # Hit@1: the top reranked candidate is the gold passage.
                if ranked and ranked[0] == gold:
                    hits += 1

        results[label] = hits / n_eval
        print(f"  {label} Hit@1: {results[label]:.4f}  ({results[label]*100:.1f}%)")

        # Free everything before loading the next model
        del model, win_tensor, win_map, p_embs, q_embs, sim, ids, scores
        gc.collect()

    # ── Report ────────────────────────────────────────────────────────
    r_base = results['base']
    r_time = results['time_aware']
    lift_here   = r_time - r_base
    # Reference Hit@1 values recorded from the earlier full-corpus run.
    lift_stored = 0.5915 - 0.4036

    print("\n" + "="*55)
    print(f"VERIFICATION ({n_eval} questions | {len(pool_passages)} passages)")
    print(f"  base_only        Hit@1: {r_base:.4f}  ({r_base*100:.1f}%)")
    print(f"  mrag_time_aware  Hit@1: {r_time:.4f}  ({r_time*100:.1f}%)")
    # The '+' format flag prints an explicit sign, so a regression shows as
    # "-1.2pp" rather than "+-1.2pp".
    print(f"  lift:            {lift_here*100:+.1f}pp")
    print()
    print("STORED full-corpus (12,695 passages | 1,219 questions):")
    print("  base_only        Hit@1: 40.4%")
    print("  mrag_time_aware  Hit@1: 59.2%")
    print(f"  lift:            {lift_stored*100:+.1f}pp")
    print()
    # A lift above MIN_LIFT already implies the time-aware model beat the base.
    verdict = "PASS ✓" if lift_here > MIN_LIFT else "INCONCLUSIVE"
    print(f"VERDICT: {verdict}")