# Source: DylanJHJ/APRIL, qrel-analysis/eval_autoqrels_sample.py
import argparse
import json
import sys
import os
import importlib
from typing import Optional
from collections import OrderedDict
import logging
logger = logging.getLogger(__name__)
class AutoQrel:
    """Derive qrels from an LLM judge run with several labeling strategies.

    Qrels and runs are both nested dicts: ``{qid: {docid: value}}``, where the
    value is a relevance label (qrel) or a judge score (run). Every strategy is
    a static method that maps a run to a qrel-shaped dict; the constructor
    precomputes the qrel of each requested strategy into ``self.llm_qrels``.
    """

    # Names accepted for ``strategies``; "all" is a shorthand for every other entry.
    STRATEGIES = [
        "all", "direct", "thresholding", "rank",
        "largest_gap", "quantile",
        "optimal_per_topic", "optimal_global",
    ]

    def __init__(
        self,
        qrel,
        judge_run,
        strategies,
        threshold=0.5,
        rank_cutoff=10,
        gap_k=1,
        quantile_cutoff=0.75,
        min_relevance=1,
    ):
        """
        qrel: dict of dicts {qid: {docid: relevance}} (human labels)
        judge_run: dict of dicts {qid: {docid: score}} from LLM judge run
        strategies: list of strategy names from STRATEGIES; a list containing
            "all" (or None) selects every concrete strategy
        threshold: float cutoff for the "thresholding" strategy
        rank_cutoff: int top-k for the "rank" strategy
        gap_k: int k for the "largest_gap" strategy
        quantile_cutoff: float quantile in [0, 1] for the "quantile" strategy
        min_relevance: int relevance grade a human label must reach to count
            as positive in the oracle ("optimal_*") strategies
        """
        self.human_qrel = qrel
        if strategies is None or "all" in strategies:
            # Build a fresh list: removing "all" from the shared class-level
            # list would corrupt STRATEGIES for every later instance.
            self.strategies = [s for s in self.STRATEGIES if s != "all"]
        else:
            self.strategies = list(strategies)

        # Parameters
        self.threshold = threshold
        # NOTE: this instance attribute shadows the static method of the same
        # name on instances, so the method is always called via the class.
        self.rank_cutoff = rank_cutoff
        self.quantile_cutoff = quantile_cutoff
        self.gap_k = gap_k
        self.min_relevance = min_relevance

        # Precompute LLM-derived qrels for all strategies
        self.llm_qrels = {s: self._dispatch(s, judge_run) for s in self.strategies}

    def _dispatch(self, strategy, run):
        """Apply one named strategy to ``run`` and drop topics without positives.

        "human" returns the human qrel unchanged (and unfiltered).
        Raises ValueError for an unknown strategy name.
        """
        if strategy == "human":
            return self.human_qrel

        # Lazy callables so that only the requested strategy is computed.
        handlers = {
            "direct": lambda: AutoQrel.direct(run),
            "thresholding": lambda: AutoQrel.thresholding(run, self.threshold),
            "rank": lambda: AutoQrel.rank_cutoff(run, self.rank_cutoff),
            "largest_gap": lambda: AutoQrel.largest_gap(run, self.gap_k),
            "quantile": lambda: AutoQrel.quantile(run, self.quantile_cutoff),
            "optimal_per_topic": lambda: AutoQrel.optimal_per_topic(
                run, self.human_qrel, self.min_relevance
            ),
            "optimal_global": lambda: AutoQrel.optimal_global(
                run, self.human_qrel, self.min_relevance
            ),
        }
        handler = handlers.get(strategy)
        if handler is None:
            raise ValueError(f"Unknown thresholding strategy: {strategy!r}")
        return AutoQrel._filter_no_relevant(handler())

    @staticmethod
    def _filter_no_relevant(qrel):
        """Remove qids where all judged documents have zero relevance."""
        return {
            qid: docs
            for qid, docs in qrel.items()
            if any(label > 0 for label in docs.values())
        }

    @staticmethod
    def direct(run):
        """Round each score to the nearest int → relevance label.

        Best for judge runs whose scores are already relevance levels
        (e.g. 1.0, 2.0, 3.0). Uses the built-in ``round``, so exact halves
        round to the nearest even integer.
        """
        return {
            qid: {did: round(score) for did, score in docs.items()}
            for qid, docs in run.items()
        }

    @staticmethod
    def thresholding(run, threshold=0.5):
        """Label a document 1 when its score is >= ``threshold``, else 0."""
        return {
            qid: {did: int(score >= threshold) for did, score in docs.items()}
            for qid, docs in run.items()
        }

    @staticmethod
    def rank_cutoff(run, topk=10):
        """Top-k ranked docs → 1, remainder → 0.

        Documents are ranked by descending score. The sort is stable, so a run
        whose dict is already in rank order keeps that order (ties included).
        """
        qrel = {}
        for qid, docs in run.items():
            ranked = sorted(docs, key=docs.get, reverse=True)
            qrel[qid] = {did: (1 if i < topk else 0) for i, did in enumerate(ranked)}
        return qrel

    @staticmethod
    def largest_gap(run, k=1):
        """Binary threshold at the k-th largest consecutive score gap (per topic).

        Scores are sorted descending; consecutive differences are computed. The
        position of the k-th largest gap becomes the positive/negative boundary:
        all docs above (and including) that position are labeled 1, the rest 0.
        k=1 → largest single drop (most common baseline)
        k=2 → second-largest drop (useful if the top gap is a tie artefact)

        If a topic has fewer than k gaps, the smallest gap is used. Topics with
        fewer than two documents label every document 1.
        Raises ValueError if k < 1.
        """
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k!r}")
        qrel = {}
        for qid, docs in run.items():
            ranked = sorted(docs.items(), key=lambda item: item[1], reverse=True)
            if len(ranked) < 2:
                qrel[qid] = {did: 1 for did, _ in ranked}
                continue
            scores = [score for _, score in ranked]
            # (gap size, index of the upper document of the pair)
            gaps = [
                (upper - lower, i)
                for i, (upper, lower) in enumerate(zip(scores, scores[1:]))
            ]
            gaps.sort(key=lambda gap: gap[0], reverse=True)
            cutoff_rank = gaps[min(k, len(gaps)) - 1][1]
            qrel[qid] = {
                did: (1 if i <= cutoff_rank else 0)
                for i, (did, _) in enumerate(ranked)
            }
        return qrel

    @staticmethod
    def quantile(run, q=0.75):
        """Per-topic quantile thresholding: score >= q-th quantile → 1, else 0.

        q=0.75 sets the threshold at the 75th percentile of scores for that topic,
        so roughly the top 25% of documents are labeled positive. The quantile is
        computed per-topic (with linear interpolation between neighbouring
        scores) so that score scale differences across topics don't bias the
        labeling.
        Raises ValueError if q is outside [0, 1].
        """
        if not 0 <= q <= 1:
            raise ValueError(f"q must be within [0, 1], got {q!r}")
        qrel = {}
        for qid, docs in run.items():
            if not docs:
                qrel[qid] = {}
                continue
            sorted_scores = sorted(docs.values())
            last = len(sorted_scores) - 1
            pos = q * last
            lo = int(pos)
            hi = min(lo + 1, last)
            threshold = sorted_scores[lo] + (pos - lo) * (sorted_scores[hi] - sorted_scores[lo])
            qrel[qid] = {did: (1 if score >= threshold else 0) for did, score in docs.items()}
        return qrel

    @staticmethod
    def _binary_f1(predicted_pos, actual_pos):
        """F1 between two sets of positive doc IDs (0.0 when they are disjoint)."""
        tp = len(predicted_pos & actual_pos)
        if tp == 0:
            # Also covers empty inputs, so the divisions below are safe.
            return 0.0
        precision = tp / len(predicted_pos)
        recall = tp / len(actual_pos)
        return 2 * precision * recall / (precision + recall)

    @staticmethod
    def optimal_per_topic(run, human_qrel, min_relevance=1):
        """Oracle: per-topic threshold that maximises F1 against human qrel.

        For every query, all unique scores in the run are tried as candidate
        thresholds; the one yielding the best binary F1 is chosen (the highest
        such threshold on ties). This is an oracle upper bound — it requires the
        human labels it is meant to replace. Topics without any human positive
        get all-zero labels.
        """
        qrel = {}
        for qid, docs in run.items():
            human_labels = human_qrel.get(qid, {})
            actual_pos = {did for did, rel in human_labels.items() if rel >= min_relevance}
            if not actual_pos:
                qrel[qid] = {did: 0 for did in docs}
                continue
            best_f1, best_threshold = -1.0, None
            for thresh in sorted(set(docs.values()), reverse=True):
                predicted_pos = {did for did, score in docs.items() if score >= thresh}
                f1 = AutoQrel._binary_f1(predicted_pos, actual_pos)
                if f1 > best_f1:
                    best_f1, best_threshold = f1, thresh
            qrel[qid] = {
                did: (1 if score >= best_threshold else 0)
                for did, score in docs.items()
            }
        return qrel

    @staticmethod
    def qrel_to_run(qrel, run=None):
        """Turn a qrel into a run-like score dict.

        Each document's score is its relevance label plus a small bonus
        ``round(1 / position * 0.01, 4)``, where ``position`` is the 1-based
        position of the document in ``run[qid]``'s iteration order. Documents
        absent from ``run`` get no bonus. Only documents present in ``qrel``
        appear in the result, ordered by descending relevance.
        """
        downscaling = 0.01
        run = run if run is not None else {}

        # ensure the qrel is sorted by descending relevance
        sorted_qrel = OrderedDict()
        for qid, qrel_hits in qrel.items():
            sorted_qrel[qid] = dict(sorted(qrel_hits.items(), key=lambda x: x[1], reverse=True))

        result = {}
        for qid, qrel_hits in sorted_qrel.items():
            run_hits = run.get(qid, {})
            run_scores = {
                did: round(1 / i * downscaling, 4)
                for i, did in enumerate(run_hits, start=1)
            }
            result[qid] = {
                docid: relevance + run_scores.get(docid, 0)
                for docid, relevance in qrel_hits.items()
            }
        return result

    @staticmethod
    def optimal_global(run, human_qrel, min_relevance=1):
        """Oracle: single global threshold maximising macro-average F1 vs human qrel.

        Every unique score across the entire run is tried as a global threshold;
        the one with the highest average per-topic F1 is selected and applied
        uniformly to all topics. Only topics with at least one human positive
        contribute to the average. A run without any score yields empty labels
        for every topic.
        """
        all_scores = sorted(
            {score for docs in run.values() for score in docs.values()}, reverse=True
        )

        # Human positives per topic, computed once instead of once per threshold.
        actual_by_topic = {}
        for qid in run:
            positives = {
                did for did, rel in human_qrel.get(qid, {}).items()
                if rel >= min_relevance
            }
            if positives:
                actual_by_topic[qid] = positives

        best_avg_f1 = -1.0
        # With no scores at all, the threshold is never compared against anything.
        best_threshold = all_scores[-1] if all_scores else None
        for thresh in all_scores:
            f1s = []
            for qid, docs in run.items():
                actual_pos = actual_by_topic.get(qid)
                if actual_pos is None:
                    continue
                predicted_pos = {did for did, score in docs.items() if score >= thresh}
                f1s.append(AutoQrel._binary_f1(predicted_pos, actual_pos))
            avg_f1 = sum(f1s) / len(f1s) if f1s else 0.0
            if avg_f1 > best_avg_f1:
                best_avg_f1, best_threshold = avg_f1, thresh
                logger.info(
                    "Look for optimal threshold. New best threshold=%.5g, avg_F1=%.4f",
                    best_threshold,
                    best_avg_f1,
                )
        return {
            qid: {did: (1 if score >= best_threshold else 0) for did, score in docs.items()}
            for qid, docs in run.items()
        }
