Spaces:
Sleeping
Sleeping
File size: 5,227 Bytes
"""
Weight verification: base vs time-aware, sequential to avoid memory pressure.
Runs in ~30 min on CPU.
"""
import gc
import random
import re
import sys

BASE_NAME = 'facebook/contriever-msmarco'
TIME_PATH = 'contriever_finetuned_NEW_20k'
N_QUESTIONS = 100
N_DISTRACTORS = 400
SEED = 42

# A standalone four-digit year token in the range 1800-2029.
YEAR_RE = re.compile(r'\b(18[0-9]{2}|19[0-9]{2}|20[0-2][0-9])\b')


def collect_year_questions(examples):
    """Deduplicate passages and keep the questions that mention a year.

    Args:
        examples: Iterable of mappings. The question is read from
            ``'question'`` or ``'query'``; the passage from ``'context'``,
            ``'positive_passage'`` or ``'passage'`` (first non-empty wins).

    Returns:
        A ``(passages, questions_year, gold_ids_year)`` tuple where
        ``passages`` holds each distinct passage once (first-seen order),
        ``questions_year`` holds the questions matching ``YEAR_RE``, and
        ``gold_ids_year[i]`` is the index in ``passages`` of the passage
        paired with ``questions_year[i]``.
    """
    passages, passage_to_id = [], {}
    questions_year, gold_ids_year = [], []
    for ex in examples:
        q = ex.get('question') or ex.get('query')
        p = ex.get('context') or ex.get('positive_passage') or ex.get('passage')
        if not q or not p:
            continue
        # Every usable passage joins the corpus, even when its question has
        # no year, so it remains available as a distractor.
        if p not in passage_to_id:
            passage_to_id[p] = len(passages)
            passages.append(p)
        if YEAR_RE.search(q):
            questions_year.append(q)
            gold_ids_year.append(passage_to_id[p])
    return passages, questions_year, gold_ids_year


def build_mini_corpus(passages, questions_year, gold_ids_year,
                      n_questions=N_QUESTIONS, n_distractors=N_DISTRACTORS,
                      seed=SEED):
    """Sample a small retrieval pool: gold passages plus random distractors.

    Args:
        passages: Full list of distinct passages.
        questions_year: Candidate questions.
        gold_ids_year: Index into ``passages`` of each question's passage.
        n_questions: Maximum number of questions to keep (taken from the
            front of ``questions_year``).
        n_distractors: Maximum number of non-gold passages to sample.
        seed: Seed applied to the global ``random`` module before sampling.

    Returns:
        ``(qs, pool_passages, gold_new, n_gold, n_sampled)`` where
        ``gold_new[i]`` is the index in ``pool_passages`` of the gold passage
        for ``qs[i]``, ``n_gold`` is the number of distinct gold passages and
        ``n_sampled`` is the number of distractors actually drawn.
    """
    # Seeds the module-level generator, exactly as the inline script did.
    random.seed(seed)
    qs = questions_year[:n_questions]
    gold_ids = gold_ids_year[:n_questions]
    gold_set = set(gold_ids)
    non_gold = [i for i in range(len(passages)) if i not in gold_set]
    distractors = random.sample(non_gold, min(n_distractors, len(non_gold)))
    pool_ids = sorted(gold_set | set(distractors))
    pool_passages = [passages[i] for i in pool_ids]
    old2new = {old: new for new, old in enumerate(pool_ids)}
    gold_new = [old2new[g] for g in gold_ids]
    return qs, pool_passages, gold_new, len(gold_set), len(distractors)


def hit_rate(hits, total):
    """Return ``hits / total``, or ``0.0`` when nothing was evaluated."""
    return hits / total if total else 0.0


def format_lift(lift):
    """Format a rate difference as signed percentage points, e.g. ``+18.8pp``."""
    return f"{lift * 100:+.1f}pp"


def verdict(r_base, r_time, min_lift=0.10):
    """Return the PASS marker when ``r_time`` beats ``r_base`` by more than ``min_lift``."""
    if r_time > r_base and (r_time - r_base) > min_lift:
        return "PASS ✓"
    return "INCONCLUSIVE"


