File size: 8,947 Bytes
ddb7b62 6cf04a5 ddb7b62 6cf04a5 ddb7b62 6cf04a5 ddb7b62 6cf04a5 ddb7b62 6cf04a5 ddb7b62 6cf04a5 ddb7b62 6cf04a5 ddb7b62 6cf04a5 ddb7b62 6cf04a5 ddb7b62 6cf04a5 ddb7b62 6cf04a5 ddb7b62 6cf04a5 ddb7b62 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 | import os
import json
import hashlib
import random
import pandas as pd
from typing import Union, Tuple
from threading import Lock
import torch
from pyserini.search.lucene import LuceneSearcher
from pyserini.index.lucene import Document
from pyserini.analysis import get_lucene_analyzer
from pyserini.pyclass import autoclass
from scipy.stats import hypergeom
from subspace.tool import SubspaceBERTScore
# Lazily-populated singletons for the BM25 helpers; both start empty and are
# filled on first use by _get_cached_hgf_analyzer / get_standard_query.
_HGF_ANALYZER = _STANDARD_QUERY_PARSER_CLASS = None
# index path (exactly as passed by the caller) -> LuceneSearcher instance.
_SEARCHER_CACHE = dict()
# Guards creation of cached searchers and of the HF analyzer.
_CACHE_LOCK = Lock()
def _get_cached_searcher(index_path: str):
    """Return the LuceneSearcher for ``index_path``, creating it on first use.

    Searchers are memoised in ``_SEARCHER_CACHE`` keyed by the path string as
    given. Lookup and construction both happen while ``_CACHE_LOCK`` is held,
    so concurrent first calls for one path build a single searcher.
    """
    with _CACHE_LOCK:
        cached = _SEARCHER_CACHE.get(index_path)
        if cached is not None:
            return cached
        cached = LuceneSearcher(index_path)
        _SEARCHER_CACHE[index_path] = cached
        return cached
def _get_cached_hgf_analyzer():
    """Return the shared Lucene analyzer built on the bert-base-uncased HF tokenizer.

    Created once, under ``_CACHE_LOCK``, and kept in ``_HGF_ANALYZER``.
    NOTE(review): presumably must match the analyzer the index was built
    with — confirm against the indexing script.
    """
    global _HGF_ANALYZER
    with _CACHE_LOCK:
        if _HGF_ANALYZER is not None:
            return _HGF_ANALYZER
        _HGF_ANALYZER = get_lucene_analyzer(
            language='hgf_tokenizer',
            huggingFaceTokenizer='bert-base-uncased',
        )
        return _HGF_ANALYZER
# ==============================================================================
# BM25 Search and Query Building
# ==============================================================================
def get_standard_query(query: str, field: str = "contents", analyzer=None):
    """
    Parse ``query`` with Lucene's flexible StandardQueryParser.

    Args:
        query: Lucene query-syntax string, e.g. '"foo" OR "bar"'.
        field: Default field for terms that carry no explicit field prefix.
        analyzer: Lucene analyzer to tokenize with; when None, pyserini's
            ``get_lucene_analyzer()`` default is used.

    Returns:
        The parsed query object, as passed to ``LuceneSearcher.search``.
    """
    global _STANDARD_QUERY_PARSER_CLASS
    chosen_analyzer = get_lucene_analyzer() if analyzer is None else analyzer
    # The Java class handle is looked up once and then reused.
    parser_cls = _STANDARD_QUERY_PARSER_CLASS
    if parser_cls is None:
        parser_cls = autoclass('org.apache.lucene.queryparser.flexible.standard.StandardQueryParser')
        _STANDARD_QUERY_PARSER_CLASS = parser_cls
    parser = parser_cls()
    parser.setAnalyzer(chosen_analyzer)
    return parser.parse(query, field)
def _parse_stored_doc(raw: str, score):
    """Decode a stored raw JSON document into ``(external_id, result_row)``.

    The row holds ``id``, ``content`` (from ``contents``) and ``score``; keys
    of the document's ``metadata`` field are merged on top. ``metadata`` may be
    a JSON-encoded string or an already-decoded mapping.
    """
    jd = json.loads(raw)
    ext_id = jd.get("id")
    row = {"id": ext_id, "content": jd.get("contents", ""), "score": score}
    metadata = jd.get("metadata")
    if metadata:
        if isinstance(metadata, str):
            metadata = json.loads(metadata)
        row.update(metadata)
    return ext_id, row