if __name__ == "__main__":
    # Script entry point: sample queries, build LLM-derived qrels with each
    # requested strategy, and report nDCG@10 of the evaluated run against them.
    import random

    import ir_measures
    from ir_measures import nDCG

    parser = argparse.ArgumentParser(description="Evaluate retrieval runs against LLM-judge-derived qrel.")
    parser.add_argument("--dataset_name", type=str, required=True)
    parser.add_argument("--loader_type", type=str, default="irds")
    parser.add_argument("--init_run", type=str, default=None,
                        help="Run whose query ids are sampled from (defaults to the judge run).")
    parser.add_argument("--judge_run", type=str, required=True)
    parser.add_argument("--evaluate_run", type=str, required=True)
    ## classification thresholding strategies
    parser.add_argument("--strategies", action="append", choices=AutoQrel.STRATEGIES, default=None,
                        help="Repeatable; defaults to 'all' when omitted.")
    ### thresholding
    parser.add_argument("--threshold", type=float, default=0.5, help="Score cutoff for --strategies thresholding.")
    ### rank
    parser.add_argument("--rank_cutoff", type=int, default=10, help="Top-k for --strategies rank.")
    parser.add_argument("--gap_k", type=int, default=1, help="k-th largest gap for --strategies largest_gap.")
    parser.add_argument("--quantile_cutoff", type=float, default=0.75, help="Quantile for --strategies quantile.")
    parser.add_argument("--min_relevance", type=int, default=1, help="Min relevance grade for oracle strategies.")
    parser.add_argument("--exp", type=str, default=None, help="the experiment tag to record in the output")
    parser.add_argument("--output", type=str, default=None,
                        help="Path to write the aggregate results as JSON Lines (one record per line).")
    ### sampling
    parser.add_argument("--sampling_size", type=int, default=None,
                        help="Number of queries to sample (default: use all queries).")
    parser.add_argument("--sampling_seed", type=int, default=42, help="Random seed for query sampling.")
    args = parser.parse_args()

    # Loading
    loader = importlib.import_module(f"autollmrerank.loader_dev.{args.loader_type}")
    _, _, qrel = loader.load(args.dataset_name)
    judge_run = loader.load_run(args.judge_run)
    eval_run = loader.load_run(args.evaluate_run)
    run = loader.load_run(args.init_run) if args.init_run else judge_run

    # Filtering: restrict everything to the same (optionally sampled) query ids
    candidate_qids = list(run)
    if args.sampling_size is None:
        selected_qids = candidate_qids
    else:
        random.seed(args.sampling_seed)
        selected_qids = random.sample(candidate_qids, min(args.sampling_size, len(candidate_qids)))

    for label, mapping in (("judge run", judge_run), ("evaluated run", eval_run), ("qrel", qrel)):
        n_missing = sum(1 for qid in selected_qids if qid not in mapping)
        if n_missing:
            logger.warning("%d sampled queries are missing from the %s and are skipped.", n_missing, label)

    judge_run = {qid: judge_run[qid] for qid in selected_qids if qid in judge_run}
    eval_run = {qid: eval_run[qid] for qid in selected_qids if qid in eval_run}
    qrel = {qid: qrel[qid] for qid in selected_qids if qid in qrel}

    autoqrel = AutoQrel(
        qrel=qrel,
        judge_run=judge_run,
        strategies=args.strategies or ["all"],
        threshold=args.threshold,
        rank_cutoff=args.rank_cutoff,
        gap_k=args.gap_k,
        quantile_cutoff=args.quantile_cutoff,
        min_relevance=args.min_relevance,
    )

    ndcg_at_10 = nDCG @ 10

    # Score the judge run itself against the human (ground-truth) qrel
    judge_name = os.path.basename(args.judge_run)
    r = ir_measures.calc_aggregate([ndcg_at_10], autoqrel.human_qrel, judge_run)[ndcg_at_10]
    results = [{
        'dataset': args.dataset_name,
        'exp': args.exp,
        'judge_run': 'human.qrel',
        'evaluate_run': judge_name,
        'strategy': 'human',
        'nDCG@10': round(r, 4),
    }]

    # Evaluate the run against the qrel derived by each thresholding strategy
    eval_name = os.path.basename(args.evaluate_run)
    for strategy, llm_qrel in autoqrel.llm_qrels.items():
        llm_qrel = {qid: item for qid, item in llm_qrel.items() if qid in eval_run}
        r = ir_measures.calc_aggregate([ndcg_at_10], llm_qrel, eval_run)[ndcg_at_10]
        results.append({
            'dataset': args.dataset_name,
            'exp': args.exp,
            'judge_run': judge_name,
            'evaluate_run': eval_name,
            'strategy': strategy,
            'nDCG@10': round(r, 4),
        })

    # Write the aggregate rows as JSON Lines (the file is overwritten)
    if args.output:
        with open(args.output, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(result) + '\n' for result in results)

# Xet Storage Details
# Size: 14 kB
# Xet hash: d9ac16fa2d3f48e40b9e149f5663e1dd6ccf4e54fed6b92f82a79a96ff965315