def main():
    """Compare Hit@1 of the base and time-aware encoders on a mini corpus.

    Models are loaded and evaluated one after the other, and each model's
    tensors are released before the next is loaded, to limit peak memory.
    """
    # 'src' must be on sys.path before mrag_integration is imported. The
    # third-party imports live here so that importing this module has no
    # side effects.
    sys.path.insert(0, 'src')
    import numpy as np
    import torch
    from datasets import load_dataset
    from tqdm import tqdm
    from transformers import AutoModel, AutoTokenizer
    from mrag_integration import (
        encode_texts,
        mrag_rerank_1,
        precompute_window_embeddings,
    )

    # ---- Dataset ---------------------------------------------------------
    print("Loading ChroniclingAmericaQA...")
    caqa = load_dataset("Bhawna/ChroniclingAmericaQA", split="validation")
    passages, questions_year, gold_ids_year = collect_year_questions(caqa)
    print(f"Full corpus: {len(passages)} passages | {len(questions_year)} year-questions")

    # ---- Mini corpus -----------------------------------------------------
    qs, pool_passages, gold_new, n_gold, n_sampled = build_mini_corpus(
        passages, questions_year, gold_ids_year)
    print(f"Mini corpus: {len(pool_passages)} passages "
          f"({n_gold} gold + {n_sampled} distractors)\n")
    if not qs:
        raise SystemExit("No year-bearing questions found; nothing to evaluate.")

    # The same tokenizer (from BASE_NAME) is used for both checkpoints.
    tokenizer = AutoTokenizer.from_pretrained(BASE_NAME)
    top_k = min(100, len(pool_passages))
    results = {}

    # ---- Evaluate one model, then free all memory before loading next ----
    for model_path, label in [(BASE_NAME, 'base'), (TIME_PATH, 'time_aware')]:
        print(f"\n{'='*55}")
        print(f"Loading: {label} ({model_path})")
        model = AutoModel.from_pretrained(model_path).eval()

        # Inference only: make sure no autograd graph is retained. Harmless
        # if the helpers already disable gradients themselves.
        with torch.no_grad():
            print("Precomputing window embeddings...")
            win_tensor, win_map = precompute_window_embeddings(
                model, tokenizer, pool_passages)
            print("Encoding passages + queries...")
            p_embs = encode_texts(model, tokenizer, pool_passages)
            q_embs = encode_texts(model, tokenizer, qs)

            # First-stage retrieval: top_k passages per query by similarity.
            sim = q_embs @ p_embs.T
            ids = np.argsort(-sim, axis=1)[:, :top_k]
            scores = np.take_along_axis(sim, ids, axis=1)

            hits = 0
            for qi, gold in enumerate(tqdm(gold_new, desc=f"eval [{label}]")):
                cand_ids = [int(c) for c in ids[qi] if 0 <= c < len(pool_passages)]
                cand_scores = scores[qi][:len(cand_ids)]
                cand_texts = [pool_passages[c] for c in cand_ids]
                ranked = mrag_rerank_1(
                    qs[qi], cand_texts, cand_ids, model, tokenizer,
                    base_scores=cand_scores, blend_weight=0.0, temporal_weight=1.0,
                    window_emb_tensor=win_tensor, doc_window_map=win_map,
                )
                if ranked and ranked[0] == gold:
                    hits += 1

        # Normalise by the number of questions actually evaluated, which can
        # be smaller than N_QUESTIONS when the dataset has fewer matches.
        results[label] = hit_rate(hits, len(gold_new))
        print(f" {label} Hit@1: {results[label]:.4f} ({results[label]*100:.1f}%)")

        # Free everything before loading the next model.
        del model, win_tensor, win_map, p_embs, q_embs, sim, ids, scores
        gc.collect()

    # ---- Report ----------------------------------------------------------
    r_base = results['base']
    r_time = results['time_aware']
    lift_here = r_time - r_base
    print("\n" + "="*55)
    print(f"VERIFICATION ({len(qs)} questions | {len(pool_passages)} passages)")
    print(f" base_only Hit@1: {r_base:.4f} ({r_base*100:.1f}%)")
    print(f" mrag_time_aware Hit@1: {r_time:.4f} ({r_time*100:.1f}%)")
    print(f" lift: {format_lift(lift_here)}")
    print()
    # Reference numbers recorded from an earlier full-corpus run.
    print("STORED full-corpus (12,695 passages | 1,219 questions):")
    print(" base_only Hit@1: 40.4%")
    print(" mrag_time_aware Hit@1: 59.2%")
    print(" lift: +18.8pp")
    print()
    print(f"VERDICT: {verdict(r_base, r_time)}")


if __name__ == '__main__':
    main()
|