def _phrase_or_query(keywords) -> str:
    """Build '"kw1" OR "kw2" ...', escaping ``\\`` and ``"`` inside each keyword."""
    quoted = []
    for kw in keywords:
        # Without escaping, a quote in a keyword ends the phrase early and a
        # backslash silently escapes the following character.
        text = str(kw).replace("\\", "\\\\").replace('"', '\\"')
        quoted.append(f'"{text}"')
    return " OR ".join(quoted)


def _padding_rows(searcher, returned_ext_ids: set, needed: int, query_string: str) -> list:
    """Return up to ``needed`` not-yet-returned documents as rows with ``score=None``.

    Candidates are the internal doc numbers in ``range(searcher.num_docs)``
    whose external ``id`` is not in ``returned_ext_ids``. They are shuffled with
    a seed derived from the MD5 of ``query_string``, so a given query always
    pads with the same documents in the same order.
    """
    pool = []
    for docnum in range(searcher.num_docs):
        ext_id = json.loads(Document(searcher.doc(docnum)).raw()).get("id")
        if ext_id not in returned_ext_ids:
            pool.append(docnum)
    seed = int(hashlib.md5(query_string.encode("utf-8")).hexdigest(), 16) % 2**32
    random.Random(seed).shuffle(pool)
    return [
        _parse_stored_doc(Document(searcher.doc(docnum)).raw(), None)[1]
        for docnum in pool[:needed]
    ]


def query_bm25_index(index_path: str, keywords: list, doc_count: int = None) -> pd.DataFrame:
    """Run an OR-of-phrases BM25 search and return the ranked results.

    Each keyword becomes a quoted phrase; the phrases are joined with ``OR`` and
    parsed with the cached HuggingFace (bert-base-uncased) analyzer. Retrieved
    hits come first, in search order, with their BM25 ``score``. If fewer than
    ``doc_count`` rows are retrieved, the result is padded with unretrieved
    documents (``score=None``) picked by a shuffle seeded from the query
    string, so padding is reproducible for the same keywords.

    Args:
        index_path: Path of the Lucene index (searchers are cached per path).
        keywords: Keywords/phrases to OR together. A single string is treated
            as one keyword. Iteration order fixes the query string and hence
            the padding seed, so pass an ordered sequence.
        doc_count: Number of rows wanted; defaults to every document in the
            index.

    Returns:
        DataFrame with ``id``, ``content``, ``score`` plus any keys from each
        document's ``metadata``. If no rows at all are produced, an empty frame
        that still has the ``id``/``content``/``score`` columns.

    Raises:
        ValueError: if ``keywords`` is empty (an empty query string is not
            valid Lucene query syntax).
    """
    if isinstance(keywords, str):
        # A bare string would otherwise be iterated character by character.
        keywords = [keywords]
    keywords = list(keywords)
    if not keywords:
        raise ValueError("query_bm25_index requires at least one keyword")

    searcher = _get_cached_searcher(index_path)
    analyzer = _get_cached_hgf_analyzer()
    query_string = _phrase_or_query(keywords)
    phrase_q = get_standard_query(query_string, analyzer=analyzer)

    if doc_count is None:
        doc_count = searcher.num_docs
    hits = searcher.search(phrase_q, doc_count)

    results = []
    returned_ext_ids = set()
    for hit in hits:
        ext_id, row = _parse_stored_doc(Document(hit.lucene_document).raw(), hit.score)
        # Track the document's own id, not row["id"], which metadata may override;
        # the padding pool is filtered by the document's own id as well.
        returned_ext_ids.add(ext_id)
        results.append(row)

    if len(results) < doc_count:
        results.extend(
            _padding_rows(searcher, returned_ext_ids, doc_count - len(results), query_string)
        )

    if not results:
        return pd.DataFrame(columns=["id", "content", "score"])
    return pd.DataFrame(results)
