Spaces:
Running
Running
"""
Weight verification: base vs time-aware, sequential to avoid memory pressure.
Runs in ~30 min on CPU.
"""
import gc
import random
import re
import sys

BASE_NAME = 'facebook/contriever-msmarco'
TIME_PATH = 'contriever_finetuned_NEW_20k'
N_QUESTIONS = 100
N_DISTRACTORS = 400
SEED = 42
# Minimum Hit@1 lift (time-aware minus base) required for a PASS verdict.
MIN_LIFT = 0.10

# Matches a standalone four-digit year in the range 1800-2029.
YEAR_RE = re.compile(r'\b(18[0-9]{2}|19[0-9]{2}|20[0-2][0-9])\b')


def collect_year_questions(examples):
    """Deduplicate passages and pick out the questions that mention a year.

    Args:
        examples: iterable of mappings. The question is read from
            ``'question'`` or ``'query'``; the passage from ``'context'``,
            ``'positive_passage'`` or ``'passage'`` (first truthy value wins).
            Examples missing either field are skipped.

    Returns:
        ``(passages, questions, gold_ids)`` where ``passages`` holds each
        distinct passage once in first-seen order, ``questions`` holds the
        questions matching ``YEAR_RE``, and ``gold_ids[i]`` is the index in
        ``passages`` of the passage paired with ``questions[i]``.
    """
    passages, passage_to_id = [], {}
    questions, gold_ids = [], []
    for ex in examples:
        q = ex.get('question') or ex.get('query')
        p = ex.get('context') or ex.get('positive_passage') or ex.get('passage')
        if not q or not p:
            continue
        if p not in passage_to_id:
            passage_to_id[p] = len(passages)
            passages.append(p)
        if YEAR_RE.search(q):
            questions.append(q)
            gold_ids.append(passage_to_id[p])
    return passages, questions, gold_ids


def build_mini_corpus(passages, gold_ids, n_distractors, rng):
    """Build a small retrieval pool: every gold passage plus random distractors.

    Args:
        passages: full list of passages.
        gold_ids: indices into ``passages``, one per question (may repeat).
        n_distractors: maximum number of non-gold passages to add.
        rng: object exposing ``sample(population, k)`` (the ``random`` module
            or a ``random.Random`` instance).

    Returns:
        ``(pool_passages, gold_new, n_gold, n_sampled)``. ``pool_passages`` is
        ordered by original index, ``gold_new[i]`` is the position of
        ``gold_ids[i]`` inside the pool, ``n_gold`` is the number of distinct
        gold passages and ``n_sampled`` the number of distractors drawn.
    """
    gold_set = set(gold_ids)
    non_gold = [i for i in range(len(passages)) if i not in gold_set]
    distractors = rng.sample(non_gold, min(n_distractors, len(non_gold)))
    pool_ids = sorted(gold_set | set(distractors))
    old2new = {old: new for new, old in enumerate(pool_ids)}
    pool_passages = [passages[i] for i in pool_ids]
    gold_new = [old2new[g] for g in gold_ids]
    return pool_passages, gold_new, len(gold_set), len(distractors)


def format_lift(lift):
    """Format a Hit@1 difference (a fraction) as signed percentage points."""
    # The '+' flag emits the correct sign; a literal "+" prefix would print
    # "+-1.2pp" whenever the time-aware model is worse than the base model.
    return f"{lift*100:+.1f}pp"


