File size: 3,834 Bytes
5a3b322
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import List, Dict, Any

import pandas as pd

from data.catalog_loader import load_catalog
from data.train_loader import load_train
from retrieval.query_rewriter import rewrite_query
from recommenders.bm25 import BM25Recommender
from recommenders.vector_recommender import VectorRecommender
from retrieval.vector_index import VectorIndex
from models.embedding_model import EmbeddingModel
from recommenders.hybrid_rrf import HybridRRFRecommender


def rank_of_first_positive(preds: List[str], positives: set, not_found: int) -> int:
    """Return the 1-based position of the first prediction that is a positive.

    Args:
        preds: Ranked prediction ids, best first.
        positives: Ids considered relevant for the query.
        not_found: Sentinel returned when no prediction is in ``positives``
            (callers use one past the retrieval depth).

    Returns:
        The rank (starting at 1) of the earliest hit, or ``not_found``.
    """
    hits = (rank for rank, pred_id in enumerate(preds, start=1) if pred_id in positives)
    return next(hits, not_found)


def _prediction_ids(preds: List[Any]) -> List[Any]:
    """Normalize recommender output to a list of assessment ids, order preserved.

    Each prediction may be a dict carrying an ``"assessment_id"`` key or a bare
    id; both shapes are reduced to the bare id.
    """
    return [p["assessment_id"] if isinstance(p, dict) else p for p in preds]


def _rank_for_query(recommender: Any, query: str, positives: set, k: int, not_found: int) -> int:
    """Fetch the top ``k`` predictions for ``query`` and return the rank of the first positive.

    Returns ``not_found`` when none of the retrieved ids is in ``positives``.
    """
    preds = recommender.recommend(query, k=k)
    return rank_of_first_positive(_prediction_ids(preds), positives, not_found=not_found)


def main():
    """Evaluate how query rewriting changes the rank of the first relevant assessment.

    For each labelled training example, the hybrid recommender (built from the
    BM25 and vector recommenders) is queried three times:

    - with the raw query;
    - with the rule-based rewrite;
    - with the rule-based rewrite plus the optional catalog vocab.

    The 1-based rank of the first relevant assessment is recorded for each
    variant, and the rows are written as JSON lines to ``--out``. A rank of
    ``--topn + 1`` means no positive was found within the top ``--topn``
    retrieved candidates.
    """
    parser = argparse.ArgumentParser(description="Evaluate impact of query rewriting on positive ranks.")
    parser.add_argument("--catalog", default="data/catalog_docs_rich.jsonl")
    parser.add_argument("--train", required=True, help="Train file (xlsx/jsonl) with labels")
    parser.add_argument("--vector-index", required=True, help="FAISS index path")
    parser.add_argument("--assessment-ids", required=True, help="assessment_ids.json aligned with index")
    parser.add_argument("--embedding-model", default="BAAI/bge-small-en-v1.5")
    parser.add_argument("--topn", type=int, default=200, help="Candidates to fetch")
    parser.add_argument("--out", default="runs/rewrite_impact.jsonl")
    parser.add_argument("--vocab", help="Optional vocab json produced by build_role_vocab.py")
    args = parser.parse_args()
    if args.topn < 1:
        parser.error("--topn must be a positive integer")

    df_catalog, _, id_by_url = load_catalog(args.catalog)
    # The label report returned alongside the examples is not used here.
    examples, _label_report = load_train(args.train, id_by_url)

    # Build recommenders
    bm25 = BM25Recommender(df_catalog)
    embed_model = EmbeddingModel(args.embedding_model)
    index = VectorIndex.load(args.vector_index)
    with open(args.assessment_ids, encoding="utf-8") as f:
        ids = json.load(f)
    vec_rec = VectorRecommender(embed_model, index, df_catalog, ids, k_candidates=args.topn)
    hybrid = HybridRRFRecommender(bm25, vec_rec, topn_candidates=args.topn, rrf_k=60)

    vocab = {}
    if args.vocab:
        with open(args.vocab, encoding="utf-8") as f:
            vocab = json.load(f)

    # Retrieve exactly as deep as the recommenders were configured. This keeps
    # the "not found" sentinel (depth + 1) meaning "outside the top-N" rather
    # than contradicting a different hard-coded retrieval depth.
    depth = args.topn
    not_found_val = depth + 1

    rows: List[Dict[str, Any]] = []
    for ex in examples:
        positives = ex.relevant_ids

        # Raw query (no rewrite)
        raw_rank = _rank_for_query(hybrid, ex.query, positives, depth, not_found_val)

        # Rule rewrite (no vocab)
        rw_rule = rewrite_query(ex.query)
        rule_rank = _rank_for_query(hybrid, rw_rule.retrieval_query, positives, depth, not_found_val)

        # Rule + vocab rewrite. Without --vocab, vocab is {}; whether that
        # differs from the rule-only rewrite depends on rewrite_query.
        rw_vocab = rewrite_query(ex.query, catalog_vocab=vocab)
        vocab_rank = _rank_for_query(hybrid, rw_vocab.retrieval_query, positives, depth, not_found_val)

        rows.append(
            {
                "query": ex.query,
                "raw_rank": raw_rank,
                "rewrite_rank": rule_rank,
                "rewrite_vocab_rank": vocab_rank,
                "positives": list(positives),
            }
        )

    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    pd.DataFrame(rows).to_json(args.out, orient="records", lines=True)
    print(f"Saved rewrite impact to {args.out}")


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()