def warmup_bm25(index_path: str, warmup_keyword: str = "community", warmup_k: int = 1) -> None:
    """Prime the BM25 pipeline for ``index_path`` before serving real queries.

    Populates the searcher and HF-analyzer caches, parses a single phrase query
    for ``warmup_keyword`` via get_standard_query, and executes it requesting
    ``max(1, int(warmup_k))`` hits; the hits are discarded.

    Raises:
        ValueError: if ``index_path`` is falsy or not an existing directory.
    """
    path_ok = bool(index_path) and os.path.isdir(index_path)
    if not path_ok:
        raise ValueError(f"Invalid index path for warmup: {index_path}")
    cached_searcher = _get_cached_searcher(index_path)
    phrase = f'"{warmup_keyword}"'
    parsed = get_standard_query(phrase, analyzer=_get_cached_hgf_analyzer())
    cached_searcher.search(parsed, max(1, int(warmup_k)))
# ==============================================================================
# Evaluation Metrics (Precision/Lift)
# ==============================================================================
def _resolve_k(df, k):
    """Convert a cutoff ``k`` into an absolute number of rows.

    A non-integral real ``k`` in ``(0, 1]`` (e.g. ``0.1``) is a fraction of
    ``len(df)`` and is floored; any other value is returned as ``int(k)``.
    """
    import numbers

    if (
        isinstance(k, numbers.Real)
        and not isinstance(k, numbers.Integral)
        and 0.0 < k <= 1.0
    ):
        # Strip binary floating-point noise before flooring: 100 * 0.29 is
        # 28.999999999999996, which plain int() would truncate to 28.
        return int(round(len(df) * k, 9))
    return int(k)
def precision_at_k(df: pd.DataFrame, correct_demographic: str, k: Union[int, float]) -> float:
    """Share of the first ``k`` rows whose ``demographic`` equals ``correct_demographic``.

    ``k`` is an absolute count or a float in (0, 1] meaning a share of
    ``len(df)`` (see ``_resolve_k``). The hit count is divided by the resolved
    ``k`` itself, even when ``df`` has fewer rows than that. Returns 0.0 when
    the resolved ``k`` is not positive.
    """
    is_match = (df['demographic'] == correct_demographic).astype(int)
    cutoff = _resolve_k(df, k)
    if cutoff > 0:
        return is_match.iloc[:cutoff].sum() / float(cutoff)
    return 0.0
def lift_at_k(df: pd.DataFrame, correct_demographic: str, k: Union[int, float]) -> float:
    """Lift@k = precision@k divided by the overall share of matching rows.

    Returns 0.0 when ``k`` resolves to a non-positive count, when ``df`` is
    empty, or when no row has ``demographic == correct_demographic``.
    """
    cutoff = _resolve_k(df, k)
    if cutoff <= 0 or len(df) == 0:
        return 0.0
    top_precision = precision_at_k(df, correct_demographic, k)
    match_count = (df['demographic'] == correct_demographic).astype(int).sum()
    base_rate = match_count / float(len(df))
    return 0.0 if base_rate == 0 else top_precision / base_rate