def main():
    """Evaluate both checkpoints on a mini corpus and print a verdict."""
    # 'src' has to be on sys.path before the project import below. The heavy
    # imports stay inside main() so importing this module remains cheap.
    sys.path.insert(0, 'src')
    import torch  # noqa: F401  (not referenced directly; kept from the original script)
    import numpy as np
    from datasets import load_dataset
    from transformers import AutoModel, AutoTokenizer
    from tqdm import tqdm
    from mrag_integration import (
        encode_texts,
        precompute_window_embeddings, mrag_rerank_1,
    )

    # -- Dataset -------------------------------------------------------
    print("Loading ChroniclingAmericaQA...")
    caqa = load_dataset("Bhawna/ChroniclingAmericaQA", split="validation")
    passages, questions_year, gold_ids_year = collect_year_questions(caqa)
    print(f"Full corpus: {len(passages)} passages | {len(questions_year)} year-questions")

    # -- Mini corpus ---------------------------------------------------
    random.seed(SEED)
    qs = questions_year[:N_QUESTIONS]
    gold_ids = gold_ids_year[:N_QUESTIONS]
    if not qs:
        # Without questions there is nothing to score and Hit@1 is undefined.
        raise SystemExit("No year-bearing questions found; nothing to evaluate.")
    pool_passages, gold_new, n_gold, n_sampled = build_mini_corpus(
        passages, gold_ids, N_DISTRACTORS, random)
    print(f"Mini corpus: {len(pool_passages)} passages "
          f"({n_gold} gold + {n_sampled} distractors)\n")

    # The base tokenizer is used for both checkpoints.
    tokenizer = AutoTokenizer.from_pretrained(BASE_NAME)
    pool_size = len(pool_passages)
    top_k = min(100, pool_size)
    results = {}

    # -- Evaluate one model, then free all memory before loading next --
    for model_path, label in [(BASE_NAME, 'base'), (TIME_PATH, 'time_aware')]:
        print(f"\n{'='*55}")
        print(f"Loading: {label} ({model_path})")
        model = AutoModel.from_pretrained(model_path).eval()

        print("Precomputing window embeddings...")
        win_tensor, win_map = precompute_window_embeddings(
            model, tokenizer, pool_passages)

        print("Encoding passages + queries...")
        p_embs = encode_texts(model, tokenizer, pool_passages)
        q_embs = encode_texts(model, tokenizer, qs)

        # First-stage retrieval: top_k passages per question by dot product.
        sim = q_embs @ p_embs.T
        ids = np.argsort(-sim, axis=1)[:, :top_k]
        scores = np.take_along_axis(sim, ids, axis=1)

        hits = 0
        for qi, gold in enumerate(tqdm(gold_new, desc=f"eval [{label}]")):
            row_ids = ids[qi]
            # One mask filters ids and scores together so they stay aligned
            # (a prefix slice of the scores would shift them if an id were
            # ever dropped).
            keep = (row_ids >= 0) & (row_ids < pool_size)
            cand_ids = row_ids[keep].tolist()
            cand_scores = scores[qi][keep]
            cand_texts = [pool_passages[c] for c in cand_ids]
            ranked = mrag_rerank_1(
                qs[qi], cand_texts, cand_ids, model, tokenizer,
                base_scores=cand_scores, blend_weight=0.0, temporal_weight=1.0,
                window_emb_tensor=win_tensor, doc_window_map=win_map,
            )
            if ranked and ranked[0] == gold:
                hits += 1

        # Divide by the questions actually evaluated; the dataset may yield
        # fewer than N_QUESTIONS year-bearing questions.
        results[label] = hits / len(qs)
        print(f" {label} Hit@1: {results[label]:.4f} ({results[label]*100:.1f}%)")

        # Free everything before loading the next model
        del model, win_tensor, win_map, p_embs, q_embs, sim, ids, scores
        gc.collect()

    # -- Report --------------------------------------------------------
    r_base = results['base']
    r_time = results['time_aware']
    lift_here = r_time - r_base
    print("\n" + "="*55)
    print(f"VERIFICATION ({len(qs)} questions | {len(pool_passages)} passages)")
    print(f" base_only Hit@1: {r_base:.4f} ({r_base*100:.1f}%)")
    print(f" mrag_time_aware Hit@1: {r_time:.4f} ({r_time*100:.1f}%)")
    print(f" lift: {format_lift(lift_here)}")
    print()
    print("STORED full-corpus (12,695 passages | 1,219 questions):")
    print(" base_only Hit@1: 40.4%")
    print(" mrag_time_aware Hit@1: 59.2%")
    print(" lift: +18.8pp")
    print()
    # NOTE(review): the PASS marker was mis-encoded in the reviewed copy;
    # restored here as a check mark - confirm against the original file.
    verdict = "PASS ✓" if lift_here > MIN_LIFT else "INCONCLUSIVE"
    print(f"VERDICT: {verdict}")


if __name__ == '__main__':
    main()