def hypergeometric_significance_test(df: pd.DataFrame, correct_demographic: str, k: Union[int, float], alpha: float = 0.05) -> Tuple[float, Tuple[int, int], Tuple[float, float]]:
    """One-sided hypergeometric test for over-representation in the top-k rows.

    Population: ``N = len(df)`` rows, ``K`` of which have
    ``demographic == correct_demographic``. Draw: the first ``n`` rows, where
    ``n`` is ``k`` resolved by ``_resolve_k`` and capped at ``N``. ``k_obs`` is
    the number of matching rows in that draw.

    Returns:
        ``(p_value, (L, U), (L / n, U / n))`` where ``p_value = P(X >= k_obs)``
        under the hypergeometric null and ``L`` / ``U`` are its lower / upper
        ``alpha/2`` quantiles of the hit count (a central range expected by
        chance, independent of ``k_obs``). When ``K == 0`` or ``n <= 0`` the
        hit count is always 0, so ``(1.0, (0, 0), (0.0, 0.0))`` is returned.
    """
    N = len(df)
    # More draws than rows is undefined for the hypergeometric (scipy yields
    # NaN and int(NaN) raises), so cap the draw at the population size.
    n = min(_resolve_k(df, k), N)
    rel = (df['demographic'] == correct_demographic).astype(int)
    K = rel.sum()
    if K == 0 or n <= 0:
        # P(X >= 0) == 1: no evidence of over-representation. (Returning 0.0
        # here would wrongly signal maximal significance.)
        return 1.0, (0, 0), (0.0, 0.0)
    k_obs = rel.iloc[:n].sum()
    # scipy's hypergeom(M, n, N) order: population size, successes, draws.
    p_value = hypergeom.sf(k_obs - 1, N, K, n)
    L = int(hypergeom.ppf(alpha / 2, N, K, n))
    U = int(hypergeom.isf(alpha / 2, N, K, n))
    return p_value, (L, U), (L / n, U / n)
def lift_ci(df: pd.DataFrame, correct_demographic: str, k: Union[int, float], alpha: float = 0.05) -> Tuple[float, float, float]:
    """Hypergeometric p-value and lift bounds for the top-k rows.

    The bounds are the hit-rate range ``(L / n, U / n)`` reported by
    ``hypergeometric_significance_test`` divided by the overall proportion
    ``K / N`` of rows with ``demographic == correct_demographic``.

    Returns:
        ``(p_value, lower_lift, upper_lift)``. For an empty ``df``, a
        non-positive resolved ``k``, or no matching rows, returns
        ``(1.0, 0.0, 0.0)``. A p-value of 1.0 means "no evidence", matching
        P(X >= 0) == 1.
    """
    n = _resolve_k(df, k)
    N = len(df)
    # Check N before dividing: K / float(0) would produce NaN plus a warning.
    if N == 0 or n <= 0:
        return 1.0, 0.0, 0.0
    K = (df['demographic'] == correct_demographic).astype(int).sum()
    if K == 0:
        return 1.0, 0.0, 0.0
    overall_proportion = K / float(N)
    # Reuse the test's own rates so the draw size n stays consistent with it.
    pval, _, (lower_rate, upper_rate) = hypergeometric_significance_test(df, correct_demographic, k, alpha)
    return pval, lower_rate / overall_proportion, upper_rate / overall_proportion
# ==============================================================================
# Keyword Similarity (SubspaceBERTScore)
# ==============================================================================
# Scorer used by the most recent compute_keyword_similarity call (kept for
# code that reads this module attribute).
_GLOBAL_SCORER = None
# One SubspaceBERTScore per device, so an explicit ``device`` argument is
# honoured even after a scorer has been created on a different device.
_SCORERS_BY_DEVICE = {}


def compute_keyword_similarity(set1: list, set2: list, device: str = None) -> dict:
    """
    Compute precision, recall and F-score similarity between two keyword lists.

    Each list is joined with ", " into a single sentence, and the pair is
    scored with SubspaceBERTScore (bert-base-uncased). Pass ordered sequences:
    joining a Python ``set`` follows its iteration order, which can vary
    between processes.

    Args:
        set1: First keyword collection.
        set2: Second keyword collection.
        device: Torch device string; defaults to 'cuda' when available,
            otherwise 'cpu'. The model is loaded once per distinct device.

    Returns:
        ``{'Precision': float, 'Recall': float, 'F-Score': float}``.
    """
    global _GLOBAL_SCORER
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    scorer = _SCORERS_BY_DEVICE.get(device)
    if scorer is None:
        print(f"Initializing BERT model on {device}...")
        scorer = SubspaceBERTScore(device=device, model_name_or_path='bert-base-uncased')
        _SCORERS_BY_DEVICE[device] = scorer
    _GLOBAL_SCORER = scorer
    sentence_1 = [", ".join(set1)]
    sentence_2 = [", ".join(set2)]
    scores = scorer(sentence_1, sentence_2)
    return {
        'Precision': scores[0].item(),
        'Recall': scores[1].item(),
        'F-Score': scores[2].item()
    }
|